chore: merge upstream Wei-Shaw/sub2api v0.1.132

Conflicts resolved (preserving fork customizations):
- config.go: keep NodeTLSProxy + add upstream OpenAIHTTP2
- gateway_service.go: NewGatewayService now takes both rpmTokenBucketSvc
  (local) and userPlatformQuotaRepo (upstream)
- wire_gen.go: wire both new args into the call site
- http_upstream.go: drop redundant settings re-assignment; keep proxy
  URL log redaction
- http_upstream_test.go: adopt upstream's explicit-0-disables semantics;
  keep 600s default constant in nil-cfg fallback test
- user_handler_test.go / gateway_record_usage_test.go: pick up new
  userPlatformQuotaRepo nil parameter

Also updated test stubs (windsurf_google_login_test.go,
windsurf_tier_access_service_test.go, gateway_models_test.go) for new
SetModelRateLimit variadic signature and the extra NewGatewayService arg.

Upstream highlights: OpenAI embeddings gateway, user x platform USD
quota, content-moderation risk thresholds, OAuth 401 credentials
no-overwrite fix, HTTP/2 OpenAI upstream config, pool retry status code
configurability, long-context cache pricing multipliers.
This commit is contained in:
win 2026-05-29 07:21:32 +08:00
commit f519a02ec9
259 changed files with 22370 additions and 836 deletions

View File

@ -103,9 +103,7 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
<tr>
<td width="180"><a href="https://pateway.ai/?ch=1tsfr51"><img src="assets/partners/logos/pateway.png" alt="pateway" width="150"></a></td>
<td>Thanks to PatewayAI for sponsoring this project! PatewayAI is a premium model API relay service provider built for heavy AI developers, focused on direct official connections. Offering the full Claude series and Codex series models, 100% sourced directly from official providers — no dilution, no substitution, open to verification. Billing is fully transparent with token-level invoices that can be audited line by line.
Enterprise-grade high concurrency is also supported, with a dedicated management platform for enterprise clients. Enterprise customers can sign formal contracts and receive invoices. Visit the official website for more details and contact information.
Register now via <a href="https://pateway.ai/?ch=1tsfr51">this link</a> to receive $3 in trial credits. User top-ups start as low as 60% off, and referring friends earns both parties rewards — referral bonuses up to $150.</td>
<td>Thanks to PatewayAI for sponsoring this project! <a href="https://pateway.ai/?ch=1tsfr51">PatewayAI</a> is a premium API relay built for heavy AI developers, offering the full Claude and Codex series sourced 100% from official providers, with transparent token-level billing. Enterprise plans include high concurrency, dedicated management, contracts, and invoicing. Register now to get $3 in trial credits, top-ups from 60% off, and referral bonuses up to $150.</td>
</tr>
<tr>
@ -120,6 +118,18 @@ Register now via <a href="https://pateway.ai/?ch=1tsfr51">this link</a> to recei
</td>
</tr>
<tr>
<td width="180"><a href="https://unity2.ai/register?source=sub2api"><img src="assets/partners/logos/unity2.png" alt="unity2" width="150"></a></td>
<td>Thanks to Unity2 for sponsoring this project! <a href="https://unity2.ai/register?source=sub2api">Unity2</a> is a high-performance AI model API relay for individuals, teams, and enterprises, handling 30B+ tokens/day with 5000 RPM concurrency. One API Key works across Claude Code, Codex, OpenAI models, IDE plugins, and Agent workflows, with balance billing, bundled subscriptions, enterprise invoicing, and 1-on-1 support. <a href="https://unity2.ai/register?source=sub2api">Register</a> to claim $2 in balance, plus $10 more by joining the official group — up to $12 in free credit.
</td>
</tr>
<tr>
<td width="180"><a href="https://veilx.io/#/hello/SJRBRVDV"><img src="assets/partners/logos/veilx.png" alt="veilx" width="150"></a></td>
<td>Thanks to Veilx for sponsoring this project! <a href="https://veilx.io/#/hello/SJRBRVDV">Veilx</a> CDN is purpose-built for large-scale AI API traffic, deeply optimized for relay services and call chains across OpenAI, Claude, Gemini, and scenarios like chat, image generation, embeddings, and streaming — delivering lower latency and higher stability under heavy concurrency. It also offers China three-network optimized return lines, making it ideal for global AI relay platforms, overseas AI SaaS, and cross-border high-concurrency deployments.
</td>
</tr>
</table>
## Ecosystem

View File

@ -119,6 +119,18 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
</td>
</tr>
<tr>
<td width="180"><a href="https://unity2.ai/register?source=sub2api"><img src="assets/partners/logos/unity2.png" alt="unity2" width="150"></a></td>
<td>感谢 Unity2 赞助本项目! <a href="https://unity2.ai/register?source=sub2api">Unity2</a> 是面向个人开发者、团队、企业的高性能 AI 模型 API 中转平台,长期服务国内头部企业,日均承载超 300 亿 token 调用,支持 5000 RPM 级高并发。一个 API Key 即可适配 Claude Code、Codex、OpenAI 模型、IDE 插件和 Agent 工作流等场景。具备企业级稳定供应能力,在高并发、持续调用和团队集中采购场景下依然保持低延迟、高可用。同时支持余额计费、组合订阅、首充优惠、企业开票、专属 1v1 对接,适合个人高频使用和企业长期接入。现在注册 Unity2.ai 可领取 $2 余额,加入官方群再送 $10 余额,合计最高可领 $12 免费额度,适合先体验后长期使用。<a href="https://unity2.ai/register?source=sub2api">注册链接</a>
</td>
</tr>
<tr>
<td width="180"><a href="https://veilx.io/#/hello/SJRBRVDV"><img src="assets/partners/logos/veilx.png" alt="veilx" width="150"></a></td>
<td>感谢 Veilx 赞助本项目! <a href="https://veilx.io/#/hello/SJRBRVDV">Veilx</a> CDN 专为超大规模 API 请求场景打造,针对 AI 中转站业务与 AI API 调用链路进行了深度优化,轻松应对高并发、高频请求与大流量传输,为开发者与企业提供更快、更稳、更低延迟的加速体验。无论是 OpenAI、Claude、Gemini 等 AI 接口中转还是聊天、绘图、Embedding、流式输出等复杂场景Veilx 都能显著提升响应速度与连接稳定性有效降低网络波动带来的超时与失败问题。同时Veilx 提供中国三网优化回国极速线路,大幅提升中国大陆地区访问海外 AI 服务的速度与稳定性,特别适合全球 AI 中转平台、海外 AI SaaS、跨境业务与高并发 API 系统部署。专为 AI API 而生,让你的 AI 中转服务更快、更稳、更省心。<a href="https://veilx.io/#/hello/SJRBRVDV">购买地址</a>
</td>
</tr>
</table>
## 生态项目

View File

@ -119,6 +119,18 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
</td>
</tr>
<tr>
<td width="180"><a href="https://unity2.ai/register?source=sub2api"><img src="assets/partners/logos/unity2.png" alt="unity2" width="150"></a></td>
<td>Unity2 のご支援に感謝します!<a href="https://unity2.ai/register?source=sub2api">Unity2</a> は個人開発者、チーム、企業向けの高性能 AI モデル API 中継プラットフォームです。中国の大手企業に長期にわたりサービスを提供しており、1日あたり 300 億以上のトークン呼び出しを処理し、5000 RPM 級の高並列性をサポートします。1つの API キーで Claude Code、Codex、OpenAI モデル、IDE プラグイン、Agent ワークフローなど様々なシナリオに対応できます。エンタープライズグレードの安定供給能力を備え、高並列・継続的な呼び出し・チームの集中購入シーンでも低レイテンシと高可用性を維持します。残高課金、組み合わせサブスクリプション、初回チャージ特典、企業向け請求書発行、専属 1v1 サポートにも対応しており、個人の頻繁な利用にも企業の長期導入にも適しています。今 Unity2.ai に登録すると $2 の残高、公式グループに参加するとさらに $10 の残高がもらえ、合計最大 $12 の無料クレジットを獲得できます — 試用後に長期利用したい方に最適です。<a href="https://unity2.ai/register?source=sub2api">登録リンク</a>
</td>
</tr>
<tr>
<td width="180"><a href="https://veilx.io/#/hello/SJRBRVDV"><img src="assets/partners/logos/veilx.png" alt="veilx" width="150"></a></td>
<td>Veilx のご支援に感謝します!<a href="https://veilx.io/#/hello/SJRBRVDV">Veilx</a> CDN は超大規模 API リクエストシナリオ向けに設計されており、AI 中継サービスと AI API 呼び出しチェーンに対して深く最適化されています。高並列・高頻度リクエスト・大容量トラフィックに容易に対応し、開発者と企業により高速で安定した、低レイテンシの加速体験を提供します。OpenAI、Claude、Gemini などの AI インターフェース中継はもちろん、チャット、画像生成、Embedding、ストリーミング出力などの複雑なシナリオでも、Veilx は応答速度と接続安定性を大幅に向上させ、ネットワーク変動によるタイムアウトや失敗を効果的に削減します。さらに、Veilx は中国三大ネットワーク最適化の高速回線を提供しており、中国本土から海外 AI サービスへのアクセス速度と安定性を大幅に向上させます。グローバル AI 中継プラットフォーム、海外 AI SaaS、越境ビジネス、高並列 API システム展開に特に適しています。AI API のために生まれ、あなたの AI 中継サービスをより速く、より安定して、より安心に。<a href="https://veilx.io/#/hello/SJRBRVDV">購入リンク</a>
</td>
</tr>
</table>
## エコシステム

Binary file not shown.

After

Width:  |  Height:  |  Size: 744 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 376 KiB

View File

@ -33,7 +33,7 @@ func main() {
}()
userRepo := repository.NewUserRepository(client, sqlDB)
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

View File

@ -1 +1 @@
0.1.130
0.1.132

View File

@ -63,7 +63,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyRepository := repository.NewAPIKeyRepository(client, db)
userRPMCache := repository.NewUserRPMCache(redisClient)
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig)
userPlatformQuotaRepository := repository.NewUserPlatformQuotaRepository(client)
serviceUserPlatformQuotaRepository := repository.NewUserPlatformQuotaServiceAdapter(userPlatformQuotaRepository)
billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig, serviceUserPlatformQuotaRepository)
apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.ProvideAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig, billingCacheService)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
@ -71,7 +73,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
affiliateRepository := repository.NewAffiliateRepository(client, db)
affiliateService := service.NewAffiliateService(affiliateRepository, settingService, apiKeyAuthCacheInvalidator, billingCacheService)
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService)
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService, serviceUserPlatformQuotaRepository)
userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator, affiliateService)
@ -85,7 +87,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService, userAttributeService)
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService)
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService, serviceUserPlatformQuotaRepository)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
@ -143,9 +145,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService)
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService, serviceUserPlatformQuotaRepository)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory, openAIGatewayService)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService, serviceUserPlatformQuotaRepository, billingCache)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
rpmCache := repository.NewRPMCache(redisClient)
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
@ -194,7 +196,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityService := service.NewIdentityService(identityCache)
digestSessionStore := service.NewDigestSessionStore()
rpmTokenBucketService := service.NewRPMTokenBucketService()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService, serviceUserPlatformQuotaRepository)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)

View File

@ -43,7 +43,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second)
pricingSvc := service.NewPricingService(cfg, nil)
emailQueueSvc := service.NewEmailQueueService(nil, 1)
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil)
idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg)
schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg)
opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil)

View File

@ -48,6 +48,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
stdsql "database/sql"
@ -124,6 +125,8 @@ type Client struct {
UserAttributeDefinition *UserAttributeDefinitionClient
// UserAttributeValue is the client for interacting with the UserAttributeValue builders.
UserAttributeValue *UserAttributeValueClient
// UserPlatformQuota is the client for interacting with the UserPlatformQuota builders.
UserPlatformQuota *UserPlatformQuotaClient
// UserSubscription is the client for interacting with the UserSubscription builders.
UserSubscription *UserSubscriptionClient
}
@ -170,6 +173,7 @@ func (c *Client) init() {
c.UserAllowedGroup = NewUserAllowedGroupClient(c.config)
c.UserAttributeDefinition = NewUserAttributeDefinitionClient(c.config)
c.UserAttributeValue = NewUserAttributeValueClient(c.config)
c.UserPlatformQuota = NewUserPlatformQuotaClient(c.config)
c.UserSubscription = NewUserSubscriptionClient(c.config)
}
@ -296,6 +300,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
UserAttributeValue: NewUserAttributeValueClient(cfg),
UserPlatformQuota: NewUserPlatformQuotaClient(cfg),
UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@ -349,6 +354,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
UserAttributeValue: NewUserAttributeValueClient(cfg),
UserPlatformQuota: NewUserPlatformQuotaClient(cfg),
UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@ -388,7 +394,7 @@ func (c *Client) Use(hooks ...Hook) {
c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
c.UserPlatformQuota, c.UserSubscription,
} {
n.Use(hooks...)
}
@ -407,7 +413,7 @@ func (c *Client) Intercept(interceptors ...Interceptor) {
c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
c.UserPlatformQuota, c.UserSubscription,
} {
n.Intercept(interceptors...)
}
@ -482,6 +488,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.UserAttributeDefinition.mutate(ctx, m)
case *UserAttributeValueMutation:
return c.UserAttributeValue.mutate(ctx, m)
case *UserPlatformQuotaMutation:
return c.UserPlatformQuota.mutate(ctx, m)
case *UserSubscriptionMutation:
return c.UserSubscription.mutate(ctx, m)
default:
@ -5341,6 +5349,22 @@ func (c *UserClient) QueryPendingAuthSessions(_m *User) *PendingAuthSessionQuery
return query
}
// QueryPlatformQuotas queries the platform_quotas edge of a User.
func (c *UserClient) QueryPlatformQuotas(_m *User) *UserPlatformQuotaQuery {
query := (&UserPlatformQuotaClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, id),
sqlgraph.To(userplatformquota.Table, userplatformquota.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.PlatformQuotasTable, user.PlatformQuotasColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryUserAllowedGroups queries the user_allowed_groups edge of a User.
func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: c.config}).Query()
@ -5816,6 +5840,157 @@ func (c *UserAttributeValueClient) mutate(ctx context.Context, m *UserAttributeV
}
}
// UserPlatformQuotaClient is a client for the UserPlatformQuota schema.
type UserPlatformQuotaClient struct {
config
}
// NewUserPlatformQuotaClient returns a client for the UserPlatformQuota from the given config.
func NewUserPlatformQuotaClient(c config) *UserPlatformQuotaClient {
return &UserPlatformQuotaClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `userplatformquota.Hooks(f(g(h())))`.
func (c *UserPlatformQuotaClient) Use(hooks ...Hook) {
c.hooks.UserPlatformQuota = append(c.hooks.UserPlatformQuota, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `userplatformquota.Intercept(f(g(h())))`.
func (c *UserPlatformQuotaClient) Intercept(interceptors ...Interceptor) {
c.inters.UserPlatformQuota = append(c.inters.UserPlatformQuota, interceptors...)
}
// Create returns a builder for creating a UserPlatformQuota entity.
func (c *UserPlatformQuotaClient) Create() *UserPlatformQuotaCreate {
mutation := newUserPlatformQuotaMutation(c.config, OpCreate)
return &UserPlatformQuotaCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of UserPlatformQuota entities.
func (c *UserPlatformQuotaClient) CreateBulk(builders ...*UserPlatformQuotaCreate) *UserPlatformQuotaCreateBulk {
return &UserPlatformQuotaCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *UserPlatformQuotaClient) MapCreateBulk(slice any, setFunc func(*UserPlatformQuotaCreate, int)) *UserPlatformQuotaCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &UserPlatformQuotaCreateBulk{err: fmt.Errorf("calling to UserPlatformQuotaClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*UserPlatformQuotaCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &UserPlatformQuotaCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for UserPlatformQuota.
func (c *UserPlatformQuotaClient) Update() *UserPlatformQuotaUpdate {
mutation := newUserPlatformQuotaMutation(c.config, OpUpdate)
return &UserPlatformQuotaUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *UserPlatformQuotaClient) UpdateOne(_m *UserPlatformQuota) *UserPlatformQuotaUpdateOne {
mutation := newUserPlatformQuotaMutation(c.config, OpUpdateOne, withUserPlatformQuota(_m))
return &UserPlatformQuotaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *UserPlatformQuotaClient) UpdateOneID(id int64) *UserPlatformQuotaUpdateOne {
mutation := newUserPlatformQuotaMutation(c.config, OpUpdateOne, withUserPlatformQuotaID(id))
return &UserPlatformQuotaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for UserPlatformQuota.
func (c *UserPlatformQuotaClient) Delete() *UserPlatformQuotaDelete {
mutation := newUserPlatformQuotaMutation(c.config, OpDelete)
return &UserPlatformQuotaDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *UserPlatformQuotaClient) DeleteOne(_m *UserPlatformQuota) *UserPlatformQuotaDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *UserPlatformQuotaClient) DeleteOneID(id int64) *UserPlatformQuotaDeleteOne {
builder := c.Delete().Where(userplatformquota.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &UserPlatformQuotaDeleteOne{builder}
}
// Query returns a query builder for UserPlatformQuota.
func (c *UserPlatformQuotaClient) Query() *UserPlatformQuotaQuery {
return &UserPlatformQuotaQuery{
config: c.config,
ctx: &QueryContext{Type: TypeUserPlatformQuota},
inters: c.Interceptors(),
}
}
// Get returns a UserPlatformQuota entity by its id.
func (c *UserPlatformQuotaClient) Get(ctx context.Context, id int64) (*UserPlatformQuota, error) {
return c.Query().Where(userplatformquota.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *UserPlatformQuotaClient) GetX(ctx context.Context, id int64) *UserPlatformQuota {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// QueryUser queries the user edge of a UserPlatformQuota.
func (c *UserPlatformQuotaClient) QueryUser(_m *UserPlatformQuota) *UserQuery {
query := (&UserClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(userplatformquota.Table, userplatformquota.FieldID, id),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, userplatformquota.UserTable, userplatformquota.UserColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// Hooks returns the client hooks.
func (c *UserPlatformQuotaClient) Hooks() []Hook {
hooks := c.hooks.UserPlatformQuota
return append(hooks[:len(hooks):len(hooks)], userplatformquota.Hooks[:]...)
}
// Interceptors returns the client interceptors.
func (c *UserPlatformQuotaClient) Interceptors() []Interceptor {
inters := c.inters.UserPlatformQuota
return append(inters[:len(inters):len(inters)], userplatformquota.Interceptors[:]...)
}
func (c *UserPlatformQuotaClient) mutate(ctx context.Context, m *UserPlatformQuotaMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&UserPlatformQuotaCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&UserPlatformQuotaUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&UserPlatformQuotaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&UserPlatformQuotaDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown UserPlatformQuota mutation op: %q", m.Op())
}
}
// UserSubscriptionClient is a client for the UserSubscription schema.
type UserSubscriptionClient struct {
config
@ -6025,7 +6200,8 @@ type (
PaymentOrder, PaymentProviderInstance, PendingAuthSession, PromoCode,
PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan,
TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
UserAttributeDefinition, UserAttributeValue, UserPlatformQuota,
UserSubscription []ent.Hook
}
inters struct {
APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity,
@ -6035,7 +6211,8 @@ type (
PaymentOrder, PaymentProviderInstance, PendingAuthSession, PromoCode,
PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan,
TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
UserAttributeDefinition, UserAttributeValue, UserPlatformQuota,
UserSubscription []ent.Interceptor
}
)

View File

@ -45,6 +45,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@ -139,6 +140,7 @@ func checkColumn(t, c string) error {
userallowedgroup.Table: userallowedgroup.ValidColumn,
userattributedefinition.Table: userattributedefinition.ValidColumn,
userattributevalue.Table: userattributevalue.ValidColumn,
userplatformquota.Table: userplatformquota.ValidColumn,
usersubscription.Table: usersubscription.ValidColumn,
})
})

View File

@ -85,6 +85,8 @@ type Group struct {
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
// OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型
MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
// 自定义 /v1/models 展示列表配置;仅影响模型列表响应,不影响调度
ModelsListConfig domain.GroupModelsListConfig `json:"models_list_config,omitempty"`
// 分组 RPM 上限0 表示不限制;设置后接管该分组用户的限流
RpmLimit int `json:"rpm_limit,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
@ -193,7 +195,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig:
case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig, group.FieldModelsListConfig:
values[i] = new([]byte)
case group.FieldIsExclusive, group.FieldAllowImageGeneration, group.FieldImageRateIndependent, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet:
values[i] = new(sql.NullBool)
@ -440,6 +442,14 @@ func (_m *Group) assignValues(columns []string, values []any) error {
return fmt.Errorf("unmarshal field messages_dispatch_model_config: %w", err)
}
}
case group.FieldModelsListConfig:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field models_list_config", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.ModelsListConfig); err != nil {
return fmt.Errorf("unmarshal field models_list_config: %w", err)
}
}
case group.FieldRpmLimit:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field rpm_limit", values[i])
@ -641,6 +651,9 @@ func (_m *Group) String() string {
builder.WriteString("messages_dispatch_model_config=")
builder.WriteString(fmt.Sprintf("%v", _m.MessagesDispatchModelConfig))
builder.WriteString(", ")
builder.WriteString("models_list_config=")
builder.WriteString(fmt.Sprintf("%v", _m.ModelsListConfig))
builder.WriteString(", ")
builder.WriteString("rpm_limit=")
builder.WriteString(fmt.Sprintf("%v", _m.RpmLimit))
builder.WriteByte(')')

View File

@ -82,6 +82,8 @@ const (
FieldDefaultMappedModel = "default_mapped_model"
// FieldMessagesDispatchModelConfig holds the string denoting the messages_dispatch_model_config field in the database.
FieldMessagesDispatchModelConfig = "messages_dispatch_model_config"
// FieldModelsListConfig holds the string denoting the models_list_config field in the database.
FieldModelsListConfig = "models_list_config"
// FieldRpmLimit holds the string denoting the rpm_limit field in the database.
FieldRpmLimit = "rpm_limit"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
@ -192,6 +194,7 @@ var Columns = []string{
FieldRequirePrivacySet,
FieldDefaultMappedModel,
FieldMessagesDispatchModelConfig,
FieldModelsListConfig,
FieldRpmLimit,
}
@ -276,6 +279,8 @@ var (
DefaultMappedModelValidator func(string) error
// DefaultMessagesDispatchModelConfig holds the default value on creation for the "messages_dispatch_model_config" field.
DefaultMessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig
// DefaultModelsListConfig holds the default value on creation for the "models_list_config" field.
DefaultModelsListConfig domain.GroupModelsListConfig
// DefaultRpmLimit holds the default value on creation for the "rpm_limit" field.
DefaultRpmLimit int
)

View File

@ -467,6 +467,20 @@ func (_c *GroupCreate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe
return _c
}
// SetModelsListConfig sets the "models_list_config" field.
func (_c *GroupCreate) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupCreate {
_c.mutation.SetModelsListConfig(v)
return _c
}
// SetNillableModelsListConfig sets the "models_list_config" field if the given value is not nil.
func (_c *GroupCreate) SetNillableModelsListConfig(v *domain.GroupModelsListConfig) *GroupCreate {
if v != nil {
_c.SetModelsListConfig(*v)
}
return _c
}
// SetRpmLimit sets the "rpm_limit" field.
func (_c *GroupCreate) SetRpmLimit(v int) *GroupCreate {
_c.mutation.SetRpmLimit(v)
@ -698,6 +712,10 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultMessagesDispatchModelConfig
_c.mutation.SetMessagesDispatchModelConfig(v)
}
if _, ok := _c.mutation.ModelsListConfig(); !ok {
v := group.DefaultModelsListConfig
_c.mutation.SetModelsListConfig(v)
}
if _, ok := _c.mutation.RpmLimit(); !ok {
v := group.DefaultRpmLimit
_c.mutation.SetRpmLimit(v)
@ -798,6 +816,9 @@ func (_c *GroupCreate) check() error {
if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok {
return &ValidationError{Name: "messages_dispatch_model_config", err: errors.New(`ent: missing required field "Group.messages_dispatch_model_config"`)}
}
if _, ok := _c.mutation.ModelsListConfig(); !ok {
return &ValidationError{Name: "models_list_config", err: errors.New(`ent: missing required field "Group.models_list_config"`)}
}
if _, ok := _c.mutation.RpmLimit(); !ok {
return &ValidationError{Name: "rpm_limit", err: errors.New(`ent: missing required field "Group.rpm_limit"`)}
}
@ -960,6 +981,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
_node.MessagesDispatchModelConfig = value
}
if value, ok := _c.mutation.ModelsListConfig(); ok {
_spec.SetField(group.FieldModelsListConfig, field.TypeJSON, value)
_node.ModelsListConfig = value
}
if value, ok := _c.mutation.RpmLimit(); ok {
_spec.SetField(group.FieldRpmLimit, field.TypeInt, value)
_node.RpmLimit = value
@ -1642,6 +1667,18 @@ func (u *GroupUpsert) UpdateMessagesDispatchModelConfig() *GroupUpsert {
return u
}
// SetModelsListConfig sets the "models_list_config" field.
func (u *GroupUpsert) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupUpsert {
u.Set(group.FieldModelsListConfig, v)
return u
}
// UpdateModelsListConfig sets the "models_list_config" field to the value that was provided on create.
func (u *GroupUpsert) UpdateModelsListConfig() *GroupUpsert {
u.SetExcluded(group.FieldModelsListConfig)
return u
}
// SetRpmLimit sets the "rpm_limit" field.
func (u *GroupUpsert) SetRpmLimit(v int) *GroupUpsert {
u.Set(group.FieldRpmLimit, v)
@ -2314,6 +2351,20 @@ func (u *GroupUpsertOne) UpdateMessagesDispatchModelConfig() *GroupUpsertOne {
})
}
// SetModelsListConfig sets the "models_list_config" field.
func (u *GroupUpsertOne) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetModelsListConfig(v)
})
}
// UpdateModelsListConfig sets the "models_list_config" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateModelsListConfig() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateModelsListConfig()
})
}
// SetRpmLimit sets the "rpm_limit" field.
func (u *GroupUpsertOne) SetRpmLimit(v int) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
@ -3155,6 +3206,20 @@ func (u *GroupUpsertBulk) UpdateMessagesDispatchModelConfig() *GroupUpsertBulk {
})
}
// SetModelsListConfig sets the "models_list_config" field.
func (u *GroupUpsertBulk) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetModelsListConfig(v)
})
}
// UpdateModelsListConfig sets the "models_list_config" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateModelsListConfig() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateModelsListConfig()
})
}
// SetRpmLimit sets the "rpm_limit" field.
func (u *GroupUpsertBulk) SetRpmLimit(v int) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {

View File

@ -616,6 +616,20 @@ func (_u *GroupUpdate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe
return _u
}
// SetModelsListConfig sets the "models_list_config" field.
func (_u *GroupUpdate) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupUpdate {
_u.mutation.SetModelsListConfig(v)
return _u
}
// SetNillableModelsListConfig sets the "models_list_config" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableModelsListConfig(v *domain.GroupModelsListConfig) *GroupUpdate {
if v != nil {
_u.SetModelsListConfig(*v)
}
return _u
}
// SetRpmLimit sets the "rpm_limit" field.
func (_u *GroupUpdate) SetRpmLimit(v int) *GroupUpdate {
_u.mutation.ResetRpmLimit()
@ -1112,6 +1126,9 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok {
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
}
if value, ok := _u.mutation.ModelsListConfig(); ok {
_spec.SetField(group.FieldModelsListConfig, field.TypeJSON, value)
}
if value, ok := _u.mutation.RpmLimit(); ok {
_spec.SetField(group.FieldRpmLimit, field.TypeInt, value)
}
@ -2012,6 +2029,20 @@ func (_u *GroupUpdateOne) SetNillableMessagesDispatchModelConfig(v *domain.OpenA
return _u
}
// SetModelsListConfig sets the "models_list_config" field.
func (_u *GroupUpdateOne) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupUpdateOne {
_u.mutation.SetModelsListConfig(v)
return _u
}
// SetNillableModelsListConfig sets the "models_list_config" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableModelsListConfig(v *domain.GroupModelsListConfig) *GroupUpdateOne {
if v != nil {
_u.SetModelsListConfig(*v)
}
return _u
}
// SetRpmLimit sets the "rpm_limit" field.
func (_u *GroupUpdateOne) SetRpmLimit(v int) *GroupUpdateOne {
_u.mutation.ResetRpmLimit()
@ -2538,6 +2569,9 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok {
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
}
if value, ok := _u.mutation.ModelsListConfig(); ok {
_spec.SetField(group.FieldModelsListConfig, field.TypeJSON, value)
}
if value, ok := _u.mutation.RpmLimit(); ok {
_spec.SetField(group.FieldRpmLimit, field.TypeInt, value)
}

View File

@ -405,6 +405,18 @@ func (f UserAttributeValueFunc) Mutate(ctx context.Context, m ent.Mutation) (ent
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserAttributeValueMutation", m)
}
// The UserPlatformQuotaFunc type is an adapter to allow the use of ordinary
// function as UserPlatformQuota mutator.
type UserPlatformQuotaFunc func(context.Context, *ent.UserPlatformQuotaMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f UserPlatformQuotaFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.UserPlatformQuotaMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserPlatformQuotaMutation", m)
}
// The UserSubscriptionFunc type is an adapter to allow the use of ordinary
// function as UserSubscription mutator.
type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionMutation) (ent.Value, error)

View File

@ -42,6 +42,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@ -992,6 +993,33 @@ func (f TraverseUserAttributeValue) Traverse(ctx context.Context, q ent.Query) e
return fmt.Errorf("unexpected query type %T. expect *ent.UserAttributeValueQuery", q)
}
// The UserPlatformQuotaFunc type is an adapter to allow the use of ordinary function as a Querier.
type UserPlatformQuotaFunc func(context.Context, *ent.UserPlatformQuotaQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f UserPlatformQuotaFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.UserPlatformQuotaQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.UserPlatformQuotaQuery", q)
}
// The TraverseUserPlatformQuota type is an adapter to allow the use of ordinary function as Traverser.
type TraverseUserPlatformQuota func(context.Context, *ent.UserPlatformQuotaQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseUserPlatformQuota) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseUserPlatformQuota) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.UserPlatformQuotaQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.UserPlatformQuotaQuery", q)
}
// The UserSubscriptionFunc type is an adapter to allow the use of ordinary function as a Querier.
type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionQuery) (ent.Value, error)
@ -1088,6 +1116,8 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.UserAttributeDefinitionQuery, predicate.UserAttributeDefinition, userattributedefinition.OrderOption]{typ: ent.TypeUserAttributeDefinition, tq: q}, nil
case *ent.UserAttributeValueQuery:
return &query[*ent.UserAttributeValueQuery, predicate.UserAttributeValue, userattributevalue.OrderOption]{typ: ent.TypeUserAttributeValue, tq: q}, nil
case *ent.UserPlatformQuotaQuery:
return &query[*ent.UserPlatformQuotaQuery, predicate.UserPlatformQuota, userplatformquota.OrderOption]{typ: ent.TypeUserPlatformQuota, tq: q}, nil
case *ent.UserSubscriptionQuery:
return &query[*ent.UserSubscriptionQuery, predicate.UserSubscription, usersubscription.OrderOption]{typ: ent.TypeUserSubscription, tq: q}, nil
default:

View File

@ -669,6 +669,7 @@ var (
{Name: "require_privacy_set", Type: field.TypeBool, Default: false},
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
{Name: "messages_dispatch_model_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "models_list_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "rpm_limit", Type: field.TypeInt, Default: 0},
}
// GroupsTable holds the schema information for the "groups" table.
@ -1612,6 +1613,53 @@ var (
},
},
}
// UserPlatformQuotasColumns holds the columns for the "user_platform_quotas" table.
UserPlatformQuotasColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "platform", Type: field.TypeString, Size: 32},
{Name: "daily_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "weekly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "monthly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "daily_usage_usd", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "weekly_usage_usd", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "monthly_usage_usd", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "daily_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "weekly_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "monthly_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "user_id", Type: field.TypeInt64},
}
// UserPlatformQuotasTable holds the schema information for the "user_platform_quotas" table.
UserPlatformQuotasTable = &schema.Table{
Name: "user_platform_quotas",
Columns: UserPlatformQuotasColumns,
PrimaryKey: []*schema.Column{UserPlatformQuotasColumns[0]},
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "user_platform_quotas_users_platform_quotas",
Columns: []*schema.Column{UserPlatformQuotasColumns[14]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
},
Indexes: []*schema.Index{
{
Name: "userplatformquota_user_id_platform",
Unique: true,
Columns: []*schema.Column{UserPlatformQuotasColumns[14], UserPlatformQuotasColumns[4]},
Annotation: &entsql.IndexAnnotation{
Where: "deleted_at IS NULL",
},
},
{
Name: "userplatformquota_user_id",
Unique: false,
Columns: []*schema.Column{UserPlatformQuotasColumns[14]},
},
},
}
// UserSubscriptionsColumns holds the columns for the "user_subscriptions" table.
UserSubscriptionsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@ -1736,6 +1784,7 @@ var (
UserAllowedGroupsTable,
UserAttributeDefinitionsTable,
UserAttributeValuesTable,
UserPlatformQuotasTable,
UserSubscriptionsTable,
}
)
@ -1869,6 +1918,10 @@ func init() {
UserAttributeValuesTable.Annotation = &entsql.Annotation{
Table: "user_attribute_values",
}
UserPlatformQuotasTable.ForeignKeys[0].RefTable = UsersTable
UserPlatformQuotasTable.Annotation = &entsql.Annotation{
Table: "user_platform_quotas",
}
UserSubscriptionsTable.ForeignKeys[0].RefTable = GroupsTable
UserSubscriptionsTable.ForeignKeys[1].RefTable = UsersTable
UserSubscriptionsTable.ForeignKeys[2].RefTable = UsersTable

File diff suppressed because it is too large Load Diff

View File

@ -105,5 +105,8 @@ type UserAttributeDefinition func(*sql.Selector)
// UserAttributeValue is the predicate function for userattributevalue builders.
type UserAttributeValue func(*sql.Selector)
// UserPlatformQuota is the predicate function for userplatformquota builders.
type UserPlatformQuota func(*sql.Selector)
// UserSubscription is the predicate function for usersubscription builders.
type UserSubscription func(*sql.Selector)

View File

@ -39,6 +39,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
@ -869,8 +870,12 @@ func init() {
groupDescMessagesDispatchModelConfig := groupFields[29].Descriptor()
// group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field.
group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig)
// groupDescModelsListConfig is the schema descriptor for models_list_config field.
groupDescModelsListConfig := groupFields[30].Descriptor()
// group.DefaultModelsListConfig holds the default value on creation for the models_list_config field.
group.DefaultModelsListConfig = groupDescModelsListConfig.Default.(domain.GroupModelsListConfig)
// groupDescRpmLimit is the schema descriptor for rpm_limit field.
groupDescRpmLimit := groupFields[30].Descriptor()
groupDescRpmLimit := groupFields[31].Descriptor()
// group.DefaultRpmLimit holds the default value on creation for the rpm_limit field.
group.DefaultRpmLimit = groupDescRpmLimit.Default.(int)
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
@ -1997,6 +2002,56 @@ func init() {
userattributevalueDescValue := userattributevalueFields[2].Descriptor()
// userattributevalue.DefaultValue holds the default value on creation for the value field.
userattributevalue.DefaultValue = userattributevalueDescValue.Default.(string)
userplatformquotaMixin := schema.UserPlatformQuota{}.Mixin()
userplatformquotaMixinHooks1 := userplatformquotaMixin[1].Hooks()
userplatformquota.Hooks[0] = userplatformquotaMixinHooks1[0]
userplatformquotaMixinInters1 := userplatformquotaMixin[1].Interceptors()
userplatformquota.Interceptors[0] = userplatformquotaMixinInters1[0]
userplatformquotaMixinFields0 := userplatformquotaMixin[0].Fields()
_ = userplatformquotaMixinFields0
userplatformquotaFields := schema.UserPlatformQuota{}.Fields()
_ = userplatformquotaFields
// userplatformquotaDescCreatedAt is the schema descriptor for created_at field.
userplatformquotaDescCreatedAt := userplatformquotaMixinFields0[0].Descriptor()
// userplatformquota.DefaultCreatedAt holds the default value on creation for the created_at field.
userplatformquota.DefaultCreatedAt = userplatformquotaDescCreatedAt.Default.(func() time.Time)
// userplatformquotaDescUpdatedAt is the schema descriptor for updated_at field.
userplatformquotaDescUpdatedAt := userplatformquotaMixinFields0[1].Descriptor()
// userplatformquota.DefaultUpdatedAt holds the default value on creation for the updated_at field.
userplatformquota.DefaultUpdatedAt = userplatformquotaDescUpdatedAt.Default.(func() time.Time)
// userplatformquota.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
userplatformquota.UpdateDefaultUpdatedAt = userplatformquotaDescUpdatedAt.UpdateDefault.(func() time.Time)
// userplatformquotaDescPlatform is the schema descriptor for platform field.
userplatformquotaDescPlatform := userplatformquotaFields[1].Descriptor()
// userplatformquota.PlatformValidator is a validator for the "platform" field. It is called by the builders before save.
userplatformquota.PlatformValidator = func() func(string) error {
validators := userplatformquotaDescPlatform.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
validators[2].(func(string) error),
}
return func(platform string) error {
for _, fn := range fns {
if err := fn(platform); err != nil {
return err
}
}
return nil
}
}()
// userplatformquotaDescDailyUsageUsd is the schema descriptor for daily_usage_usd field.
userplatformquotaDescDailyUsageUsd := userplatformquotaFields[5].Descriptor()
// userplatformquota.DefaultDailyUsageUsd holds the default value on creation for the daily_usage_usd field.
userplatformquota.DefaultDailyUsageUsd = userplatformquotaDescDailyUsageUsd.Default.(float64)
// userplatformquotaDescWeeklyUsageUsd is the schema descriptor for weekly_usage_usd field.
userplatformquotaDescWeeklyUsageUsd := userplatformquotaFields[6].Descriptor()
// userplatformquota.DefaultWeeklyUsageUsd holds the default value on creation for the weekly_usage_usd field.
userplatformquota.DefaultWeeklyUsageUsd = userplatformquotaDescWeeklyUsageUsd.Default.(float64)
// userplatformquotaDescMonthlyUsageUsd is the schema descriptor for monthly_usage_usd field.
userplatformquotaDescMonthlyUsageUsd := userplatformquotaFields[7].Descriptor()
// userplatformquota.DefaultMonthlyUsageUsd holds the default value on creation for the monthly_usage_usd field.
userplatformquota.DefaultMonthlyUsageUsd = userplatformquotaDescMonthlyUsageUsd.Default.(float64)
usersubscriptionMixin := schema.UserSubscription{}.Mixin()
usersubscriptionMixinHooks1 := usersubscriptionMixin[1].Hooks()
usersubscription.Hooks[0] = usersubscriptionMixinHooks1[0]

View File

@ -155,6 +155,10 @@ func (Group) Fields() []ent.Field {
Default(domain.OpenAIMessagesDispatchModelConfig{}).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
Comment("OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型"),
field.JSON("models_list_config", domain.GroupModelsListConfig{}).
Default(domain.GroupModelsListConfig{}).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
Comment("自定义 /v1/models 展示列表配置;仅影响模型列表响应,不影响调度"),
// 分组级每分钟请求数上限0 = 不限制)。设置后优先于用户级兜底生效。
field.Int("rpm_limit").

View File

@ -131,6 +131,7 @@ func (User) Edges() []ent.Edge {
edge.To("auth_identities", AuthIdentity.Type).
Annotations(entsql.OnDelete(entsql.Cascade)),
edge.To("pending_auth_sessions", PendingAuthSession.Type),
edge.To("platform_quotas", UserPlatformQuota.Type),
}
}

View File

@ -0,0 +1,113 @@
package schema
import (
"fmt"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
)
// UserPlatformQuota holds the schema definition for per-user per-platform quota.
type UserPlatformQuota struct {
ent.Schema
}
func (UserPlatformQuota) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "user_platform_quotas"},
}
}
func (UserPlatformQuota) Mixin() []ent.Mixin {
return []ent.Mixin{
mixins.TimeMixin{},
mixins.SoftDeleteMixin{},
}
}
func (UserPlatformQuota) Fields() []ent.Field {
return []ent.Field{
field.Int64("user_id"),
field.String("platform").
MaxLen(32).
NotEmpty().
Validate(func(s string) error {
// 注意:平台列表的单一权威源为 service.AllowedQuotaPlatforms
// 此处为 ent 构建期约束,需与 service.AllowedQuotaPlatforms 保持同步。
switch s {
case "anthropic", "openai", "gemini", "antigravity":
return nil
default:
return fmt.Errorf("platform %q is not allowed", s)
}
}),
// 日 / 周 / 月 USD 上限:
// nil / not set → 无限额(完全放行)
// 0 → 完全禁用(任何请求都会被拒绝,因为 usage >= 0 恒成立)
// > 0 → USD 限额上限
field.Float("daily_limit_usd").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
field.Float("weekly_limit_usd").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
field.Float("monthly_limit_usd").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
// 当前窗口已用量USDpreflight 时与 limit 比较)
field.Float("daily_usage_usd").
Default(0).
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
field.Float("weekly_usage_usd").
Default(0).
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
field.Float("monthly_usage_usd").
Default(0).
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
// 窗口起点NULL = 首次还未初始化,由 InitWindowStarts 用 COALESCE 兜底)
field.Time("daily_window_start").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("weekly_window_start").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("monthly_window_start").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
}
}
func (UserPlatformQuota) Edges() []ent.Edge {
return []ent.Edge{
edge.From("user", User.Type).
Ref("platform_quotas").
Field("user_id").
Unique().
Required(),
}
}
func (UserPlatformQuota) Indexes() []ent.Index {
return []ent.Index{
// 软删除友好:只对未删记录唯一
index.Fields("user_id", "platform").
Unique().
Annotations(entsql.IndexWhere("deleted_at IS NULL")),
index.Fields("user_id"),
}
}

View File

@ -80,6 +80,8 @@ type Tx struct {
UserAttributeDefinition *UserAttributeDefinitionClient
// UserAttributeValue is the client for interacting with the UserAttributeValue builders.
UserAttributeValue *UserAttributeValueClient
// UserPlatformQuota is the client for interacting with the UserPlatformQuota builders.
UserPlatformQuota *UserPlatformQuotaClient
// UserSubscription is the client for interacting with the UserSubscription builders.
UserSubscription *UserSubscriptionClient
@ -246,6 +248,7 @@ func (tx *Tx) init() {
tx.UserAllowedGroup = NewUserAllowedGroupClient(tx.config)
tx.UserAttributeDefinition = NewUserAttributeDefinitionClient(tx.config)
tx.UserAttributeValue = NewUserAttributeValueClient(tx.config)
tx.UserPlatformQuota = NewUserPlatformQuotaClient(tx.config)
tx.UserSubscription = NewUserSubscriptionClient(tx.config)
}

View File

@ -95,11 +95,13 @@ type UserEdges struct {
AuthIdentities []*AuthIdentity `json:"auth_identities,omitempty"`
// PendingAuthSessions holds the value of the pending_auth_sessions edge.
PendingAuthSessions []*PendingAuthSession `json:"pending_auth_sessions,omitempty"`
// PlatformQuotas holds the value of the platform_quotas edge.
PlatformQuotas []*UserPlatformQuota `json:"platform_quotas,omitempty"`
// UserAllowedGroups holds the value of the user_allowed_groups edge.
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [13]bool
loadedTypes [14]bool
}
// APIKeysOrErr returns the APIKeys value or an error if the edge
@ -210,10 +212,19 @@ func (e UserEdges) PendingAuthSessionsOrErr() ([]*PendingAuthSession, error) {
return nil, &NotLoadedError{edge: "pending_auth_sessions"}
}
// PlatformQuotasOrErr returns the PlatformQuotas value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) PlatformQuotasOrErr() ([]*UserPlatformQuota, error) {
if e.loadedTypes[12] {
return e.PlatformQuotas, nil
}
return nil, &NotLoadedError{edge: "platform_quotas"}
}
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
if e.loadedTypes[12] {
if e.loadedTypes[13] {
return e.UserAllowedGroups, nil
}
return nil, &NotLoadedError{edge: "user_allowed_groups"}
@ -472,6 +483,11 @@ func (_m *User) QueryPendingAuthSessions() *PendingAuthSessionQuery {
return NewUserClient(_m.config).QueryPendingAuthSessions(_m)
}
// QueryPlatformQuotas queries the "platform_quotas" edge of the User entity.
func (_m *User) QueryPlatformQuotas() *UserPlatformQuotaQuery {
return NewUserClient(_m.config).QueryPlatformQuotas(_m)
}
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery {
return NewUserClient(_m.config).QueryUserAllowedGroups(_m)

View File

@ -85,6 +85,8 @@ const (
EdgeAuthIdentities = "auth_identities"
// EdgePendingAuthSessions holds the string denoting the pending_auth_sessions edge name in mutations.
EdgePendingAuthSessions = "pending_auth_sessions"
// EdgePlatformQuotas holds the string denoting the platform_quotas edge name in mutations.
EdgePlatformQuotas = "platform_quotas"
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
EdgeUserAllowedGroups = "user_allowed_groups"
// Table holds the table name of the user in the database.
@ -171,6 +173,13 @@ const (
PendingAuthSessionsInverseTable = "pending_auth_sessions"
// PendingAuthSessionsColumn is the table column denoting the pending_auth_sessions relation/edge.
PendingAuthSessionsColumn = "target_user_id"
// PlatformQuotasTable is the table that holds the platform_quotas relation/edge.
PlatformQuotasTable = "user_platform_quotas"
// PlatformQuotasInverseTable is the table name for the UserPlatformQuota entity.
// It exists in this package in order to avoid circular dependency with the "userplatformquota" package.
PlatformQuotasInverseTable = "user_platform_quotas"
// PlatformQuotasColumn is the table column denoting the platform_quotas relation/edge.
PlatformQuotasColumn = "user_id"
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
UserAllowedGroupsTable = "user_allowed_groups"
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
@ -569,6 +578,20 @@ func ByPendingAuthSessions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOpti
}
}
// ByPlatformQuotasCount orders the results by platform_quotas count.
func ByPlatformQuotasCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newPlatformQuotasStep(), opts...)
}
}
// ByPlatformQuotas orders the results by platform_quotas terms.
func ByPlatformQuotas(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newPlatformQuotasStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
@ -666,6 +689,13 @@ func newPendingAuthSessionsStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn),
)
}
func newPlatformQuotasStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(PlatformQuotasInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, PlatformQuotasTable, PlatformQuotasColumn),
)
}
func newUserAllowedGroupsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),

View File

@ -1616,6 +1616,29 @@ func HasPendingAuthSessionsWith(preds ...predicate.PendingAuthSession) predicate
})
}
// HasPlatformQuotas applies the HasEdge predicate on the "platform_quotas" edge.
func HasPlatformQuotas() predicate.User {
return predicate.User(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, PlatformQuotasTable, PlatformQuotasColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasPlatformQuotasWith applies the HasEdge predicate on the "platform_quotas" edge with a given conditions (other predicates).
func HasPlatformQuotasWith(preds ...predicate.UserPlatformQuota) predicate.User {
return predicate.User(func(s *sql.Selector) {
step := newPlatformQuotasStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
func HasUserAllowedGroups() predicate.User {
return predicate.User(func(s *sql.Selector) {

View File

@ -22,6 +22,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@ -519,6 +520,21 @@ func (_c *UserCreate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserCrea
return _c.AddPendingAuthSessionIDs(ids...)
}
// AddPlatformQuotaIDs adds the "platform_quotas" edge to the UserPlatformQuota entity by IDs.
func (_c *UserCreate) AddPlatformQuotaIDs(ids ...int64) *UserCreate {
_c.mutation.AddPlatformQuotaIDs(ids...)
return _c
}
// AddPlatformQuotas adds the "platform_quotas" edges to the UserPlatformQuota entity.
func (_c *UserCreate) AddPlatformQuotas(v ...*UserPlatformQuota) *UserCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddPlatformQuotaIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_c *UserCreate) Mutation() *UserMutation {
return _c.mutation
@ -1023,6 +1039,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
}
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.PlatformQuotasIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PlatformQuotasTable,
Columns: []string{user.PlatformQuotasColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}

View File

@ -26,6 +26,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@ -48,6 +49,7 @@ type UserQuery struct {
withPaymentOrders *PaymentOrderQuery
withAuthIdentities *AuthIdentityQuery
withPendingAuthSessions *PendingAuthSessionQuery
withPlatformQuotas *UserPlatformQuotaQuery
withUserAllowedGroups *UserAllowedGroupQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
@ -350,6 +352,28 @@ func (_q *UserQuery) QueryPendingAuthSessions() *PendingAuthSessionQuery {
return query
}
// QueryPlatformQuotas chains the current query on the "platform_quotas" edge.
func (_q *UserQuery) QueryPlatformQuotas() *UserPlatformQuotaQuery {
query := (&UserPlatformQuotaClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, selector),
sqlgraph.To(userplatformquota.Table, userplatformquota.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.PlatformQuotasTable, user.PlatformQuotasColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: _q.config}).Query()
@ -576,6 +600,7 @@ func (_q *UserQuery) Clone() *UserQuery {
withPaymentOrders: _q.withPaymentOrders.Clone(),
withAuthIdentities: _q.withAuthIdentities.Clone(),
withPendingAuthSessions: _q.withPendingAuthSessions.Clone(),
withPlatformQuotas: _q.withPlatformQuotas.Clone(),
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
@ -715,6 +740,17 @@ func (_q *UserQuery) WithPendingAuthSessions(opts ...func(*PendingAuthSessionQue
return _q
}
// WithPlatformQuotas tells the query-builder to eager-load the nodes that are connected to
// the "platform_quotas" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithPlatformQuotas(opts ...func(*UserPlatformQuotaQuery)) *UserQuery {
query := (&UserPlatformQuotaClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withPlatformQuotas = query
return _q
}
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
@ -804,7 +840,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var (
nodes = []*User{}
_spec = _q.querySpec()
loadedTypes = [13]bool{
loadedTypes = [14]bool{
_q.withAPIKeys != nil,
_q.withRedeemCodes != nil,
_q.withSubscriptions != nil,
@ -817,6 +853,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
_q.withPaymentOrders != nil,
_q.withAuthIdentities != nil,
_q.withPendingAuthSessions != nil,
_q.withPlatformQuotas != nil,
_q.withUserAllowedGroups != nil,
}
)
@ -929,6 +966,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err
}
}
if query := _q.withPlatformQuotas; query != nil {
if err := _q.loadPlatformQuotas(ctx, query, nodes,
func(n *User) { n.Edges.PlatformQuotas = []*UserPlatformQuota{} },
func(n *User, e *UserPlatformQuota) { n.Edges.PlatformQuotas = append(n.Edges.PlatformQuotas, e) }); err != nil {
return nil, err
}
}
if query := _q.withUserAllowedGroups; query != nil {
if err := _q.loadUserAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
@ -1339,6 +1383,36 @@ func (_q *UserQuery) loadPendingAuthSessions(ctx context.Context, query *Pending
}
return nil
}
func (_q *UserQuery) loadPlatformQuotas(ctx context.Context, query *UserPlatformQuotaQuery, nodes []*User, init func(*User), assign func(*User, *UserPlatformQuota)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(userplatformquota.FieldUserID)
}
query.Where(predicate.UserPlatformQuota(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(user.PlatformQuotasColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.UserID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)

View File

@ -23,6 +23,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@ -590,6 +591,21 @@ func (_u *UserUpdate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpda
return _u.AddPendingAuthSessionIDs(ids...)
}
// AddPlatformQuotaIDs adds the "platform_quotas" edge to the UserPlatformQuota entity by IDs.
func (_u *UserUpdate) AddPlatformQuotaIDs(ids ...int64) *UserUpdate {
_u.mutation.AddPlatformQuotaIDs(ids...)
return _u
}
// AddPlatformQuotas adds the "platform_quotas" edges to the UserPlatformQuota entity.
func (_u *UserUpdate) AddPlatformQuotas(v ...*UserPlatformQuota) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddPlatformQuotaIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdate) Mutation() *UserMutation {
return _u.mutation
@ -847,6 +863,27 @@ func (_u *UserUpdate) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserU
return _u.RemovePendingAuthSessionIDs(ids...)
}
// ClearPlatformQuotas clears all "platform_quotas" edges to the UserPlatformQuota entity.
func (_u *UserUpdate) ClearPlatformQuotas() *UserUpdate {
_u.mutation.ClearPlatformQuotas()
return _u
}
// RemovePlatformQuotaIDs removes the "platform_quotas" edge to UserPlatformQuota entities by IDs.
func (_u *UserUpdate) RemovePlatformQuotaIDs(ids ...int64) *UserUpdate {
_u.mutation.RemovePlatformQuotaIDs(ids...)
return _u
}
// RemovePlatformQuotas removes "platform_quotas" edges to UserPlatformQuota entities.
func (_u *UserUpdate) RemovePlatformQuotas(v ...*UserPlatformQuota) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemovePlatformQuotaIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
@ -1587,6 +1624,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.PlatformQuotasCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PlatformQuotasTable,
Columns: []string{user.PlatformQuotasColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedPlatformQuotasIDs(); len(nodes) > 0 && !_u.mutation.PlatformQuotasCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PlatformQuotasTable,
Columns: []string{user.PlatformQuotasColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.PlatformQuotasIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PlatformQuotasTable,
Columns: []string{user.PlatformQuotasColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label}
@ -2158,6 +2240,21 @@ func (_u *UserUpdateOne) AddPendingAuthSessions(v ...*PendingAuthSession) *UserU
return _u.AddPendingAuthSessionIDs(ids...)
}
// AddPlatformQuotaIDs adds the "platform_quotas" edge to the UserPlatformQuota entity by IDs.
func (_u *UserUpdateOne) AddPlatformQuotaIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddPlatformQuotaIDs(ids...)
return _u
}
// AddPlatformQuotas adds the "platform_quotas" edges to the UserPlatformQuota entity.
func (_u *UserUpdateOne) AddPlatformQuotas(v ...*UserPlatformQuota) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddPlatformQuotaIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdateOne) Mutation() *UserMutation {
return _u.mutation
@ -2415,6 +2512,27 @@ func (_u *UserUpdateOne) RemovePendingAuthSessions(v ...*PendingAuthSession) *Us
return _u.RemovePendingAuthSessionIDs(ids...)
}
// ClearPlatformQuotas clears all "platform_quotas" edges to the UserPlatformQuota entity.
func (_u *UserUpdateOne) ClearPlatformQuotas() *UserUpdateOne {
_u.mutation.ClearPlatformQuotas()
return _u
}
// RemovePlatformQuotaIDs removes the "platform_quotas" edge to UserPlatformQuota entities by IDs.
func (_u *UserUpdateOne) RemovePlatformQuotaIDs(ids ...int64) *UserUpdateOne {
_u.mutation.RemovePlatformQuotaIDs(ids...)
return _u
}
// RemovePlatformQuotas removes "platform_quotas" edges to UserPlatformQuota entities.
func (_u *UserUpdateOne) RemovePlatformQuotas(v ...*UserPlatformQuota) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemovePlatformQuotaIDs(ids...)
}
// Where appends a list predicates to the UserUpdate builder.
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
_u.mutation.Where(ps...)
@ -3185,6 +3303,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.PlatformQuotasCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PlatformQuotasTable,
Columns: []string{user.PlatformQuotasColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedPlatformQuotasIDs(); len(nodes) > 0 && !_u.mutation.PlatformQuotasCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PlatformQuotasTable,
Columns: []string{user.PlatformQuotasColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.PlatformQuotasIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PlatformQuotasTable,
Columns: []string{user.PlatformQuotasColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &User{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues

View File

@ -0,0 +1,301 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
)
// UserPlatformQuota is the model entity for the UserPlatformQuota schema.
type UserPlatformQuota struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
UpdatedAt time.Time `json:"updated_at,omitempty"`
// DeletedAt holds the value of the "deleted_at" field.
DeletedAt *time.Time `json:"deleted_at,omitempty"`
// UserID holds the value of the "user_id" field.
UserID int64 `json:"user_id,omitempty"`
// Platform holds the value of the "platform" field.
Platform string `json:"platform,omitempty"`
// DailyLimitUsd holds the value of the "daily_limit_usd" field.
DailyLimitUsd *float64 `json:"daily_limit_usd,omitempty"`
// WeeklyLimitUsd holds the value of the "weekly_limit_usd" field.
WeeklyLimitUsd *float64 `json:"weekly_limit_usd,omitempty"`
// MonthlyLimitUsd holds the value of the "monthly_limit_usd" field.
MonthlyLimitUsd *float64 `json:"monthly_limit_usd,omitempty"`
// DailyUsageUsd holds the value of the "daily_usage_usd" field.
DailyUsageUsd float64 `json:"daily_usage_usd,omitempty"`
// WeeklyUsageUsd holds the value of the "weekly_usage_usd" field.
WeeklyUsageUsd float64 `json:"weekly_usage_usd,omitempty"`
// MonthlyUsageUsd holds the value of the "monthly_usage_usd" field.
MonthlyUsageUsd float64 `json:"monthly_usage_usd,omitempty"`
// DailyWindowStart holds the value of the "daily_window_start" field.
DailyWindowStart *time.Time `json:"daily_window_start,omitempty"`
// WeeklyWindowStart holds the value of the "weekly_window_start" field.
WeeklyWindowStart *time.Time `json:"weekly_window_start,omitempty"`
// MonthlyWindowStart holds the value of the "monthly_window_start" field.
MonthlyWindowStart *time.Time `json:"monthly_window_start,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserPlatformQuotaQuery when eager-loading is set.
Edges UserPlatformQuotaEdges `json:"edges"`
selectValues sql.SelectValues
}
// UserPlatformQuotaEdges holds the relations/edges for other nodes in the graph.
type UserPlatformQuotaEdges struct {
// User holds the value of the user edge.
User *User `json:"user,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [1]bool
}
// UserOrErr returns the User value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e UserPlatformQuotaEdges) UserOrErr() (*User, error) {
if e.User != nil {
return e.User, nil
} else if e.loadedTypes[0] {
return nil, &NotFoundError{label: user.Label}
}
return nil, &NotLoadedError{edge: "user"}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*UserPlatformQuota) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case userplatformquota.FieldDailyLimitUsd, userplatformquota.FieldWeeklyLimitUsd, userplatformquota.FieldMonthlyLimitUsd, userplatformquota.FieldDailyUsageUsd, userplatformquota.FieldWeeklyUsageUsd, userplatformquota.FieldMonthlyUsageUsd:
values[i] = new(sql.NullFloat64)
case userplatformquota.FieldID, userplatformquota.FieldUserID:
values[i] = new(sql.NullInt64)
case userplatformquota.FieldPlatform:
values[i] = new(sql.NullString)
case userplatformquota.FieldCreatedAt, userplatformquota.FieldUpdatedAt, userplatformquota.FieldDeletedAt, userplatformquota.FieldDailyWindowStart, userplatformquota.FieldWeeklyWindowStart, userplatformquota.FieldMonthlyWindowStart:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the UserPlatformQuota fields.
func (_m *UserPlatformQuota) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case userplatformquota.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
case userplatformquota.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
case userplatformquota.FieldUpdatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
} else if value.Valid {
_m.UpdatedAt = value.Time
}
case userplatformquota.FieldDeletedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field deleted_at", values[i])
} else if value.Valid {
_m.DeletedAt = new(time.Time)
*_m.DeletedAt = value.Time
}
case userplatformquota.FieldUserID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field user_id", values[i])
} else if value.Valid {
_m.UserID = value.Int64
}
case userplatformquota.FieldPlatform:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field platform", values[i])
} else if value.Valid {
_m.Platform = value.String
}
case userplatformquota.FieldDailyLimitUsd:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field daily_limit_usd", values[i])
} else if value.Valid {
_m.DailyLimitUsd = new(float64)
*_m.DailyLimitUsd = value.Float64
}
case userplatformquota.FieldWeeklyLimitUsd:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field weekly_limit_usd", values[i])
} else if value.Valid {
_m.WeeklyLimitUsd = new(float64)
*_m.WeeklyLimitUsd = value.Float64
}
case userplatformquota.FieldMonthlyLimitUsd:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field monthly_limit_usd", values[i])
} else if value.Valid {
_m.MonthlyLimitUsd = new(float64)
*_m.MonthlyLimitUsd = value.Float64
}
case userplatformquota.FieldDailyUsageUsd:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field daily_usage_usd", values[i])
} else if value.Valid {
_m.DailyUsageUsd = value.Float64
}
case userplatformquota.FieldWeeklyUsageUsd:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field weekly_usage_usd", values[i])
} else if value.Valid {
_m.WeeklyUsageUsd = value.Float64
}
case userplatformquota.FieldMonthlyUsageUsd:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field monthly_usage_usd", values[i])
} else if value.Valid {
_m.MonthlyUsageUsd = value.Float64
}
case userplatformquota.FieldDailyWindowStart:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field daily_window_start", values[i])
} else if value.Valid {
_m.DailyWindowStart = new(time.Time)
*_m.DailyWindowStart = value.Time
}
case userplatformquota.FieldWeeklyWindowStart:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field weekly_window_start", values[i])
} else if value.Valid {
_m.WeeklyWindowStart = new(time.Time)
*_m.WeeklyWindowStart = value.Time
}
case userplatformquota.FieldMonthlyWindowStart:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field monthly_window_start", values[i])
} else if value.Valid {
_m.MonthlyWindowStart = new(time.Time)
*_m.MonthlyWindowStart = value.Time
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the UserPlatformQuota.
// This includes values selected through modifiers, order, etc.
func (_m *UserPlatformQuota) Value(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// QueryUser queries the "user" edge of the UserPlatformQuota entity.
func (_m *UserPlatformQuota) QueryUser() *UserQuery {
return NewUserPlatformQuotaClient(_m.config).QueryUser(_m)
}
// Update returns a builder for updating this UserPlatformQuota.
// Note that you need to call UserPlatformQuota.Unwrap() before calling this method if this UserPlatformQuota
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *UserPlatformQuota) Update() *UserPlatformQuotaUpdateOne {
return NewUserPlatformQuotaClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the UserPlatformQuota entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *UserPlatformQuota) Unwrap() *UserPlatformQuota {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("ent: UserPlatformQuota is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *UserPlatformQuota) String() string {
var builder strings.Builder
builder.WriteString("UserPlatformQuota(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("updated_at=")
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
builder.WriteString(", ")
if v := _m.DeletedAt; v != nil {
builder.WriteString("deleted_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
builder.WriteString("user_id=")
builder.WriteString(fmt.Sprintf("%v", _m.UserID))
builder.WriteString(", ")
builder.WriteString("platform=")
builder.WriteString(_m.Platform)
builder.WriteString(", ")
if v := _m.DailyLimitUsd; v != nil {
builder.WriteString("daily_limit_usd=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.WeeklyLimitUsd; v != nil {
builder.WriteString("weekly_limit_usd=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.MonthlyLimitUsd; v != nil {
builder.WriteString("monthly_limit_usd=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("daily_usage_usd=")
builder.WriteString(fmt.Sprintf("%v", _m.DailyUsageUsd))
builder.WriteString(", ")
builder.WriteString("weekly_usage_usd=")
builder.WriteString(fmt.Sprintf("%v", _m.WeeklyUsageUsd))
builder.WriteString(", ")
builder.WriteString("monthly_usage_usd=")
builder.WriteString(fmt.Sprintf("%v", _m.MonthlyUsageUsd))
builder.WriteString(", ")
if v := _m.DailyWindowStart; v != nil {
builder.WriteString("daily_window_start=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.WeeklyWindowStart; v != nil {
builder.WriteString("weekly_window_start=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.MonthlyWindowStart; v != nil {
builder.WriteString("monthly_window_start=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteByte(')')
return builder.String()
}
// UserPlatformQuotaSlice is a parsable slice of UserPlatformQuota.
type UserPlatformQuotaSlice []*UserPlatformQuota

View File

@ -0,0 +1,202 @@
// Code generated by ent, DO NOT EDIT.
package userplatformquota
import (
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
const (
// Label holds the string label denoting the userplatformquota type in the database.
Label = "user_platform_quota"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at"
// FieldDeletedAt holds the string denoting the deleted_at field in the database.
FieldDeletedAt = "deleted_at"
// FieldUserID holds the string denoting the user_id field in the database.
FieldUserID = "user_id"
// FieldPlatform holds the string denoting the platform field in the database.
FieldPlatform = "platform"
// FieldDailyLimitUsd holds the string denoting the daily_limit_usd field in the database.
FieldDailyLimitUsd = "daily_limit_usd"
// FieldWeeklyLimitUsd holds the string denoting the weekly_limit_usd field in the database.
FieldWeeklyLimitUsd = "weekly_limit_usd"
// FieldMonthlyLimitUsd holds the string denoting the monthly_limit_usd field in the database.
FieldMonthlyLimitUsd = "monthly_limit_usd"
// FieldDailyUsageUsd holds the string denoting the daily_usage_usd field in the database.
FieldDailyUsageUsd = "daily_usage_usd"
// FieldWeeklyUsageUsd holds the string denoting the weekly_usage_usd field in the database.
FieldWeeklyUsageUsd = "weekly_usage_usd"
// FieldMonthlyUsageUsd holds the string denoting the monthly_usage_usd field in the database.
FieldMonthlyUsageUsd = "monthly_usage_usd"
// FieldDailyWindowStart holds the string denoting the daily_window_start field in the database.
FieldDailyWindowStart = "daily_window_start"
// FieldWeeklyWindowStart holds the string denoting the weekly_window_start field in the database.
FieldWeeklyWindowStart = "weekly_window_start"
// FieldMonthlyWindowStart holds the string denoting the monthly_window_start field in the database.
FieldMonthlyWindowStart = "monthly_window_start"
// EdgeUser holds the string denoting the user edge name in mutations.
EdgeUser = "user"
// Table holds the table name of the userplatformquota in the database.
Table = "user_platform_quotas"
// UserTable is the table that holds the user relation/edge.
UserTable = "user_platform_quotas"
// UserInverseTable is the table name for the User entity.
// It exists in this package in order to avoid circular dependency with the "user" package.
UserInverseTable = "users"
// UserColumn is the table column denoting the user relation/edge.
UserColumn = "user_id"
)
// Columns holds all SQL columns for userplatformquota fields.
var Columns = []string{
FieldID,
FieldCreatedAt,
FieldUpdatedAt,
FieldDeletedAt,
FieldUserID,
FieldPlatform,
FieldDailyLimitUsd,
FieldWeeklyLimitUsd,
FieldMonthlyLimitUsd,
FieldDailyUsageUsd,
FieldWeeklyUsageUsd,
FieldMonthlyUsageUsd,
FieldDailyWindowStart,
FieldWeeklyWindowStart,
FieldMonthlyWindowStart,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
// Note that the variables below are initialized by the runtime
// package on the initialization of the application. Therefore,
// it should be imported in the main as follows:
//
// import _ "github.com/Wei-Shaw/sub2api/ent/runtime"
var (
Hooks [1]ent.Hook
Interceptors [1]ent.Interceptor
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
DefaultUpdatedAt func() time.Time
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
UpdateDefaultUpdatedAt func() time.Time
// PlatformValidator is a validator for the "platform" field. It is called by the builders before save.
PlatformValidator func(string) error
// DefaultDailyUsageUsd holds the default value on creation for the "daily_usage_usd" field.
DefaultDailyUsageUsd float64
// DefaultWeeklyUsageUsd holds the default value on creation for the "weekly_usage_usd" field.
DefaultWeeklyUsageUsd float64
// DefaultMonthlyUsageUsd holds the default value on creation for the "monthly_usage_usd" field.
DefaultMonthlyUsageUsd float64
)
// OrderOption defines the ordering options for the UserPlatformQuota queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByDeletedAt orders the results by the deleted_at field.
func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDeletedAt, opts...).ToFunc()
}
// ByUserID orders the results by the user_id field.
func ByUserID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUserID, opts...).ToFunc()
}
// ByPlatform orders the results by the platform field.
func ByPlatform(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldPlatform, opts...).ToFunc()
}
// ByDailyLimitUsd orders the results by the daily_limit_usd field.
func ByDailyLimitUsd(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDailyLimitUsd, opts...).ToFunc()
}
// ByWeeklyLimitUsd orders the results by the weekly_limit_usd field.
func ByWeeklyLimitUsd(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldWeeklyLimitUsd, opts...).ToFunc()
}
// ByMonthlyLimitUsd orders the results by the monthly_limit_usd field.
func ByMonthlyLimitUsd(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMonthlyLimitUsd, opts...).ToFunc()
}
// ByDailyUsageUsd orders the results by the daily_usage_usd field.
func ByDailyUsageUsd(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDailyUsageUsd, opts...).ToFunc()
}
// ByWeeklyUsageUsd orders the results by the weekly_usage_usd field.
func ByWeeklyUsageUsd(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldWeeklyUsageUsd, opts...).ToFunc()
}
// ByMonthlyUsageUsd orders the results by the monthly_usage_usd field.
func ByMonthlyUsageUsd(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMonthlyUsageUsd, opts...).ToFunc()
}
// ByDailyWindowStart orders the results by the daily_window_start field.
func ByDailyWindowStart(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDailyWindowStart, opts...).ToFunc()
}
// ByWeeklyWindowStart orders the results by the weekly_window_start field.
func ByWeeklyWindowStart(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldWeeklyWindowStart, opts...).ToFunc()
}
// ByMonthlyWindowStart orders the results by the monthly_window_start field.
func ByMonthlyWindowStart(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMonthlyWindowStart, opts...).ToFunc()
}
// ByUserField orders the results by user field.
func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...))
}
}
func newUserStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(UserInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
)
}

View File

@ -0,0 +1,799 @@
// Code generated by ent, DO NOT EDIT.
package userplatformquota
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldID, id))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldCreatedAt, v))
}
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
func UpdatedAt(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldUpdatedAt, v))
}
// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ.
func DeletedAt(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDeletedAt, v))
}
// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
func UserID(v int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldUserID, v))
}
// Platform applies equality check predicate on the "platform" field. It's identical to PlatformEQ.
func Platform(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldPlatform, v))
}
// DailyLimitUsd applies equality check predicate on the "daily_limit_usd" field. It's identical to DailyLimitUsdEQ.
func DailyLimitUsd(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyLimitUsd, v))
}
// WeeklyLimitUsd applies equality check predicate on the "weekly_limit_usd" field. It's identical to WeeklyLimitUsdEQ.
func WeeklyLimitUsd(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyLimitUsd, v))
}
// MonthlyLimitUsd applies equality check predicate on the "monthly_limit_usd" field. It's identical to MonthlyLimitUsdEQ.
func MonthlyLimitUsd(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyLimitUsd, v))
}
// DailyUsageUsd applies equality check predicate on the "daily_usage_usd" field. It's identical to DailyUsageUsdEQ.
func DailyUsageUsd(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyUsageUsd, v))
}
// WeeklyUsageUsd applies equality check predicate on the "weekly_usage_usd" field. It's identical to WeeklyUsageUsdEQ.
func WeeklyUsageUsd(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyUsageUsd, v))
}
// MonthlyUsageUsd applies equality check predicate on the "monthly_usage_usd" field. It's identical to MonthlyUsageUsdEQ.
func MonthlyUsageUsd(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyUsageUsd, v))
}
// DailyWindowStart applies equality check predicate on the "daily_window_start" field. It's identical to DailyWindowStartEQ.
func DailyWindowStart(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyWindowStart, v))
}
// WeeklyWindowStart applies equality check predicate on the "weekly_window_start" field. It's identical to WeeklyWindowStartEQ.
func WeeklyWindowStart(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyWindowStart, v))
}
// MonthlyWindowStart applies equality check predicate on the "monthly_window_start" field. It's identical to MonthlyWindowStartEQ.
func MonthlyWindowStart(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyWindowStart, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldCreatedAt, v))
}
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
func UpdatedAtEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldUpdatedAt, v))
}
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
func UpdatedAtNEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldUpdatedAt, v))
}
// UpdatedAtIn applies the In predicate on the "updated_at" field.
func UpdatedAtIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldUpdatedAt, vs...))
}
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
func UpdatedAtNotIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldUpdatedAt, vs...))
}
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
func UpdatedAtGT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldUpdatedAt, v))
}
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
func UpdatedAtGTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldUpdatedAt, v))
}
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
func UpdatedAtLT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldUpdatedAt, v))
}
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
func UpdatedAtLTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldUpdatedAt, v))
}
// DeletedAtEQ applies the EQ predicate on the "deleted_at" field.
func DeletedAtEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDeletedAt, v))
}
// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field.
func DeletedAtNEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDeletedAt, v))
}
// DeletedAtIn applies the In predicate on the "deleted_at" field.
func DeletedAtIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldDeletedAt, vs...))
}
// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field.
func DeletedAtNotIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDeletedAt, vs...))
}
// DeletedAtGT applies the GT predicate on the "deleted_at" field.
func DeletedAtGT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldDeletedAt, v))
}
// DeletedAtGTE applies the GTE predicate on the "deleted_at" field.
func DeletedAtGTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldDeletedAt, v))
}
// DeletedAtLT applies the LT predicate on the "deleted_at" field.
func DeletedAtLT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldDeletedAt, v))
}
// DeletedAtLTE applies the LTE predicate on the "deleted_at" field.
func DeletedAtLTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldDeletedAt, v))
}
// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field.
func DeletedAtIsNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldDeletedAt))
}
// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field.
func DeletedAtNotNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldDeletedAt))
}
// UserIDEQ applies the EQ predicate on the "user_id" field.
func UserIDEQ(v int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldUserID, v))
}
// UserIDNEQ applies the NEQ predicate on the "user_id" field.
func UserIDNEQ(v int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldUserID, v))
}
// UserIDIn applies the In predicate on the "user_id" field.
func UserIDIn(vs ...int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldUserID, vs...))
}
// UserIDNotIn applies the NotIn predicate on the "user_id" field.
func UserIDNotIn(vs ...int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldUserID, vs...))
}
// PlatformEQ applies the EQ predicate on the "platform" field.
func PlatformEQ(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldPlatform, v))
}
// PlatformNEQ applies the NEQ predicate on the "platform" field.
func PlatformNEQ(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldPlatform, v))
}
// PlatformIn applies the In predicate on the "platform" field.
func PlatformIn(vs ...string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldPlatform, vs...))
}
// PlatformNotIn applies the NotIn predicate on the "platform" field.
func PlatformNotIn(vs ...string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldPlatform, vs...))
}
// PlatformGT applies the GT predicate on the "platform" field.
func PlatformGT(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldPlatform, v))
}
// PlatformGTE applies the GTE predicate on the "platform" field.
func PlatformGTE(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldPlatform, v))
}
// PlatformLT applies the LT predicate on the "platform" field.
func PlatformLT(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldPlatform, v))
}
// PlatformLTE applies the LTE predicate on the "platform" field.
func PlatformLTE(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldPlatform, v))
}
// PlatformContains applies the Contains predicate on the "platform" field.
func PlatformContains(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldContains(FieldPlatform, v))
}
// PlatformHasPrefix applies the HasPrefix predicate on the "platform" field.
func PlatformHasPrefix(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldHasPrefix(FieldPlatform, v))
}
// PlatformHasSuffix applies the HasSuffix predicate on the "platform" field.
func PlatformHasSuffix(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldHasSuffix(FieldPlatform, v))
}
// PlatformEqualFold applies the EqualFold predicate on the "platform" field.
func PlatformEqualFold(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEqualFold(FieldPlatform, v))
}
// PlatformContainsFold applies the ContainsFold predicate on the "platform" field.
func PlatformContainsFold(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldContainsFold(FieldPlatform, v))
}
// DailyLimitUsdEQ applies the EQ predicate on the "daily_limit_usd" field.
func DailyLimitUsdEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyLimitUsd, v))
}
// DailyLimitUsdNEQ applies the NEQ predicate on the "daily_limit_usd" field.
func DailyLimitUsdNEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDailyLimitUsd, v))
}
// DailyLimitUsdIn applies the In predicate on the "daily_limit_usd" field.
func DailyLimitUsdIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldDailyLimitUsd, vs...))
}
// DailyLimitUsdNotIn applies the NotIn predicate on the "daily_limit_usd" field.
func DailyLimitUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDailyLimitUsd, vs...))
}
// DailyLimitUsdGT applies the GT predicate on the "daily_limit_usd" field.
func DailyLimitUsdGT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldDailyLimitUsd, v))
}
// DailyLimitUsdGTE applies the GTE predicate on the "daily_limit_usd" field.
func DailyLimitUsdGTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldDailyLimitUsd, v))
}
// DailyLimitUsdLT applies the LT predicate on the "daily_limit_usd" field.
func DailyLimitUsdLT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldDailyLimitUsd, v))
}
// DailyLimitUsdLTE applies the LTE predicate on the "daily_limit_usd" field.
func DailyLimitUsdLTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldDailyLimitUsd, v))
}
// DailyLimitUsdIsNil applies the IsNil predicate on the "daily_limit_usd" field.
func DailyLimitUsdIsNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldDailyLimitUsd))
}
// DailyLimitUsdNotNil applies the NotNil predicate on the "daily_limit_usd" field.
func DailyLimitUsdNotNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldDailyLimitUsd))
}
// WeeklyLimitUsdEQ applies the EQ predicate on the "weekly_limit_usd" field.
func WeeklyLimitUsdEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyLimitUsd, v))
}
// WeeklyLimitUsdNEQ applies the NEQ predicate on the "weekly_limit_usd" field.
func WeeklyLimitUsdNEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldWeeklyLimitUsd, v))
}
// WeeklyLimitUsdIn applies the In predicate on the "weekly_limit_usd" field.
func WeeklyLimitUsdIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldWeeklyLimitUsd, vs...))
}
// WeeklyLimitUsdNotIn applies the NotIn predicate on the "weekly_limit_usd" field.
func WeeklyLimitUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldWeeklyLimitUsd, vs...))
}
// WeeklyLimitUsdGT applies the GT predicate on the "weekly_limit_usd" field.
func WeeklyLimitUsdGT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldWeeklyLimitUsd, v))
}
// WeeklyLimitUsdGTE applies the GTE predicate on the "weekly_limit_usd" field.
func WeeklyLimitUsdGTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldWeeklyLimitUsd, v))
}
// WeeklyLimitUsdLT applies the LT predicate on the "weekly_limit_usd" field.
func WeeklyLimitUsdLT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldWeeklyLimitUsd, v))
}
// WeeklyLimitUsdLTE applies the LTE predicate on the "weekly_limit_usd" field.
func WeeklyLimitUsdLTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldWeeklyLimitUsd, v))
}
// WeeklyLimitUsdIsNil applies the IsNil predicate on the "weekly_limit_usd" field.
func WeeklyLimitUsdIsNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldWeeklyLimitUsd))
}
// WeeklyLimitUsdNotNil applies the NotNil predicate on the "weekly_limit_usd" field.
func WeeklyLimitUsdNotNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldWeeklyLimitUsd))
}
// MonthlyLimitUsdEQ applies the EQ predicate on the "monthly_limit_usd" field.
func MonthlyLimitUsdEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyLimitUsd, v))
}
// MonthlyLimitUsdNEQ applies the NEQ predicate on the "monthly_limit_usd" field.
func MonthlyLimitUsdNEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldMonthlyLimitUsd, v))
}
// MonthlyLimitUsdIn applies the In predicate on the "monthly_limit_usd" field.
func MonthlyLimitUsdIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldMonthlyLimitUsd, vs...))
}
// MonthlyLimitUsdNotIn applies the NotIn predicate on the "monthly_limit_usd" field.
func MonthlyLimitUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldMonthlyLimitUsd, vs...))
}
// MonthlyLimitUsdGT applies the GT predicate on the "monthly_limit_usd" field.
func MonthlyLimitUsdGT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldMonthlyLimitUsd, v))
}
// MonthlyLimitUsdGTE applies the GTE predicate on the "monthly_limit_usd" field.
func MonthlyLimitUsdGTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldMonthlyLimitUsd, v))
}
// MonthlyLimitUsdLT applies the LT predicate on the "monthly_limit_usd" field.
func MonthlyLimitUsdLT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldMonthlyLimitUsd, v))
}
// MonthlyLimitUsdLTE applies the LTE predicate on the "monthly_limit_usd" field.
func MonthlyLimitUsdLTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldMonthlyLimitUsd, v))
}
// MonthlyLimitUsdIsNil applies the IsNil predicate on the "monthly_limit_usd" field.
func MonthlyLimitUsdIsNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldMonthlyLimitUsd))
}
// MonthlyLimitUsdNotNil applies the NotNil predicate on the "monthly_limit_usd" field.
func MonthlyLimitUsdNotNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldMonthlyLimitUsd))
}
// DailyUsageUsdEQ applies the EQ predicate on the "daily_usage_usd" field.
func DailyUsageUsdEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyUsageUsd, v))
}
// DailyUsageUsdNEQ applies the NEQ predicate on the "daily_usage_usd" field.
func DailyUsageUsdNEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDailyUsageUsd, v))
}
// DailyUsageUsdIn applies the In predicate on the "daily_usage_usd" field.
func DailyUsageUsdIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldDailyUsageUsd, vs...))
}
// DailyUsageUsdNotIn applies the NotIn predicate on the "daily_usage_usd" field.
func DailyUsageUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDailyUsageUsd, vs...))
}
// DailyUsageUsdGT applies the GT predicate on the "daily_usage_usd" field.
func DailyUsageUsdGT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldDailyUsageUsd, v))
}
// DailyUsageUsdGTE applies the GTE predicate on the "daily_usage_usd" field.
func DailyUsageUsdGTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldDailyUsageUsd, v))
}
// DailyUsageUsdLT applies the LT predicate on the "daily_usage_usd" field.
func DailyUsageUsdLT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldDailyUsageUsd, v))
}
// DailyUsageUsdLTE applies the LTE predicate on the "daily_usage_usd" field.
func DailyUsageUsdLTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldDailyUsageUsd, v))
}
// WeeklyUsageUsdEQ applies the EQ predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyUsageUsd, v))
}
// WeeklyUsageUsdNEQ applies the NEQ predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdNEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldWeeklyUsageUsd, v))
}
// WeeklyUsageUsdIn applies the In predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldWeeklyUsageUsd, vs...))
}
// WeeklyUsageUsdNotIn applies the NotIn predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldWeeklyUsageUsd, vs...))
}
// WeeklyUsageUsdGT applies the GT predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdGT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldWeeklyUsageUsd, v))
}
// WeeklyUsageUsdGTE applies the GTE predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdGTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldWeeklyUsageUsd, v))
}
// WeeklyUsageUsdLT applies the LT predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdLT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldWeeklyUsageUsd, v))
}
// WeeklyUsageUsdLTE applies the LTE predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdLTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldWeeklyUsageUsd, v))
}
// MonthlyUsageUsdEQ applies the EQ predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyUsageUsd, v))
}
// MonthlyUsageUsdNEQ applies the NEQ predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdNEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldMonthlyUsageUsd, v))
}
// MonthlyUsageUsdIn applies the In predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldMonthlyUsageUsd, vs...))
}
// MonthlyUsageUsdNotIn applies the NotIn predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldMonthlyUsageUsd, vs...))
}
// MonthlyUsageUsdGT applies the GT predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdGT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldMonthlyUsageUsd, v))
}
// MonthlyUsageUsdGTE applies the GTE predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdGTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldMonthlyUsageUsd, v))
}
// MonthlyUsageUsdLT applies the LT predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdLT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldMonthlyUsageUsd, v))
}
// MonthlyUsageUsdLTE applies the LTE predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdLTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldMonthlyUsageUsd, v))
}
// DailyWindowStartEQ applies the EQ predicate on the "daily_window_start" field.
func DailyWindowStartEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyWindowStart, v))
}
// DailyWindowStartNEQ applies the NEQ predicate on the "daily_window_start" field.
func DailyWindowStartNEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDailyWindowStart, v))
}
// DailyWindowStartIn applies the In predicate on the "daily_window_start" field.
func DailyWindowStartIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldDailyWindowStart, vs...))
}
// DailyWindowStartNotIn applies the NotIn predicate on the "daily_window_start" field.
func DailyWindowStartNotIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDailyWindowStart, vs...))
}
// DailyWindowStartGT applies the GT predicate on the "daily_window_start" field.
func DailyWindowStartGT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldDailyWindowStart, v))
}
// DailyWindowStartGTE applies the GTE predicate on the "daily_window_start" field.
func DailyWindowStartGTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldDailyWindowStart, v))
}
// DailyWindowStartLT applies the LT predicate on the "daily_window_start" field.
func DailyWindowStartLT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldDailyWindowStart, v))
}
// DailyWindowStartLTE applies the LTE predicate on the "daily_window_start" field.
func DailyWindowStartLTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldDailyWindowStart, v))
}
// DailyWindowStartIsNil applies the IsNil predicate on the "daily_window_start" field.
func DailyWindowStartIsNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldDailyWindowStart))
}
// DailyWindowStartNotNil applies the NotNil predicate on the "daily_window_start" field.
func DailyWindowStartNotNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldDailyWindowStart))
}
// WeeklyWindowStartEQ applies the EQ predicate on the "weekly_window_start" field.
func WeeklyWindowStartEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyWindowStart, v))
}
// WeeklyWindowStartNEQ applies the NEQ predicate on the "weekly_window_start" field.
func WeeklyWindowStartNEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldWeeklyWindowStart, v))
}
// WeeklyWindowStartIn applies the In predicate on the "weekly_window_start" field.
func WeeklyWindowStartIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldWeeklyWindowStart, vs...))
}
// WeeklyWindowStartNotIn applies the NotIn predicate on the "weekly_window_start" field.
func WeeklyWindowStartNotIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldWeeklyWindowStart, vs...))
}
// WeeklyWindowStartGT applies the GT predicate on the "weekly_window_start" field.
func WeeklyWindowStartGT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldWeeklyWindowStart, v))
}
// WeeklyWindowStartGTE applies the GTE predicate on the "weekly_window_start" field.
func WeeklyWindowStartGTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldWeeklyWindowStart, v))
}
// WeeklyWindowStartLT applies the LT predicate on the "weekly_window_start" field.
func WeeklyWindowStartLT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldWeeklyWindowStart, v))
}
// WeeklyWindowStartLTE applies the LTE predicate on the "weekly_window_start" field.
func WeeklyWindowStartLTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldWeeklyWindowStart, v))
}
// WeeklyWindowStartIsNil applies the IsNil predicate on the "weekly_window_start" field.
func WeeklyWindowStartIsNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldWeeklyWindowStart))
}
// WeeklyWindowStartNotNil applies the NotNil predicate on the "weekly_window_start" field.
func WeeklyWindowStartNotNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldWeeklyWindowStart))
}
// MonthlyWindowStartEQ applies the EQ predicate on the "monthly_window_start" field.
func MonthlyWindowStartEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyWindowStart, v))
}
// MonthlyWindowStartNEQ applies the NEQ predicate on the "monthly_window_start" field.
func MonthlyWindowStartNEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldMonthlyWindowStart, v))
}
// MonthlyWindowStartIn applies the In predicate on the "monthly_window_start" field.
func MonthlyWindowStartIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldMonthlyWindowStart, vs...))
}
// MonthlyWindowStartNotIn applies the NotIn predicate on the "monthly_window_start" field.
func MonthlyWindowStartNotIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldMonthlyWindowStart, vs...))
}
// MonthlyWindowStartGT applies the GT predicate on the "monthly_window_start" field.
func MonthlyWindowStartGT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldMonthlyWindowStart, v))
}
// MonthlyWindowStartGTE applies the GTE predicate on the "monthly_window_start" field.
func MonthlyWindowStartGTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldMonthlyWindowStart, v))
}
// MonthlyWindowStartLT applies the LT predicate on the "monthly_window_start" field.
func MonthlyWindowStartLT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldMonthlyWindowStart, v))
}
// MonthlyWindowStartLTE applies the LTE predicate on the "monthly_window_start" field.
func MonthlyWindowStartLTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldMonthlyWindowStart, v))
}
// MonthlyWindowStartIsNil applies the IsNil predicate on the "monthly_window_start" field.
func MonthlyWindowStartIsNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldMonthlyWindowStart))
}
// MonthlyWindowStartNotNil applies the NotNil predicate on the "monthly_window_start" field.
func MonthlyWindowStartNotNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldMonthlyWindowStart))
}
// HasUser applies the HasEdge predicate on the "user" edge.
func HasUser() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates).
func HasUserWith(preds ...predicate.User) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(func(s *sql.Selector) {
step := newUserStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.UserPlatformQuota) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.UserPlatformQuota) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.UserPlatformQuota) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.NotPredicates(p))
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
)
// UserPlatformQuotaDelete is the builder for deleting a UserPlatformQuota entity.
type UserPlatformQuotaDelete struct {
config
hooks []Hook
mutation *UserPlatformQuotaMutation
}
// Where appends a list predicates to the UserPlatformQuotaDelete builder.
func (_d *UserPlatformQuotaDelete) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *UserPlatformQuotaDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UserPlatformQuotaDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *UserPlatformQuotaDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(userplatformquota.Table, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// UserPlatformQuotaDeleteOne is the builder for deleting a single UserPlatformQuota entity.
type UserPlatformQuotaDeleteOne struct {
_d *UserPlatformQuotaDelete
}
// Where appends a list predicates to the UserPlatformQuotaDelete builder.
func (_d *UserPlatformQuotaDeleteOne) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *UserPlatformQuotaDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{userplatformquota.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UserPlatformQuotaDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@ -0,0 +1,643 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
)
// UserPlatformQuotaQuery is the builder for querying UserPlatformQuota entities.
type UserPlatformQuotaQuery struct {
config
ctx *QueryContext
order []userplatformquota.OrderOption
inters []Interceptor
predicates []predicate.UserPlatformQuota
withUser *UserQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the UserPlatformQuotaQuery builder.
func (_q *UserPlatformQuotaQuery) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *UserPlatformQuotaQuery) Limit(limit int) *UserPlatformQuotaQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *UserPlatformQuotaQuery) Offset(offset int) *UserPlatformQuotaQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *UserPlatformQuotaQuery) Unique(unique bool) *UserPlatformQuotaQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *UserPlatformQuotaQuery) Order(o ...userplatformquota.OrderOption) *UserPlatformQuotaQuery {
_q.order = append(_q.order, o...)
return _q
}
// QueryUser chains the current query on the "user" edge.
func (_q *UserPlatformQuotaQuery) QueryUser() *UserQuery {
query := (&UserClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(userplatformquota.Table, userplatformquota.FieldID, selector),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, userplatformquota.UserTable, userplatformquota.UserColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first UserPlatformQuota entity from the query.
// Returns a *NotFoundError when no UserPlatformQuota was found.
func (_q *UserPlatformQuotaQuery) First(ctx context.Context) (*UserPlatformQuota, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{userplatformquota.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *UserPlatformQuotaQuery) FirstX(ctx context.Context) *UserPlatformQuota {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first UserPlatformQuota ID from the query.
// Returns a *NotFoundError when no UserPlatformQuota ID was found.
func (_q *UserPlatformQuotaQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{userplatformquota.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *UserPlatformQuotaQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single UserPlatformQuota entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one UserPlatformQuota entity is found.
// Returns a *NotFoundError when no UserPlatformQuota entities are found.
func (_q *UserPlatformQuotaQuery) Only(ctx context.Context) (*UserPlatformQuota, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{userplatformquota.Label}
default:
return nil, &NotSingularError{userplatformquota.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *UserPlatformQuotaQuery) OnlyX(ctx context.Context) *UserPlatformQuota {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only UserPlatformQuota ID in the query.
// Returns a *NotSingularError when more than one UserPlatformQuota ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *UserPlatformQuotaQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{userplatformquota.Label}
default:
err = &NotSingularError{userplatformquota.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *UserPlatformQuotaQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of UserPlatformQuotaSlice.
func (_q *UserPlatformQuotaQuery) All(ctx context.Context) ([]*UserPlatformQuota, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*UserPlatformQuota, *UserPlatformQuotaQuery]()
return withInterceptors[[]*UserPlatformQuota](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *UserPlatformQuotaQuery) AllX(ctx context.Context) []*UserPlatformQuota {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of UserPlatformQuota IDs.
func (_q *UserPlatformQuotaQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(userplatformquota.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *UserPlatformQuotaQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *UserPlatformQuotaQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*UserPlatformQuotaQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *UserPlatformQuotaQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *UserPlatformQuotaQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *UserPlatformQuotaQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the UserPlatformQuotaQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *UserPlatformQuotaQuery) Clone() *UserPlatformQuotaQuery {
if _q == nil {
return nil
}
return &UserPlatformQuotaQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]userplatformquota.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.UserPlatformQuota{}, _q.predicates...),
withUser: _q.withUser.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// WithUser tells the query-builder to eager-load the nodes that are connected to
// the "user" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserPlatformQuotaQuery) WithUser(opts ...func(*UserQuery)) *UserPlatformQuotaQuery {
query := (&UserClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withUser = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.UserPlatformQuota.Query().
// GroupBy(userplatformquota.FieldCreatedAt).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *UserPlatformQuotaQuery) GroupBy(field string, fields ...string) *UserPlatformQuotaGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &UserPlatformQuotaGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = userplatformquota.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// }
//
// client.UserPlatformQuota.Query().
// Select(userplatformquota.FieldCreatedAt).
// Scan(ctx, &v)
func (_q *UserPlatformQuotaQuery) Select(fields ...string) *UserPlatformQuotaSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &UserPlatformQuotaSelect{UserPlatformQuotaQuery: _q}
sbuild.label = userplatformquota.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a UserPlatformQuotaSelect configured with the given aggregations.
func (_q *UserPlatformQuotaQuery) Aggregate(fns ...AggregateFunc) *UserPlatformQuotaSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *UserPlatformQuotaQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !userplatformquota.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *UserPlatformQuotaQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UserPlatformQuota, error) {
var (
nodes = []*UserPlatformQuota{}
_spec = _q.querySpec()
loadedTypes = [1]bool{
_q.withUser != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*UserPlatformQuota).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &UserPlatformQuota{config: _q.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
if query := _q.withUser; query != nil {
if err := _q.loadUser(ctx, query, nodes, nil,
func(n *UserPlatformQuota, e *User) { n.Edges.User = e }); err != nil {
return nil, err
}
}
return nodes, nil
}
func (_q *UserPlatformQuotaQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*UserPlatformQuota, init func(*UserPlatformQuota), assign func(*UserPlatformQuota, *User)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*UserPlatformQuota)
for i := range nodes {
fk := nodes[i].UserID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(user.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *UserPlatformQuotaQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *UserPlatformQuotaQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(userplatformquota.Table, userplatformquota.Columns, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, userplatformquota.FieldID)
for i := range fields {
if fields[i] != userplatformquota.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
if _q.withUser != nil {
_spec.Node.AddColumnOnce(userplatformquota.FieldUserID)
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *UserPlatformQuotaQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(userplatformquota.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = userplatformquota.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserPlatformQuotaQuery) ForUpdate(opts ...sql.LockOption) *UserPlatformQuotaQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserPlatformQuotaQuery) ForShare(opts ...sql.LockOption) *UserPlatformQuotaQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserPlatformQuotaGroupBy is the group-by builder for UserPlatformQuota entities.
type UserPlatformQuotaGroupBy struct {
selector
build *UserPlatformQuotaQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *UserPlatformQuotaGroupBy) Aggregate(fns ...AggregateFunc) *UserPlatformQuotaGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *UserPlatformQuotaGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UserPlatformQuotaQuery, *UserPlatformQuotaGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *UserPlatformQuotaGroupBy) sqlScan(ctx context.Context, root *UserPlatformQuotaQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// UserPlatformQuotaSelect is the builder for selecting fields of UserPlatformQuota entities.
type UserPlatformQuotaSelect struct {
*UserPlatformQuotaQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *UserPlatformQuotaSelect) Aggregate(fns ...AggregateFunc) *UserPlatformQuotaSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *UserPlatformQuotaSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UserPlatformQuotaQuery, *UserPlatformQuotaSelect](ctx, _s.UserPlatformQuotaQuery, _s, _s.inters, v)
}
func (_s *UserPlatformQuotaSelect) sqlScan(ctx context.Context, root *UserPlatformQuotaQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

View File

@ -0,0 +1,985 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
)
// UserPlatformQuotaUpdate is the builder for updating UserPlatformQuota entities.
type UserPlatformQuotaUpdate struct {
config
hooks []Hook
mutation *UserPlatformQuotaMutation
}
// Where appends a list predicates to the UserPlatformQuotaUpdate builder.
func (_u *UserPlatformQuotaUpdate) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *UserPlatformQuotaUpdate) SetUpdatedAt(v time.Time) *UserPlatformQuotaUpdate {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetDeletedAt sets the "deleted_at" field.
func (_u *UserPlatformQuotaUpdate) SetDeletedAt(v time.Time) *UserPlatformQuotaUpdate {
_u.mutation.SetDeletedAt(v)
return _u
}
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableDeletedAt(v *time.Time) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetDeletedAt(*v)
}
return _u
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (_u *UserPlatformQuotaUpdate) ClearDeletedAt() *UserPlatformQuotaUpdate {
_u.mutation.ClearDeletedAt()
return _u
}
// SetUserID sets the "user_id" field.
func (_u *UserPlatformQuotaUpdate) SetUserID(v int64) *UserPlatformQuotaUpdate {
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableUserID(v *int64) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// SetPlatform sets the "platform" field.
func (_u *UserPlatformQuotaUpdate) SetPlatform(v string) *UserPlatformQuotaUpdate {
_u.mutation.SetPlatform(v)
return _u
}
// SetNillablePlatform sets the "platform" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillablePlatform(v *string) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetPlatform(*v)
}
return _u
}
// SetDailyLimitUsd sets the "daily_limit_usd" field.
func (_u *UserPlatformQuotaUpdate) SetDailyLimitUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.ResetDailyLimitUsd()
_u.mutation.SetDailyLimitUsd(v)
return _u
}
// SetNillableDailyLimitUsd sets the "daily_limit_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableDailyLimitUsd(v *float64) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetDailyLimitUsd(*v)
}
return _u
}
// AddDailyLimitUsd adds value to the "daily_limit_usd" field.
func (_u *UserPlatformQuotaUpdate) AddDailyLimitUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.AddDailyLimitUsd(v)
return _u
}
// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field.
func (_u *UserPlatformQuotaUpdate) ClearDailyLimitUsd() *UserPlatformQuotaUpdate {
_u.mutation.ClearDailyLimitUsd()
return _u
}
// SetWeeklyLimitUsd sets the "weekly_limit_usd" field.
func (_u *UserPlatformQuotaUpdate) SetWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.ResetWeeklyLimitUsd()
_u.mutation.SetWeeklyLimitUsd(v)
return _u
}
// SetNillableWeeklyLimitUsd sets the "weekly_limit_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableWeeklyLimitUsd(v *float64) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetWeeklyLimitUsd(*v)
}
return _u
}
// AddWeeklyLimitUsd adds value to the "weekly_limit_usd" field.
func (_u *UserPlatformQuotaUpdate) AddWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.AddWeeklyLimitUsd(v)
return _u
}
// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field.
func (_u *UserPlatformQuotaUpdate) ClearWeeklyLimitUsd() *UserPlatformQuotaUpdate {
_u.mutation.ClearWeeklyLimitUsd()
return _u
}
// SetMonthlyLimitUsd sets the "monthly_limit_usd" field.
func (_u *UserPlatformQuotaUpdate) SetMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.ResetMonthlyLimitUsd()
_u.mutation.SetMonthlyLimitUsd(v)
return _u
}
// SetNillableMonthlyLimitUsd sets the "monthly_limit_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableMonthlyLimitUsd(v *float64) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetMonthlyLimitUsd(*v)
}
return _u
}
// AddMonthlyLimitUsd adds value to the "monthly_limit_usd" field.
func (_u *UserPlatformQuotaUpdate) AddMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.AddMonthlyLimitUsd(v)
return _u
}
// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field.
func (_u *UserPlatformQuotaUpdate) ClearMonthlyLimitUsd() *UserPlatformQuotaUpdate {
_u.mutation.ClearMonthlyLimitUsd()
return _u
}
// SetDailyUsageUsd sets the "daily_usage_usd" field.
func (_u *UserPlatformQuotaUpdate) SetDailyUsageUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.ResetDailyUsageUsd()
_u.mutation.SetDailyUsageUsd(v)
return _u
}
// SetNillableDailyUsageUsd sets the "daily_usage_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableDailyUsageUsd(v *float64) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetDailyUsageUsd(*v)
}
return _u
}
// AddDailyUsageUsd adds value to the "daily_usage_usd" field.
func (_u *UserPlatformQuotaUpdate) AddDailyUsageUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.AddDailyUsageUsd(v)
return _u
}
// SetWeeklyUsageUsd sets the "weekly_usage_usd" field.
func (_u *UserPlatformQuotaUpdate) SetWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.ResetWeeklyUsageUsd()
_u.mutation.SetWeeklyUsageUsd(v)
return _u
}
// SetNillableWeeklyUsageUsd sets the "weekly_usage_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableWeeklyUsageUsd(v *float64) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetWeeklyUsageUsd(*v)
}
return _u
}
// AddWeeklyUsageUsd adds value to the "weekly_usage_usd" field.
func (_u *UserPlatformQuotaUpdate) AddWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.AddWeeklyUsageUsd(v)
return _u
}
// SetMonthlyUsageUsd sets the "monthly_usage_usd" field.
func (_u *UserPlatformQuotaUpdate) SetMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.ResetMonthlyUsageUsd()
_u.mutation.SetMonthlyUsageUsd(v)
return _u
}
// SetNillableMonthlyUsageUsd sets the "monthly_usage_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableMonthlyUsageUsd(v *float64) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetMonthlyUsageUsd(*v)
}
return _u
}
// AddMonthlyUsageUsd adds value to the "monthly_usage_usd" field.
func (_u *UserPlatformQuotaUpdate) AddMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.AddMonthlyUsageUsd(v)
return _u
}
// SetDailyWindowStart sets the "daily_window_start" field.
func (_u *UserPlatformQuotaUpdate) SetDailyWindowStart(v time.Time) *UserPlatformQuotaUpdate {
_u.mutation.SetDailyWindowStart(v)
return _u
}
// SetNillableDailyWindowStart sets the "daily_window_start" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableDailyWindowStart(v *time.Time) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetDailyWindowStart(*v)
}
return _u
}
// ClearDailyWindowStart clears the value of the "daily_window_start" field.
func (_u *UserPlatformQuotaUpdate) ClearDailyWindowStart() *UserPlatformQuotaUpdate {
_u.mutation.ClearDailyWindowStart()
return _u
}
// SetWeeklyWindowStart sets the "weekly_window_start" field.
func (_u *UserPlatformQuotaUpdate) SetWeeklyWindowStart(v time.Time) *UserPlatformQuotaUpdate {
_u.mutation.SetWeeklyWindowStart(v)
return _u
}
// SetNillableWeeklyWindowStart sets the "weekly_window_start" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableWeeklyWindowStart(v *time.Time) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetWeeklyWindowStart(*v)
}
return _u
}
// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field.
func (_u *UserPlatformQuotaUpdate) ClearWeeklyWindowStart() *UserPlatformQuotaUpdate {
_u.mutation.ClearWeeklyWindowStart()
return _u
}
// SetMonthlyWindowStart sets the "monthly_window_start" field.
func (_u *UserPlatformQuotaUpdate) SetMonthlyWindowStart(v time.Time) *UserPlatformQuotaUpdate {
_u.mutation.SetMonthlyWindowStart(v)
return _u
}
// SetNillableMonthlyWindowStart sets the "monthly_window_start" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableMonthlyWindowStart(v *time.Time) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetMonthlyWindowStart(*v)
}
return _u
}
// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field.
func (_u *UserPlatformQuotaUpdate) ClearMonthlyWindowStart() *UserPlatformQuotaUpdate {
_u.mutation.ClearMonthlyWindowStart()
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *UserPlatformQuotaUpdate) SetUser(v *User) *UserPlatformQuotaUpdate {
return _u.SetUserID(v.ID)
}
// Mutation returns the UserPlatformQuotaMutation object of the builder.
func (_u *UserPlatformQuotaUpdate) Mutation() *UserPlatformQuotaMutation {
return _u.mutation
}
// ClearUser clears the "user" edge to the User entity.
func (_u *UserPlatformQuotaUpdate) ClearUser() *UserPlatformQuotaUpdate {
_u.mutation.ClearUser()
return _u
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserPlatformQuotaUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
return 0, err
}
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *UserPlatformQuotaUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *UserPlatformQuotaUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *UserPlatformQuotaUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *UserPlatformQuotaUpdate) defaults() error {
if _, ok := _u.mutation.UpdatedAt(); !ok {
if userplatformquota.UpdateDefaultUpdatedAt == nil {
return fmt.Errorf("ent: uninitialized userplatformquota.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)")
}
v := userplatformquota.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
return nil
}
// check runs all checks and user-defined validators on the builder.
func (_u *UserPlatformQuotaUpdate) check() error {
if v, ok := _u.mutation.Platform(); ok {
if err := userplatformquota.PlatformValidator(v); err != nil {
return &ValidationError{Name: "platform", err: fmt.Errorf(`ent: validator failed for field "UserPlatformQuota.platform": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UserPlatformQuota.user"`)
}
return nil
}
func (_u *UserPlatformQuotaUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(userplatformquota.Table, userplatformquota.Columns, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(userplatformquota.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.DeletedAt(); ok {
_spec.SetField(userplatformquota.FieldDeletedAt, field.TypeTime, value)
}
if _u.mutation.DeletedAtCleared() {
_spec.ClearField(userplatformquota.FieldDeletedAt, field.TypeTime)
}
if value, ok := _u.mutation.Platform(); ok {
_spec.SetField(userplatformquota.FieldPlatform, field.TypeString, value)
}
if value, ok := _u.mutation.DailyLimitUsd(); ok {
_spec.SetField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedDailyLimitUsd(); ok {
_spec.AddField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value)
}
if _u.mutation.DailyLimitUsdCleared() {
_spec.ClearField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64)
}
if value, ok := _u.mutation.WeeklyLimitUsd(); ok {
_spec.SetField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedWeeklyLimitUsd(); ok {
_spec.AddField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value)
}
if _u.mutation.WeeklyLimitUsdCleared() {
_spec.ClearField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64)
}
if value, ok := _u.mutation.MonthlyLimitUsd(); ok {
_spec.SetField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedMonthlyLimitUsd(); ok {
_spec.AddField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value)
}
if _u.mutation.MonthlyLimitUsdCleared() {
_spec.ClearField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64)
}
if value, ok := _u.mutation.DailyUsageUsd(); ok {
_spec.SetField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedDailyUsageUsd(); ok {
_spec.AddField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.WeeklyUsageUsd(); ok {
_spec.SetField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedWeeklyUsageUsd(); ok {
_spec.AddField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.MonthlyUsageUsd(); ok {
_spec.SetField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedMonthlyUsageUsd(); ok {
_spec.AddField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.DailyWindowStart(); ok {
_spec.SetField(userplatformquota.FieldDailyWindowStart, field.TypeTime, value)
}
if _u.mutation.DailyWindowStartCleared() {
_spec.ClearField(userplatformquota.FieldDailyWindowStart, field.TypeTime)
}
if value, ok := _u.mutation.WeeklyWindowStart(); ok {
_spec.SetField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime, value)
}
if _u.mutation.WeeklyWindowStartCleared() {
_spec.ClearField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime)
}
if value, ok := _u.mutation.MonthlyWindowStart(); ok {
_spec.SetField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime, value)
}
if _u.mutation.MonthlyWindowStartCleared() {
_spec.ClearField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userplatformquota.UserTable,
Columns: []string{userplatformquota.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userplatformquota.UserTable,
Columns: []string{userplatformquota.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{userplatformquota.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// UserPlatformQuotaUpdateOne is the builder for updating a single UserPlatformQuota entity.
type UserPlatformQuotaUpdateOne struct {
config
fields []string
hooks []Hook
mutation *UserPlatformQuotaMutation
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *UserPlatformQuotaUpdateOne) SetUpdatedAt(v time.Time) *UserPlatformQuotaUpdateOne {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetDeletedAt sets the "deleted_at" field.
func (_u *UserPlatformQuotaUpdateOne) SetDeletedAt(v time.Time) *UserPlatformQuotaUpdateOne {
_u.mutation.SetDeletedAt(v)
return _u
}
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableDeletedAt(v *time.Time) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetDeletedAt(*v)
}
return _u
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (_u *UserPlatformQuotaUpdateOne) ClearDeletedAt() *UserPlatformQuotaUpdateOne {
_u.mutation.ClearDeletedAt()
return _u
}
// SetUserID sets the "user_id" field.
func (_u *UserPlatformQuotaUpdateOne) SetUserID(v int64) *UserPlatformQuotaUpdateOne {
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableUserID(v *int64) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// SetPlatform sets the "platform" field.
func (_u *UserPlatformQuotaUpdateOne) SetPlatform(v string) *UserPlatformQuotaUpdateOne {
_u.mutation.SetPlatform(v)
return _u
}
// SetNillablePlatform sets the "platform" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillablePlatform(v *string) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetPlatform(*v)
}
return _u
}
// SetDailyLimitUsd sets the "daily_limit_usd" field.
func (_u *UserPlatformQuotaUpdateOne) SetDailyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.ResetDailyLimitUsd()
_u.mutation.SetDailyLimitUsd(v)
return _u
}
// SetNillableDailyLimitUsd sets the "daily_limit_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableDailyLimitUsd(v *float64) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetDailyLimitUsd(*v)
}
return _u
}
// AddDailyLimitUsd adds value to the "daily_limit_usd" field.
func (_u *UserPlatformQuotaUpdateOne) AddDailyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.AddDailyLimitUsd(v)
return _u
}
// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field.
func (_u *UserPlatformQuotaUpdateOne) ClearDailyLimitUsd() *UserPlatformQuotaUpdateOne {
_u.mutation.ClearDailyLimitUsd()
return _u
}
// SetWeeklyLimitUsd sets the "weekly_limit_usd" field.
func (_u *UserPlatformQuotaUpdateOne) SetWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.ResetWeeklyLimitUsd()
_u.mutation.SetWeeklyLimitUsd(v)
return _u
}
// SetNillableWeeklyLimitUsd sets the "weekly_limit_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableWeeklyLimitUsd(v *float64) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetWeeklyLimitUsd(*v)
}
return _u
}
// AddWeeklyLimitUsd adds value to the "weekly_limit_usd" field.
func (_u *UserPlatformQuotaUpdateOne) AddWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.AddWeeklyLimitUsd(v)
return _u
}
// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field.
func (_u *UserPlatformQuotaUpdateOne) ClearWeeklyLimitUsd() *UserPlatformQuotaUpdateOne {
_u.mutation.ClearWeeklyLimitUsd()
return _u
}
// SetMonthlyLimitUsd sets the "monthly_limit_usd" field.
func (_u *UserPlatformQuotaUpdateOne) SetMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.ResetMonthlyLimitUsd()
_u.mutation.SetMonthlyLimitUsd(v)
return _u
}
// SetNillableMonthlyLimitUsd sets the "monthly_limit_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableMonthlyLimitUsd(v *float64) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetMonthlyLimitUsd(*v)
}
return _u
}
// AddMonthlyLimitUsd adds value to the "monthly_limit_usd" field.
func (_u *UserPlatformQuotaUpdateOne) AddMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.AddMonthlyLimitUsd(v)
return _u
}
// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field.
func (_u *UserPlatformQuotaUpdateOne) ClearMonthlyLimitUsd() *UserPlatformQuotaUpdateOne {
_u.mutation.ClearMonthlyLimitUsd()
return _u
}
// SetDailyUsageUsd sets the "daily_usage_usd" field.
func (_u *UserPlatformQuotaUpdateOne) SetDailyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.ResetDailyUsageUsd()
_u.mutation.SetDailyUsageUsd(v)
return _u
}
// SetNillableDailyUsageUsd sets the "daily_usage_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableDailyUsageUsd(v *float64) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetDailyUsageUsd(*v)
}
return _u
}
// AddDailyUsageUsd adds value to the "daily_usage_usd" field.
func (_u *UserPlatformQuotaUpdateOne) AddDailyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.AddDailyUsageUsd(v)
return _u
}
// SetWeeklyUsageUsd sets the "weekly_usage_usd" field.
func (_u *UserPlatformQuotaUpdateOne) SetWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.ResetWeeklyUsageUsd()
_u.mutation.SetWeeklyUsageUsd(v)
return _u
}
// SetNillableWeeklyUsageUsd sets the "weekly_usage_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableWeeklyUsageUsd(v *float64) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetWeeklyUsageUsd(*v)
}
return _u
}
// AddWeeklyUsageUsd adds value to the "weekly_usage_usd" field.
func (_u *UserPlatformQuotaUpdateOne) AddWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.AddWeeklyUsageUsd(v)
return _u
}
// SetMonthlyUsageUsd sets the "monthly_usage_usd" field.
func (_u *UserPlatformQuotaUpdateOne) SetMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.ResetMonthlyUsageUsd()
_u.mutation.SetMonthlyUsageUsd(v)
return _u
}
// SetNillableMonthlyUsageUsd sets the "monthly_usage_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableMonthlyUsageUsd(v *float64) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetMonthlyUsageUsd(*v)
}
return _u
}
// AddMonthlyUsageUsd adds value to the "monthly_usage_usd" field.
func (_u *UserPlatformQuotaUpdateOne) AddMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.AddMonthlyUsageUsd(v)
return _u
}
// SetDailyWindowStart sets the "daily_window_start" field.
func (_u *UserPlatformQuotaUpdateOne) SetDailyWindowStart(v time.Time) *UserPlatformQuotaUpdateOne {
_u.mutation.SetDailyWindowStart(v)
return _u
}
// SetNillableDailyWindowStart sets the "daily_window_start" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableDailyWindowStart(v *time.Time) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetDailyWindowStart(*v)
}
return _u
}
// ClearDailyWindowStart clears the value of the "daily_window_start" field.
func (_u *UserPlatformQuotaUpdateOne) ClearDailyWindowStart() *UserPlatformQuotaUpdateOne {
_u.mutation.ClearDailyWindowStart()
return _u
}
// SetWeeklyWindowStart sets the "weekly_window_start" field.
func (_u *UserPlatformQuotaUpdateOne) SetWeeklyWindowStart(v time.Time) *UserPlatformQuotaUpdateOne {
_u.mutation.SetWeeklyWindowStart(v)
return _u
}
// SetNillableWeeklyWindowStart sets the "weekly_window_start" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableWeeklyWindowStart(v *time.Time) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetWeeklyWindowStart(*v)
}
return _u
}
// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field.
func (_u *UserPlatformQuotaUpdateOne) ClearWeeklyWindowStart() *UserPlatformQuotaUpdateOne {
_u.mutation.ClearWeeklyWindowStart()
return _u
}
// SetMonthlyWindowStart sets the "monthly_window_start" field.
func (_u *UserPlatformQuotaUpdateOne) SetMonthlyWindowStart(v time.Time) *UserPlatformQuotaUpdateOne {
_u.mutation.SetMonthlyWindowStart(v)
return _u
}
// SetNillableMonthlyWindowStart sets the "monthly_window_start" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableMonthlyWindowStart(v *time.Time) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetMonthlyWindowStart(*v)
}
return _u
}
// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field.
func (_u *UserPlatformQuotaUpdateOne) ClearMonthlyWindowStart() *UserPlatformQuotaUpdateOne {
_u.mutation.ClearMonthlyWindowStart()
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *UserPlatformQuotaUpdateOne) SetUser(v *User) *UserPlatformQuotaUpdateOne {
return _u.SetUserID(v.ID)
}
// Mutation returns the UserPlatformQuotaMutation object of the builder.
func (_u *UserPlatformQuotaUpdateOne) Mutation() *UserPlatformQuotaMutation {
return _u.mutation
}
// ClearUser clears the "user" edge to the User entity.
func (_u *UserPlatformQuotaUpdateOne) ClearUser() *UserPlatformQuotaUpdateOne {
_u.mutation.ClearUser()
return _u
}
// Where appends a list predicates to the UserPlatformQuotaUpdate builder.
func (_u *UserPlatformQuotaUpdateOne) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *UserPlatformQuotaUpdateOne) Select(field string, fields ...string) *UserPlatformQuotaUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated UserPlatformQuota entity.
func (_u *UserPlatformQuotaUpdateOne) Save(ctx context.Context) (*UserPlatformQuota, error) {
if err := _u.defaults(); err != nil {
return nil, err
}
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *UserPlatformQuotaUpdateOne) SaveX(ctx context.Context) *UserPlatformQuota {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *UserPlatformQuotaUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *UserPlatformQuotaUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *UserPlatformQuotaUpdateOne) defaults() error {
if _, ok := _u.mutation.UpdatedAt(); !ok {
if userplatformquota.UpdateDefaultUpdatedAt == nil {
return fmt.Errorf("ent: uninitialized userplatformquota.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)")
}
v := userplatformquota.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
return nil
}
// check runs all checks and user-defined validators on the builder.
func (_u *UserPlatformQuotaUpdateOne) check() error {
if v, ok := _u.mutation.Platform(); ok {
if err := userplatformquota.PlatformValidator(v); err != nil {
return &ValidationError{Name: "platform", err: fmt.Errorf(`ent: validator failed for field "UserPlatformQuota.platform": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UserPlatformQuota.user"`)
}
return nil
}
func (_u *UserPlatformQuotaUpdateOne) sqlSave(ctx context.Context) (_node *UserPlatformQuota, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(userplatformquota.Table, userplatformquota.Columns, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UserPlatformQuota.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, userplatformquota.FieldID)
for _, f := range fields {
if !userplatformquota.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != userplatformquota.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(userplatformquota.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.DeletedAt(); ok {
_spec.SetField(userplatformquota.FieldDeletedAt, field.TypeTime, value)
}
if _u.mutation.DeletedAtCleared() {
_spec.ClearField(userplatformquota.FieldDeletedAt, field.TypeTime)
}
if value, ok := _u.mutation.Platform(); ok {
_spec.SetField(userplatformquota.FieldPlatform, field.TypeString, value)
}
if value, ok := _u.mutation.DailyLimitUsd(); ok {
_spec.SetField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedDailyLimitUsd(); ok {
_spec.AddField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value)
}
if _u.mutation.DailyLimitUsdCleared() {
_spec.ClearField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64)
}
if value, ok := _u.mutation.WeeklyLimitUsd(); ok {
_spec.SetField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedWeeklyLimitUsd(); ok {
_spec.AddField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value)
}
if _u.mutation.WeeklyLimitUsdCleared() {
_spec.ClearField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64)
}
if value, ok := _u.mutation.MonthlyLimitUsd(); ok {
_spec.SetField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedMonthlyLimitUsd(); ok {
_spec.AddField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value)
}
if _u.mutation.MonthlyLimitUsdCleared() {
_spec.ClearField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64)
}
if value, ok := _u.mutation.DailyUsageUsd(); ok {
_spec.SetField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedDailyUsageUsd(); ok {
_spec.AddField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.WeeklyUsageUsd(); ok {
_spec.SetField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedWeeklyUsageUsd(); ok {
_spec.AddField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.MonthlyUsageUsd(); ok {
_spec.SetField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedMonthlyUsageUsd(); ok {
_spec.AddField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.DailyWindowStart(); ok {
_spec.SetField(userplatformquota.FieldDailyWindowStart, field.TypeTime, value)
}
if _u.mutation.DailyWindowStartCleared() {
_spec.ClearField(userplatformquota.FieldDailyWindowStart, field.TypeTime)
}
if value, ok := _u.mutation.WeeklyWindowStart(); ok {
_spec.SetField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime, value)
}
if _u.mutation.WeeklyWindowStartCleared() {
_spec.ClearField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime)
}
if value, ok := _u.mutation.MonthlyWindowStart(); ok {
_spec.SetField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime, value)
}
if _u.mutation.MonthlyWindowStartCleared() {
_spec.ClearField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userplatformquota.UserTable,
Columns: []string{userplatformquota.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userplatformquota.UserTable,
Columns: []string{userplatformquota.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &UserPlatformQuota{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{userplatformquota.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}

View File

@ -6,12 +6,14 @@ require (
connectrpc.com/connect v1.19.2
entgo.io/ent v0.14.5
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/alicebob/miniredis/v2 v2.38.0
github.com/alitto/pond/v2 v2.6.2
github.com/andybalholm/brotli v1.2.0
github.com/aws/aws-sdk-go-v2 v1.41.3
github.com/aws/aws-sdk-go-v2/config v1.32.10
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
github.com/aws/smithy-go v1.24.2
github.com/cespare/xxhash/v2 v2.3.0
github.com/coder/websocket v1.8.14
github.com/dgraph-io/ristretto v0.2.0
@ -74,7 +76,6 @@ require (
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
github.com/aws/smithy-go v1.24.2 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/bytedance/sonic v1.9.1 // indirect
@ -160,6 +161,7 @@ require (
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/yuin/gopher-lua v1.1.1 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
github.com/zclconf/go-cty v1.14.4 // indirect
github.com/zclconf/go-cty-yaml v1.1.0 // indirect

View File

@ -18,6 +18,8 @@ github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7l
github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558=
github.com/agiledragon/gomonkey v2.0.2+incompatible h1:eXKi9/piiC3cjJD1658mEE2o3NjkJ5vDLgYjCQu0Xlw=
github.com/agiledragon/gomonkey v2.0.2+incompatible/go.mod h1:2NGfXu1a80LLr2cmWXGBDaHEjb1idR6+FVlX5T3D9hw=
github.com/alicebob/miniredis/v2 v2.38.0 h1:nZAzCR+Lj+Vxk4ZXzm2NuKq2O33RXj1XxJ2e2uP9jiw=
github.com/alicebob/miniredis/v2 v2.38.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
github.com/alitto/pond/v2 v2.6.2 h1:Sphe40g0ILeM1pA2c2K+Th0DGU+pt0A/Kprr+WB24Pw=
github.com/alitto/pond/v2 v2.6.2/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
@ -362,6 +364,8 @@ github.com/wechatpay-apiv3/wechatpay-go v0.2.21 h1:uIyMpzvcaHA33W/QPtHstccw+X52H
github.com/wechatpay-apiv3/wechatpay-go v0.2.21/go.mod h1:A254AUBVB6R+EqQFo3yTgeh7HtyqRRtN2w9hQSOrd4Q=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
github.com/zclconf/go-cty v1.14.4 h1:uXXczd9QDGsgu0i/QFR/hzI5NYCHLf6NQw/atrbnhq8=

View File

@ -651,6 +651,12 @@ type ProxyProbeConfig struct {
type BillingConfig struct {
CircuitBreaker CircuitBreakerConfig `mapstructure:"circuit_breaker"`
// UserPlatformQuotaCacheTTLSeconds 用户 × 平台 quota 缓存 TTL默认 86400=1天覆盖典型 daily 窗口。
// 消费点:
// - billing_cache_service.cacheWriteWorker 异步累加
// - billing_cache_service.checkUserPlatformQuotaEligibility 首次缓存装载
// 读写两端必须共用同一 TTL避免缓存生命周期不一致导致 quota 计数漂移。
UserPlatformQuotaCacheTTLSeconds int `mapstructure:"user_platform_quota_cache_ttl_seconds"`
}
type CircuitBreakerConfig struct {
@ -688,6 +694,9 @@ type GatewayConfig struct {
// 等待上游响应头的超时时间0表示无超时
// 注意:这不影响流式数据传输,只控制等待响应头的时间
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
// OpenAIResponseHeaderTimeout: OpenAI/Codex 上游等待响应头的超时时间0表示无超时
// OpenAI/Codex 请求可能在上游排队较久;默认不使用通用响应头超时截断。
OpenAIResponseHeaderTimeout int `mapstructure:"openai_response_header_timeout"`
// 请求体最大字节数,用于网关请求体大小限制
MaxBodySize int64 `mapstructure:"max_body_size"`
// 非流式上游响应体读取上限(字节),用于防止无界读取导致内存放大
@ -717,6 +726,8 @@ type GatewayConfig struct {
OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"`
// NodeTLSProxy: Node.js TLS 代理配置
NodeTLSProxy NodeTLSProxyConfig `mapstructure:"node_tls_proxy"`
// OpenAIHTTP2: OpenAI HTTP 上游协议策略(默认启用 HTTP/2可按代理能力回退 HTTP/1.1
OpenAIHTTP2 GatewayOpenAIHTTP2Config `mapstructure:"openai_http2"`
// ImageConcurrency: 图片生成独立并发限制配置(默认关闭)
ImageConcurrency ImageConcurrencyConfig `mapstructure:"image_concurrency"`
@ -815,6 +826,21 @@ type GatewayConfig struct {
ContextCompression ContextCompressionConfig `mapstructure:"context_compression"`
}
// GatewayOpenAIHTTP2Config OpenAI HTTP 上游协议配置。
// 默认启用 HTTP/2在部分代理不兼容时按策略回退 HTTP/1.1。
type GatewayOpenAIHTTP2Config struct {
// Enabled: 是否启用 OpenAI HTTP/2 优先策略
Enabled bool `mapstructure:"enabled"`
// AllowProxyFallbackToHTTP1: HTTP/HTTPS 代理出现明确 H2 兼容错误时,临时回退 HTTP/1.1
AllowProxyFallbackToHTTP1 bool `mapstructure:"allow_proxy_fallback_to_http1"`
// FallbackErrorThreshold: 回退窗口内累计多少次兼容错误后触发回退
FallbackErrorThreshold int `mapstructure:"fallback_error_threshold"`
// FallbackWindowSeconds: 统计兼容错误的时间窗口(秒)
FallbackWindowSeconds int `mapstructure:"fallback_window_seconds"`
// FallbackTTLSeconds: 触发后回退 HTTP/1.1 的持续时间(秒)
FallbackTTLSeconds int `mapstructure:"fallback_ttl_seconds"`
}
// UserMessageQueueConfig 用户消息串行队列配置
// 用于 Anthropic OAuth/SetupToken 账号的用户消息串行化发送
type UserMessageQueueConfig struct {
@ -1647,6 +1673,7 @@ func setDefaults() {
viper.SetDefault("billing.circuit_breaker.failure_threshold", 5)
viper.SetDefault("billing.circuit_breaker.reset_timeout_seconds", 30)
viper.SetDefault("billing.circuit_breaker.half_open_requests", 3)
viper.SetDefault("billing.user_platform_quota_cache_ttl_seconds", 86400)
// Turnstile
viper.SetDefault("turnstile.required", false)
@ -1847,6 +1874,7 @@ func setDefaults() {
// Gateway
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头LLM高负载时可能排队较久
viper.SetDefault("gateway.openai_response_header_timeout", 0)
viper.SetDefault("gateway.log_upstream_error_body", true)
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
viper.SetDefault("gateway.inject_beta_for_apikey", false)
@ -1902,6 +1930,12 @@ func setDefaults() {
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5)
// OpenAI HTTP upstream protocol strategy
viper.SetDefault("gateway.openai_http2.enabled", true)
viper.SetDefault("gateway.openai_http2.allow_proxy_fallback_to_http1", true)
viper.SetDefault("gateway.openai_http2.fallback_error_threshold", 2)
viper.SetDefault("gateway.openai_http2.fallback_window_seconds", 60)
viper.SetDefault("gateway.openai_http2.fallback_ttl_seconds", 600)
viper.SetDefault("gateway.image_concurrency.enabled", false)
viper.SetDefault("gateway.image_concurrency.max_concurrent_requests", 0)
viper.SetDefault("gateway.image_concurrency.overflow_mode", ImageConcurrencyOverflowModeReject)
@ -2523,6 +2557,12 @@ func (c *Config) Validate() error {
if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 {
return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive")
}
if c.Gateway.ResponseHeaderTimeout < 0 {
return fmt.Errorf("gateway.response_header_timeout must be non-negative")
}
if c.Gateway.OpenAIResponseHeaderTimeout < 0 {
return fmt.Errorf("gateway.openai_response_header_timeout must be non-negative")
}
if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" {
switch c.Gateway.ConnectionPoolIsolation {
case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy:
@ -2697,6 +2737,15 @@ func (c *Config) Validate() error {
if c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds < 0 {
return fmt.Errorf("gateway.openai_ws.sticky_previous_response_ttl_seconds must be non-negative")
}
if c.Gateway.OpenAIHTTP2.FallbackErrorThreshold < 0 {
return fmt.Errorf("gateway.openai_http2.fallback_error_threshold must be non-negative")
}
if c.Gateway.OpenAIHTTP2.FallbackWindowSeconds < 0 {
return fmt.Errorf("gateway.openai_http2.fallback_window_seconds must be non-negative")
}
if c.Gateway.OpenAIHTTP2.FallbackTTLSeconds < 0 {
return fmt.Errorf("gateway.openai_http2.fallback_ttl_seconds must be non-negative")
}
if c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority < 0 ||
c.Gateway.OpenAIWS.SchedulerScoreWeights.Load < 0 ||
c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue < 0 ||

View File

@ -163,6 +163,41 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
}
}
func TestLoadDefaultOpenAIHTTP2Enabled(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
require.NoError(t, err)
require.True(t, cfg.Gateway.OpenAIHTTP2.Enabled)
require.True(t, cfg.Gateway.OpenAIHTTP2.AllowProxyFallbackToHTTP1)
}
func TestLoadOpenAIHTTP2DisabledFromEnv(t *testing.T) {
resetViperWithJWTSecret(t)
t.Setenv("GATEWAY_OPENAI_HTTP2_ENABLED", "false")
cfg, err := Load()
require.NoError(t, err)
require.False(t, cfg.Gateway.OpenAIHTTP2.Enabled)
}
func TestLoadDefaultOpenAIResponseHeaderTimeoutUnlimited(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
require.NoError(t, err)
require.Equal(t, 0, cfg.Gateway.OpenAIResponseHeaderTimeout)
}
func TestLoadOpenAIResponseHeaderTimeoutFromEnv(t *testing.T) {
resetViperWithJWTSecret(t)
t.Setenv("GATEWAY_OPENAI_RESPONSE_HEADER_TIMEOUT", "1800")
cfg, err := Load()
require.NoError(t, err)
require.Equal(t, 1800, cfg.Gateway.OpenAIResponseHeaderTimeout)
}
func TestLoadOpenAIWSStickyTTLCompatibility(t *testing.T) {
resetViperWithJWTSecret(t)
t.Setenv("GATEWAY_OPENAI_WS_STICKY_RESPONSE_ID_TTL_SECONDS", "0")
@ -1220,6 +1255,16 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.Gateway.MaxBodySize = 0 },
wantErr: "gateway.max_body_size",
},
{
name: "gateway response header timeout",
mutate: func(c *Config) { c.Gateway.ResponseHeaderTimeout = -1 },
wantErr: "gateway.response_header_timeout",
},
{
name: "gateway openai response header timeout",
mutate: func(c *Config) { c.Gateway.OpenAIResponseHeaderTimeout = -1 },
wantErr: "gateway.openai_response_header_timeout",
},
{
name: "gateway max idle conns",
mutate: func(c *Config) { c.Gateway.MaxIdleConns = 0 },
@ -1275,6 +1320,21 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0 },
wantErr: "gateway.openai_ws.apikey_max_conns_factor",
},
{
name: "gateway openai http2 fallback threshold",
mutate: func(c *Config) { c.Gateway.OpenAIHTTP2.FallbackErrorThreshold = -1 },
wantErr: "gateway.openai_http2.fallback_error_threshold",
},
{
name: "gateway openai http2 fallback window",
mutate: func(c *Config) { c.Gateway.OpenAIHTTP2.FallbackWindowSeconds = -1 },
wantErr: "gateway.openai_http2.fallback_window_seconds",
},
{
name: "gateway openai http2 fallback ttl",
mutate: func(c *Config) { c.Gateway.OpenAIHTTP2.FallbackTTLSeconds = -1 },
wantErr: "gateway.openai_http2.fallback_ttl_seconds",
},
{
name: "gateway stream data interval range",
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 },

View File

@ -0,0 +1,7 @@
package domain
// GroupModelsListConfig controls the optional custom /v1/models response list.
type GroupModelsListConfig struct {
Enabled bool `json:"enabled"`
Models []string `json:"models,omitempty"`
}

View File

@ -982,6 +982,100 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
}
// ApplyOAuthCredentialsRequest is the payload for persisting re-authorized OAuth credentials.
type ApplyOAuthCredentialsRequest struct {
Type string `json:"type" binding:"required,oneof=oauth setup-token"`
Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"`
}
// ApplyOAuthCredentials 将"重新授权"得到的新凭据原子落库。
// POST /api/v1/admin/accounts/:id/apply-oauth-credentials
//
// 与通用 PUT /:id (Update) 接口的关键区别:
// - 仅接收 type / credentials / extra 三个字段(不接受 concurrency / rpm / quota_* 等可能误传的字段)
// - Extra 走 UpdateAccountExtra(JSONB key 级合并)**绝不**全量覆盖;
// 避免 base_rpm / window_cost_limit / max_sessions / quota_* / privacy_mode
// 等持久化配置在重新授权后丢失
// - 内置 ClearError + InvalidateToken避免前端额外两次调用
// 并修复旧路径未失效 token 缓存导致重新授权后立即 401 的隐性 bug
//
// 与 /refresh 的区别:/refresh 用现有 refresh_token 换 access_token无用户交互
// 本接口承接前端完成完整 OAuth 流程后的落库步骤。
func (h *AccountHandler) ApplyOAuthCredentials(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
var req ApplyOAuthCredentialsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
ctx := c.Request.Context()
// 预检查账号存在 + OAuth 类型(与 Refresh handler 语义一致,提供更友好的错误信息)。
existing, err := h.adminService.GetAccount(ctx, accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
if !existing.IsOAuth() {
response.ErrorFrom(c, infraerrors.BadRequest("NOT_OAUTH", "cannot apply oauth credentials to non-OAuth account"))
return
}
updatedAccount, err := h.adminService.UpdateAccount(ctx, accountID, &service.UpdateAccountInput{
Type: req.Type,
Credentials: req.Credentials,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
// 增量合并 ExtraJSONB key 级 merge绝不覆盖 base_rpm / window_cost_limit /
// max_sessions / quota_* / privacy_mode 等持久化键)。
// best-effort失败仅记日志下方 ClearAccountError 会从 DB 重新读取最新 account
// 因此响应里的 extra 始终以 DB 为准——这里不需要手动维护内存快照。
if len(req.Extra) > 0 {
if extraErr := h.adminService.UpdateAccountExtra(ctx, accountID, req.Extra); extraErr != nil {
extraKeys := make([]string, 0, len(req.Extra))
for k := range req.Extra {
extraKeys = append(extraKeys, k)
}
slog.Error("apply_oauth_credentials.update_extra_failed",
"account_id", accountID,
"extra_keys", extraKeys,
"err", extraErr,
)
}
}
if cleared, clearErr := h.adminService.ClearAccountError(ctx, accountID); clearErr != nil {
slog.Warn("apply_oauth_credentials.clear_error_failed",
"account_id", accountID,
"err", clearErr,
)
} else if cleared != nil {
updatedAccount = cleared
}
if h.tokenCacheInvalidator != nil && updatedAccount.IsOAuth() {
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(ctx, updatedAccount); invalidateErr != nil {
slog.Warn("apply_oauth_credentials.invalidate_token_failed",
"account_id", accountID,
"err", invalidateErr,
)
}
}
response.Success(c, h.buildAccountResponseWithRuntime(ctx, updatedAccount))
}
// GetStats handles getting account statistics
// GET /api/v1/admin/accounts/:id/stats
func (h *AccountHandler) GetStats(c *gin.Context) {

View File

@ -0,0 +1,52 @@
package admin
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func setupAccountListRouter() (*gin.Engine, *stubAdminService) {
gin.SetMode(gin.TestMode)
router := gin.New()
adminSvc := newStubAdminService()
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.GET("/api/v1/admin/accounts", handler.List)
return router, adminSvc
}
func TestAccountHandlerListIncludesCreatedAt(t *testing.T) {
router, adminSvc := setupAccountListRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts?page=1&page_size=20&sort_by=created_at&sort_order=desc", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "created_at", adminSvc.lastListAccounts.sortBy)
var payload struct {
Data struct {
Items []struct {
ID int64 `json:"id"`
CreatedAt string `json:"created_at"`
} `json:"items"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload))
require.Len(t, payload.Data.Items, 1)
createdAt := payload.Data.Items[0].CreatedAt
require.NotEmpty(t, createdAt)
require.True(t, strings.HasSuffix(createdAt, "Z"), "created_at should be serialized as UTC")
parsed, err := time.Parse(time.RFC3339Nano, createdAt)
require.NoError(t, err)
_, offset := parsed.Zone()
require.Equal(t, 0, offset)
}

View File

@ -16,7 +16,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router := gin.New()
adminSvc := newStubAdminService()
userHandler := NewUserHandler(adminSvc, nil)
userHandler := NewUserHandler(adminSvc, nil, nil, nil)
groupHandler := NewGroupHandler(adminSvc, nil, nil)
proxyHandler := NewProxyHandler(adminSvc)
redeemHandler := NewRedeemHandler(adminSvc, nil)
@ -33,6 +33,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router.GET("/api/v1/admin/groups", groupHandler.List)
router.GET("/api/v1/admin/groups/all", groupHandler.GetAll)
router.GET("/api/v1/admin/groups/:id/models-list-candidates", groupHandler.GetModelsListCandidates)
router.GET("/api/v1/admin/groups/:id", groupHandler.GetByID)
router.POST("/api/v1/admin/groups", groupHandler.Create)
router.PUT("/api/v1/admin/groups/:id", groupHandler.Update)
@ -177,6 +178,12 @@ func TestGroupHandlerEndpoints(t *testing.T) {
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/0/models-list-candidates?platform=openai", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Contains(t, rec.Body.String(), "gpt-5.5")
body, _ := json.Marshal(map[string]any{"name": "new", "platform": "anthropic", "subscription_type": "standard"})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/groups", bytes.NewReader(body))

View File

@ -24,6 +24,7 @@ type stubAdminService struct {
updatedProxyIDs []int64
updatedProxies []*service.UpdateProxyInput
testedProxyIDs []int64
getUserErr error
createAccountErr error
updateAccountErr error
bulkUpdateAccountErr error
@ -147,6 +148,9 @@ func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, fi
}
func (s *stubAdminService) GetUser(ctx context.Context, id int64) (*service.User, error) {
if s.getUserErr != nil {
return nil, s.getUserErr
}
for i := range s.users {
if s.users[i].ID == id {
return &s.users[i], nil
@ -261,6 +265,13 @@ func (s *stubAdminService) GetGroup(ctx context.Context, id int64) (*service.Gro
return &group, nil
}
func (s *stubAdminService) GetGroupModelsListCandidates(ctx context.Context, id int64, platform string) ([]string, error) {
if platform == service.PlatformOpenAI {
return []string{"gpt-5.5", "gpt-5.4"}, nil
}
return []string{"claude-sonnet-4-6"}, nil
}
func (s *stubAdminService) CreateGroup(ctx context.Context, input *service.CreateGroupInput) (*service.Group, error) {
group := service.Group{ID: 200, Name: input.Name, Status: service.StatusActive}
return &group, nil
@ -345,6 +356,10 @@ func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *s
return &account, nil
}
func (s *stubAdminService) UpdateAccountExtra(ctx context.Context, id int64, updates map[string]any) error {
return nil
}
func (s *stubAdminService) DeleteAccount(ctx context.Context, id int64) error {
return nil
}

View File

@ -34,6 +34,7 @@ type contentModerationConfigRequest struct {
AllGroups *bool `json:"all_groups"`
GroupIDs *[]int64 `json:"group_ids"`
RecordNonHits *bool `json:"record_non_hits"`
Thresholds *map[string]float64 `json:"thresholds"`
WorkerCount *int `json:"worker_count"`
QueueSize *int `json:"queue_size"`
BlockStatus *int `json:"block_status"`
@ -94,6 +95,7 @@ func (h *ContentModerationHandler) UpdateConfig(c *gin.Context) {
AllGroups: req.AllGroups,
GroupIDs: req.GroupIDs,
RecordNonHits: req.RecordNonHits,
Thresholds: req.Thresholds,
WorkerCount: req.WorkerCount,
QueueSize: req.QueueSize,
BlockStatus: req.BlockStatus,

View File

@ -113,6 +113,7 @@ type CreateGroupRequest struct {
RequirePrivacySet bool `json:"require_privacy_set"`
DefaultMappedModel string `json:"default_mapped_model"`
MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
ModelsListConfig service.GroupModelsListConfig `json:"models_list_config"`
// 分组 RPM 上限0 = 不限制)
RPMLimit int `json:"rpm_limit"`
// 从指定分组复制账号(创建后自动绑定)
@ -153,6 +154,7 @@ type UpdateGroupRequest struct {
RequirePrivacySet *bool `json:"require_privacy_set"`
DefaultMappedModel *string `json:"default_mapped_model"`
MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
ModelsListConfig *service.GroupModelsListConfig `json:"models_list_config"`
// 分组 RPM 上限0 = 不限制nil 表示未提供不改动
RPMLimit *int `json:"rpm_limit"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
@ -238,6 +240,28 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
response.Success(c, dto.GroupFromServiceAdmin(group))
}
// GetModelsListCandidates handles getting candidate model IDs for custom /v1/models list.
// GET /api/v1/admin/groups/:id/models-list-candidates
func (h *GroupHandler) GetModelsListCandidates(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || groupID < 0 {
response.BadRequest(c, "Invalid group ID")
return
}
models, err := h.adminService.GetGroupModelsListCandidates(
c.Request.Context(),
groupID,
c.Query("platform"),
)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"models": models})
}
// Create handles creating a new group
// POST /api/v1/admin/groups
func (h *GroupHandler) Create(c *gin.Context) {
@ -275,6 +299,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
RequirePrivacySet: req.RequirePrivacySet,
DefaultMappedModel: req.DefaultMappedModel,
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
ModelsListConfig: req.ModelsListConfig,
RPMLimit: req.RPMLimit,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
@ -330,6 +355,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
RequirePrivacySet: req.RequirePrivacySet,
DefaultMappedModel: req.DefaultMappedModel,
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
ModelsListConfig: req.ModelsListConfig,
RPMLimit: req.RPMLimit,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})

View File

@ -305,6 +305,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
}
// Default platform quotasJSON map
if platformQuotas, err := h.settingService.GetDefaultPlatformQuotas(c.Request.Context()); err != nil {
slog.Error("default_platform_quotas_get_failed", "error", err)
} else {
payload.DefaultPlatformQuotas = platformQuotas
}
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
}
@ -637,6 +644,18 @@ type UpdateSettingsRequest struct {
// OpenAI fast/flex policy (optional, only updated when provided)
OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
// 系统全局 platform quota 默认值整体替换语义nil = 不修改non-nil = 整体覆盖)。
DefaultPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"default_platform_quotas"`
// auth-source 层 platform quota 覆盖override 语义nil = 不修改non-nil = 整体覆盖该 source 的 quota 配置)。
AuthSourceEmailPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_email_platform_quotas"`
AuthSourceLinuxDoPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_linuxdo_platform_quotas"`
AuthSourceOIDCPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_oidc_platform_quotas"`
AuthSourceWeChatPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_wechat_platform_quotas"`
AuthSourceGitHubPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_github_platform_quotas"`
AuthSourceGooglePlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_google_platform_quotas"`
AuthSourceDingTalkPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_dingtalk_platform_quotas"`
}
// UpdateSettings 更新系统设置
@ -1438,6 +1457,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
settings := &service.SystemSettings{
// 系统全局 platform quota 默认值(整体替换语义)
DefaultPlatformQuotas: req.DefaultPlatformQuotas,
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
@ -1731,6 +1753,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}(),
}
// req.AuthSourceXxxPlatformQuotas 为 nil 表示本次请求未包含该 source 的 quota 配置(保留 previousAuthSourceDefaults 中的值);
// non-nil含 empty map表示整体覆盖empty map = 清空该 source 的所有 quota 配置。
authSourceDefaults := &service.AuthSourceDefaultSettings{
Email: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance),
@ -1738,6 +1762,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultEmailSubscriptions, previousAuthSourceDefaults.Email.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnSignup, previousAuthSourceDefaults.Email.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnFirstBind, previousAuthSourceDefaults.Email.GrantOnFirstBind),
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceEmailPlatformQuotas, previousAuthSourceDefaults.Email.PlatformQuotas),
},
LinuxDo: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultLinuxDoBalance, previousAuthSourceDefaults.LinuxDo.Balance),
@ -1745,6 +1770,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultLinuxDoSubscriptions, previousAuthSourceDefaults.LinuxDo.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnSignup, previousAuthSourceDefaults.LinuxDo.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnFirstBind, previousAuthSourceDefaults.LinuxDo.GrantOnFirstBind),
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceLinuxDoPlatformQuotas, previousAuthSourceDefaults.LinuxDo.PlatformQuotas),
},
OIDC: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultOIDCBalance, previousAuthSourceDefaults.OIDC.Balance),
@ -1752,6 +1778,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultOIDCSubscriptions, previousAuthSourceDefaults.OIDC.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnSignup, previousAuthSourceDefaults.OIDC.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnFirstBind, previousAuthSourceDefaults.OIDC.GrantOnFirstBind),
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceOIDCPlatformQuotas, previousAuthSourceDefaults.OIDC.PlatformQuotas),
},
WeChat: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultWeChatBalance, previousAuthSourceDefaults.WeChat.Balance),
@ -1759,6 +1786,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultWeChatSubscriptions, previousAuthSourceDefaults.WeChat.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind),
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceWeChatPlatformQuotas, previousAuthSourceDefaults.WeChat.PlatformQuotas),
},
GitHub: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultGitHubBalance, previousAuthSourceDefaults.GitHub.Balance),
@ -1766,6 +1794,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGitHubSubscriptions, previousAuthSourceDefaults.GitHub.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnSignup, previousAuthSourceDefaults.GitHub.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnFirstBind, previousAuthSourceDefaults.GitHub.GrantOnFirstBind),
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceGitHubPlatformQuotas, previousAuthSourceDefaults.GitHub.PlatformQuotas),
},
Google: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultGoogleBalance, previousAuthSourceDefaults.Google.Balance),
@ -1773,6 +1802,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGoogleSubscriptions, previousAuthSourceDefaults.Google.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnSignup, previousAuthSourceDefaults.Google.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnFirstBind, previousAuthSourceDefaults.Google.GrantOnFirstBind),
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceGooglePlatformQuotas, previousAuthSourceDefaults.Google.PlatformQuotas),
},
DingTalk: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultDingTalkBalance, previousAuthSourceDefaults.DingTalk.Balance),
@ -1780,6 +1810,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultDingTalkSubscriptions, previousAuthSourceDefaults.DingTalk.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnSignup, previousAuthSourceDefaults.DingTalk.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnFirstBind, previousAuthSourceDefaults.DingTalk.GrantOnFirstBind),
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceDingTalkPlatformQuotas, previousAuthSourceDefaults.DingTalk.PlatformQuotas),
},
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
}
@ -2047,6 +2078,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
} else if fastPolicy != nil {
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
}
// Default platform quotasJSON map—— 与 GetSettings 一致,避免保存后响应缺失该字段
if platformQuotas, err := h.settingService.GetDefaultPlatformQuotas(c.Request.Context()); err != nil {
slog.Error("default_platform_quotas_get_failed", "error", err)
} else {
payload.DefaultPlatformQuotas = platformQuotas
}
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
}
@ -2511,6 +2549,10 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.RiskControlEnabled != after.RiskControlEnabled {
changed = append(changed, "risk_control_enabled")
}
// Default platform quotasJSON map整体比较
if !equalPlatformQuotaSettings(before.DefaultPlatformQuotas, after.DefaultPlatformQuotas) {
changed = append(changed, service.SettingKeyDefaultPlatformQuotas)
}
changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
return changed
}
@ -2554,6 +2596,10 @@ func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSource
if field.before.GrantOnFirstBind != field.after.GrantOnFirstBind {
changed = append(changed, "auth_source_default_"+field.name+"_grant_on_first_bind")
}
// Platform quotas diff整体替换语义发单个 JSON key。
if !equalPlatformQuotaSettings(field.before.PlatformQuotas, field.after.PlatformQuotas) {
changed = append(changed, service.SettingKeyAuthSourcePlatformQuotas(field.name))
}
}
if before.ForceEmailOnThirdPartySignup != after.ForceEmailOnThirdPartySignup {
changed = append(changed, "force_email_on_third_party_signup")
@ -2621,6 +2667,17 @@ func defaultSubscriptionsValueOrDefault(input *[]dto.DefaultSubscriptionSetting,
return result
}
// platformQuotasValueOrDefault 处理 auth-source platform quota 的 nil 语义:
// nil = 请求未包含该字段(保留 fallbacknon-nil含 empty map= 整体覆盖。
// 注意JSON null 与字段省略等价——两者均反序列化为 nil map因此都保留旧值
// 若要清空某 source 的所有 quota 配置,须显式发空对象 {}。
func platformQuotasValueOrDefault(value, fallback map[string]*service.DefaultPlatformQuotaSetting) map[string]*service.DefaultPlatformQuotaSetting {
if value == nil {
return fallback
}
return value
}
func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults *service.AuthSourceDefaultSettings) map[string]any {
data := make(map[string]any)
raw, err := json.Marshal(settings)
@ -2666,6 +2723,13 @@ func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults
data["auth_source_default_google_subscriptions"] = authSourceDefaults.Google.Subscriptions
data["auth_source_default_google_grant_on_signup"] = authSourceDefaults.Google.GrantOnSignup
data["auth_source_default_google_grant_on_first_bind"] = authSourceDefaults.Google.GrantOnFirstBind
data["auth_source_default_email_platform_quotas"] = authSourceDefaults.Email.PlatformQuotas
data["auth_source_default_linuxdo_platform_quotas"] = authSourceDefaults.LinuxDo.PlatformQuotas
data["auth_source_default_oidc_platform_quotas"] = authSourceDefaults.OIDC.PlatformQuotas
data["auth_source_default_wechat_platform_quotas"] = authSourceDefaults.WeChat.PlatformQuotas
data["auth_source_default_github_platform_quotas"] = authSourceDefaults.GitHub.PlatformQuotas
data["auth_source_default_google_platform_quotas"] = authSourceDefaults.Google.PlatformQuotas
data["auth_source_default_dingtalk_platform_quotas"] = authSourceDefaults.DingTalk.PlatformQuotas
data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup
return data
@ -3552,3 +3616,48 @@ func emailTemplatePlaceholderUnion(events []service.NotificationEmailEventInfo)
}
return placeholders
}
// equalNullableFloat compares two *float64 values treating nil as a distinct case.
func equalNullableFloat(a, b *float64) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
return *a == *b
}
// slotOf returns the *float64 for the given window from a DefaultPlatformQuotaSetting.
func slotOf(s *service.DefaultPlatformQuotaSetting, win string) *float64 {
if s == nil {
return nil
}
switch win {
case "daily":
return s.DailyLimitUSD
case "weekly":
return s.WeeklyLimitUSD
case "monthly":
return s.MonthlyLimitUSD
}
return nil
}
// equalPlatformQuotaSettings reports whether two platform-quota maps are identical across all 12 slots.
func equalPlatformQuotaSettings(before, after map[string]*service.DefaultPlatformQuotaSetting) bool {
for _, platform := range service.AllowedQuotaPlatforms {
b := before[platform]
a := after[platform]
if !equalNullableFloat(slotOf(b, "daily"), slotOf(a, "daily")) {
return false
}
if !equalNullableFloat(slotOf(b, "weekly"), slotOf(a, "weekly")) {
return false
}
if !equalNullableFloat(slotOf(b, "monthly"), slotOf(a, "monthly")) {
return false
}
}
return true
}

View File

@ -0,0 +1,188 @@
//go:build unit
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestDiffSettings_DetectsGlobalPlatformQuotaChange(t *testing.T) {
five := 5.0
ten := 10.0
before := &service.SystemSettings{
DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
"anthropic": {DailyLimitUSD: &five},
},
}
after := &service.SystemSettings{
DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
"anthropic": {DailyLimitUSD: &ten},
},
}
changed := diffSettings(before, after, nil, nil, UpdateSettingsRequest{})
found := false
for _, key := range changed {
if key == service.SettingKeyDefaultPlatformQuotas {
found = true
break
}
}
if !found {
t.Errorf("expected change detection for default platform quotas, got %v", changed)
}
}
func TestDiffSettings_NoChangeWhenEqual(t *testing.T) {
five := 5.0
before := &service.SystemSettings{
DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
"anthropic": {DailyLimitUSD: &five},
},
}
after := &service.SystemSettings{
DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
"anthropic": {DailyLimitUSD: &five},
},
}
changed := diffSettings(before, after, nil, nil, UpdateSettingsRequest{})
for _, key := range changed {
if key == service.SettingKeyDefaultPlatformQuotas {
t.Error("equal values should not be detected as changed")
}
}
}
func TestEqualNullableFloat(t *testing.T) {
five := 5.0
five2 := 5.0
ten := 10.0
cases := []struct {
a, b *float64
want bool
}{
{nil, nil, true},
{&five, nil, false},
{nil, &five, false},
{&five, &five2, true},
{&five, &ten, false},
}
for _, c := range cases {
if got := equalNullableFloat(c.a, c.b); got != c.want {
t.Errorf("equalNullableFloat(%v, %v) = %v, want %v", c.a, c.b, got, c.want)
}
}
}
func TestEqualPlatformQuotaSettings_DetectsPerWindowChange(t *testing.T) {
five := 5.0
ten := 10.0
before := map[string]*service.DefaultPlatformQuotaSetting{
"anthropic": {DailyLimitUSD: &five},
}
after := map[string]*service.DefaultPlatformQuotaSetting{
"anthropic": {DailyLimitUSD: &ten},
}
if equalPlatformQuotaSettings(before, after) {
t.Error("expected unequal")
}
}
func TestAppendAuthSourceDefaultChanges_DetectsPerWindow(t *testing.T) {
five := 5.0
ten := 10.0
before := &service.AuthSourceDefaultSettings{
LinuxDo: service.ProviderDefaultGrantSettings{
PlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
"anthropic": {DailyLimitUSD: &five},
},
},
}
after := &service.AuthSourceDefaultSettings{
LinuxDo: service.ProviderDefaultGrantSettings{
PlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
"anthropic": {DailyLimitUSD: &ten},
},
},
}
changed := appendAuthSourceDefaultChanges([]string{}, before, after)
// 改动 B5整体替换语义审计 log 发单个 JSON key而非展开 84 个扁平 key。
key := service.SettingKeyAuthSourcePlatformQuotas("linuxdo")
found := false
for _, k := range changed {
if k == key {
found = true
break
}
}
if !found {
t.Errorf("expected %q in changed, got %v", key, changed)
}
}
// TestSettingHandler_AuthSourcePlatformQuotas_PutGetRoundTrip 验证 Bug A 修复:
// PUT 发 auth_source_default_email_platform_quotasGET 能读回相同值(端到端往返)。
func TestSettingHandler_AuthSourcePlatformQuotas_PutGetRoundTrip(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
values: map[string]string{
service.SettingKeyPromoCodeEnabled: "true",
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
// PUT发 email platform quotaopenai monthly=20
putBody := map[string]any{
"auth_source_default_email_platform_quotas": map[string]any{
"openai": map[string]any{
"monthly": 20,
},
},
}
rawBody, err := json.Marshal(putBody)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
// 验证 DB 中写入了 JSON key
jsonKey := service.SettingKeyAuthSourcePlatformQuotas("email")
require.NotEmpty(t, repo.values[jsonKey], "expected JSON key to be written to DB")
// GET验证响应中 auth_source_default_email_platform_quotas.openai.monthly = 20
rec2 := httptest.NewRecorder()
c2, _ := gin.CreateTestContext(rec2)
c2.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil)
handler.GetSettings(c2)
require.Equal(t, http.StatusOK, rec2.Code)
var resp response.Response
require.NoError(t, json.Unmarshal(rec2.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
emailPQ, ok := data["auth_source_default_email_platform_quotas"].(map[string]any)
require.True(t, ok, "expected auth_source_default_email_platform_quotas to be a map")
openaiPQ, ok := emailPQ["openai"].(map[string]any)
require.True(t, ok, "expected openai entry in email platform quotas")
monthly, ok := openaiPQ["monthly"].(float64)
require.True(t, ok, "expected monthly to be float64")
require.Equal(t, float64(20), monthly, "expected openai monthly=20")
}

View File

@ -2,10 +2,15 @@ package admin
import (
"context"
"errors"
"log/slog"
"math"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/handler/quotaview"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@ -20,15 +25,24 @@ type UserWithConcurrency struct {
// UserHandler handles admin user management
type UserHandler struct {
adminService service.AdminService
concurrencyService *service.ConcurrencyService
adminService service.AdminService
concurrencyService *service.ConcurrencyService
userPlatformQuotaRepo service.UserPlatformQuotaRepository // T13 admin quota view
billingCache service.BillingCache // T17/T18 缓存失效PUT/POST 路径)
}
// NewUserHandler creates a new admin user handler
func NewUserHandler(adminService service.AdminService, concurrencyService *service.ConcurrencyService) *UserHandler {
func NewUserHandler(
adminService service.AdminService,
concurrencyService *service.ConcurrencyService,
userPlatformQuotaRepo service.UserPlatformQuotaRepository,
billingCache service.BillingCache,
) *UserHandler {
return &UserHandler{
adminService: adminService,
concurrencyService: concurrencyService,
adminService: adminService,
concurrencyService: concurrencyService,
userPlatformQuotaRepo: userPlatformQuotaRepo,
billingCache: billingCache,
}
}
@ -537,3 +551,294 @@ func (h *UserHandler) BatchUpdateConcurrency(c *gin.Context) {
}
response.Success(c, gin.H{"affected": affected})
}
// GetUserPlatformQuotas GET /admin/users/:id/platform-quotas
// admin 视角D14 lazy 归零 + 暴露 *_window_start 调试字段
func (h *UserHandler) GetUserPlatformQuotas(c *gin.Context) {
idStr := c.Param("id")
userID, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
response.BadRequest(c, "invalid user id")
return
}
if h.userPlatformQuotaRepo == nil {
response.Success(c, map[string]any{"platform_quotas": []any{}})
return
}
// 校验用户存在:与 PUT/POST 路径一致,不存在返回 404 而非空数组(避免 admin 界面误判用户存在)。
if _, err := h.adminService.GetUser(c.Request.Context(), userID); err != nil {
response.ErrorFrom(c, err)
return
}
records, err := h.userPlatformQuotaRepo.ListByUser(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
now := time.Now().UTC()
out := make([]map[string]any, 0, len(records))
for _, r := range records {
out = append(out, quotaview.LazyZeroQuotaForResponse(r, now, true)) // true = 暴露 window_start
}
response.Success(c, map[string]any{"platform_quotas": out})
}
// UpdateUserPlatformQuotasRequest is the body for PUT /admin/users/:id/platform-quotas.
type UpdateUserPlatformQuotasRequest struct {
Quotas []PlatformQuotaInput `json:"quotas" binding:"required"`
}
// PlatformQuotaInput 单平台限额输入limit 字段为 nil 表示不限制。
type PlatformQuotaInput struct {
Platform string `json:"platform" binding:"required"`
DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
}
// platform 合法性由 service.IsAllowedQuotaPlatform / service.AllowedQuotaPlatforms 统一判断(单一源)。
// UpdateUserPlatformQuotas PUT /admin/users/:id/platform-quotas
// 全量替换该用户所有平台限额。
func (h *UserHandler) UpdateUserPlatformQuotas(c *gin.Context) {
if h.userPlatformQuotaRepo == nil {
response.Error(c, 503, "platform quota service not available")
return
}
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
var req UpdateUserPlatformQuotasRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.Quotas) > 4 {
response.BadRequest(c, "quotas length must be <= 4")
return
}
seen := make(map[string]struct{}, len(req.Quotas))
for _, q := range req.Quotas {
if !service.IsAllowedQuotaPlatform(q.Platform) {
response.BadRequest(c, "invalid platform: "+q.Platform)
return
}
if _, dup := seen[q.Platform]; dup {
response.BadRequest(c, "duplicate platform: "+q.Platform)
return
}
seen[q.Platform] = struct{}{}
// daily_limit_usd / weekly_limit_usd / monthly_limit_usd 的语义:
// nil / not set → 无限额(完全放行)
// 0 → 完全禁用(任何请求都会被拒绝,因为 usage >= 0 恒成立)
// > 0 → USD 限额上限
// 拦截 NaN / ±Inf客户端可发送超大数如 1e308 × 2使 JSON 反序列化得到 +Inf
// 进入 DB 后 cache check 中 usage >= limit 永不成立limit 等同失效。
for _, f := range []struct {
name string
val *float64
}{
{"daily_limit_usd", q.DailyLimitUSD},
{"weekly_limit_usd", q.WeeklyLimitUSD},
{"monthly_limit_usd", q.MonthlyLimitUSD},
} {
if f.val == nil {
continue
}
v := *f.val
if v < 0 {
response.BadRequest(c, f.name+" must be >= 0")
return
}
if math.IsNaN(v) || math.IsInf(v, 0) {
response.BadRequest(c, f.name+" must be a finite number")
return
}
}
}
records := make([]service.UserPlatformQuotaRecord, 0, len(req.Quotas))
for _, q := range req.Quotas {
records = append(records, service.UserPlatformQuotaRecord{
UserID: userID,
Platform: q.Platform,
DailyLimitUSD: q.DailyLimitUSD,
WeeklyLimitUSD: q.WeeklyLimitUSD,
MonthlyLimitUSD: q.MonthlyLimitUSD,
})
}
ctx := c.Request.Context()
// 校验用户是否存在,避免 FK 违反导致 500用户不存在时返回 404。
if _, err := h.adminService.GetUser(ctx, userID); err != nil {
response.ErrorFrom(c, err)
return
}
// 在 UpsertForUser 之前抓取 before snapshot 用于审计 before/after 对比。
// ListByUser 失败不阻断主操作best-effort仅记录降级 warn。
beforeRecords, beforeErr := h.userPlatformQuotaRepo.ListByUser(ctx, userID)
if beforeErr != nil {
slog.Warn("quota audit before snapshot failed", "user_id", userID, "err", beforeErr)
}
if err := h.userPlatformQuotaRepo.UpsertForUser(ctx, userID, records); err != nil {
response.ErrorFrom(c, err)
return
}
beforeByPlatform := make(map[string]service.UserPlatformQuotaRecord, len(beforeRecords))
for _, r := range beforeRecords {
beforeByPlatform[r.Platform] = r
}
afterPlatforms := make(map[string]struct{}, len(records))
for _, r := range records {
afterPlatforms[r.Platform] = struct{}{}
}
changes := make([]map[string]any, 0, len(records))
for _, r := range records {
entry := map[string]any{
"platform": r.Platform,
"daily_limit_usd": r.DailyLimitUSD,
"weekly_limit_usd": r.WeeklyLimitUSD,
"monthly_limit_usd": r.MonthlyLimitUSD,
}
if prev, ok := beforeByPlatform[r.Platform]; ok {
entry["before_daily_limit_usd"] = prev.DailyLimitUSD
entry["before_weekly_limit_usd"] = prev.WeeklyLimitUSD
entry["before_monthly_limit_usd"] = prev.MonthlyLimitUSD
}
changes = append(changes, entry)
}
// 补 removed 条目before 存在但 after 缺失 = 该平台被软删除。
// 缺少这条记录,审计消费方无法察觉"管理员把某平台从配额列表移除"的操作(合规盲区)。
for _, prev := range beforeRecords {
if _, kept := afterPlatforms[prev.Platform]; kept {
continue
}
changes = append(changes, map[string]any{
"platform": prev.Platform,
"removed": true,
"before_daily_limit_usd": prev.DailyLimitUSD,
"before_weekly_limit_usd": prev.WeeklyLimitUSD,
"before_monthly_limit_usd": prev.MonthlyLimitUSD,
})
}
// before_snapshot_available 让审计消费方能识别 changes 中是否带 before_* 字段;
// false 时所有 entry 都会缺失 before_*_limit_usd仅有 after 视图。
slog.Info("admin.quota_updated",
"actor_admin_id", getAdminIDFromContext(c),
"target_user_id", userID,
"platform_count", len(records),
"before_snapshot_available", beforeErr == nil,
"changes", changes)
// 失效 cache对全部允许的 platform 统一 invalidate。
// Trade-off精确失效仅 req 涉及平台 + 被软删平台)需 upsert 前额外 ListByUser
// 增加一次 DB 查询和逻辑复杂度。由于 AllowedQuotaPlatforms 只有 4 个元素,
// 全量 invalidate 的额外开销可接受,且能可靠覆盖软删除场景。
if h.billingCache != nil {
for _, p := range service.AllowedQuotaPlatforms {
if err := h.billingCache.DeleteUserPlatformQuotaCache(ctx, userID, p); err != nil {
slog.Warn("quota cache invalidation failed", "user_id", userID, "platform", p, "err", err)
}
}
}
// 返回最新状态
now := time.Now().UTC()
records2, err := h.userPlatformQuotaRepo.ListByUser(ctx, userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]map[string]any, 0, len(records2))
for i := range records2 {
out = append(out, quotaview.LazyZeroQuotaForResponse(records2[i], now, true))
}
response.Success(c, map[string]any{"platform_quotas": out})
}
// ResetUserPlatformQuotaWindowRequest is the body for POST /admin/users/:id/platform-quotas/reset.
type ResetUserPlatformQuotaWindowRequest struct {
Platform string `json:"platform" binding:"required"`
Window string `json:"window" binding:"required"`
}
var allowedWindowsForQuotaReset = map[string]struct{}{
"daily": {},
"weekly": {},
"monthly": {},
}
// ResetUserPlatformQuotaWindow POST /admin/users/:id/platform-quotas/reset
// 立即归零指定 (platform, window) 的用量并更新 window_start。
func (h *UserHandler) ResetUserPlatformQuotaWindow(c *gin.Context) {
if h.userPlatformQuotaRepo == nil {
response.Error(c, 503, "platform quota service not available")
return
}
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
var req ResetUserPlatformQuotaWindowRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !service.IsAllowedQuotaPlatform(req.Platform) {
response.BadRequest(c, "invalid platform: "+req.Platform)
return
}
if _, ok := allowedWindowsForQuotaReset[req.Window]; !ok {
response.BadRequest(c, "invalid window: "+req.Window)
return
}
ctx := c.Request.Context()
// 校验用户是否存在,避免对不存在的用户执行操作返回误导性的 500。
if _, err := h.adminService.GetUser(ctx, userID); err != nil {
response.ErrorFrom(c, err)
return
}
now := time.Now().UTC()
if err := h.userPlatformQuotaRepo.ResetExpiredWindow(ctx, userID, req.Platform, req.Window, now); err != nil {
if errors.Is(err, service.ErrUserPlatformQuotaNotFound) {
response.NotFound(c, "user platform quota not found")
return
}
response.ErrorFrom(c, err)
return
}
slog.Info("admin.quota_window_reset",
"actor_admin_id", getAdminIDFromContext(c),
"target_user_id", userID,
"platform", req.Platform,
"window", req.Window)
if h.billingCache != nil {
if err := h.billingCache.DeleteUserPlatformQuotaCache(ctx, userID, req.Platform); err != nil {
slog.Warn("quota cache invalidation failed", "user_id", userID, "platform", req.Platform, "err", err)
}
}
records, err := h.userPlatformQuotaRepo.ListByUser(ctx, userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]map[string]any, 0, len(records))
for i := range records {
out = append(out, quotaview.LazyZeroQuotaForResponse(records[i], now, true))
}
response.Success(c, map[string]any{"platform_quotas": out})
}

View File

@ -35,7 +35,7 @@ func TestUserHandlerListIncludesActivityFieldsAndSortParams(t *testing.T) {
UpdatedAt: lastLoginAt,
},
}
handler := NewUserHandler(adminSvc, nil)
handler := NewUserHandler(adminSvc, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@ -89,7 +89,7 @@ func TestUserHandlerGetByIDIncludesActivityFields(t *testing.T) {
UpdatedAt: lastLoginAt,
},
}
handler := NewUserHandler(adminSvc, nil)
handler := NewUserHandler(adminSvc, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)

View File

@ -0,0 +1,301 @@
//go:build unit
package admin
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// upsertCapturingQuotaRepo 实现 service.UserPlatformQuotaRepository捕获 UpsertForUser 调用。
type upsertCapturingQuotaRepo struct {
service.UserPlatformQuotaRepository
listRecords []service.UserPlatformQuotaRecord
listErr error
upsertCalls []upsertCall
upsertErr error
resetCalls []resetCall
resetErr error
}
type upsertCall struct {
userID int64
records []service.UserPlatformQuotaRecord
}
type resetCall struct {
userID int64
platform string
window string
newStart time.Time
}
func (r *upsertCapturingQuotaRepo) ListByUser(_ context.Context, _ int64) ([]service.UserPlatformQuotaRecord, error) {
return r.listRecords, r.listErr
}
func (r *upsertCapturingQuotaRepo) UpsertForUser(_ context.Context, userID int64, records []service.UserPlatformQuotaRecord) error {
cloned := make([]service.UserPlatformQuotaRecord, len(records))
copy(cloned, records)
r.upsertCalls = append(r.upsertCalls, upsertCall{userID: userID, records: cloned})
return r.upsertErr
}
func (r *upsertCapturingQuotaRepo) ResetExpiredWindow(_ context.Context, userID int64, platform string, window string, newStart time.Time) error {
r.resetCalls = append(r.resetCalls, resetCall{userID, platform, window, newStart})
return r.resetErr
}
// billingCacheStub 实现 service.BillingCache 中本测试关心的 Delete 方法;其他方法 panic。
type billingCacheStub struct {
service.BillingCache
deleteCalls []deleteCall
deleteErr error
}
type deleteCall struct {
userID int64
platform string
}
func (b *billingCacheStub) DeleteUserPlatformQuotaCache(_ context.Context, userID int64, platform string) error {
b.deleteCalls = append(b.deleteCalls, deleteCall{userID, platform})
return b.deleteErr
}
func buildTestHandler(repo service.UserPlatformQuotaRepository, cache service.BillingCache) *UserHandler {
return &UserHandler{
userPlatformQuotaRepo: repo,
billingCache: cache,
adminService: newStubAdminService(),
}
}
func putReq(t *testing.T, body string) (*gin.Context, *httptest.ResponseRecorder) {
t.Helper()
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
req, _ := http.NewRequest(http.MethodPut, "/", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
c.Request = req
c.Params = []gin.Param{{Key: "id", Value: "42"}}
return c, w
}
func TestUpdateUserPlatformQuotas_Success(t *testing.T) {
repo := &upsertCapturingQuotaRepo{}
cache := &billingCacheStub{}
h := buildTestHandler(repo, cache)
body := `{"quotas":[
{"platform":"anthropic","daily_limit_usd":10.0,"weekly_limit_usd":null,"monthly_limit_usd":100.0},
{"platform":"openai","daily_limit_usd":null,"weekly_limit_usd":null,"monthly_limit_usd":null}
]}`
c, w := putReq(t, body)
h.UpdateUserPlatformQuotas(c)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if len(repo.upsertCalls) != 1 {
t.Fatalf("UpsertForUser should be called once, got %d", len(repo.upsertCalls))
}
if repo.upsertCalls[0].userID != 42 || len(repo.upsertCalls[0].records) != 2 {
t.Errorf("unexpected upsert call: %+v", repo.upsertCalls[0])
}
// 缓存失效:请求中 2 个 platform + 软删除的 2 个 platformgemini, antigravity= 4 次
if len(cache.deleteCalls) != 4 {
t.Errorf("expected 4 cache delete calls, got %d: %+v", len(cache.deleteCalls), cache.deleteCalls)
}
}
func TestUpdateUserPlatformQuotas_RejectsDuplicatePlatform(t *testing.T) {
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
body := `{"quotas":[
{"platform":"anthropic","daily_limit_usd":1},
{"platform":"anthropic","daily_limit_usd":2}
]}`
c, w := putReq(t, body)
h.UpdateUserPlatformQuotas(c)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdateUserPlatformQuotas_RejectsInvalidPlatform(t *testing.T) {
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
body := `{"quotas":[{"platform":"unknown","daily_limit_usd":1}]}`
c, w := putReq(t, body)
h.UpdateUserPlatformQuotas(c)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdateUserPlatformQuotas_RejectsNegativeLimit(t *testing.T) {
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":-1}]}`
c, w := putReq(t, body)
h.UpdateUserPlatformQuotas(c)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdateUserPlatformQuotas_RejectsTooManyEntries(t *testing.T) {
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
body := `{"quotas":[
{"platform":"anthropic"},{"platform":"openai"},{"platform":"gemini"},{"platform":"antigravity"},{"platform":"anthropic"}
]}`
c, w := putReq(t, body)
h.UpdateUserPlatformQuotas(c)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdateUserPlatformQuotas_ReturnsLatestState(t *testing.T) {
repo := &upsertCapturingQuotaRepo{
listRecords: []service.UserPlatformQuotaRecord{
{UserID: 42, Platform: "anthropic"},
},
}
cache := &billingCacheStub{}
h := buildTestHandler(repo, cache)
body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":10}]}`
c, w := putReq(t, body)
h.UpdateUserPlatformQuotas(c)
if !strings.Contains(w.Body.String(), `"platform_quotas"`) {
t.Errorf("response should contain platform_quotas array: %s", w.Body.String())
}
}
// ───────── T4: Reset 测试 ─────────
func postReq(t *testing.T, body string) (*gin.Context, *httptest.ResponseRecorder) {
t.Helper()
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
req, _ := http.NewRequest(http.MethodPost, "/", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
c.Request = req
c.Params = []gin.Param{{Key: "id", Value: "42"}}
return c, w
}
func TestResetUserPlatformQuotaWindow_Success(t *testing.T) {
repo := &upsertCapturingQuotaRepo{}
cache := &billingCacheStub{}
h := buildTestHandler(repo, cache)
body := `{"platform":"anthropic","window":"daily"}`
c, w := postReq(t, body)
h.ResetUserPlatformQuotaWindow(c)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if len(repo.resetCalls) != 1 {
t.Fatalf("ResetExpiredWindow should be called once, got %d", len(repo.resetCalls))
}
if repo.resetCalls[0].userID != 42 ||
repo.resetCalls[0].platform != "anthropic" ||
repo.resetCalls[0].window != "daily" {
t.Errorf("unexpected reset call: %+v", repo.resetCalls[0])
}
if len(cache.deleteCalls) != 1 ||
cache.deleteCalls[0].userID != 42 ||
cache.deleteCalls[0].platform != "anthropic" {
t.Errorf("expected 1 cache delete for anthropic, got %+v", cache.deleteCalls)
}
}
func TestResetUserPlatformQuotaWindow_RejectsInvalidWindow(t *testing.T) {
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
c, w := postReq(t, `{"platform":"anthropic","window":"yearly"}`)
h.ResetUserPlatformQuotaWindow(c)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
}
func TestResetUserPlatformQuotaWindow_RejectsInvalidPlatform(t *testing.T) {
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
c, w := postReq(t, `{"platform":"unknown","window":"daily"}`)
h.ResetUserPlatformQuotaWindow(c)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
}
func TestResetUserPlatformQuotaWindow_NotFound(t *testing.T) {
// handler 检查 service.ErrUserPlatformQuotaNotFound由 adapter 包装而来)
repo := &upsertCapturingQuotaRepo{resetErr: service.ErrUserPlatformQuotaNotFound}
h := buildTestHandler(repo, &billingCacheStub{})
c, w := postReq(t, `{"platform":"anthropic","window":"daily"}`)
h.ResetUserPlatformQuotaWindow(c)
if w.Code != http.StatusNotFound {
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdateUserPlatformQuotas_JSONErrorOnRepoFailure(t *testing.T) {
repo := &upsertCapturingQuotaRepo{upsertErr: errors.New("db down")}
cache := &billingCacheStub{}
h := buildTestHandler(repo, cache)
body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":10}]}`
c, w := putReq(t, body)
h.UpdateUserPlatformQuotas(c)
if w.Code < 500 {
t.Errorf("expected 5xx, got %d", w.Code)
}
// 返回 JSON 错误响应
var body2 map[string]any
if err := json.Unmarshal(w.Body.Bytes(), &body2); err != nil {
t.Errorf("expected JSON error body, got: %s", w.Body.String())
}
}
func TestUpdateUserPlatformQuotas_UserNotFound(t *testing.T) {
repo := &upsertCapturingQuotaRepo{}
cache := &billingCacheStub{}
adminSvc := newStubAdminService()
adminSvc.getUserErr = service.ErrUserNotFound
h := &UserHandler{
userPlatformQuotaRepo: repo,
billingCache: cache,
adminService: adminSvc,
}
body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":10}]}`
c, w := putReq(t, body)
h.UpdateUserPlatformQuotas(c)
if w.Code != http.StatusNotFound {
t.Errorf("expected 404 when user not found, got %d: %s", w.Code, w.Body.String())
}
}
func TestResetUserPlatformQuotaWindow_UserNotFound(t *testing.T) {
repo := &upsertCapturingQuotaRepo{}
cache := &billingCacheStub{}
adminSvc := newStubAdminService()
adminSvc.getUserErr = service.ErrUserNotFound
h := &UserHandler{
userPlatformQuotaRepo: repo,
billingCache: cache,
adminService: adminSvc,
}
c, w := postReq(t, `{"platform":"anthropic","window":"daily"}`)
h.ResetUserPlatformQuotaWindow(c)
if w.Code != http.StatusNotFound {
t.Errorf("expected 404 when user not found, got %d: %s", w.Code, w.Body.String())
}
}

View File

@ -0,0 +1,124 @@
//go:build unit
package admin
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type fakeQuotaRepoForAdmin struct {
service.UserPlatformQuotaRepository
records []service.UserPlatformQuotaRecord
err error
}
func (f *fakeQuotaRepoForAdmin) ListByUser(_ context.Context, _ int64) ([]service.UserPlatformQuotaRecord, error) {
return f.records, f.err
}
func newAdminQuotaTestContext(w *httptest.ResponseRecorder) *gin.Context {
c, _ := gin.CreateTestContext(w)
req, _ := http.NewRequest(http.MethodGet, "/", nil)
c.Request = req
return c
}
func TestAdminGetUserPlatformQuotas_IncludesWindowStart(t *testing.T) {
start := time.Now().Add(-1 * time.Hour)
repo := &fakeQuotaRepoForAdmin{records: []service.UserPlatformQuotaRecord{{
UserID: 99, Platform: "anthropic",
DailyUsageUSD: 1.0, DailyWindowStart: &start,
}}}
h := &UserHandler{userPlatformQuotaRepo: repo, adminService: newStubAdminService()}
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c := newAdminQuotaTestContext(w)
c.Params = []gin.Param{{Key: "id", Value: "99"}}
h.GetUserPlatformQuotas(c)
if w.Code != 200 {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if !strings.Contains(w.Body.String(), `"daily_window_start"`) {
t.Errorf("admin response missing daily_window_start, got: %s", w.Body.String())
}
}
func TestAdminGetUserPlatformQuotas_InvalidIDReturns400(t *testing.T) {
h := &UserHandler{userPlatformQuotaRepo: &fakeQuotaRepoForAdmin{}}
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c := newAdminQuotaTestContext(w)
c.Params = []gin.Param{{Key: "id", Value: "abc"}}
h.GetUserPlatformQuotas(c)
if w.Code < 400 || w.Code >= 500 {
t.Errorf("invalid id should yield 4xx, got %d", w.Code)
}
}
func TestAdminGetUserPlatformQuotas_EmptyReturnsEmptyArray(t *testing.T) {
repo := &fakeQuotaRepoForAdmin{records: nil}
h := &UserHandler{userPlatformQuotaRepo: repo, adminService: newStubAdminService()}
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c := newAdminQuotaTestContext(w)
c.Params = []gin.Param{{Key: "id", Value: "99"}}
h.GetUserPlatformQuotas(c)
if w.Code != 200 {
t.Errorf("empty list should be 200, got %d", w.Code)
}
var body map[string]any
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
t.Fatalf("response is not valid JSON: %v", err)
}
data, ok := body["data"].(map[string]any)
if !ok {
t.Fatalf("response missing data object: %v", body)
}
quotas, ok := data["platform_quotas"].([]any)
if !ok {
t.Fatalf("data.platform_quotas missing or wrong type: %v", data)
}
if len(quotas) != 0 {
t.Errorf("expected empty platform_quotas, got %d entries: %v", len(quotas), quotas)
}
}
func TestAdminGetUserPlatformQuotas_NilRepoReturnsEmpty(t *testing.T) {
h := &UserHandler{userPlatformQuotaRepo: nil}
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c := newAdminQuotaTestContext(w)
c.Params = []gin.Param{{Key: "id", Value: "1"}}
h.GetUserPlatformQuotas(c)
if w.Code != 200 {
t.Errorf("nil repo should return 200 empty, got %d", w.Code)
}
}
// TestAdminGetUserPlatformQuotas_UserNotFoundReturns404 验证 GET 在用户不存在时返回 404
// (与 PUT / POST reset 端点行为一致review fix原实现返回空数组会让 admin 界面误判用户存在)
func TestAdminGetUserPlatformQuotas_UserNotFoundReturns404(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.getUserErr = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
repo := &fakeQuotaRepoForAdmin{records: nil}
h := &UserHandler{userPlatformQuotaRepo: repo, adminService: adminSvc}
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c := newAdminQuotaTestContext(w)
c.Params = []gin.Param{{Key: "id", Value: "999"}}
h.GetUserPlatformQuotas(c)
if w.Code != http.StatusNotFound {
t.Errorf("expected 404 for non-existent user, got %d: %s", w.Code, w.Body.String())
}
}

View File

@ -2233,6 +2233,7 @@ CREATE TABLE IF NOT EXISTS user_affiliates (
nil,
options.defaultSubAssigner,
affiliateService,
nil,
)
userSvc := service.NewUserService(userRepo, nil, nil, nil)
var totpSvc *service.TotpService

View File

@ -35,7 +35,7 @@ func TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens(t *testing.T) {
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
handler := &AuthHandler{authService: authService}
recorder := httptest.NewRecorder()

View File

@ -1400,6 +1400,7 @@ func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool,
nil,
nil,
nil,
nil,
)
return &AuthHandler{

View File

@ -147,6 +147,7 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
MCPXMLInject: g.MCPXMLInject,
DefaultMappedModel: g.DefaultMappedModel,
MessagesDispatchModelConfig: g.MessagesDispatchModelConfig,
ModelsListConfig: g.ModelsListConfig,
SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount,
ActiveAccountCount: g.ActiveAccountCount,

View File

@ -3,6 +3,8 @@ package dto
import (
"encoding/json"
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// CustomMenuItem represents a user-configured custom menu entry.
@ -246,6 +248,9 @@ type SystemSettings struct {
// OpenAI fast/flex policy
OpenAIFastPolicySettings *OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
// 系统全局默认平台配额key = platformnil/缺省 = 不限制)
DefaultPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"default_platform_quotas,omitempty"`
}
type DefaultSubscriptionSetting struct {

View File

@ -138,6 +138,7 @@ type AdminGroup struct {
// OpenAI Messages 调度配置(仅 openai 平台使用)
DefaultMappedModel string `json:"default_mapped_model"`
MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
ModelsListConfig domain.GroupModelsListConfig `json:"models_list_config"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`

View File

@ -17,6 +17,7 @@ import (
const (
EndpointMessages = "/v1/messages"
EndpointChatCompletions = "/v1/chat/completions"
EndpointEmbeddings = "/v1/embeddings"
EndpointResponses = "/v1/responses"
EndpointImagesGenerations = "/v1/images/generations"
EndpointImagesEdits = "/v1/images/edits"
@ -42,6 +43,8 @@ const (
func NormalizeInboundEndpoint(path string) string {
path = strings.TrimSpace(path)
switch {
case strings.Contains(path, EndpointEmbeddings):
return EndpointEmbeddings
case strings.Contains(path, EndpointChatCompletions):
return EndpointChatCompletions
case strings.Contains(path, EndpointMessages):
@ -75,7 +78,7 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
switch platform {
case service.PlatformOpenAI:
if inbound == EndpointImagesGenerations || inbound == EndpointImagesEdits {
if inbound == EndpointEmbeddings || inbound == EndpointImagesGenerations || inbound == EndpointImagesEdits {
return inbound
}
// OpenAI forwards everything to the Responses API.

View File

@ -24,6 +24,7 @@ func TestNormalizeInboundEndpoint(t *testing.T) {
// Direct canonical paths.
{"/v1/messages", EndpointMessages},
{"/v1/chat/completions", EndpointChatCompletions},
{"/v1/embeddings", EndpointEmbeddings},
{"/v1/responses", EndpointResponses},
{"/v1/images/generations", EndpointImagesGenerations},
{"/v1/images/edits", EndpointImagesEdits},
@ -77,6 +78,7 @@ func TestDeriveUpstreamEndpoint(t *testing.T) {
{"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"},
{"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses},
{"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses},
{"openai embeddings", EndpointEmbeddings, "/v1/embeddings", service.PlatformOpenAI, EndpointEmbeddings},
{"openai image generations", EndpointImagesGenerations, "/v1/images/generations", service.PlatformOpenAI, EndpointImagesGenerations},
{"openai image edits", EndpointImagesEdits, "/openai/v1/images/edits", service.PlatformOpenAI, EndpointImagesEdits},

View File

@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
"net/http"
"strconv"
"strings"
@ -253,7 +254,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err))
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
@ -533,10 +534,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ParsedRequest: parsedReq,
QuotaPlatform: quotaPlatform,
APIKey: apiKey,
User: apiKey.User,
Account: account,
@ -825,6 +828,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// Beta policy block: return 400 immediately, no failover
var betaBlockedErr *service.BetaBlockedError
if errors.As(err, &betaBlockedErr) {
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalPolicyDenied)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", betaBlockedErr.Message)
return
}
@ -855,7 +859,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil, service.PlatformFromAPIKey(fallbackAPIKey)); err != nil {
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
@ -960,10 +964,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
quotaPlatform := service.QuotaPlatform(c.Request.Context(), currentAPIKey)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ParsedRequest: parsedReq,
QuotaPlatform: quotaPlatform,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: account,
@ -1015,22 +1021,14 @@ func (h *GatewayHandler) Models(c *gin.Context) {
// Get available models from account configurations for the selected group platform.
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, platform)
if apiKey != nil && apiKey.Group != nil && apiKey.Group.CustomModelsListEnabled() {
availableModels = filterModelsByCustomList(availableModels, defaultModelIDsForPlatform(platform), apiKey.Group.ModelsListConfig.Models)
writeCustomModelsList(c, platform, availableModels)
return
}
if len(availableModels) > 0 {
// Build model list from whitelist
models := make([]claude.Model, 0, len(availableModels))
for _, modelID := range availableModels {
models = append(models, claude.Model{
ID: modelID,
Type: "model",
DisplayName: modelID,
CreatedAt: "2024-01-01T00:00:00Z",
})
}
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": models,
})
writeModelsList(c, availableModels)
return
}
@ -1057,6 +1055,134 @@ func (h *GatewayHandler) Models(c *gin.Context) {
})
}
func writeModelsList(c *gin.Context, modelIDs []string) {
models := make([]claude.Model, 0, len(modelIDs))
for _, modelID := range modelIDs {
models = append(models, claude.Model{
ID: modelID,
Type: "model",
DisplayName: modelID,
CreatedAt: "2024-01-01T00:00:00Z",
})
}
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": models,
})
}
func writeCustomModelsList(c *gin.Context, platform string, modelIDs []string) {
if platform == service.PlatformOpenAI {
writeOpenAIModelsList(c, modelIDs)
return
}
writeModelsList(c, modelIDs)
}
func writeOpenAIModelsList(c *gin.Context, modelIDs []string) {
defaultsByID := make(map[string]openai.Model, len(openai.DefaultModels))
for _, model := range openai.DefaultModels {
defaultsByID[model.ID] = model
}
models := make([]openai.Model, 0, len(modelIDs))
for _, modelID := range modelIDs {
if model, ok := defaultsByID[modelID]; ok {
models = append(models, model)
continue
}
models = append(models, openai.Model{
ID: modelID,
Object: "model",
Created: 1704067200,
OwnedBy: "openai",
Type: "model",
DisplayName: modelID,
})
}
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": models,
})
}
func filterModelsByCustomList(availableModels, fallbackModels, selectedModels []string) []string {
if len(selectedModels) == 0 {
return availableModels
}
source := availableModels
if len(source) == 0 {
source = fallbackModels
}
if len(source) == 0 {
return nil
}
allowed := make([]string, 0, len(source))
for _, model := range source {
model = strings.TrimSpace(model)
if model != "" {
allowed = append(allowed, model)
}
}
seen := make(map[string]struct{}, len(selectedModels))
filtered := make([]string, 0, len(selectedModels))
for _, model := range selectedModels {
model = strings.TrimSpace(model)
if model == "" {
continue
}
if !customModelsListAllowsModel(allowed, model) {
continue
}
if _, ok := seen[model]; ok {
continue
}
seen[model] = struct{}{}
filtered = append(filtered, model)
}
return filtered
}
func customModelsListAllowsModel(availablePatterns []string, model string) bool {
for _, pattern := range availablePatterns {
if pattern == model {
return true
}
if strings.HasSuffix(pattern, "*") && strings.HasPrefix(model, strings.TrimSuffix(pattern, "*")) {
return true
}
}
return false
}
func defaultModelIDsForPlatform(platform string) []string {
switch platform {
case service.PlatformOpenAI:
return openai.DefaultModelIDs()
case service.PlatformGemini:
ids := make([]string, 0, len(geminicli.DefaultModels))
for _, model := range geminicli.DefaultModels {
ids = append(ids, model.ID)
}
return ids
case service.PlatformAntigravity:
models := antigravity.DefaultModels()
ids := make([]string, 0, len(models))
for _, model := range models {
ids = append(ids, model.ID)
}
return ids
default:
ids := make([]string, 0, len(claude.DefaultModels))
for _, model := range claude.DefaultModels {
ids = append(ids, model.ID)
}
return ids
}
}
// AntigravityModels 返回 Antigravity 支持的全部模型
// GET /antigravity/models
func (h *GatewayHandler) AntigravityModels(c *gin.Context) {
@ -1502,6 +1628,14 @@ func (h *GatewayHandler) sendFailoverKeepalivePing(c *gin.Context, streamStarted
// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
// /v1/responses 的严格 SDKCodex CLI要求终止事件必须属于
// response.completed/failed/incomplete/cancelled 集合。
// Anthropic-backed Responses 路径同样会因为通用 error 帧被拒。
if inboundIsResponses(c) {
if writeResponsesFailedSSE(c, errType, message) {
return
}
}
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
@ -1520,10 +1654,16 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
}
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
// Writer 已被写过时ping 已 flush走 streamStarted 分支,
// 让 handleStreamingAwareError 通过 SSE 发协议合规的终止事件,
// 否则下游收到的就是 silent EOF。
func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
if c == nil || c.Writer == nil || c.Writer.Written() {
if c == nil || c.Writer == nil {
return false
}
if c.Writer.Written() {
streamStarted = true
}
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
return true
}
@ -1650,7 +1790,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 校验 billing eligibility订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
@ -1898,6 +2038,36 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter
c.JSON(http.StatusOK, response)
}
// extractQuotaResetSeconds 从 quota 错误的 metadata 中提取 window_resets_at 并计算
// 距重置剩余秒数。fallback 路径必须返回 ≥1 秒,避免客户端立即重试无限循环。
func extractQuotaResetSeconds(err error) int {
const fallback = 60
appErr := pkgerrors.FromError(err)
if appErr == nil {
return fallback
}
raw, ok := appErr.Metadata["window_resets_at"]
if !ok || raw == "" {
return fallback
}
resetAt, parseErr := time.Parse(time.RFC3339, raw)
if parseErr != nil {
logger.L().With(
zap.String("component", "handler.gateway.billing"),
zap.String("raw", raw),
zap.Error(parseErr),
).Warn("quota.invalid_window_resets_at_format")
return fallback
}
secs := time.Until(resetAt).Seconds()
if secs <= 0 {
// reset 时间已过cache 与 DB 应该正在自愈,返回 fallback 让客户端按常规节奏退避,
// 避免返回 1 秒导致客户端立即重试仍触发限额的退避循环。
return fallback
}
return int(math.Ceil(secs))
}
func billingErrorDetails(err error) (status int, code, message string, retryAfter int) {
if errors.Is(err, service.ErrBillingServiceUnavailable) {
msg := pkgerrors.Message(err)
@ -1925,6 +2095,14 @@ func billingErrorDetails(err error) (status int, code, message string, retryAfte
retrySeconds := 60 - int(time.Now().Unix()%60)
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, retrySeconds
}
if errors.Is(err, service.ErrUserPlatformDailyQuotaExhausted) ||
errors.Is(err, service.ErrUserPlatformWeeklyQuotaExhausted) ||
errors.Is(err, service.ErrUserPlatformMonthlyQuotaExhausted) {
// 与 RPM 超限一致映射 429 + Retry-After让 SDK 自动退避(而非 403 直接失败)。
// 错误码用 rate_limit_exceeded 与 OpenAI 兼容客户端一致;细分类型由 ErrCode + window_resets_at metadata 区分。
msg := pkgerrors.Message(err)
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, extractQuotaResetSeconds(err)
}
msg := pkgerrors.Message(err)
if msg == "" {
logger.L().With(

View File

@ -1,8 +1,10 @@
package handler
import (
"errors"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
@ -52,3 +54,75 @@ func TestBillingErrorDetails_UnknownErrorFallsBackTo403(t *testing.T) {
require.Equal(t, "billing_error", code)
require.NotEmpty(t, msg)
}
func TestExtractQuotaResetSeconds_T19_HappyPath(t *testing.T) {
err := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{
"window_resets_at": time.Now().Add(10 * time.Second).UTC().Format(time.RFC3339),
})
got := extractQuotaResetSeconds(err)
if got < 10 || got > 11 {
t.Errorf("T19: got %d, want 10 or 11 (math.Ceil boundary)", got)
}
}
func TestExtractQuotaResetSeconds_T20_NoMetadataFallback(t *testing.T) {
if got := extractQuotaResetSeconds(errors.New("naked error")); got != 60 {
t.Errorf("T20: got %d, want 60 fallback", got)
}
}
func TestExtractQuotaResetSeconds_T21_BadFormatFallback(t *testing.T) {
err := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{
"window_resets_at": "not-a-time",
})
if got := extractQuotaResetSeconds(err); got != 60 {
t.Errorf("T21: got %d, want 60 fallback", got)
}
}
func TestExtractQuotaResetSeconds_T22_PastResetFallsBackToDefault(t *testing.T) {
// 当 window_resets_at 已过去时返回 fallback (60s) 而非 1s
// 1 秒会导致客户端立即重试仍触发限额的退避循环;
// 60s 让客户端按常规节奏退避cache/DB 自愈期间不会反复打抖。
err := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{
"window_resets_at": time.Now().Add(-5 * time.Second).UTC().Format(time.RFC3339),
})
if got := extractQuotaResetSeconds(err); got != 60 {
t.Errorf("T22: got %d, want 60 (fallback on past reset)", got)
}
}
func TestBillingErrorDetails_T10_QuotaExhaustedReturns429WithRetryAfter(t *testing.T) {
// quota 超限映射 429 + Retry-AfterRFC 6585 / 与 RPM 一致),
// 让 SDKOpenAI 兼容客户端等)能按 Retry-After 自动退避。
// 旧实现用 403 导致客户端不退避直接报错。
// 三个窗口共用同一映射分支,循环覆盖避免漏测某个窗口的 status/code。
cases := []struct {
name string
err error
}{
{"daily", service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{
"window_resets_at": time.Now().Add(60 * time.Minute).UTC().Format(time.RFC3339),
})},
{"weekly", service.ErrUserPlatformWeeklyQuotaExhausted.WithMetadata(map[string]string{
"window_resets_at": time.Now().Add(60 * time.Minute).UTC().Format(time.RFC3339),
})},
{"monthly", service.ErrUserPlatformMonthlyQuotaExhausted.WithMetadata(map[string]string{
"window_resets_at": time.Now().Add(60 * time.Minute).UTC().Format(time.RFC3339),
})},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
status, code, _, retryAfter := billingErrorDetails(tc.err)
if status != http.StatusTooManyRequests {
t.Errorf("status = %d, want 429", status)
}
if code != "rate_limit_exceeded" {
t.Errorf("code = %q, want rate_limit_exceeded", code)
}
if retryAfter < 3599 || retryAfter > 3601 {
t.Errorf("retryAfter = %d, want ~3600", retryAfter)
}
})
}
}

View File

@ -140,7 +140,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
}
// 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err))
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
@ -291,9 +291,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
QuotaPlatform: quotaPlatform,
APIKey: apiKey,
User: apiKey.User,
Account: account,

View File

@ -33,7 +33,9 @@ func TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testi
assert.Equal(t, "Upstream request failed", errorObj["message"])
}
func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
// Writer 已写后 ensureForwardErrorResponse 必须把错误以 SSE 形式追加,
// 而不是 silent EOF。非 /responses 路径走 legacy data:{"type":"error"} 分支。
func TestGatewayEnsureForwardErrorResponse_AppendsSSEAfterWritten(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@ -43,7 +45,27 @@ func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *tes
h := &GatewayHandler{}
wrote := h.ensureForwardErrorResponse(c, false)
require.False(t, wrote)
require.True(t, wrote)
require.Equal(t, http.StatusTeapot, w.Code)
assert.Equal(t, "already written", w.Body.String())
assert.Contains(t, w.Body.String(), "already written")
assert.Contains(t, w.Body.String(), `data: {"type":"error"`)
}
// case B 回归Anthropic-backed /responsesWriter 已被写过时
// ensureForwardErrorResponse 仍要发 response.failed。
func TestGatewayEnsureForwardErrorResponse_ResponsesRouteAfterWrittenEmitsResponseFailed(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, EndpointResponses, nil)
_, _ = c.Writer.WriteString(":\n\n")
h := &GatewayHandler{}
wrote := h.ensureForwardErrorResponse(c, false)
require.True(t, wrote)
body := w.Body.String()
assert.Contains(t, body, ":\n\n")
assert.Contains(t, body, "event: response.failed\n")
assert.Contains(t, body, `"type":"response.failed"`)
}

View File

@ -145,7 +145,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
}
// 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
@ -266,9 +266,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
QuotaPlatform: quotaPlatform,
APIKey: apiKey,
User: apiKey.User,
Account: account,

View File

@ -172,11 +172,12 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // channelService
nil, // resolver
nil, // balanceNotifyService
nil, // userPlatformQuotaRepo
)
// RunModeSimple跳过计费检查避免引入 repo/cache 依赖。
cfg := &config.Config{RunMode: config.RunModeSimple}
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil)
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)

View File

@ -25,7 +25,11 @@ type gatewayModelsResponseForTest struct {
}
type gatewayModelItemForTest struct {
ID string `json:"id"`
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
CreatedAt string `json:"created_at"`
}
func (s *gatewayModelsAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
@ -43,7 +47,7 @@ func newGatewayModelsHandlerForTest(repo service.AccountRepository) *GatewayHand
gatewayService: service.NewGatewayService(
repo,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
),
}
}
@ -127,6 +131,267 @@ func TestGatewayModels_GeminiGroupFiltersMappedModelsByPlatform(t *testing.T) {
require.Equal(t, []string{"gemini-2.5-flash"}, modelIDsForTest(got.Data))
}
func TestGatewayModels_CustomModelsListDisabledKeepsOriginalModels(t *testing.T) {
gin.SetMode(gin.TestMode)
groupID := int64(22)
h := newGatewayModelsHandlerForTest(
&gatewayModelsAccountRepoStub{
byGroup: map[int64][]service.Account{
groupID: {
{
ID: 1,
Platform: service.PlatformOpenAI,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5.5": "gpt-5.5",
"gpt-5.4": "gpt-5.4",
},
},
},
},
},
},
)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
Group: &service.Group{
ID: groupID,
Platform: service.PlatformOpenAI,
ModelsListConfig: service.GroupModelsListConfig{
Enabled: false,
Models: []string{"gpt-5.5"},
},
},
})
h.Models(c)
require.Equal(t, http.StatusOK, rec.Code)
var got gatewayModelsResponseForTest
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
require.Equal(t, []string{"gpt-5.4", "gpt-5.5"}, modelIDsForTest(got.Data))
}
func TestGatewayModels_CustomModelsListFiltersAndOrdersMappedModels(t *testing.T) {
gin.SetMode(gin.TestMode)
groupID := int64(23)
h := newGatewayModelsHandlerForTest(
&gatewayModelsAccountRepoStub{
byGroup: map[int64][]service.Account{
groupID: {
{
ID: 1,
Platform: service.PlatformOpenAI,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4",
"gpt-5.5": "gpt-5.5",
"legacy-gpt-2024": "legacy-gpt-2024",
},
},
},
},
},
},
)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
Group: &service.Group{
ID: groupID,
Platform: service.PlatformOpenAI,
ModelsListConfig: service.GroupModelsListConfig{
Enabled: true,
Models: []string{"gpt-5.5", "missing-model", "gpt-5.4"},
},
},
})
h.Models(c)
require.Equal(t, http.StatusOK, rec.Code)
var got gatewayModelsResponseForTest
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
require.Equal(t, []string{"gpt-5.5", "gpt-5.4"}, modelIDsForTest(got.Data))
}
func TestGatewayModels_CustomModelsListKeepsConcreteModelAllowedByWildcardMapping(t *testing.T) {
gin.SetMode(gin.TestMode)
groupID := int64(26)
h := newGatewayModelsHandlerForTest(
&gatewayModelsAccountRepoStub{
byGroup: map[int64][]service.Account{
groupID: {
{
ID: 1,
Platform: service.PlatformAnthropic,
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-*": "claude-sonnet-4-6",
},
},
},
},
},
},
)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
Group: &service.Group{
ID: groupID,
Platform: service.PlatformAnthropic,
ModelsListConfig: service.GroupModelsListConfig{
Enabled: true,
Models: []string{"claude-sonnet-4-6"},
},
},
})
h.Models(c)
require.Equal(t, http.StatusOK, rec.Code)
var got gatewayModelsResponseForTest
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
require.Equal(t, []string{"claude-sonnet-4-6"}, modelIDsForTest(got.Data))
}
func TestGatewayModels_CustomModelsListCanReturnEmptyWhenSelectionsUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode)
groupID := int64(24)
h := newGatewayModelsHandlerForTest(
&gatewayModelsAccountRepoStub{
byGroup: map[int64][]service.Account{
groupID: {
{
ID: 1,
Platform: service.PlatformOpenAI,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4",
},
},
},
},
},
},
)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
Group: &service.Group{
ID: groupID,
Platform: service.PlatformOpenAI,
ModelsListConfig: service.GroupModelsListConfig{
Enabled: true,
Models: []string{"gpt-5.5"},
},
},
})
h.Models(c)
require.Equal(t, http.StatusOK, rec.Code)
var got gatewayModelsResponseForTest
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
require.Empty(t, modelIDsForTest(got.Data))
}
func TestGatewayModels_CustomModelsListFiltersDefaultFallbackModels(t *testing.T) {
gin.SetMode(gin.TestMode)
groupID := int64(25)
h := newGatewayModelsHandlerForTest(
&gatewayModelsAccountRepoStub{
byGroup: map[int64][]service.Account{
groupID: {
{ID: 1, Platform: service.PlatformOpenAI},
},
},
},
)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
Group: &service.Group{
ID: groupID,
Platform: service.PlatformOpenAI,
ModelsListConfig: service.GroupModelsListConfig{
Enabled: true,
Models: []string{"gpt-5.5", "legacy-gpt-2024", "gpt-5.4"},
},
},
})
h.Models(c)
require.Equal(t, http.StatusOK, rec.Code)
var got gatewayModelsResponseForTest
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
require.Equal(t, []string{"gpt-5.5", "gpt-5.4"}, modelIDsForTest(got.Data))
}
func TestGatewayModels_OpenAICustomModelsListKeepsOpenAIResponseShapeForDefaultFallback(t *testing.T) {
gin.SetMode(gin.TestMode)
groupID := int64(27)
h := newGatewayModelsHandlerForTest(
&gatewayModelsAccountRepoStub{
byGroup: map[int64][]service.Account{
groupID: {
{ID: 1, Platform: service.PlatformOpenAI},
},
},
},
)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
Group: &service.Group{
ID: groupID,
Platform: service.PlatformOpenAI,
ModelsListConfig: service.GroupModelsListConfig{
Enabled: true,
Models: []string{"gpt-5.5", "gpt-5.4"},
},
},
})
h.Models(c)
require.Equal(t, http.StatusOK, rec.Code)
var got gatewayModelsResponseForTest
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
require.Equal(t, []string{"gpt-5.5", "gpt-5.4"}, modelIDsForTest(got.Data))
require.Equal(t, "model", got.Data[0].Object)
require.NotZero(t, got.Data[0].Created)
require.Equal(t, "openai", got.Data[0].OwnedBy)
require.Empty(t, got.Data[0].CreatedAt)
}
func modelIDsForTest(models []gatewayModelItemForTest) []string {
ids := make([]string, 0, len(models))
for _, model := range models {

View File

@ -247,7 +247,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err))
status, _, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
@ -527,9 +527,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result,
QuotaPlatform: quotaPlatform,
APIKey: apiKey,
User: apiKey.User,
Account: account,

View File

@ -206,7 +206,7 @@ func TestOpenAIGatewayHandlerResponses_TextOnlyNotRejectedByImageConcurrency(t *
h := &OpenAIGatewayHandler{
gatewayService: &service.OpenAIGatewayService{},
billingCacheService: service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}),
billingCacheService: service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}, nil),
apiKeyService: &service.APIKeyService{},
concurrencyHelper: &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})},
cfg: &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: config.ImageConcurrencyConfig{

View File

@ -106,7 +106,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
defer userReleaseFunc()
}
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {

View File

@ -0,0 +1,253 @@
package handler
import (
"context"
"errors"
"net/http"
"strconv"
"strings"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// Embeddings handles the OpenAI-compatible Embeddings API.
// POST /v1/embeddings
func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) {
streamStarted := false
requestStart := time.Now()
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.openai_gateway.embeddings",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
if !h.ensureResponsesDependencies(c, reqLog) {
return
}
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
if !gjson.ValidBytes(body) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || strings.TrimSpace(modelResult.String()) == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
reqLog = reqLog.With(zap.String("model", reqModel))
setOpsRequestContext(c, reqModel, false)
setOpsEndpointContext(c, "", int16(service.RequestTypeSync))
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
subscription, _ := middleware2.GetSubscriptionFromContext(c)
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, false, &streamStarted, reqLog)
if !acquired {
return
}
if userReleaseFunc != nil {
defer userReleaseFunc()
}
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("openai_embeddings.billing_check_failed", zap.Error(err))
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.errorResponse(c, status, code, message)
return
}
failedAccountIDs := make(map[int64]struct{})
var lastFailoverErr *service.UpstreamFailoverError
switchCount := 0
maxAccountSwitches := h.maxAccountSwitches
if maxAccountSwitches <= 0 {
maxAccountSwitches = 3
}
routingStart := time.Now()
for {
selection, _, err := h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"",
"",
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportHTTPSSE,
false,
)
if err != nil {
reqLog.Warn("openai_embeddings.account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
return
}
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, false)
} else {
h.errorResponse(c, http.StatusBadGateway, "api_error", "Upstream request failed")
}
return
}
if selection == nil || selection.Account == nil {
markOpsRoutingCapacityLimited(c)
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
account := selection.Account
if account.Type != service.AccountTypeAPIKey {
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
failedAccountIDs[account.ID] = struct{}{}
continue
}
setOpsSelectedAccount(c, account.ID, account.Platform)
accountReleaseFunc, accountAcquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, "", selection, false, &streamStarted, reqLog)
if !accountAcquired {
return
}
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
forwardBody := body
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
writerSizeBeforeForward := c.Writer.Size()
result, err := func() (*service.OpenAIForwardResult, error) {
defer func() {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
}()
return h.gatewayService.ForwardEmbeddings(c.Request.Context(), c, account, forwardBody, "")
}()
forwardDurationMs := time.Since(forwardStart).Milliseconds()
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
}
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
if c.Writer.Size() != writerSizeBeforeForward {
h.handleFailoverExhausted(c, failoverErr, true)
return
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, false)
return
}
switchCount++
reqLog.Warn("openai_embeddings.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
if c.Writer.Size() == writerSizeBeforeForward {
h.errorResponse(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
}
reqLog.Warn("openai_embeddings.forward_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
)
return
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.embeddings"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("openai_embeddings.record_usage_failed", zap.Error(err))
}
})
reqLog.Debug("openai_embeddings.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
)
return
}
}

View File

@ -243,7 +243,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
// 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
@ -648,7 +648,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
defer userReleaseFunc()
}
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err))
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
@ -1209,11 +1209,14 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
var currentUserRelease func()
var currentAccountRelease func()
releaseTurnSlots := func() {
releaseAccountSlot := func() {
if currentAccountRelease != nil {
currentAccountRelease()
currentAccountRelease = nil
}
}
releaseTurnSlots := func() {
releaseAccountSlot()
if currentUserRelease != nil {
currentUserRelease()
currentUserRelease = nil
@ -1233,9 +1236,26 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
return
}
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
ensureUserSlotHeld := func() bool {
if currentUserRelease != nil {
return true
}
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
if err != nil {
reqLog.Warn("openai.websocket_user_slot_reacquire_failed", zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot")
return false
}
if !userAcquired {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later")
return false
}
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
return true
}
subscription, _ := middleware2.GetSubscriptionFromContext(c)
if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("openai.websocket_billing_eligibility_check_failed", zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "billing check failed")
return
@ -1246,195 +1266,244 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
firstMessage,
openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID),
)
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
ctx,
apiKey.GroupID,
previousResponseID,
sessionHash,
reqModel,
nil,
service.OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
if err != nil {
reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
return
}
if selection == nil || selection.Account == nil {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
return
}
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
var lastFailoverErr *service.UpstreamFailoverError
account := selection.Account
accountMaxConcurrency := account.Concurrency
if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
}
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
return
}
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
for {
reqLog.Debug("openai.websocket_account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
ctx,
account.ID,
selection.WaitPlan.MaxConcurrency,
apiKey.GroupID,
previousResponseID,
sessionHash,
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
if err != nil {
reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
reqLog.Warn("openai.websocket_account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if lastFailoverErr != nil {
closeOpenAIWSFailoverExhausted(wsConn, lastFailoverErr)
} else {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
}
return
}
if !fastAcquired {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
if selection == nil || selection.Account == nil {
if lastFailoverErr != nil {
closeOpenAIWSFailoverExhausted(wsConn, lastFailoverErr)
} else {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
}
return
}
accountReleaseFunc = fastReleaseFunc
}
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
token, _, err := h.gatewayService.GetAccessToken(ctx, account)
if err != nil {
reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
return
}
reqLog.Debug("openai.websocket_account_selected",
zap.Int64("account_id", account.ID),
zap.String("account_name", account.Name),
zap.String("schedule_layer", scheduleDecision.Layer),
zap.Int("candidate_count", scheduleDecision.CandidateCount),
)
hooks := &service.OpenAIWSIngressHooks{
InitialRequestModel: reqModel,
BeforeRequest: func(turn int, payload []byte, originalModel string) error {
if turn == 1 {
return nil
}
if !gjson.ValidBytes(payload) {
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
}
model := strings.TrimSpace(originalModel)
if model == "" {
model = strings.TrimSpace(gjson.GetBytes(payload, "model").String())
}
if model == "" {
model = reqModel
}
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked {
writeContentModerationWSError(ctx, wsConn, decision)
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil)
}
return nil
},
BeforeTurn: func(turn int) error {
if turn == 1 {
return nil
}
// 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
releaseTurnSlots()
// 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
if err != nil {
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
}
if !userAcquired {
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
}
accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
if err != nil {
if userReleaseFunc != nil {
userReleaseFunc()
}
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
}
if !accountAcquired {
if userReleaseFunc != nil {
userReleaseFunc()
}
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
}
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
return nil
},
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
releaseTurnSlots()
if turnErr != nil {
if result == nil || result.ImageCount <= 0 {
return
}
reqLog.Warn("openai.websocket_partial_error_with_image_result",
zap.Int64("account_id", account.ID),
zap.Int("image_count", result.ImageCount),
zap.Error(turnErr),
)
}
if result == nil {
account := selection.Account
accountMaxConcurrency := account.Concurrency
if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
}
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
return
}
if account.Type == service.AccountTypeOAuth {
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
ctx,
account.ID,
selection.WaitPlan.MaxConcurrency,
)
if err != nil {
reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
return
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
reqLog.Error("openai.websocket_record_usage_failed",
zap.Int64("account_id", account.ID),
zap.String("request_id", result.RequestID),
zap.Error(err),
)
}
})
},
}
if !fastAcquired {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
return
}
accountReleaseFunc = fastReleaseFunc
}
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
// 应用渠道模型映射到 WebSocket 首条消息
wsFirstMessage := firstMessage
if channelMappingWS.Mapped {
wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel)
}
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
reqLog.Warn("openai.websocket_proxy_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
zap.String("close_status", closeStatus),
zap.String("close_reason", closeReason),
)
var closeErr *service.OpenAIWSClientCloseError
if errors.As(err, &closeErr) {
closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
token, _, err := h.gatewayService.GetAccessToken(ctx, account)
if err != nil {
reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
return
}
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
reqLog.Debug("openai.websocket_account_selected",
zap.Int64("account_id", account.ID),
zap.String("account_name", account.Name),
zap.String("schedule_layer", scheduleDecision.Layer),
zap.Int("candidate_count", scheduleDecision.CandidateCount),
)
hooks := &service.OpenAIWSIngressHooks{
InitialRequestModel: reqModel,
BeforeRequest: func(turn int, payload []byte, originalModel string) error {
if turn == 1 {
return nil
}
if !gjson.ValidBytes(payload) {
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
}
model := strings.TrimSpace(originalModel)
if model == "" {
model = strings.TrimSpace(gjson.GetBytes(payload, "model").String())
}
if model == "" {
model = reqModel
}
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked {
writeContentModerationWSError(ctx, wsConn, decision)
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil)
}
return nil
},
BeforeTurn: func(turn int) error {
if turn == 1 {
return nil
}
// 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
releaseTurnSlots()
// 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
if err != nil {
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
}
if !userAcquired {
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
}
accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
if err != nil {
if userReleaseFunc != nil {
userReleaseFunc()
}
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
}
if !accountAcquired {
if userReleaseFunc != nil {
userReleaseFunc()
}
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
}
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
return nil
},
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
releaseTurnSlots()
if turnErr != nil {
if result == nil || result.ImageCount <= 0 {
return
}
reqLog.Warn("openai.websocket_partial_error_with_image_result",
zap.Int64("account_id", account.ID),
zap.Int("image_count", result.ImageCount),
zap.Error(turnErr),
)
}
if result == nil {
return
}
if account.Type == service.AccountTypeOAuth {
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
reqLog.Error("openai.websocket_record_usage_failed",
zap.Int64("account_id", account.ID),
zap.String("request_id", result.RequestID),
zap.Error(err),
)
}
})
},
}
// 应用渠道模型映射到 WebSocket 首条消息
wsFirstMessage := firstMessage
if channelMappingWS.Mapped {
wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel)
}
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
releaseAccountSlot()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
closeOpenAIWSFailoverExhausted(wsConn, failoverErr)
return
}
switchCount++
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
closeOpenAIWSFailoverExhausted(wsConn, failoverErr)
return
}
h.gatewayService.RecordOpenAIAccountSwitch()
reqLog.Warn("openai.websocket_upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
if !ensureUserSlotHeld() {
return
}
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
reqLog.Warn("openai.websocket_proxy_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
zap.String("close_status", closeStatus),
zap.String("close_reason", closeReason),
)
var closeErr *service.OpenAIWSClientCloseError
if errors.As(err, &closeErr) {
closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
return
}
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
return
}
reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
return
}
reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
}
func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) {
@ -1691,6 +1760,15 @@ func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, st
// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
// /v1/responses 的严格 SDKCodex CLI要求终止事件必须属于
// response.completed/failed/incomplete/cancelled 集合。
// 通用 `event: error` 帧不被识别为终止事件,会导致
// "stream closed before response.completed"。
if inboundIsResponses(c) {
if writeResponsesFailedSSE(c, errType, message) {
return
}
}
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
@ -1710,9 +1788,17 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
if c == nil || c.Writer == nil || c.Writer.Written() {
if c == nil || c.Writer == nil {
return false
}
// 旧实现在 Writer.Written 时直接 return false导致 ping 已 flush 之后的
// 上游错误http2 timeout、连接中断等完全无法把错误传给客户端——
// HTTP 200 已锁死TCP 直接 EOFCodex CLI 报 "stream closed before response.completed"。
// 这里改成Writer 已写过时强制走 streamStarted 分支,让
// handleStreamingAwareError 通过 SSE 发协议合规的 response.failed。
if c.Writer.Written() {
streamStarted = true
}
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
return true
}
@ -1783,6 +1869,23 @@ func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason s
_ = conn.CloseNow()
}
func closeOpenAIWSFailoverExhausted(conn *coderws.Conn, failoverErr *service.UpstreamFailoverError) {
if failoverErr == nil {
closeOpenAIClientWS(conn, coderws.StatusInternalError, "upstream websocket proxy failed")
return
}
switch failoverErr.StatusCode {
case http.StatusTooManyRequests:
closeOpenAIClientWS(conn, coderws.StatusTryAgainLater, "upstream rate limit exceeded, please retry later")
case 529, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
closeOpenAIClientWS(conn, coderws.StatusTryAgainLater, "upstream service temporarily unavailable")
case http.StatusUnauthorized, http.StatusForbidden:
closeOpenAIClientWS(conn, coderws.StatusPolicyViolation, "upstream websocket authentication failed")
default:
closeOpenAIClientWS(conn, coderws.StatusInternalError, "upstream websocket proxy failed")
}
}
func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) {
if conn == nil || decision == nil {
return

View File

@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
@ -174,7 +175,11 @@ func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testin
assert.Equal(t, "Upstream request failed", errorObj["message"])
}
func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
// Writer 已写后 ensureForwardErrorResponse 必须仍然把错误信息以 SSE
// 形式追加给客户端streamStarted 强制 true
// 这是 case B 修复:旧实现遇到 Writer.Written 直接 return false
// 客户端只能拿到 silent EOFCodex CLI 报 "stream closed before response.completed"。
func TestOpenAIEnsureForwardErrorResponse_AppendsSSEAfterWritten(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@ -184,9 +189,34 @@ func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *test
h := &OpenAIGatewayHandler{}
wrote := h.ensureForwardErrorResponse(c, false)
require.False(t, wrote)
require.True(t, wrote, "must attempt to communicate the failure to the client via SSE")
// 状态码改不了headers 已 flush但 body 应该追加 SSE 错误事件。
require.Equal(t, http.StatusTeapot, w.Code)
assert.Equal(t, "already written", w.Body.String())
assert.Contains(t, w.Body.String(), "already written")
// 非 /responses 路径走 legacy event: error 分支。
assert.Contains(t, w.Body.String(), "event: error\n")
}
// case B 回归测试:/responses 路径Writer 已被写过(模拟 ping flushed
// ensureForwardErrorResponse 必须发 response.failed让 Codex 收到合规终止事件。
func TestOpenAIEnsureForwardErrorResponse_ResponsesRouteAfterWrittenEmitsResponseFailed(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, EndpointResponses, nil)
// 模拟 ping 已 flush 的状态Writer 已写过 1 个字节
_, _ = c.Writer.WriteString(":\n\n")
h := &OpenAIGatewayHandler{}
wrote := h.ensureForwardErrorResponse(c, false)
require.True(t, wrote)
body := w.Body.String()
assert.Contains(t, body, ":\n\n", "earlier ping bytes preserved")
assert.Contains(t, body, "event: response.failed\n", "appended a Responses terminal event")
assert.Contains(t, body, `"type":"response.failed"`)
assert.Contains(t, body, `"code":"upstream_error"`)
assert.Contains(t, body, "Upstream request failed")
}
func TestShouldLogOpenAIForwardFailureAsWarn(t *testing.T) {
@ -266,7 +296,9 @@ func TestOpenAIRecoverResponsesPanic_NoPanicNoWrite(t *testing.T) {
assert.Equal(t, "", w.Body.String())
}
func TestOpenAIRecoverResponsesPanic_DoesNotOverrideWrittenResponse(t *testing.T) {
// Panic 在已 flush 的 /v1/responses 流中:状态码无法改(已 written
// 但 body 应追加 response.failed 让客户端识别为合规截断而不是 silent EOF。
func TestOpenAIRecoverResponsesPanic_AppendsResponseFailedAfterWritten(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
@ -284,7 +316,9 @@ func TestOpenAIRecoverResponsesPanic_DoesNotOverrideWrittenResponse(t *testing.T
})
require.Equal(t, http.StatusTeapot, w.Code)
assert.Equal(t, "already written", w.Body.String())
body := w.Body.String()
assert.Contains(t, body, "already written")
assert.Contains(t, body, "event: response.failed\n")
}
func TestOpenAIMissingResponsesDependencies(t *testing.T) {
@ -707,16 +741,31 @@ func (r *contentModerationHandlerSettingRepo) Delete(ctx context.Context, key st
}
type contentModerationHandlerTestRepo struct {
mu sync.Mutex
logs []service.ContentModerationLog
}
func (r *contentModerationHandlerTestRepo) CreateLog(ctx context.Context, log *service.ContentModerationLog) error {
if log != nil {
r.mu.Lock()
defer r.mu.Unlock()
r.logs = append(r.logs, *log)
}
return nil
}
func (r *contentModerationHandlerTestRepo) resetLogs() {
r.mu.Lock()
defer r.mu.Unlock()
r.logs = nil
}
func (r *contentModerationHandlerTestRepo) logSnapshot() []service.ContentModerationLog {
r.mu.Lock()
defer r.mu.Unlock()
return append([]service.ContentModerationLog(nil), r.logs...)
}
func (r *contentModerationHandlerTestRepo) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) {
return nil, nil, nil
}
@ -775,7 +824,10 @@ func TestOpenAIResponsesWebSocket_ContentModerationBlocksFirstFrame(t *testing.T
})
require.NoError(t, err)
require.True(t, decision.Blocked)
repo.logs = nil
require.Eventually(t, func() bool {
return len(repo.logSnapshot()) == 1
}, time.Second, 10*time.Millisecond)
repo.resetLogs()
h := &OpenAIGatewayHandler{
gatewayService: &service.OpenAIGatewayService{},
billingCacheService: &service.BillingCacheService{},
@ -815,10 +867,11 @@ func TestOpenAIResponsesWebSocket_ContentModerationBlocksFirstFrame(t *testing.T
require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
require.Contains(t, closeErr.Reason, "内容审计测试阻断")
}
require.Len(t, repo.logs, 1)
require.True(t, repo.logs[0].Flagged)
require.Equal(t, service.ContentModerationActionBlock, repo.logs[0].Action)
require.Equal(t, "bad prompt", repo.logs[0].InputExcerpt)
logs := repo.logSnapshot()
require.Len(t, logs, 1)
require.True(t, logs[0].Flagged)
require.Equal(t, service.ContentModerationActionBlock, logs[0].Action)
require.Equal(t, "bad prompt", logs[0].InputExcerpt)
}
func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) {
@ -1042,6 +1095,52 @@ func (s *openAIWSUsageHandlerAccountRepoStub) GetByID(ctx context.Context, id in
return &account, nil
}
type openAIWSFailoverHandlerAccountRepoStub struct {
service.AccountRepository
accounts []service.Account
rateLimitedIDs []int64
}
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
out := make([]service.Account, 0, len(s.accounts))
for _, account := range s.accounts {
if account.Platform == platform && account.IsSchedulable() {
out = append(out, account)
}
}
return out, nil
}
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
return s.ListSchedulableByPlatform(ctx, platform)
}
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return s.ListSchedulableByPlatform(ctx, platform)
}
func (s *openAIWSFailoverHandlerAccountRepoStub) GetByID(ctx context.Context, id int64) (*service.Account, error) {
for _, account := range s.accounts {
if account.ID == id {
acc := account
return &acc, nil
}
}
return nil, nil
}
func (s *openAIWSFailoverHandlerAccountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
s.rateLimitedIDs = append(s.rateLimitedIDs, id)
for i := range s.accounts {
if s.accounts[i].ID == id {
reset := resetAt
s.accounts[i].RateLimitResetAt = &reset
break
}
}
return nil
}
type openAIWSUsageHandlerUsageLogRepoStub struct {
service.UsageLogRepository
created chan *service.UsageLog
@ -1074,6 +1173,201 @@ func (s *openAIWSUsageHandlerChannelRepoStub) GetGroupPlatforms(ctx context.Cont
return out, nil
}
func TestOpenAIResponsesWebSocket_FailoverOnUpstreamUsageLimitEvent(t *testing.T) {
gin.SetMode(gin.TestMode)
firstHitCh := make(chan []byte, 1)
secondHitCh := make(chan []byte, 1)
firstUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover})
if err != nil {
return
}
defer func() { _ = conn.CloseNow() }()
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
_, payload, readErr := conn.Read(readCtx)
cancelRead()
if readErr == nil {
firstHitCh <- payload
}
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
_ = conn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"error","error":{"code":"rate_limit_exceeded","type":"usage_limit_reached","message":"The usage limit has been reached"}}`))
cancelWrite()
}))
defer firstUpstream.Close()
secondUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover})
if err != nil {
return
}
defer func() { _ = conn.CloseNow() }()
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
_, payload, readErr := conn.Read(readCtx)
cancelRead()
if readErr == nil {
secondHitCh <- payload
}
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
_ = conn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.completed","response":{"id":"resp_ws_failover_ok","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`))
cancelWrite()
_ = conn.Close(coderws.StatusNormalClosure, "done")
}))
defer secondUpstream.Close()
groupID := int64(4202)
accounts := []service.Account{
{
ID: 9902,
Name: "openai-ws-rate-limited",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 1,
Credentials: map[string]any{
"api_key": "sk-first",
"base_url": firstUpstream.URL,
},
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
},
},
{
ID: 9903,
Name: "openai-ws-healthy",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 2,
Credentials: map[string]any{
"api_key": "sk-second",
"base_url": secondUpstream.URL,
},
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
},
},
}
cfg := &config.Config{}
cfg.RunMode = config.RunModeSimple
cfg.Default.RateMultiplier = 1
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
cfg.Gateway.MaxAccountSwitches = 3
accountRepo := &openAIWSFailoverHandlerAccountRepoStub{accounts: accounts}
rateLimitSvc := service.NewRateLimitService(accountRepo, nil, cfg, nil, nil)
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil)
gatewaySvc := service.NewOpenAIGatewayService(
accountRepo,
nil,
nil,
nil,
nil,
nil,
nil,
cfg,
nil,
nil,
service.NewBillingService(cfg, nil),
rateLimitSvc,
billingCacheSvc,
nil,
&service.DeferredService{},
nil,
nil,
nil,
nil,
nil,
nil,
)
cache := &concurrencyCacheMock{
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
},
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
},
}
h := &OpenAIGatewayHandler{
gatewayService: gatewaySvc,
billingCacheService: billingCacheSvc,
apiKeyService: &service.APIKeyService{},
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
maxAccountSwitches: 3,
}
apiKey := &service.APIKey{
ID: 1802,
GroupID: &groupID,
User: &service.User{ID: 1702, Status: service.StatusActive},
Group: &service.Group{ID: groupID, Platform: service.PlatformOpenAI, Status: service.StatusActive},
}
router := gin.New()
router.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.User.ID, Concurrency: 1})
c.Next()
})
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
handlerServer := httptest.NewServer(router)
defer handlerServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(
dialCtx,
"ws"+strings.TrimPrefix(handlerServer.URL, "http")+"/openai/v1/responses",
&coderws.DialOptions{CompressionMode: coderws.CompressionContextTakeover},
)
cancelDial()
require.NoError(t, err)
defer func() { _ = clientConn.CloseNow() }()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
cancelWrite()
require.NoError(t, err)
readCtx, cancelRead := context.WithTimeout(context.Background(), 5*time.Second)
_, event, err := clientConn.Read(readCtx)
cancelRead()
require.NoError(t, err)
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
require.Equal(t, "resp_ws_failover_ok", gjson.GetBytes(event, "response.id").String())
select {
case <-firstHitCh:
case <-time.After(3 * time.Second):
t.Fatal("等待第一个上游收到首帧超时")
}
select {
case <-secondHitCh:
case <-time.After(3 * time.Second):
t.Fatal("等待第二个上游收到重放首帧超时")
}
require.Equal(t, []int64{int64(9902)}, accountRepo.rateLimitedIDs)
}
func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult {
t.Helper()
gin.SetMode(gin.TestMode)
@ -1168,7 +1462,7 @@ func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSU
}, nil, nil, nil)
}
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil)
gatewaySvc := service.NewOpenAIGatewayService(
accountRepo,
usageRepo,
@ -1190,6 +1484,7 @@ func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSU
channelSvc,
nil,
nil,
nil, // userPlatformQuotaRepo
)
cache := &concurrencyCacheMock{

View File

@ -123,7 +123,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
defer userReleaseFunc()
}
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err))
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {

View File

@ -53,6 +53,8 @@ const (
opsCodeUserNotFound = "USER_NOT_FOUND"
opsCodeAPIKeyQuotaExhausted = "API_KEY_QUOTA_EXHAUSTED"
opsCodeAPIKeyQueryDeprecated = "api_key_in_query_deprecated"
opsCodeGroupDeleted = "GROUP_DELETED"
opsCodeGroupDisabled = "GROUP_DISABLED"
)
const (
@ -1012,6 +1014,8 @@ func parseOpsErrorResponse(body []byte) parsedOpsError {
var code string
if v, ok := errObj["code"]; ok {
switch n := v.(type) {
case string:
code = strings.TrimSpace(n)
case float64:
code = strconvItoa(int(n))
case int:
@ -1190,14 +1194,19 @@ func isOpsClientAuthError(code string, msg string) bool {
opsCodeAPIKeyExpired,
opsCodeAPIKeyDisabled,
opsCodeUserNotFound,
opsCodeUserInactive:
opsCodeUserInactive,
opsCodeGroupDeleted,
opsCodeGroupDisabled:
return true
}
return strings.Contains(msg, "invalid api key") ||
strings.Contains(msg, "api key is required") ||
strings.Contains(msg, "api key is disabled") ||
strings.Contains(msg, "user associated with api key not found") ||
strings.Contains(msg, "user account is not active")
strings.Contains(msg, "user account is not active") ||
strings.Contains(msg, "api key 所属分组已删除") ||
strings.Contains(msg, "api key 所属分组已停用") ||
strings.Contains(msg, "api key is not assigned to any group")
}
func isOpsLocalBusinessLimitError(code string, msg string) bool {
@ -1213,6 +1222,7 @@ func isOpsLocalBusinessLimitError(code string, msg string) bool {
return strings.Contains(msg, "api key in query parameter is deprecated") ||
strings.Contains(msg, "query parameter api_key is deprecated") ||
strings.Contains(msg, "no active subscription found for this group") ||
strings.Contains(msg, "subscription is invalid or expired") ||
strings.Contains(msg, opsErrInsufficientBalance) ||
strings.Contains(msg, "insufficient account balance") ||
strings.Contains(msg, "api key group platform is not gemini") ||
@ -1223,7 +1233,22 @@ func isOpsLocalBusinessLimitError(code string, msg string) bool {
strings.Contains(msg, "daily usage limit exceeded") ||
strings.Contains(msg, "weekly usage limit exceeded") ||
strings.Contains(msg, "monthly usage limit exceeded") ||
strings.Contains(msg, "requests-per-minute limit exceeded")
strings.Contains(msg, "usage quota exhausted for this platform") ||
strings.Contains(msg, "requests-per-minute limit exceeded") ||
strings.Contains(msg, "too many pending requests") ||
strings.Contains(msg, "concurrency limit exceeded") ||
strings.Contains(msg, "image generation concurrency limit exceeded") ||
strings.Contains(msg, "this group is restricted to claude code clients") ||
strings.Contains(msg, "this group does not allow /v1/messages dispatch") ||
strings.Contains(msg, "image generation is not enabled for this group") ||
strings.Contains(msg, "token counting is not supported for this platform") ||
strings.Contains(msg, "images api is not supported for this platform") ||
(strings.Contains(msg, "model ") && strings.Contains(msg, " not in whitelist")) ||
(strings.Contains(msg, "beta feature ") && strings.Contains(msg, " is not allowed")) ||
(strings.Contains(msg, "openai service_tier=") && strings.Contains(msg, " is not allowed for model")) ||
strings.Contains(msg, "this account only allows codex official clients") ||
strings.Contains(msg, "openai wsv1 is temporarily unsupported") ||
strings.Contains(msg, "openai codex passthrough requires a non-empty instructions field")
}
func hasOpsUpstreamErrorContext(c *gin.Context) bool {

View File

@ -288,6 +288,34 @@ func TestClassifyOpsAuthClientErrorsExcludedFromSLA(t *testing.T) {
code: "USER_INACTIVE",
status: http.StatusUnauthorized,
},
{
name: "deleted local API key group",
errType: "api_error",
message: "API Key 所属分组已删除",
code: "GROUP_DELETED",
status: http.StatusForbidden,
},
{
name: "disabled local API key group",
errType: "api_error",
message: "API Key 所属分组已停用",
code: "GROUP_DISABLED",
status: http.StatusForbidden,
},
{
name: "google deleted API key group message without semantic code",
errType: "api_error",
message: "API Key 所属分组已删除",
code: "403",
status: http.StatusForbidden,
},
{
name: "anthropic unassigned API key group",
errType: "permission_error",
message: "API Key is not assigned to any group and cannot be used. Please contact the administrator to assign it to a group.",
code: "",
status: http.StatusForbidden,
},
{
name: "google invalid API key",
errType: "api_error",
@ -389,6 +417,15 @@ func TestClassifyOpsLocalBusinessLimitErrorsExcludedFromSLA(t *testing.T) {
wantErrType: "api_error",
wantPhase: "request",
},
{
name: "gateway subscription invalid cache recheck",
errType: "billing_error",
message: "subscription is invalid or expired",
code: "billing_error",
status: http.StatusForbidden,
wantErrType: "billing_error",
wantPhase: "request",
},
{
name: "google insufficient account balance",
errType: "api_error",
@ -443,6 +480,132 @@ func TestClassifyOpsLocalBusinessLimitErrorsExcludedFromSLA(t *testing.T) {
wantErrType: "api_error",
wantPhase: "request",
},
{
name: "user platform daily quota exhausted",
errType: "api_error",
message: "Daily usage quota exhausted for this platform.",
code: "rate_limit_exceeded",
status: http.StatusTooManyRequests,
wantErrType: "api_error",
wantPhase: "request",
},
{
name: "local pending queue limit",
errType: "rate_limit_error",
message: "Too many pending requests, please retry later",
code: "",
status: http.StatusTooManyRequests,
wantErrType: "rate_limit_error",
wantPhase: "request",
},
{
name: "local concurrency limit",
errType: "rate_limit_error",
message: "Concurrency limit exceeded for user, please retry later",
code: "",
status: http.StatusTooManyRequests,
wantErrType: "rate_limit_error",
wantPhase: "request",
},
{
name: "group claude code only feature gate",
errType: "permission_error",
message: "This group is restricted to Claude Code clients (/v1/messages only)",
code: "",
status: http.StatusForbidden,
wantErrType: "api_error",
wantPhase: "request",
},
{
name: "group image generation feature gate",
errType: "permission_error",
message: "Image generation is not enabled for this group",
code: "",
status: http.StatusForbidden,
wantErrType: "api_error",
wantPhase: "request",
},
{
name: "route token counting platform unsupported",
errType: "not_found_error",
message: "Token counting is not supported for this platform",
code: "",
status: http.StatusNotFound,
wantErrType: "not_found_error",
wantPhase: "request",
},
{
name: "route images API platform unsupported",
errType: "not_found_error",
message: "Images API is not supported for this platform",
code: "",
status: http.StatusNotFound,
wantErrType: "not_found_error",
wantPhase: "request",
},
{
name: "antigravity model whitelist feature gate",
errType: "permission_error",
message: "model claude-3-5-sonnet not in whitelist",
code: "",
status: http.StatusForbidden,
wantErrType: "api_error",
wantPhase: "request",
},
{
name: "google antigravity model whitelist feature gate",
errType: "api_error",
message: "model gemini-2.5-pro not in whitelist",
code: "403",
status: http.StatusForbidden,
wantErrType: "api_error",
wantPhase: "request",
},
{
name: "claude beta policy block",
errType: "invalid_request_error",
message: "beta feature interleaved-thinking-2025-05-14 is not allowed",
code: "",
status: http.StatusBadRequest,
wantErrType: "invalid_request_error",
wantPhase: "request",
},
{
name: "openai fast policy block",
errType: "permission_error",
message: "openai service_tier=priority is not allowed for model gpt-5.5",
code: "",
status: http.StatusForbidden,
wantErrType: "api_error",
wantPhase: "request",
},
{
name: "codex official client policy block",
errType: "forbidden_error",
message: "This account only allows Codex official clients",
code: "",
status: http.StatusForbidden,
wantErrType: "forbidden_error",
wantPhase: "request",
},
{
name: "openai wsv1 unsupported feature gate",
errType: "invalid_request_error",
message: "OpenAI WSv1 is temporarily unsupported. Please enable responses_websockets_v2.",
code: "",
status: http.StatusBadRequest,
wantErrType: "invalid_request_error",
wantPhase: "request",
},
{
name: "openai passthrough instructions policy block",
errType: "forbidden_error",
message: "OpenAI codex passthrough requires a non-empty instructions field",
code: "",
status: http.StatusForbidden,
wantErrType: "forbidden_error",
wantPhase: "request",
},
}
for _, tt := range tests {
@ -479,6 +642,22 @@ func TestClassifyOpsIPRestrictionAccessDeniedExcludedFromSLA(t *testing.T) {
require.Equal(t, "client_request", errorSource)
}
func TestClassifyOpsClientBusinessLimitedMarkerExcludesCustomPolicyDenialFromSLA(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalPolicyDenied)
errType := normalizeOpsErrorType("invalid_request_error", "")
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, "custom admin policy message", "", http.StatusBadRequest)
require.Equal(t, "invalid_request_error", errType)
require.Equal(t, "auth", phase)
require.True(t, isBusinessLimited)
require.Equal(t, "client", errorOwner)
require.Equal(t, "client_request", errorSource)
}
func TestClassifyOpsOtherErrorsStillCountForSLA(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
@ -583,6 +762,78 @@ func TestClassifyOpsUpstreamAuthTextStillCountsForSLA(t *testing.T) {
code: "API_KEY_QUOTA_EXHAUSTED",
status: http.StatusTooManyRequests,
},
{
name: "provider deleted group shaped error",
message: "API Key 所属分组已删除",
code: "GROUP_DELETED",
status: http.StatusForbidden,
},
{
name: "provider unassigned group shaped error",
message: "API Key is not assigned to any group and cannot be used. Please contact the administrator to assign it to a group.",
code: "403",
status: http.StatusForbidden,
},
{
name: "provider local quota shaped error",
message: "Daily usage quota exhausted for this platform.",
code: "rate_limit_exceeded",
status: http.StatusTooManyRequests,
},
{
name: "provider feature gate shaped error",
message: "Image generation is not enabled for this group",
code: "403",
status: http.StatusForbidden,
},
{
name: "provider token counting unsupported shaped error",
message: "Token counting is not supported for this platform",
code: "404",
status: http.StatusNotFound,
},
{
name: "provider image API unsupported shaped error",
message: "Images API is not supported for this platform",
code: "404",
status: http.StatusNotFound,
},
{
name: "provider antigravity whitelist shaped error",
message: "model claude-3-5-sonnet not in whitelist",
code: "403",
status: http.StatusForbidden,
},
{
name: "provider beta policy shaped error",
message: "beta feature interleaved-thinking-2025-05-14 is not allowed",
code: "400",
status: http.StatusBadRequest,
},
{
name: "provider openai fast policy shaped error",
message: "openai service_tier=priority is not allowed for model gpt-5.5",
code: "403",
status: http.StatusForbidden,
},
{
name: "provider codex client policy shaped error",
message: "This account only allows Codex official clients",
code: "403",
status: http.StatusForbidden,
},
{
name: "provider wsv1 unsupported shaped error",
message: "OpenAI WSv1 is temporarily unsupported. Please enable responses_websockets_v2.",
code: "400",
status: http.StatusBadRequest,
},
{
name: "provider passthrough instructions shaped error",
message: "OpenAI codex passthrough requires a non-empty instructions field",
code: "403",
status: http.StatusForbidden,
},
}
for _, tt := range tests {
@ -628,6 +879,14 @@ func TestClassifyOpsUpstreamNoAvailableTextStillCountsForSLA(t *testing.T) {
require.Equal(t, "upstream_http", errorSource)
}
func TestParseOpsErrorResponsePreservesNestedStringCode(t *testing.T) {
parsed := parseOpsErrorResponse([]byte(`{"error":{"type":"permission_error","code":"GROUP_DELETED","message":"API Key 所属分组已删除"}}`))
require.Equal(t, "permission_error", parsed.ErrorType)
require.Equal(t, "GROUP_DELETED", parsed.Code)
require.Equal(t, "API Key 所属分组已删除", parsed.Message)
}
func TestSetOpsEndpointContext_SetsContextKeys(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()

View File

@ -0,0 +1,104 @@
// Package quotaview provides shared quota response helpers for user and admin handlers.
// Extracted to avoid import cycles between handler and handler/admin packages.
package quotaview
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// LazyZeroQuotaForResponse 按 D14 规则把过期档位归零(不写 DB
// includeWindowStart=true 时输出 *_window_start 字段admin 视角调试用)
func LazyZeroQuotaForResponse(r service.UserPlatformQuotaRecord, now time.Time, includeWindowStart bool) map[string]any {
daily := buildWindowSlice(r.DailyUsageUSD, r.DailyLimitUSD, r.DailyWindowStart, NeedsDailyReset(r.DailyWindowStart, now), nextDailyResetTime(now), includeWindowStart)
weekly := buildWindowSlice(r.WeeklyUsageUSD, r.WeeklyLimitUSD, r.WeeklyWindowStart, NeedsWeeklyReset(r.WeeklyWindowStart, now), nextWeeklyResetTime(now), includeWindowStart)
monthly := buildWindowSlice(r.MonthlyUsageUSD, r.MonthlyLimitUSD, r.MonthlyWindowStart, NeedsMonthlyReset(r.MonthlyWindowStart, now), NextMonthlyResetTimeFrom(r.MonthlyWindowStart, now), includeWindowStart)
out := map[string]any{
"platform": r.Platform,
"daily_usage_usd": daily.usage,
"daily_limit_usd": daily.limit,
"daily_window_resets_at": daily.resetsAt,
"weekly_usage_usd": weekly.usage,
"weekly_limit_usd": weekly.limit,
"weekly_window_resets_at": weekly.resetsAt,
"monthly_usage_usd": monthly.usage,
"monthly_limit_usd": monthly.limit,
"monthly_window_resets_at": monthly.resetsAt,
}
if includeWindowStart {
out["daily_window_start"] = daily.windowStart
out["weekly_window_start"] = weekly.windowStart
out["monthly_window_start"] = monthly.windowStart
}
return out
}
type windowSlice struct {
usage float64
limit *float64
resetsAt *string
windowStart *string
}
func buildWindowSlice(usage float64, limit *float64, start *time.Time, expired bool, nextReset time.Time, includeStart bool) windowSlice {
out := windowSlice{usage: usage, limit: limit}
if expired {
out.usage = 0
out.resetsAt = nil
} else if start != nil {
s := nextReset.Format(time.RFC3339)
out.resetsAt = &s
}
if includeStart && start != nil {
s := start.Format(time.RFC3339)
out.windowStart = &s
}
return out
}
// NeedsDailyReset 判断日窗口是否已过期start 早于「全局时区当天 0 点」即过期。
// 时区跟随 timezone.Location()(全局服务器时区),与 billing / repo 写入的 window_start 同口径。
func NeedsDailyReset(start *time.Time, now time.Time) bool {
if start == nil {
return false
}
return start.Before(timezone.StartOfDay(now))
}
func NeedsWeeklyReset(start *time.Time, now time.Time) bool {
if start == nil {
return false
}
return start.Before(timezone.StartOfWeek(now))
}
// NeedsMonthlyReset 30 天滚动窗口语义(与订阅模式 NeedsMonthlyReset 一致)。
func NeedsMonthlyReset(start *time.Time, now time.Time) bool {
if start == nil {
return false
}
return now.Sub(*start) >= 30*24*time.Hour
}
func nextDailyResetTime(now time.Time) time.Time {
return timezone.StartOfDay(now).AddDate(0, 0, 1)
}
func nextWeeklyResetTime(now time.Time) time.Time {
return timezone.StartOfWeek(now).AddDate(0, 0, 7)
}
// NextMonthlyResetTimeFrom 计算 30 天滚动月度窗口的下次重置时间。
// 语义:
// - start != nil → 返回 start + 30d与 billing_cache_service.nextMonthlyResetFrom 一致)
// - start == nil → 退化为 now + 30d保留旧行为避免 nil 崩溃)
//
// 导出(首字母大写)以允许测试直接调用。
func NextMonthlyResetTimeFrom(start *time.Time, now time.Time) time.Time {
if start == nil {
return now.Add(30 * 24 * time.Hour)
}
return start.Add(30 * 24 * time.Hour)
}

View File

@ -0,0 +1,133 @@
package quotaview
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// TestNextMonthlyResetTimeFrom_FromStart 验证start 已知时返回 start+30d不随 now 漂移。
func TestNextMonthlyResetTimeFrom_FromStart(t *testing.T) {
t0 := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
now := t0.Add(15 * 24 * time.Hour) // t0 + 15d
want := t0.Add(30 * 24 * time.Hour) // t0 + 30d
got := NextMonthlyResetTimeFrom(&t0, now)
if !got.Equal(want) {
t.Errorf("NextMonthlyResetTimeFrom: want %v, got %v", want, got)
}
}
// TestNextMonthlyResetTimeFrom_NilStart 验证start=nil 时退化为 now+30d不 panic
func TestNextMonthlyResetTimeFrom_NilStart(t *testing.T) {
now := time.Date(2024, 3, 15, 12, 0, 0, 0, time.UTC)
want := now.Add(30 * 24 * time.Hour)
got := NextMonthlyResetTimeFrom(nil, now)
if !got.Equal(want) {
t.Errorf("NextMonthlyResetTimeFrom(nil): want %v, got %v", want, got)
}
}
// TestLazyZeroQuotaForResponse_MonthlyResetsAt_NotDrifting 验证:
// 连续两次以不同 now 调用、但 MonthlyWindowStart 相同的 record
// monthly_window_resets_at 始终等于 windowStart+30d不随 now 漂移。
func TestLazyZeroQuotaForResponse_MonthlyResetsAt_NotDrifting(t *testing.T) {
windowStart := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
wantResetsAt := windowStart.Add(30 * 24 * time.Hour).Format(time.RFC3339)
r := service.UserPlatformQuotaRecord{
Platform: "openai",
MonthlyUsageUSD: 5.0,
MonthlyWindowStart: &windowStart,
}
// 第一次调用now = windowStart + 5d
now1 := windowStart.Add(5 * 24 * time.Hour)
out1 := LazyZeroQuotaForResponse(r, now1, false)
resetsAt1, ok1 := out1["monthly_window_resets_at"]
if !ok1 || resetsAt1 == nil {
t.Fatal("first call: monthly_window_resets_at should be set for active window")
}
s1, ok := resetsAt1.(*string)
if !ok || s1 == nil {
t.Fatalf("first call: monthly_window_resets_at should be *string, got %T", resetsAt1)
}
if *s1 != wantResetsAt {
t.Errorf("first call: want %s, got %s", wantResetsAt, *s1)
}
// 第二次调用now = windowStart + 10d不同 now但 resetsAt 应不变)
now2 := windowStart.Add(10 * 24 * time.Hour)
out2 := LazyZeroQuotaForResponse(r, now2, false)
resetsAt2, ok2 := out2["monthly_window_resets_at"]
if !ok2 || resetsAt2 == nil {
t.Fatal("second call: monthly_window_resets_at should be set for active window")
}
s2, ok := resetsAt2.(*string)
if !ok || s2 == nil {
t.Fatalf("second call: monthly_window_resets_at should be *string, got %T", resetsAt2)
}
if *s2 != wantResetsAt {
t.Errorf("second call: want %s, got %s", wantResetsAt, *s2)
}
// 两次结果必须相等
if *s1 != *s2 {
t.Errorf("resetsAt drifted between calls: %s vs %s", *s1, *s2)
}
}
// TestNeedsDailyReset_FollowsServerTimezone 验证日窗口过期判断按全局时区(北京 0 点)而非 UTC。
func TestNeedsDailyReset_FollowsServerTimezone(t *testing.T) {
if err := timezone.Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init: %v", err)
}
t.Cleanup(func() { _ = timezone.Init("UTC") })
// now = 2026-05-25 23:00 UTC = 2026-05-26 07:00 +08北京 5/26
now := time.Date(2026, 5, 25, 23, 0, 0, 0, time.UTC)
// start = 2026-05-25 10:00 UTC = 2026-05-25 18:00 +08北京 5/25→ 应判定为过期
startPrevBeijingDay := time.Date(2026, 5, 25, 10, 0, 0, 0, time.UTC)
if !NeedsDailyReset(&startPrevBeijingDay, now) {
t.Error("上一个北京日的窗口应判定为过期")
}
// start = 2026-05-25 20:00 UTC = 2026-05-26 04:00 +08北京 5/26 同日)→ 不应过期
startSameBeijingDay := time.Date(2026, 5, 25, 20, 0, 0, 0, time.UTC)
if NeedsDailyReset(&startSameBeijingDay, now) {
t.Error("同一北京日的窗口不应判定为过期")
}
}
// TestNextDailyResetTime_FollowsServerTimezone 验证下次日重置 = 次日北京 0 点。
func TestNextDailyResetTime_FollowsServerTimezone(t *testing.T) {
if err := timezone.Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init: %v", err)
}
t.Cleanup(func() { _ = timezone.Init("UTC") })
now := time.Date(2026, 5, 25, 23, 0, 0, 0, time.UTC) // 北京 5/26 07:00
want := time.Date(2026, 5, 27, 0, 0, 0, 0, timezone.Location()) // 北京 5/27 00:00
if got := nextDailyResetTime(now); !got.Equal(want) {
t.Errorf("nextDailyResetTime = %v, want %v", got, want)
}
}
// TestNextWeeklyResetTime_FollowsServerTimezone 验证下次周重置 = 下周一北京 0 点。
func TestNextWeeklyResetTime_FollowsServerTimezone(t *testing.T) {
if err := timezone.Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init: %v", err)
}
t.Cleanup(func() { _ = timezone.Init("UTC") })
// 北京 2026-05-26周二→ 下周一是 2026-06-01
now := time.Date(2026, 5, 25, 23, 0, 0, 0, time.UTC) // 北京 5/26 07:00 周二
want := time.Date(2026, 6, 1, 0, 0, 0, 0, timezone.Location())
if got := nextWeeklyResetTime(now); !got.Equal(want) {
t.Errorf("nextWeeklyResetTime = %v, want %v", got, want)
}
}

View File

@ -0,0 +1,165 @@
package handler
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// responsesFailedError 对齐 OpenAI Responses 协议 error 子对象。
type responsesFailedError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// responsesFailedBody 对齐 apicompat.makeResponsesCompletedEvent 输出的 response 子对象字段集。
// Output 用空 slice不是 nil确保 marshal 为 `[]` 而非 `null`。
type responsesFailedBody struct {
ID string `json:"id"`
Object string `json:"object"`
Model string `json:"model,omitempty"`
Status string `json:"status"`
Output []any `json:"output"`
Error responsesFailedError `json:"error"`
}
// responsesFailedEvent 是写入 SSE data 行的顶层结构。
// 故意不带 sequence_numberspec 标记可选,且本函数被调用时无法可靠拿到 last seq。
type responsesFailedEvent struct {
Type string `json:"type"`
Response responsesFailedBody `json:"response"`
}
// writeResponsesFailedSSE emits a `response.failed` SSE event in the OpenAI
// Responses API protocol after the stream has already started.
//
// 必要性:一旦 SSE 头和任意数据(例如等待槽位时的 ping comment已经 flush
// HTTP 200 状态码就被固化。此后若网关需要回报错误,只能继续通过 SSE 事件传达。
// 通用的 `event: error` 帧不是 Responses 协议规定的终止事件,
// Codex CLI 等严格 SDK 会因为没收到 `response.completed/failed/incomplete/cancelled`
// 而抛出 "stream closed before response.completed"。
//
// 字段集对齐 apicompat.makeResponsesCompletedEventid/object/model/status/output/error。
// 故意不写 sequence_number本函数被调用时无法可靠拿到当前流的 last sequence
// 而 OpenAI spec 将 sequence_number 设为可选;省略避免破坏单调性约束。
//
// 返回 true 表示已尝试 SSE 写出(不论 Write 是否成功caller 都应直接 return
// 返回 false 表示 writer 不支持 Flusher无法以 SSE 形式回报错误;
// 此时 caller 也无法回退到 JSONHTTP 200 已固化),通常意味着连接已经损坏,
// 应当让请求处理函数 return由上层关闭连接。
func writeResponsesFailedSSE(c *gin.Context, errType, message string) bool {
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return false
}
payload, err := json.Marshal(responsesFailedEvent{
Type: "response.failed",
Response: responsesFailedBody{
ID: synthesizeResponseID(c),
Object: "response",
Model: requestModel(c),
Status: "failed",
Output: []any{},
Error: responsesFailedError{
Code: mapResponsesErrorCode(errType),
Message: message,
},
},
})
if err != nil {
_ = c.Error(err)
return true
}
if _, err := fmt.Fprintf(c.Writer, "event: response.failed\ndata: %s\n\n", payload); err != nil {
_ = c.Error(err)
return true
}
flusher.Flush()
return true
}
// inboundIsResponses 判断当前请求是否落在任何 /responses 路由上。
//
// 不能直接用 GetInboundEndpoint(c) == EndpointResponses 比较,因为
// NormalizeInboundEndpoint 只识别包含 "/v1/responses" 子串的路径;
// 项目里实际注册了多组路由gateway_v1、top-level bare、codex direct
// 其中 r.POST("/responses", ...) 和 codexDirect.POST("/responses", ...)
// 的 c.FullPath() 不含 "/v1/" 前缀,会被归一化为原始路径,
// 导致协议合规终止事件没法发出去。
//
// 这里用 FullPath 的后缀判断,覆盖所有变体:
// - /v1/responses
// - /v1/responses/compact
// - /responses
// - /responses/compact
// - /backend-api/codex/responses
// - /backend-api/codex/responses/compact
func inboundIsResponses(c *gin.Context) bool {
if c == nil {
return false
}
p := strings.TrimRight(c.FullPath(), "/")
if p == "" && c.Request != nil && c.Request.URL != nil {
p = strings.TrimRight(c.Request.URL.Path, "/")
}
if p == "" {
return false
}
return strings.HasSuffix(p, "/responses") || strings.Contains(p, "/responses/")
}
// synthesizeResponseID 为合成的 response.failed 事件生成一个稳定的 id。
// 优先复用 server 端生成的 request_id存在 request.Context 里,由 request_logger 写入),
// 以便客户端报错能与 server 日志关联;缺失时回退 uuid。
func synthesizeResponseID(c *gin.Context) string {
if c != nil && c.Request != nil {
if rid, ok := c.Request.Context().Value(ctxkey.RequestID).(string); ok {
if rid = strings.TrimSpace(rid); rid != "" {
return "resp_" + strings.ReplaceAll(rid, "-", "")
}
}
}
return "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
}
// requestModel 取当前请求的 inbound model由 setOpsRequestContext 写入)。
// 缺失时返回 ""caller 据此决定是否忽略该字段。
func requestModel(c *gin.Context) string {
if c == nil {
return ""
}
if v, ok := c.Get(opsModelKey); ok {
if s, ok := v.(string); ok {
return strings.TrimSpace(s)
}
}
return ""
}
// mapResponsesErrorCode 把内部 errType 映射为 Responses 协议常见的 error.code。
// 无明确映射时原样返回,保证至少可读。
func mapResponsesErrorCode(errType string) string {
switch errType {
case "rate_limit_error":
return "rate_limit_exceeded"
case "invalid_request_error":
return "invalid_request"
case "permission_error":
return "permission_denied"
case "authentication_error":
return "authentication_failed"
case "upstream_error":
return "upstream_error"
case "server_error", "api_error", "":
return "server_error"
default:
return errType
}
}

View File

@ -0,0 +1,253 @@
package handler
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Regression for the production incident on 2026-05-24 around 9:13 CST:
// user 16 sent /v1/responses with stream:true via Codex CLI; the user-concurrency
// slot wait sent SSE ping comments (flushing HTTP 200 + headers), then the 30s
// timeout fired and the handler emitted `event: error\ndata: {...}`. Codex CLI
// does not recognize that as a Responses terminal event and reports
// "stream closed before response.completed". The fix is to emit a synthetic
// response.failed event when the inbound endpoint is /v1/responses.
func newGinContextForEndpoint(t *testing.T, endpoint string) (*gin.Context, *httptest.ResponseRecorder) {
t.Helper()
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, endpoint, nil)
return c, w
}
// parseResponsesFailedSSE 抽出 SSE 中 data 行的 JSON返回 (response 对象, error 对象)。
func parseResponsesFailedSSE(t *testing.T, body string) (map[string]any, map[string]any) {
t.Helper()
require.True(t, strings.HasPrefix(body, "event: response.failed\n"),
"expect event: response.failed prefix, got: %q", body)
require.True(t, strings.HasSuffix(body, "\n\n"))
lines := strings.SplitN(strings.TrimSuffix(body, "\n\n"), "\n", 2)
require.Len(t, lines, 2)
require.True(t, strings.HasPrefix(lines[1], "data: "))
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed), "data must be valid JSON: %s", jsonStr)
assert.Equal(t, "response.failed", parsed["type"])
// 故意不发 sequence_number避免与后续真实事件的序号冲突。
_, hasSeq := parsed["sequence_number"]
assert.False(t, hasSeq, "synthetic event must not emit sequence_number")
resp, ok := parsed["response"].(map[string]any)
require.True(t, ok, "response object missing")
assert.Equal(t, "response", resp["object"])
assert.Equal(t, "failed", resp["status"])
errObj, ok := resp["error"].(map[string]any)
require.True(t, ok, "error object missing")
return resp, errObj
}
// OpenAI handler: /v1/responses streaming, after stream started, must emit response.failed.
func TestOpenAIHandleStreamingAwareError_ResponsesStreamingEmitsResponseFailed(t *testing.T) {
c, w := newGinContextForEndpoint(t, EndpointResponses)
h := &OpenAIGatewayHandler{}
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
"Concurrency limit exceeded for user, please retry later", true)
resp, errObj := parseResponsesFailedSSE(t, w.Body.String())
id, _ := resp["id"].(string)
assert.True(t, strings.HasPrefix(id, "resp_"), "id should start with resp_, got %q", id)
assert.Equal(t, "rate_limit_exceeded", errObj["code"])
assert.Equal(t, "Concurrency limit exceeded for user, please retry later", errObj["message"])
}
// 当 setOpsRequestContext 写过 model合成事件应回填该字段与 codebase 已有 makeResponsesCompletedEvent 对齐)。
func TestOpenAIHandleStreamingAwareError_ResponsesStreamingIncludesModel(t *testing.T) {
c, w := newGinContextForEndpoint(t, EndpointResponses)
setOpsRequestContext(c, "gpt-5.5", true)
h := &OpenAIGatewayHandler{}
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "boom", true)
resp, _ := parseResponsesFailedSSE(t, w.Body.String())
assert.Equal(t, "gpt-5.5", resp["model"])
}
// 没有 model 时 model 字段不应出现(避免发空字符串污染下游解析)。
func TestOpenAIHandleStreamingAwareError_ResponsesStreamingOmitsEmptyModel(t *testing.T) {
c, w := newGinContextForEndpoint(t, EndpointResponses)
h := &OpenAIGatewayHandler{}
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "boom", true)
resp, _ := parseResponsesFailedSSE(t, w.Body.String())
_, hasModel := resp["model"]
assert.False(t, hasModel, "model field must be omitted when unknown")
}
// 当 request.Context 携带 ctxkey.RequestID 时,合成 id 应与之关联,便于和 server log 串起来。
func TestOpenAIHandleStreamingAwareError_ResponsesStreamingReusesRequestID(t *testing.T) {
c, w := newGinContextForEndpoint(t, EndpointResponses)
c.Request = c.Request.WithContext(
context.WithValue(c.Request.Context(), ctxkey.RequestID, "fd277bc5-ff7e-45d1-8aa9-f54e1df318f1"),
)
h := &OpenAIGatewayHandler{}
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "x", true)
resp, _ := parseResponsesFailedSSE(t, w.Body.String())
assert.Equal(t, "resp_fd277bc5ff7e45d18aa9f54e1df318f1", resp["id"])
}
// 与旧分支的 TestOpenAIHandleStreamingAwareError_JSONEscaping 对齐:
// 新的 response.failed payload 也必须正确转义 message 里的特殊字符,
// 否则下游 SDK 解析 JSON 时会失败。
func TestOpenAIHandleStreamingAwareError_ResponsesStreamingJSONEscaping(t *testing.T) {
cases := []struct {
name string
errType string
message string
}{
{"双引号", "server_error", `upstream returned "invalid" response`},
{"反斜杠", "server_error", `path C:\Users\test\file.txt not found`},
{"双引号+反斜杠", "upstream_error", `error parsing "key\value": unexpected token`},
{"换行与制表", "server_error", "line1\nline2\ttab"},
{"普通", "upstream_error", "Upstream service temporarily unavailable"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
c, w := newGinContextForEndpoint(t, EndpointResponses)
h := &OpenAIGatewayHandler{}
h.handleStreamingAwareError(c, http.StatusBadGateway, tc.errType, tc.message, true)
_, errObj := parseResponsesFailedSSE(t, w.Body.String())
assert.Equal(t, tc.message, errObj["message"], "message 必须被原样还原")
})
}
}
// OpenAI handler: /v1/chat/completions streaming keeps the legacy event: error format
// (out of scope for this fix; covered to prevent regression of unrelated paths).
func TestOpenAIHandleStreamingAwareError_ChatCompletionsStreamingKeepsLegacy(t *testing.T) {
c, w := newGinContextForEndpoint(t, EndpointChatCompletions)
h := &OpenAIGatewayHandler{}
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "boom", true)
body := w.Body.String()
assert.True(t, strings.HasPrefix(body, "event: error\n"), "got: %q", body)
}
// Gateway (Anthropic-backed) handler: /v1/responses path also must emit response.failed.
func TestGatewayHandleStreamingAwareError_ResponsesStreamingEmitsResponseFailed(t *testing.T) {
c, w := newGinContextForEndpoint(t, EndpointResponses)
h := &GatewayHandler{}
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "upstream gone", true)
_, errObj := parseResponsesFailedSSE(t, w.Body.String())
assert.Equal(t, "upstream_error", errObj["code"])
assert.Equal(t, "upstream gone", errObj["message"])
}
// Gateway handler: /v1/messages preserves the legacy data:{type:error,...} format
// (Anthropic spec accepts a type:"error" stream event).
func TestGatewayHandleStreamingAwareError_MessagesStreamingKeepsLegacy(t *testing.T) {
c, w := newGinContextForEndpoint(t, EndpointMessages)
h := &GatewayHandler{}
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "boom", true)
body := w.Body.String()
assert.True(t, strings.HasPrefix(body, `data: {"type":"error"`), "got: %q", body)
}
// 项目里 /responses 注册在多组路由:/v1/responsesgateway、裸 /responsestop-level
// /backend-api/codex/responsescodex direct。我们 fix 必须覆盖全部,
// 否则一些客户端走的路径就不会发 response.failed照样报 stream closed。
// 这是生产 2026-05-24 ~11:05 UTC user 16 实际命中的 bug。
func TestInboundIsResponses_CoversAllRoutes(t *testing.T) {
cases := []struct {
route string
want bool
}{
{"/v1/responses", true},
{"/v1/responses/compact", true},
{"/responses", true}, // <-- 用户 16 实际走这条
{"/responses/compact", true},
{"/backend-api/codex/responses", true},
{"/backend-api/codex/responses/compact", true},
{"/v1/chat/completions", false},
{"/v1/messages", false},
{"/", false},
{"/responses-fake", false},
}
for _, tc := range cases {
t.Run(tc.route, func(t *testing.T) {
c, _ := newGinContextForEndpoint(t, tc.route)
assert.Equal(t, tc.want, inboundIsResponses(c), "route=%q", tc.route)
})
}
}
// 用 c.Request.URL.Path 作为 fallback当 c.FullPath() 为空时,例如某些测试 fixture
func TestInboundIsResponses_FallsBackToURLPath(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/responses", nil)
// 这种情况下 c.FullPath() 是 "",必须 fallback 到 URL.Path
assert.True(t, inboundIsResponses(c), "URL.Path fallback must work when FullPath is empty")
}
// 回归生产事故:用户 16 走 /responses 路径,必须发 response.failed。
func TestOpenAIHandleStreamingAwareError_BareResponsesRouteEmitsResponseFailed(t *testing.T) {
c, w := newGinContextForEndpoint(t, "/responses")
h := &OpenAIGatewayHandler{}
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
"Concurrency limit exceeded for user, please retry later", true)
resp, errObj := parseResponsesFailedSSE(t, w.Body.String())
id, _ := resp["id"].(string)
assert.True(t, strings.HasPrefix(id, "resp_"))
assert.Equal(t, "rate_limit_exceeded", errObj["code"])
}
// Synthesized response.failed id falls back to uuid when no request_id is present.
func TestSynthesizeResponseID_FallbackUUID(t *testing.T) {
c, _ := newGinContextForEndpoint(t, EndpointResponses)
id := synthesizeResponseID(c)
assert.True(t, strings.HasPrefix(id, "resp_"))
// uuid 去掉短横线后 32 hex 字符;前缀 "resp_" 共 37。
assert.Len(t, id, 37)
}
func TestMapResponsesErrorCode(t *testing.T) {
cases := []struct{ in, out string }{
{"rate_limit_error", "rate_limit_exceeded"},
{"invalid_request_error", "invalid_request"},
{"permission_error", "permission_denied"},
{"authentication_error", "authentication_failed"},
{"upstream_error", "upstream_error"},
{"server_error", "server_error"},
{"api_error", "server_error"},
{"", "server_error"},
{"custom_thing", "custom_thing"},
}
for _, tc := range cases {
assert.Equal(t, tc.out, mapResponsesErrorCode(tc.in), "in=%q", tc.in)
}
}

View File

@ -3,8 +3,10 @@ package handler
import (
"context"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/handler/quotaview"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@ -14,11 +16,12 @@ import (
// UserHandler handles user-related requests
type UserHandler struct {
userService *service.UserService
authService *service.AuthService
emailService *service.EmailService
emailCache service.EmailCache
affiliateService *service.AffiliateService
userService *service.UserService
authService *service.AuthService
emailService *service.EmailService
emailCache service.EmailCache
affiliateService *service.AffiliateService
userPlatformQuotaRepo service.UserPlatformQuotaRepository
}
// NewUserHandler creates a new UserHandler
@ -28,16 +31,44 @@ func NewUserHandler(
emailService *service.EmailService,
emailCache service.EmailCache,
affiliateService *service.AffiliateService,
userPlatformQuotaRepo service.UserPlatformQuotaRepository,
) *UserHandler {
return &UserHandler{
userService: userService,
authService: authService,
emailService: emailService,
emailCache: emailCache,
affiliateService: affiliateService,
userService: userService,
authService: authService,
emailService: emailService,
emailCache: emailCache,
affiliateService: affiliateService,
userPlatformQuotaRepo: userPlatformQuotaRepo,
}
}
// GetMyPlatformQuotas GET /user/platform-quotas
// 返回当前 JWT 用户的 platform quota 状态。
// D14: 对每条记录逐档判断窗口过期,过期档位 usage=0、window_resets_at=null不写 DB
func (h *UserHandler) GetMyPlatformQuotas(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
if h.userPlatformQuotaRepo == nil {
response.Success(c, map[string]any{"platform_quotas": []any{}})
return
}
records, err := h.userPlatformQuotaRepo.ListByUser(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
now := time.Now().UTC()
out := make([]map[string]any, 0, len(records))
for _, r := range records {
out = append(out, quotaview.LazyZeroQuotaForResponse(r, now, false))
}
response.Success(c, map[string]any{"platform_quotas": out})
}
// ChangePasswordRequest represents the change password request payload
type ChangePasswordRequest struct {
OldPassword string `json:"old_password" binding:"required"`

View File

@ -87,8 +87,12 @@ func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.Pagina
func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (s *userHandlerRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *userHandlerRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *userHandlerRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) {
return 0, nil
}
func (s *userHandlerRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) {
return 0, nil
}
func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
@ -144,7 +148,7 @@ func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
Status: service.StatusActive,
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
recorder := httptest.NewRecorder()
@ -202,7 +206,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@ -285,7 +289,7 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@ -364,7 +368,7 @@ func TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIde
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@ -513,8 +517,8 @@ func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) {
},
}
emailService := service.NewEmailService(nil, emailCache)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil)
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`)
recorder := httptest.NewRecorder()
@ -568,7 +572,7 @@ func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) {
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@ -627,8 +631,8 @@ func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigure
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@ -670,8 +674,8 @@ func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t *
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@ -714,8 +718,8 @@ func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t
},
}
emailService := service.NewEmailService(nil, emailCache)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil)
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`)
recorder := httptest.NewRecorder()
@ -752,7 +756,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
Status: service.StatusActive,
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`)
recorder := httptest.NewRecorder()

View File

@ -0,0 +1,212 @@
//go:build unit
package handler
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/quotaview"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// fakeQuotaRepoForUserHandler 实现 service.UserPlatformQuotaRepository 最小子集
type fakeQuotaRepoForUserHandler struct {
service.UserPlatformQuotaRepository
records []service.UserPlatformQuotaRecord
}
func (f *fakeQuotaRepoForUserHandler) ListByUser(_ context.Context, _ int64) ([]service.UserPlatformQuotaRecord, error) {
return f.records, nil
}
func TestGetMyPlatformQuotas_EmptyReturns200WithEmptyArray(t *testing.T) {
repo := &fakeQuotaRepoForUserHandler{records: nil}
h := &UserHandler{userPlatformQuotaRepo: repo}
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42})
h.GetMyPlatformQuotas(c)
if w.Code != 200 {
t.Fatalf("expected 200, got %d. body: %s", w.Code, w.Body.String())
}
var body struct {
Code int `json:"code"`
Data struct {
PlatformQuotas []any `json:"platform_quotas"`
} `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal error: %v, body: %s", err, w.Body.String())
}
if body.Code != 0 {
t.Errorf("expected code=0, got %d", body.Code)
}
if body.Data.PlatformQuotas == nil {
// nil 和 empty slice 均视为可接受JSON 可能序列化为 null 或 []
// 此断言只验证 HTTP 200 + code=0 即可
}
}
func TestGetMyPlatformQuotas_D14_LazyZeroForExpiredWindow(t *testing.T) {
pastStart := time.Now().UTC().AddDate(0, 0, -2)
daily := 5.0
repo := &fakeQuotaRepoForUserHandler{records: []service.UserPlatformQuotaRecord{{
UserID: 42,
Platform: "anthropic",
DailyLimitUSD: &daily,
DailyUsageUSD: 3.0,
DailyWindowStart: &pastStart,
}}}
h := &UserHandler{userPlatformQuotaRepo: repo}
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42})
h.GetMyPlatformQuotas(c)
if w.Code != 200 {
t.Fatalf("expected 200, got %d. body: %s", w.Code, w.Body.String())
}
// 解析 response验证过期 daily 的 usage_usd=0 且 window_resets_at=null
body := w.Body.String()
if !strings.Contains(body, `"daily_usage_usd":0`) {
t.Errorf("expected daily_usage_usd:0 in body, got: %s", body)
}
if !strings.Contains(body, `"daily_window_resets_at":null`) {
t.Errorf("expected daily_window_resets_at:null in body, got: %s", body)
}
}
func TestGetMyPlatformQuotas_NilRepo_Returns200Empty(t *testing.T) {
h := &UserHandler{userPlatformQuotaRepo: nil}
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 99})
h.GetMyPlatformQuotas(c)
if w.Code != 200 {
t.Fatalf("expected 200, got %d", w.Code)
}
}
func TestGetMyPlatformQuotas_NoAuth_Returns401(t *testing.T) {
h := &UserHandler{userPlatformQuotaRepo: nil}
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil)
// 不设置 auth subject
h.GetMyPlatformQuotas(c)
if w.Code != 401 {
t.Fatalf("expected 401, got %d", w.Code)
}
}
func TestLazyZeroQuotaForResponse_UserViewStripsWindowStart(t *testing.T) {
start := time.Now().UTC().Add(-1 * time.Hour)
r := service.UserPlatformQuotaRecord{
Platform: "anthropic",
DailyUsageUSD: 1.0,
DailyWindowStart: &start,
}
out := quotaview.LazyZeroQuotaForResponse(r, time.Now().UTC(), false)
if _, ok := out["daily_window_start"]; ok {
t.Error("user view should not include daily_window_start")
}
}
func TestLazyZeroQuotaForResponse_AdminViewIncludesWindowStart(t *testing.T) {
start := time.Now().UTC().Add(-1 * time.Hour)
r := service.UserPlatformQuotaRecord{
Platform: "anthropic",
DailyWindowStart: &start,
}
out := quotaview.LazyZeroQuotaForResponse(r, time.Now().UTC(), true)
if _, ok := out["daily_window_start"]; !ok {
t.Error("admin view should include daily_window_start")
}
}
func TestLazyZeroQuotaForResponse_ActiveWindowPreservesUsage(t *testing.T) {
// 今天的窗口起始时间(不过期):按全局时区取当天 0 点,与 view 层同口径
now := time.Now()
today := timezone.StartOfDay(now)
usage := 2.5
r := service.UserPlatformQuotaRecord{
Platform: "openai",
DailyUsageUSD: usage,
DailyWindowStart: &today,
}
out := quotaview.LazyZeroQuotaForResponse(r, now, false)
if out["daily_usage_usd"] != usage {
t.Errorf("expected daily_usage_usd=%v, got %v", usage, out["daily_usage_usd"])
}
// 活跃窗口应有 resets_at非 nil
if out["daily_window_resets_at"] == nil {
t.Error("active window should have daily_window_resets_at set")
}
}
func TestNeedsDailyReset_NilStart_ReturnsFalse(t *testing.T) {
if quotaview.NeedsDailyReset(nil, time.Now().UTC()) {
t.Error("nil start should not need reset")
}
}
func TestNeedsDailyReset_OldStart_ReturnsTrue(t *testing.T) {
old := time.Now().UTC().AddDate(0, 0, -1)
if !quotaview.NeedsDailyReset(&old, time.Now().UTC()) {
t.Error("yesterday start should need daily reset")
}
}
func TestNeedsWeeklyReset_NilStart_ReturnsFalse(t *testing.T) {
if quotaview.NeedsWeeklyReset(nil, time.Now().UTC()) {
t.Error("nil start should not need weekly reset")
}
}
func TestNeedsMonthlyReset_NilStart_ReturnsFalse(t *testing.T) {
if quotaview.NeedsMonthlyReset(nil, time.Now().UTC()) {
t.Error("nil start should not need monthly reset")
}
}
// TestNeedsMonthlyReset_30DayRolling 验证 30 天滚动语义C-NEW-1
func TestNeedsMonthlyReset_30DayRolling_Expired(t *testing.T) {
start := time.Now().UTC().Add(-31 * 24 * time.Hour) // 31 天前,已过期
if !quotaview.NeedsMonthlyReset(&start, time.Now().UTC()) {
t.Error("31 days ago should need monthly reset (30-day rolling)")
}
}
func TestNeedsMonthlyReset_30DayRolling_Active(t *testing.T) {
start := time.Now().UTC().Add(-15 * 24 * time.Hour) // 15 天前,窗口有效
if quotaview.NeedsMonthlyReset(&start, time.Now().UTC()) {
t.Error("15 days ago should NOT need monthly reset (30-day rolling, still active)")
}
}
// TestNeedsMonthlyReset_CrossMonthBoundary 验证跨自然月时 30 天未满不重置(旧自然月语义会提前重置)。
func TestNeedsMonthlyReset_CrossMonthBoundary(t *testing.T) {
// 窗口起始 4 月 20 日5 月 1 日仅过了 11 天,不足 30 天,不应重置
windowStart := time.Date(2026, 4, 20, 0, 0, 0, 0, time.UTC)
now := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC)
if quotaview.NeedsMonthlyReset(&windowStart, now) {
t.Error("cross-month boundary within 30 days should NOT trigger reset (30-day rolling)")
}
}

View File

@ -517,6 +517,33 @@ func TestResponsesEventToAnthropicEvents_ResponseDone(t *testing.T) {
assert.Nil(t, FinalizeResponsesAnthropicStream(state))
}
func TestResponsesEventToAnthropicEvents_TopLevelTerminalUsage(t *testing.T) {
state := NewResponsesEventToAnthropicState()
state.Model = "gpt-4o"
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.completed",
Response: &ResponsesResponse{
Status: "completed",
},
Usage: &ResponsesUsage{
InputTokens: 20,
OutputTokens: 6,
InputTokensDetails: &ResponsesInputTokensDetails{
CachedTokens: 5,
},
},
}, state)
require.Len(t, events, 2)
assert.Equal(t, "message_delta", events[0].Type)
require.NotNil(t, events[0].Usage)
assert.Equal(t, 15, events[0].Usage.InputTokens)
assert.Equal(t, 5, events[0].Usage.CacheReadInputTokens)
assert.Equal(t, 6, events[0].Usage.OutputTokens)
assert.Equal(t, "message_stop", events[1].Type)
}
func TestResponsesEventToAnthropicEvents_ResponseDoneIncomplete(t *testing.T) {
state := NewResponsesEventToAnthropicState()
state.Model = "gpt-4o"

View File

@ -846,6 +846,33 @@ func TestResponsesEventToChatChunks_ResponseDone(t *testing.T) {
assert.Nil(t, FinalizeResponsesChatStream(state))
}
func TestResponsesEventToChatChunks_TopLevelTerminalUsage(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.IncludeUsage = true
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.completed",
Response: &ResponsesResponse{
Status: "completed",
},
Usage: &ResponsesUsage{
InputTokens: 21,
OutputTokens: 9,
InputTokensDetails: &ResponsesInputTokensDetails{
CachedTokens: 4,
},
},
}, state)
require.Len(t, chunks, 2)
require.NotNil(t, chunks[1].Usage)
assert.Equal(t, 21, chunks[1].Usage.PromptTokens)
assert.Equal(t, 9, chunks[1].Usage.CompletionTokens)
require.NotNil(t, chunks[1].Usage.PromptTokensDetails)
assert.Equal(t, 4, chunks[1].Usage.PromptTokensDetails.CachedTokens)
}
func TestResponsesEventToChatChunks_ResponseDoneIncomplete(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"

View File

@ -567,6 +567,12 @@ func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo
events = append(events, closeCurrentBlock(state)...)
stopReason := "end_turn"
if evt.Usage != nil {
usage := anthropicUsageFromResponsesUsage(evt.Usage)
state.InputTokens = usage.InputTokens
state.OutputTokens = usage.OutputTokens
state.CacheReadInputTokens = usage.CacheReadInputTokens
}
if evt.Response != nil {
if evt.Response.Usage != nil {
usage := anthropicUsageFromResponsesUsage(evt.Response.Usage)

View File

@ -293,20 +293,12 @@ func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo
state.Finalized = true
finishReason := "stop"
if evt.Usage != nil {
state.Usage = chatUsageFromResponsesUsage(evt.Usage)
}
if evt.Response != nil {
if evt.Response.Usage != nil {
u := evt.Response.Usage
usage := &ChatUsage{
PromptTokens: u.InputTokens,
CompletionTokens: u.OutputTokens,
TotalTokens: u.InputTokens + u.OutputTokens,
}
if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 {
usage.PromptTokensDetails = &ChatTokenDetails{
CachedTokens: u.InputTokensDetails.CachedTokens,
}
}
state.Usage = usage
state.Usage = chatUsageFromResponsesUsage(evt.Response.Usage)
}
switch evt.Response.Status {
@ -340,6 +332,23 @@ func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo
return chunks
}
func chatUsageFromResponsesUsage(u *ResponsesUsage) *ChatUsage {
if u == nil {
return nil
}
usage := &ChatUsage{
PromptTokens: u.InputTokens,
CompletionTokens: u.OutputTokens,
TotalTokens: u.InputTokens + u.OutputTokens,
}
if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 {
usage.PromptTokensDetails = &ChatTokenDetails{
CachedTokens: u.InputTokensDetails.CachedTokens,
}
}
return usage
}
func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk {
return ChatCompletionsChunk{
ID: state.ID,

View File

@ -380,6 +380,8 @@ type ResponsesStreamEvent struct {
// response.created / response.completed / response.done / response.failed / response.incomplete
Response *ResponsesResponse `json:"response,omitempty"`
// 部分 OpenAI 兼容上游会把 usage 放在终止事件顶层,而不是 response.usage。
Usage *ResponsesUsage `json:"usage,omitempty"`
// response.output_item.added / response.output_item.done
Item *ResponsesOutput `json:"item,omitempty"`

View File

@ -135,3 +135,29 @@ func TestDSTAwareness(t *testing.T) {
_ = Now()
_ = StartOfDay(Now())
}
func TestStartOfWeek_Boundaries(t *testing.T) {
if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init: %v", err)
}
t.Cleanup(func() { _ = Init("UTC") })
loc := Location()
wantMon := time.Date(2026, 5, 18, 0, 0, 0, 0, loc) // 2026-05-18 是周一
cases := []struct {
name string
in time.Time
}{
{"friday", time.Date(2026, 5, 22, 14, 30, 0, 0, loc)},
{"sunday", time.Date(2026, 5, 24, 10, 0, 0, 0, loc)},
{"monday-self", time.Date(2026, 5, 18, 9, 15, 30, 0, loc)},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
if got := StartOfWeek(c.in); !got.Equal(wantMon) {
t.Errorf("StartOfWeek(%v) = %v, want %v", c.in, got, wantMon)
}
})
}
}

View File

@ -1085,7 +1085,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
return nil
}
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error {
if scope == "" {
return nil
}
@ -1094,6 +1094,11 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
"rate_limited_at": now.Format(time.RFC3339),
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
}
if len(reason) > 0 {
if value := strings.TrimSpace(reason[0]); value != "" {
payload["reason"] = value
}
}
raw, err := json.Marshal(payload)
if err != nil {
return err
@ -1129,6 +1134,7 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
}

View File

@ -183,6 +183,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldAllowMessagesDispatch,
group.FieldDefaultMappedModel,
group.FieldMessagesDispatchModelConfig,
group.FieldModelsListConfig,
group.FieldRpmLimit,
)
}).
@ -723,6 +724,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RequirePrivacySet: g.RequirePrivacySet,
DefaultMappedModel: g.DefaultMappedModel,
MessagesDispatchModelConfig: g.MessagesDispatchModelConfig,
ModelsListConfig: g.ModelsListConfig,
RPMLimit: g.RpmLimit,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,

View File

@ -328,3 +328,174 @@ func (c *billingCache) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int6
key := billingRateLimitKey(keyID)
return c.rdb.Del(ctx, key).Err()
}
// ============================================
// user × platform quota 缓存
// ============================================
// userPlatformQuotaCacheKey 构造 Redis key
func userPlatformQuotaCacheKey(userID int64, platform string) string {
return fmt.Sprintf("billing:user_platform_quota:%d:%s", userID, platform)
}
func (c *billingCache) GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*service.UserPlatformQuotaCacheEntry, bool, error) {
key := userPlatformQuotaCacheKey(userID, platform)
fields := []string{
"daily_usage", "weekly_usage", "monthly_usage", "version", "schema_version",
"daily_limit", "weekly_limit", "monthly_limit",
"daily_window_start", "weekly_window_start", "monthly_window_start",
}
vals, err := c.rdb.HMGet(ctx, key, fields...).Result()
if err != nil {
return nil, false, err
}
// 前4个全为nil → key 不存在
if vals[0] == nil && vals[1] == nil && vals[2] == nil && vals[3] == nil {
return nil, false, nil
}
parseFloat := func(v any) float64 {
if v == nil {
return 0
}
s, ok := v.(string)
if !ok {
return 0
}
f, _ := strconv.ParseFloat(s, 64)
return f
}
parseFloatPtr := func(v any) *float64 {
if v == nil {
return nil
}
s, ok := v.(string)
if !ok || s == "" {
return nil
}
f, err := strconv.ParseFloat(s, 64)
if err != nil {
return nil
}
return &f
}
parseTimePtr := func(v any) *time.Time {
if v == nil {
return nil
}
s, ok := v.(string)
if !ok || s == "" {
return nil
}
n, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return nil
}
t := time.Unix(n, 0).UTC()
return &t
}
parseInt64 := func(v any) int64 {
if v == nil {
return 0
}
s, ok := v.(string)
if !ok {
return 0
}
n, _ := strconv.ParseInt(s, 10, 64)
return n
}
return &service.UserPlatformQuotaCacheEntry{
DailyUsageUSD: parseFloat(vals[0]),
WeeklyUsageUSD: parseFloat(vals[1]),
MonthlyUsageUSD: parseFloat(vals[2]),
Version: parseInt64(vals[3]),
SchemaVersion: parseInt64(vals[4]),
DailyLimitUSD: parseFloatPtr(vals[5]),
WeeklyLimitUSD: parseFloatPtr(vals[6]),
MonthlyLimitUSD: parseFloatPtr(vals[7]),
DailyWindowStart: parseTimePtr(vals[8]),
WeeklyWindowStart: parseTimePtr(vals[9]),
MonthlyWindowStart: parseTimePtr(vals[10]),
}, true, nil
}
func (c *billingCache) SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *service.UserPlatformQuotaCacheEntry, ttl time.Duration) error {
if entry == nil {
return nil
}
key := userPlatformQuotaCacheKey(userID, platform)
pipe := c.rdb.TxPipeline()
// 浮点可空字段nil → 空字符串(读取时 parseFloatPtr 返回 nil表示无限额
fmtFloatPtr := func(p *float64) string {
if p == nil {
return ""
}
return strconv.FormatFloat(*p, 'f', -1, 64)
}
// time.Time 可空字段nil → 空字符串;有值 → unix 秒
fmtTimePtr := func(p *time.Time) string {
if p == nil {
return ""
}
return strconv.FormatInt(p.Unix(), 10)
}
pipe.HSet(ctx, key,
"daily_usage", entry.DailyUsageUSD,
"weekly_usage", entry.WeeklyUsageUSD,
"monthly_usage", entry.MonthlyUsageUSD,
"version", entry.Version,
"schema_version", entry.SchemaVersion,
"daily_limit", fmtFloatPtr(entry.DailyLimitUSD),
"weekly_limit", fmtFloatPtr(entry.WeeklyLimitUSD),
"monthly_limit", fmtFloatPtr(entry.MonthlyLimitUSD),
"daily_window_start", fmtTimePtr(entry.DailyWindowStart),
"weekly_window_start", fmtTimePtr(entry.WeeklyWindowStart),
"monthly_window_start", fmtTimePtr(entry.MonthlyWindowStart),
)
pipe.Expire(ctx, key, ttl)
_, err := pipe.Exec(ctx)
return err
}
func (c *billingCache) DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error {
return c.rdb.Del(ctx, userPlatformQuotaCacheKey(userID, platform)).Err()
}
// updateUserPlatformQuotaUsageScript 缓存累加EXISTS + schema_version 双重守卫。
// 旧版 entryschema_version != ARGV[3],包括缺字段的 0 值)不参与累加,由上层走 DB fallback 后
// SetCache 重建为新版 entry —— 若此处仍累加,上层覆盖时会丢失这部分增量,导致 Redis usage 比真实偏小。
// key 不存在同样跳过(由下次 SetCache 重建)。
// KEYS[1] = hash key
// ARGV[1] = cost (string float)
// ARGV[2] = ttl seconds
// ARGV[3] = expected schema_version (Go 侧 UserPlatformQuotaCacheSchemaV1)
const updateUserPlatformQuotaUsageScript = `
if redis.call("EXISTS", KEYS[1]) == 0 then
return 0
end
local ver = redis.call("HGET", KEYS[1], "schema_version")
if ver == false or tonumber(ver) ~= tonumber(ARGV[3]) then
return 0
end
redis.call("HINCRBYFLOAT", KEYS[1], "daily_usage", ARGV[1])
redis.call("HINCRBYFLOAT", KEYS[1], "weekly_usage", ARGV[1])
redis.call("HINCRBYFLOAT", KEYS[1], "monthly_usage", ARGV[1])
redis.call("HINCRBY", KEYS[1], "version", 1)
redis.call("EXPIRE", KEYS[1], ARGV[2])
return 1
`
func (c *billingCache) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error {
key := userPlatformQuotaCacheKey(userID, platform)
_, err := c.rdb.Eval(ctx, updateUserPlatformQuotaUsageScript, []string{key},
strconv.FormatFloat(cost, 'f', -1, 64),
int(ttl.Seconds()),
service.UserPlatformQuotaCacheSchemaV1,
).Result()
if err != nil && !errors.Is(err, redis.Nil) {
return err
}
return nil
}

View File

@ -0,0 +1,134 @@
//go:build unit
package repository
import (
"context"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func newMiniRedisCache(t *testing.T) (*billingCache, *miniredis.Miniredis) {
t.Helper()
mr := miniredis.RunT(t)
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
return &billingCache{rdb: rdb}, mr
}
func TestUserPlatformQuotaCache_GetMissReturnsNotFound(t *testing.T) {
c, _ := newMiniRedisCache(t)
entry, ok, err := c.GetUserPlatformQuotaCache(context.Background(), 1, "anthropic")
if err != nil {
t.Fatal(err)
}
if ok || entry != nil {
t.Errorf("expected miss, got ok=%v entry=%v", ok, entry)
}
}
func TestUserPlatformQuotaCache_SetThenGet(t *testing.T) {
c, _ := newMiniRedisCache(t)
ctx := context.Background()
dailyLimit := 20.0
ts := time.Date(2024, 5, 1, 0, 0, 0, 0, time.UTC)
in := &service.UserPlatformQuotaCacheEntry{
DailyUsageUSD: 1.5,
WeeklyUsageUSD: 3.0,
MonthlyUsageUSD: 10.0,
Version: 7,
SchemaVersion: service.UserPlatformQuotaCacheSchemaV1,
DailyLimitUSD: &dailyLimit,
DailyWindowStart: &ts,
}
if err := c.SetUserPlatformQuotaCache(ctx, 1, "openai", in, time.Minute); err != nil {
t.Fatal(err)
}
got, ok, err := c.GetUserPlatformQuotaCache(ctx, 1, "openai")
if err != nil || !ok {
t.Fatalf("get: ok=%v err=%v", ok, err)
}
if got.DailyUsageUSD != 1.5 || got.WeeklyUsageUSD != 3.0 || got.MonthlyUsageUSD != 10.0 || got.Version != 7 {
t.Errorf("got = %+v, want %+v", got, in)
}
if got.SchemaVersion != service.UserPlatformQuotaCacheSchemaV1 {
t.Errorf("SchemaVersion = %d, want %d", got.SchemaVersion, service.UserPlatformQuotaCacheSchemaV1)
}
if got.DailyLimitUSD == nil || *got.DailyLimitUSD != dailyLimit {
t.Errorf("DailyLimitUSD = %v, want %v", got.DailyLimitUSD, dailyLimit)
}
if got.DailyWindowStart == nil || !got.DailyWindowStart.Equal(ts) {
t.Errorf("DailyWindowStart = %v, want %v", got.DailyWindowStart, ts)
}
}
func TestUserPlatformQuotaCache_NilLimitSetThenGet(t *testing.T) {
c, _ := newMiniRedisCache(t)
ctx := context.Background()
in := &service.UserPlatformQuotaCacheEntry{
DailyUsageUSD: 1.0,
SchemaVersion: service.UserPlatformQuotaCacheSchemaV1,
// DailyLimitUSD nil → 无限额
}
if err := c.SetUserPlatformQuotaCache(ctx, 1, "openai", in, time.Minute); err != nil {
t.Fatal(err)
}
got, ok, err := c.GetUserPlatformQuotaCache(ctx, 1, "openai")
if err != nil || !ok {
t.Fatalf("get: ok=%v err=%v", ok, err)
}
if got.DailyLimitUSD != nil {
t.Errorf("DailyLimitUSD should be nil for unlimited, got %v", got.DailyLimitUSD)
}
}
func TestUserPlatformQuotaCache_IncrMissIsNoop(t *testing.T) {
c, _ := newMiniRedisCache(t)
if err := c.IncrUserPlatformQuotaUsageCache(context.Background(), 1, "openai", 0.5, time.Minute); err != nil {
t.Fatal(err)
}
_, ok, _ := c.GetUserPlatformQuotaCache(context.Background(), 1, "openai")
if ok {
t.Error("expected key absent after no-op incr")
}
}
func TestUserPlatformQuotaCache_IncrHitAccumulates(t *testing.T) {
c, _ := newMiniRedisCache(t)
ctx := context.Background()
// SchemaVersion 必须显式设为 V1,否则 Lua 脚本会因 schema 不匹配而 return 0,跳过累加。
_ = c.SetUserPlatformQuotaCache(ctx, 1, "openai", &service.UserPlatformQuotaCacheEntry{
Version: 1,
SchemaVersion: service.UserPlatformQuotaCacheSchemaV1,
}, time.Minute)
if err := c.IncrUserPlatformQuotaUsageCache(ctx, 1, "openai", 0.5, time.Minute); err != nil {
t.Fatal(err)
}
if err := c.IncrUserPlatformQuotaUsageCache(ctx, 1, "openai", 0.25, time.Minute); err != nil {
t.Fatal(err)
}
got, _, _ := c.GetUserPlatformQuotaCache(ctx, 1, "openai")
if got.DailyUsageUSD != 0.75 || got.WeeklyUsageUSD != 0.75 || got.MonthlyUsageUSD != 0.75 {
t.Errorf("got %+v, want daily/weekly/monthly=0.75", got)
}
if got.Version != 3 {
t.Errorf("version = %d, want 3 (initial 1 + 2 incr)", got.Version)
}
}
func TestUserPlatformQuotaCache_Delete(t *testing.T) {
c, _ := newMiniRedisCache(t)
ctx := context.Background()
_ = c.SetUserPlatformQuotaCache(ctx, 1, "openai", &service.UserPlatformQuotaCacheEntry{Version: 1}, time.Minute)
if err := c.DeleteUserPlatformQuotaCache(ctx, 1, "openai"); err != nil {
t.Fatal(err)
}
_, ok, _ := c.GetUserPlatformQuotaCache(ctx, 1, "openai")
if ok {
t.Error("expected miss after delete")
}
}

View File

@ -192,6 +192,7 @@ SELECT COUNT(*)
FROM content_moderation_logs
WHERE user_id = $1
AND flagged = TRUE
AND action <> 'hash_block'
AND created_at >= $2
AND created_at > COALESCE((SELECT at FROM last_auto_ban), '-infinity'::timestamptz)
`, userID, since).Scan(&count)
@ -246,7 +247,7 @@ func buildContentModerationLogWhere(filter service.ContentModerationLogFilter) (
case "hit", "flagged":
where = append(where, "l.flagged = TRUE")
case "blocked", "block":
where = append(where, "l.action = 'block'")
where = append(where, "l.action IN ('block', 'keyword_block', 'hash_block')")
case "pass", "allow":
where = append(where, "l.flagged = FALSE AND l.error = ''")
case "error":

View File

@ -0,0 +1,40 @@
package repository
import (
"context"
"regexp"
"strings"
"testing"
"time"
sqlmock "github.com/DATA-DOG/go-sqlmock"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestBuildContentModerationLogWhere_BlockedIncludesAllBlockActions(t *testing.T) {
where, args := buildContentModerationLogWhere(service.ContentModerationLogFilter{Result: "blocked"})
require.Empty(t, args)
sql := strings.Join(where, " AND ")
require.Contains(t, sql, "l.action IN ('block', 'keyword_block', 'hash_block')")
require.NotContains(t, sql, "l.action = 'block'")
}
func TestContentModerationRepositoryCountFlaggedByUserSince_ExcludesHashBlock(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
repo := NewContentModerationRepository(db)
since := time.Now().Add(-time.Hour)
mock.ExpectQuery(regexp.QuoteMeta("AND action <> 'hash_block'")).
WithArgs(int64(1001), since).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(2))
count, err := repo.CountFlaggedByUserSince(context.Background(), 1001, since)
require.NoError(t, err)
require.Equal(t, 2, count)
require.NoError(t, mock.ExpectationsWereMet())
}

View File

@ -66,6 +66,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetRequirePrivacySet(groupIn.RequirePrivacySet).
SetDefaultMappedModel(groupIn.DefaultMappedModel).
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
SetModelsListConfig(groupIn.ModelsListConfig).
SetRpmLimit(groupIn.RPMLimit)
// 设置模型路由配置
@ -141,6 +142,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetRequirePrivacySet(groupIn.RequirePrivacySet).
SetDefaultMappedModel(groupIn.DefaultMappedModel).
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
SetModelsListConfig(groupIn.ModelsListConfig).
SetRpmLimit(groupIn.RPMLimit)
// 显式处理可空字段nil 需要 clear非 nil 需要 set。

View File

@ -3,6 +3,8 @@ package repository
import (
"compress/flate"
"compress/gzip"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
@ -50,6 +52,17 @@ const (
defaultMaxUpstreamClients = 5000
// defaultClientIdleTTLSeconds: 默认客户端空闲回收阈值15分钟
defaultClientIdleTTLSeconds = 900
// OpenAI HTTP/2 代理回退策略默认值
defaultOpenAIHTTP2FallbackErrorThreshold = 2
defaultOpenAIHTTP2FallbackWindow = 60 * time.Second
defaultOpenAIHTTP2FallbackTTL = 10 * time.Minute
)
const (
upstreamProtocolModeDefault = "default"
upstreamProtocolModeOpenAIH1 = "openai_h1"
upstreamProtocolModeOpenAIH2 = "openai_h2"
upstreamProtocolModeOpenAIH1Fallback = "openai_h1_fallback"
)
var errUpstreamClientLimitReached = errors.New("upstream client cache limit reached")
@ -64,14 +77,30 @@ type poolSettings struct {
responseHeaderTimeout time.Duration // 等待响应头超时时间
}
type openAIHTTP2Settings struct {
enabled bool
allowProxyFallbackToHTTP1 bool
fallbackErrorThreshold int
fallbackWindow time.Duration
fallbackTTL time.Duration
}
// upstreamClientEntry 上游客户端缓存条目
// 记录客户端实例及其元数据,用于连接池管理和淘汰策略
type upstreamClientEntry struct {
client *http.Client // HTTP 客户端实例
proxyKey string // 代理标识(用于检测代理变更)
poolKey string // 连接池配置标识(用于检测配置变更)
lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰
inFlight int64 // 当前进行中的请求数,>0 时不可淘汰
client *http.Client // HTTP 客户端实例
proxyKey string // 代理标识(用于检测代理变更)
poolKey string // 连接池配置标识(用于检测配置变更)
protocolMode string // 协议模式default/openai_h1/openai_h2/openai_h1_fallback
lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰
inFlight int64 // 当前进行中的请求数,>0 时不可淘汰
}
type openAIHTTP2FallbackState struct {
mu sync.Mutex
windowStart time.Time
errorCount int
fallbackUntil time.Time
}
// httpUpstreamService 通用 HTTP 上游服务
@ -95,6 +124,8 @@ type httpUpstreamService struct {
cfg *config.Config // 全局配置
mu sync.RWMutex // 保护 clients map 的读写锁
clients map[string]*upstreamClientEntry // 客户端缓存池key 由隔离策略决定
// OpenAI 走 HTTP/HTTPS 代理时的 H2->H1 回退状态key=标准化 proxyKey
openAIHTTP2Fallbacks sync.Map
}
// NewHTTPUpstream 创建通用 HTTP 上游服务
@ -142,9 +173,13 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
if err := s.validateRequestHost(req); err != nil {
return nil, err
}
profile := service.HTTPUpstreamProfileDefault
if req != nil {
profile = service.HTTPUpstreamProfileFromContext(req.Context())
}
// 获取或创建对应的客户端,并标记请求占用
entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
entry, err := s.acquireClientWithProfile(proxyURL, accountID, accountConcurrency, profile)
if err != nil {
return nil, err
}
@ -152,11 +187,13 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
// 执行请求
resp, err := entry.client.Do(req)
if err != nil {
s.recordOpenAIHTTP2Failure(profile, entry.protocolMode, entry.proxyKey, err)
// 请求失败,立即减少计数
atomic.AddInt64(&entry.inFlight, -1)
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
return nil, err
}
s.recordOpenAIHTTP2Success(profile, entry.protocolMode, entry.proxyKey)
// 如果上游返回了压缩内容,解压后再交给业务层
decompressResponseBody(resp)
@ -179,6 +216,10 @@ func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, acco
if profile == nil {
return s.Do(req, proxyURL, accountID, accountConcurrency)
}
upstreamProfile := service.HTTPUpstreamProfileDefault
if req != nil {
upstreamProfile = service.HTTPUpstreamProfileFromContext(req.Context())
}
targetHost := ""
if req != nil && req.URL != nil {
@ -194,7 +235,7 @@ func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, acco
return nil, err
}
entry, err := s.acquireClientWithTLS(proxyURL, accountID, accountConcurrency, profile)
entry, err := s.acquireClientWithTLS(proxyURL, accountID, accountConcurrency, profile, upstreamProfile)
if err != nil {
slog.Debug("tls_fingerprint_acquire_client_failed", "account_id", accountID, "error", err)
return nil, err
@ -219,21 +260,23 @@ func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, acco
}
// acquireClientWithTLS 获取或创建带 TLS 指纹的客户端
func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*upstreamClientEntry, error) {
return s.getClientEntryWithTLS(proxyURL, accountID, accountConcurrency, profile, true, true)
func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, upstreamProfile service.HTTPUpstreamProfile) (*upstreamClientEntry, error) {
return s.getClientEntryWithTLS(proxyURL, accountID, accountConcurrency, profile, upstreamProfile, true, true)
}
// getClientEntryWithTLS 获取或创建带 TLS 指纹的客户端条目
// TLS 指纹客户端使用独立的缓存键,与普通客户端隔离
func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, upstreamProfile service.HTTPUpstreamProfile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
isolation := s.getIsolationMode()
proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL)
if err != nil {
return nil, err
}
settings := s.resolvePoolSettings(isolation, accountConcurrency)
settings = s.applyProfilePoolSettings(settings, upstreamProfile)
// TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀
cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID)
poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls"
cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID, upstreamProtocolModeDefault)
poolKey := buildPoolKey(settings, upstreamProtocolModeDefault) + ":tls"
now := time.Now()
nowUnix := now.UnixNano()
@ -284,7 +327,6 @@ func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID i
// 创建带 TLS 指纹的 Transport
slog.Debug("tls_fingerprint_creating_new_client", "account_id", accountID, "cache_key", cacheKey, "proxy", logredact.RedactProxyURL(proxyKey))
settings := s.resolvePoolSettings(isolation, accountConcurrency)
transport, err := buildUpstreamTransportWithTLSFingerprint(settings, parsedProxy, profile)
if err != nil {
s.mu.Unlock()
@ -350,7 +392,12 @@ func (s *httpUpstreamService) redirectChecker(req *http.Request, via []*http.Req
// acquireClient 获取或创建客户端,并标记为进行中请求
// 用于请求路径,避免在获取后被淘汰
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
return s.getClientEntry(proxyURL, accountID, accountConcurrency, true, true)
return s.acquireClientWithProfile(proxyURL, accountID, accountConcurrency, service.HTTPUpstreamProfileDefault)
}
// acquireClientWithProfile 获取或创建客户端,并按请求 profile 选择协议策略。
func (s *httpUpstreamService) acquireClientWithProfile(proxyURL string, accountID int64, accountConcurrency int, profile service.HTTPUpstreamProfile) (*upstreamClientEntry, error) {
return s.getClientEntry(proxyURL, accountID, accountConcurrency, profile, true, true)
}
// getOrCreateClient 获取或创建客户端
@ -369,13 +416,13 @@ func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, ac
// - account: 按账户隔离,同一账户共享客户端(代理变更时重建)
// - account_proxy: 按账户+代理组合隔离,最细粒度
func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
return s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false)
return s.getClientEntry(proxyURL, accountID, accountConcurrency, service.HTTPUpstreamProfileDefault, false, false)
}
// getClientEntry 获取或创建客户端条目
// markInFlight=true 时会标记进行中请求,用于请求路径防止被淘汰
// enforceLimit=true 时会限制客户端数量,超限且无法淘汰时返回错误
func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, profile service.HTTPUpstreamProfile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
// 获取隔离模式
isolation := s.getIsolationMode()
// 标准化代理 URL 并解析
@ -383,10 +430,14 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
if err != nil {
return nil, err
}
// 根据请求 profile例如 OpenAI选择协议模式
protocolMode := s.resolveProtocolMode(profile, proxyKey, parsedProxy)
settings := s.resolvePoolSettings(isolation, accountConcurrency)
settings = s.applyProfilePoolSettings(settings, profile)
// 构建缓存键(根据隔离策略不同)
cacheKey := buildCacheKey(isolation, proxyKey, accountID)
cacheKey := buildCacheKey(isolation, proxyKey, accountID, protocolMode)
// 构建连接池配置键(用于检测配置变更)
poolKey := s.buildPoolKey(isolation, accountConcurrency)
poolKey := buildPoolKey(settings, protocolMode)
now := time.Now()
nowUnix := now.UnixNano()
@ -429,8 +480,7 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
}
// 缓存未命中或需要重建,创建新客户端
settings := s.resolvePoolSettings(isolation, accountConcurrency)
transport, err := buildUpstreamTransport(settings, parsedProxy)
transport, err := buildUpstreamTransport(settings, parsedProxy, protocolMode)
if err != nil {
s.mu.Unlock()
return nil, fmt.Errorf("build transport: %w", err)
@ -440,9 +490,10 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
client.CheckRedirect = s.redirectChecker
}
entry := &upstreamClientEntry{
client: client,
proxyKey: proxyKey,
poolKey: poolKey,
client: client,
proxyKey: proxyKey,
poolKey: poolKey,
protocolMode: protocolMode,
}
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
@ -626,22 +677,31 @@ func (s *httpUpstreamService) resolvePoolSettings(isolation string, accountConcu
return settings
}
// buildPoolKey 构建连接池配置键
// 用于检测配置变更,配置变更时需要重建客户端
//
// 参数:
// - isolation: 隔离模式
// - accountConcurrency: 账户并发限制
//
// 返回:
// - string: 配置键
func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency int) string {
if isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy {
if accountConcurrency > 0 {
return fmt.Sprintf("account:%d", accountConcurrency)
}
func (s *httpUpstreamService) applyProfilePoolSettings(settings poolSettings, profile service.HTTPUpstreamProfile) poolSettings {
if profile != service.HTTPUpstreamProfileOpenAI {
return settings
}
return "default"
settings.responseHeaderTimeout = 0
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIResponseHeaderTimeout > 0 {
settings.responseHeaderTimeout = time.Duration(s.cfg.Gateway.OpenAIResponseHeaderTimeout) * time.Second
}
return settings
}
// buildPoolKey 构建连接池配置键,用于检测连接池配置变更。
func buildPoolKey(settings poolSettings, protocolMode string) string {
base := fmt.Sprintf(
"idle:%d|idle_host:%d|max:%d|idle_timeout:%s|header_timeout:%s",
settings.maxIdleConns,
settings.maxIdleConnsPerHost,
settings.maxConnsPerHost,
settings.idleConnTimeout,
settings.responseHeaderTimeout,
)
if protocolMode == "" || protocolMode == upstreamProtocolModeDefault {
return base
}
return base + "|proto:" + protocolMode
}
// buildCacheKey 构建客户端缓存键
@ -659,15 +719,245 @@ func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency
// - proxy 模式: "proxy:{proxyKey}"
// - account 模式: "account:{accountID}"
// - account_proxy 模式: "account:{accountID}|proxy:{proxyKey}"
func buildCacheKey(isolation, proxyKey string, accountID int64) string {
func buildCacheKey(isolation, proxyKey string, accountID int64, protocolMode string) string {
var base string
switch isolation {
case config.ConnectionPoolIsolationAccount:
return fmt.Sprintf("account:%d", accountID)
base = fmt.Sprintf("account:%d", accountID)
case config.ConnectionPoolIsolationAccountProxy:
return fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey)
base = fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey)
default:
return fmt.Sprintf("proxy:%s", proxyKey)
base = fmt.Sprintf("proxy:%s", proxyKey)
}
if protocolMode != "" && protocolMode != upstreamProtocolModeDefault {
base += "|proto:" + protocolMode
}
return base
}
func (s *httpUpstreamService) resolveOpenAIHTTP2Settings() openAIHTTP2Settings {
settings := openAIHTTP2Settings{
enabled: false,
allowProxyFallbackToHTTP1: true,
fallbackErrorThreshold: defaultOpenAIHTTP2FallbackErrorThreshold,
fallbackWindow: defaultOpenAIHTTP2FallbackWindow,
fallbackTTL: defaultOpenAIHTTP2FallbackTTL,
}
if s == nil || s.cfg == nil {
return settings
}
cfg := s.cfg.Gateway.OpenAIHTTP2
settings.enabled = cfg.Enabled
settings.allowProxyFallbackToHTTP1 = cfg.AllowProxyFallbackToHTTP1
if cfg.FallbackErrorThreshold > 0 {
settings.fallbackErrorThreshold = cfg.FallbackErrorThreshold
}
if cfg.FallbackWindowSeconds > 0 {
settings.fallbackWindow = time.Duration(cfg.FallbackWindowSeconds) * time.Second
}
if cfg.FallbackTTLSeconds > 0 {
settings.fallbackTTL = time.Duration(cfg.FallbackTTLSeconds) * time.Second
}
return settings
}
func (s *httpUpstreamService) resolveProtocolMode(profile service.HTTPUpstreamProfile, proxyKey string, parsedProxy *url.URL) string {
if profile != service.HTTPUpstreamProfileOpenAI {
return upstreamProtocolModeDefault
}
settings := s.resolveOpenAIHTTP2Settings()
if !settings.enabled {
return upstreamProtocolModeOpenAIH1
}
if parsedProxy == nil {
return upstreamProtocolModeOpenAIH2
}
scheme := strings.ToLower(parsedProxy.Scheme)
if scheme != "http" && scheme != "https" {
return upstreamProtocolModeOpenAIH2
}
if settings.allowProxyFallbackToHTTP1 && s.isOpenAIHTTP2FallbackActive(proxyKey) {
return upstreamProtocolModeOpenAIH1Fallback
}
return upstreamProtocolModeOpenAIH2
}
func (s *httpUpstreamService) isOpenAIHTTP2FallbackActive(proxyKey string) bool {
raw, ok := s.openAIHTTP2Fallbacks.Load(proxyKey)
if !ok {
return false
}
state, ok := raw.(*openAIHTTP2FallbackState)
if !ok || state == nil {
return false
}
return state.isFallbackActive(time.Now())
}
func (s *httpUpstreamService) getOrCreateOpenAIHTTP2FallbackState(proxyKey string) *openAIHTTP2FallbackState {
state := &openAIHTTP2FallbackState{}
actual, _ := s.openAIHTTP2Fallbacks.LoadOrStore(proxyKey, state)
cached, ok := actual.(*openAIHTTP2FallbackState)
if !ok || cached == nil {
return state
}
return cached
}
func isHTTPProxyKey(proxyKey string) bool {
return strings.HasPrefix(proxyKey, "http://") || strings.HasPrefix(proxyKey, "https://")
}
func isOpenAIHTTP2CompatibilityError(err error) bool {
if err == nil {
return false
}
if isUpstreamTimeoutError(err) {
return false
}
msg := strings.ToLower(err.Error())
if msg == "" {
return false
}
markers := []string{
"alpn",
"no application protocol",
"protocol error",
"stream error",
"goaway",
"refused_stream",
"frame too large",
}
for _, marker := range markers {
if strings.Contains(msg, marker) {
return true
}
}
return false
}
func isUpstreamTimeoutError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, context.DeadlineExceeded) {
return true
}
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return true
}
msg := strings.ToLower(err.Error())
if msg == "" {
return false
}
timeoutMarkers := []string{
"timeout awaiting response headers",
"i/o timeout",
"context deadline exceeded",
"client.timeout exceeded while awaiting headers",
"tls handshake timeout",
}
for _, marker := range timeoutMarkers {
if strings.Contains(msg, marker) {
return true
}
}
return false
}
func (s *httpUpstreamService) recordOpenAIHTTP2Failure(profile service.HTTPUpstreamProfile, protocolMode, proxyKey string, err error) {
if profile != service.HTTPUpstreamProfileOpenAI || protocolMode != upstreamProtocolModeOpenAIH2 {
return
}
settings := s.resolveOpenAIHTTP2Settings()
if !settings.enabled || !settings.allowProxyFallbackToHTTP1 {
return
}
if !isHTTPProxyKey(proxyKey) || !isOpenAIHTTP2CompatibilityError(err) {
return
}
state := s.getOrCreateOpenAIHTTP2FallbackState(proxyKey)
activated, until := state.recordFailure(time.Now(), settings.fallbackErrorThreshold, settings.fallbackWindow, settings.fallbackTTL)
if activated {
slog.Warn("openai_http2_proxy_fallback_activated",
"proxy", proxyKey,
"fallback_until", until.Format(time.RFC3339))
}
}
func (s *httpUpstreamService) recordOpenAIHTTP2Success(profile service.HTTPUpstreamProfile, protocolMode, proxyKey string) {
if profile != service.HTTPUpstreamProfileOpenAI || protocolMode != upstreamProtocolModeOpenAIH2 {
return
}
if !isHTTPProxyKey(proxyKey) {
return
}
raw, ok := s.openAIHTTP2Fallbacks.Load(proxyKey)
if !ok {
return
}
state, ok := raw.(*openAIHTTP2FallbackState)
if !ok || state == nil {
return
}
state.resetErrorWindow()
}
func (s *openAIHTTP2FallbackState) isFallbackActive(now time.Time) bool {
s.mu.Lock()
defer s.mu.Unlock()
if s.fallbackUntil.IsZero() {
return false
}
if now.Before(s.fallbackUntil) {
return true
}
s.fallbackUntil = time.Time{}
return false
}
func (s *openAIHTTP2FallbackState) resetErrorWindow() {
s.mu.Lock()
defer s.mu.Unlock()
s.windowStart = time.Time{}
s.errorCount = 0
}
func (s *openAIHTTP2FallbackState) recordFailure(now time.Time, threshold int, window, ttl time.Duration) (bool, time.Time) {
if threshold <= 0 {
threshold = defaultOpenAIHTTP2FallbackErrorThreshold
}
if window <= 0 {
window = defaultOpenAIHTTP2FallbackWindow
}
if ttl <= 0 {
ttl = defaultOpenAIHTTP2FallbackTTL
}
s.mu.Lock()
defer s.mu.Unlock()
if !s.fallbackUntil.IsZero() && now.Before(s.fallbackUntil) {
return false, s.fallbackUntil
}
if !s.fallbackUntil.IsZero() && !now.Before(s.fallbackUntil) {
s.fallbackUntil = time.Time{}
}
if s.windowStart.IsZero() || now.Sub(s.windowStart) > window {
s.windowStart = now
s.errorCount = 0
}
s.errorCount++
if s.errorCount < threshold {
return false, time.Time{}
}
s.fallbackUntil = now.Add(ttl)
s.windowStart = time.Time{}
s.errorCount = 0
return true, s.fallbackUntil
}
// normalizeProxyURL 标准化代理 URL
@ -739,7 +1029,7 @@ func defaultPoolSettings(cfg *config.Config) poolSettings {
if cfg.Gateway.IdleConnTimeoutSeconds > 0 {
idleConnTimeout = time.Duration(cfg.Gateway.IdleConnTimeoutSeconds) * time.Second
}
if cfg.Gateway.ResponseHeaderTimeout > 0 {
if cfg.Gateway.ResponseHeaderTimeout >= 0 {
responseHeaderTimeout = time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
}
}
@ -770,7 +1060,7 @@ func defaultPoolSettings(cfg *config.Config) poolSettings {
// - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待)
// - IdleConnTimeout: 空闲连接超时(超时后关闭)
// - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输)
func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Transport, error) {
func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL, protocolMode string) (*http.Transport, error) {
transport := &http.Transport{
MaxIdleConns: settings.maxIdleConns,
MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
@ -778,6 +1068,17 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Tra
IdleConnTimeout: settings.idleConnTimeout,
ResponseHeaderTimeout: settings.responseHeaderTimeout,
}
switch protocolMode {
case upstreamProtocolModeOpenAIH2:
transport.ForceAttemptHTTP2 = true
case upstreamProtocolModeOpenAIH1:
transport.ForceAttemptHTTP2 = false
transport.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper)
case upstreamProtocolModeOpenAIH1Fallback:
// 显式禁用 HTTP/2确保代理不兼容场景回退到 HTTP/1.1。
transport.ForceAttemptHTTP2 = false
transport.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper)
}
if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil {
return nil, err
}

Some files were not shown because too many files have changed in this diff Show More