sub2api/backend/internal/repository/user_platform_quota_repo.go
DaydreamCoding 6b39b344d8 feat(quota): 用户 × 平台 USD 配额
为用户在 anthropic/openai/gemini/antigravity 四个平台上提供日/周/月
三个窗口的 USD 配额管控。配额语义:未设置=不限制,0=禁用,>0=美元上限。

两层模型:
- 配置层:系统默认配额,以及 email/linuxdo/oidc/wechat/github/google/
  dingtalk 七个鉴权来源的默认配额,存于 settings,以嵌套 JSON 整体读写
  (系统 1 个 key + 每个来源 1 个 key),整体替换语义。
- 运行时层:user_platform_quota 表按用户记录实际配额,与配置层解耦。

后端:新增 ent schema 与 140_user_platform_quotas.sql 迁移、repository
与 service 端口、计费链路集成、管理端与用户端读写接口。
前端:管理端设置页配额编辑、用户配额管理 Modal、用户 Dashboard 展示、
中英文案。

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-26 10:49:20 +08:00

417 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"context"
"fmt"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
)
// UserPlatformQuotaRecord 是 repository 层的传输结构体,
// 与 ent.UserPlatformQuota 实体解耦,供业务层使用。
type UserPlatformQuotaRecord struct {
UserID int64
Platform string
DailyLimitUSD *float64
WeeklyLimitUSD *float64
MonthlyLimitUSD *float64
DailyUsageUSD float64
WeeklyUsageUSD float64
MonthlyUsageUSD float64
DailyWindowStart *time.Time
WeeklyWindowStart *time.Time
MonthlyWindowStart *time.Time
}
// ErrUserPlatformQuotaNotFound 用于 ResetExpiredWindow 等需要"必须命中已有记录"的方法。
var ErrUserPlatformQuotaNotFound = fmt.Errorf("user platform quota record not found")
// UserPlatformQuotaRepository 定义用户平台配额的数据访问接口。
type UserPlatformQuotaRepository interface {
// BulkInsertInitial 幂等批量插入初始配额记录ON CONFLICT DO NOTHING
BulkInsertInitial(ctx context.Context, records []UserPlatformQuotaRecord) error
// GetByUserPlatform 查询单条配额记录,未找到时返回 (nil, nil)。
GetByUserPlatform(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaRecord, error)
// ListByUser 查询用户的所有平台配额记录(排除软删除)。
ListByUser(ctx context.Context, userID int64) ([]UserPlatformQuotaRecord, error)
// IncrementUsageWithReset 原子地累加用量,若窗口已过期则先重置再累加。
IncrementUsageWithReset(ctx context.Context, userID int64, platform string, cost float64, now time.Time) error
// ResetExpiredWindow 重置指定窗口daily/weekly/monthly的用量与起始时间。
ResetExpiredWindow(ctx context.Context, userID int64, platform string, window string, newStart time.Time) error
// UpsertForUser 全量替换该用户所有平台限额配置(详见 service.UserPlatformQuotaRepository.UpsertForUser
UpsertForUser(ctx context.Context, userID int64, records []UserPlatformQuotaRecord) error
}
type userPlatformQuotaRepository struct {
client *dbent.Client
}
// NewUserPlatformQuotaRepository 创建 UserPlatformQuotaRepository 实现。
func NewUserPlatformQuotaRepository(client *dbent.Client) UserPlatformQuotaRepository {
return &userPlatformQuotaRepository{client: client}
}
// BulkInsertInitial 用原生 SQL ON CONFLICT 实现幂等批量插入(带条件 limit 覆盖)。
// 仅插入 limit_usd 与元数据usage_usd 用 DB 默认 0window_start 留 NULL。
// FK 约束要求 user_id 在 users 表中存在,调用方负责保证。
//
// 冲突策略CASE WHEN existing.*_limit_usd IS NULL THEN EXCLUDED.*_limit_usd ELSE existing ...
// - 若 IncrementUsageWithReset 因时序问题已先建行limit 全 NULL
// 此处会把注册时的默认 limit 写入,避免该用户在该平台永久无限额。
// - 若管理员已通过 UpsertForUser 设置了非 NULL 个性化 limit**保留不动**
// —— 旧实现无条件 EXCLUDED 覆盖会丢失个性化配置。
// - 不会改 usage_usd / window_start保留累计的用量。
// - 仅命中 deleted_at IS NULL 的活跃记录partial unique index 作用域)。
func (r *userPlatformQuotaRepository) BulkInsertInitial(ctx context.Context, records []UserPlatformQuotaRecord) error {
if len(records) == 0 {
return nil
}
client := clientFromContext(ctx, r.client)
var sb strings.Builder
_, _ = sb.WriteString("INSERT INTO user_platform_quotas (user_id, platform, daily_limit_usd, weekly_limit_usd, monthly_limit_usd, daily_usage_usd, weekly_usage_usd, monthly_usage_usd, created_at, updated_at) VALUES ")
args := make([]any, 0, len(records)*6)
// 统一时间戳:避免循环内多次 time.Now() 让同一批记录的 created_at/updated_at
// 出现亚毫秒级偏差(与 UpsertForUser 的 now := time.Now() 风格一致)。
now := time.Now()
for i, rec := range records {
base := i * 6
if i > 0 {
_, _ = sb.WriteString(",")
}
fmt.Fprintf(&sb, "($%d,$%d,$%d,$%d,$%d,0,0,0,$%d,$%d)",
base+1, base+2, base+3, base+4, base+5, base+6, base+6)
args = append(args,
rec.UserID, rec.Platform,
rec.DailyLimitUSD, rec.WeeklyLimitUSD, rec.MonthlyLimitUSD,
now,
)
}
// 精确命中 partial unique indexdeleted_at IS NULL避免对软删记录的歧义冲突。
// 条件覆盖:仅在现有 limit 为 NULL 时才写入 EXCLUDED否则保留现有非 NULL 值。
// - 修复 IncrementUsageWithReset 已用 NULL limit 建行的场景NULL → 注册默认)
// - 保护管理员通过 UpsertForUser 设置的个性化 limit 不被静默覆盖
_, _ = sb.WriteString(` ON CONFLICT (user_id, platform) WHERE deleted_at IS NULL
DO UPDATE SET
daily_limit_usd = COALESCE(user_platform_quotas.daily_limit_usd, EXCLUDED.daily_limit_usd),
weekly_limit_usd = COALESCE(user_platform_quotas.weekly_limit_usd, EXCLUDED.weekly_limit_usd),
monthly_limit_usd = COALESCE(user_platform_quotas.monthly_limit_usd, EXCLUDED.monthly_limit_usd),
updated_at = EXCLUDED.updated_at`)
_, err := client.ExecContext(ctx, sb.String(), args...)
return err
}
// GetByUserPlatform 通过 ent 查询单条配额(排除软删除)。未找到返回 (nil, nil)。
func (r *userPlatformQuotaRepository) GetByUserPlatform(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaRecord, error) {
client := clientFromContext(ctx, r.client)
entity, err := client.UserPlatformQuota.Query().
Where(
userplatformquota.UserIDEQ(userID),
userplatformquota.PlatformEQ(platform),
userplatformquota.DeletedAtIsNil(),
).
Only(ctx)
if dbent.IsNotFound(err) {
return nil, nil
}
if err != nil {
return nil, err
}
return entQuotaToRecord(entity), nil
}
// ListByUser 查询用户的所有平台配额记录(排除软删除)。
func (r *userPlatformQuotaRepository) ListByUser(ctx context.Context, userID int64) ([]UserPlatformQuotaRecord, error) {
client := clientFromContext(ctx, r.client)
rows, err := client.UserPlatformQuota.Query().
Where(
userplatformquota.UserIDEQ(userID),
userplatformquota.DeletedAtIsNil(),
).
All(ctx)
if err != nil {
return nil, err
}
out := make([]UserPlatformQuotaRecord, 0, len(rows))
for _, e := range rows {
out = append(out, *entQuotaToRecord(e))
}
return out, nil
}
// IncrementUsageWithReset 原子累加 cost 到 (user, platform) 三个窗口的 *_usage_usd。
// 行为:
// - 若记录存在:在事务内 SELECT FOR UPDATE按 (prev_window_start vs current_window_start)
// 判断是否需要重置(不同 = 重置为 cost相同 = 累加 cost
// - 若记录不存在fail-open create 分支):插入新记录,**limit 字段保留 nil无限制**
// —— 这是预期行为billing 链路不能因 quota 表缺失而阻断请求,未注册路径
// 的用户 quota 默认放行,由调度层指标观测 + 后台对账补建 limit
//
// 上层正常路径(注册时 BulkInsertInitial保证 limit 在记录创建时就被写入。
func (r *userPlatformQuotaRepository) IncrementUsageWithReset(ctx context.Context, userID int64, platform string, cost float64, now time.Time) error {
return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
existing, err := txClient.UserPlatformQuota.Query().
Where(
userplatformquota.UserIDEQ(userID),
userplatformquota.PlatformEQ(platform),
userplatformquota.DeletedAtIsNil(),
).
ForUpdate().
Only(txCtx)
if dbent.IsNotFound(err) {
// fail-open 建行limit_* 保留 NULL无限额
// 用 ON CONFLICT DO UPDATE 累加,而非裸 INSERT并发下另一请求可能在本事务
// SELECT FOR UPDATE 之后、INSERT 之前刚建行,裸 INSERT 会撞 partial unique index
// 致事务回滚、本次 cost 丢失DO UPDATE 把 cost 累加到既有 usage 上。
// 写法与本文件 insertLimitsRow / BulkInsertInitial 的 ON CONFLICT 一致。
const insertSQL = `INSERT INTO user_platform_quotas
(user_id, platform, daily_usage_usd, weekly_usage_usd, monthly_usage_usd,
daily_window_start, weekly_window_start, monthly_window_start, created_at, updated_at)
VALUES ($1, $2, $3, $3, $3, $4, $5, $6, $7, $7)
ON CONFLICT (user_id, platform) WHERE deleted_at IS NULL DO UPDATE SET
daily_usage_usd = user_platform_quotas.daily_usage_usd + EXCLUDED.daily_usage_usd,
weekly_usage_usd = user_platform_quotas.weekly_usage_usd + EXCLUDED.weekly_usage_usd,
monthly_usage_usd = user_platform_quotas.monthly_usage_usd + EXCLUDED.monthly_usage_usd,
updated_at = EXCLUDED.updated_at`
// $6 = now30 天滚动月度窗口以当前时刻为起始
_, e := txClient.ExecContext(txCtx, insertSQL,
userID, platform, cost,
timezone.StartOfDay(now), timezone.StartOfWeek(now), now, now)
return e
}
if err != nil {
return err
}
newDaily := maybeReset(existing.DailyUsageUsd, existing.DailyWindowStart, timezone.StartOfDay(now), cost)
newWeekly := maybeReset(existing.WeeklyUsageUsd, existing.WeeklyWindowStart, timezone.StartOfWeek(now), cost)
// 30 天滚动月度窗口:过期时重置为 cost 并以 now 为新起始,否则累加保留原起始
newMonthly, newMonthlyStart := monthlyMaybeReset(existing.MonthlyUsageUsd, existing.MonthlyWindowStart, cost, now)
_, e := existing.Update().
SetDailyUsageUsd(newDaily).
SetWeeklyUsageUsd(newWeekly).
SetMonthlyUsageUsd(newMonthly).
SetDailyWindowStart(timezone.StartOfDay(now)).
SetWeeklyWindowStart(timezone.StartOfWeek(now)).
SetMonthlyWindowStart(newMonthlyStart). // 30 天滚动:仅过期时更新起始
Save(txCtx)
return e
})
}
// ResetExpiredWindow 无条件重置指定窗口daily/weekly/monthly的用量与起始时间。
//
// ⚠️ 命名警告NOT a "check-then-reset" helper
//
// 名字里的 "Expired" 是历史遗留,**实现并不校验窗口是否真的过期**。
// 任何调用都会无条件把对应窗口的 *_usage_usd 清零并重写 *_window_start。
// 目前唯一合法 caller 是 admin POST /reset 接口(管理员强制归零)。
//
// 如果你想要"仅在窗口过期才重置"的语义,请直接使用 IncrementUsageWithReset
// 的内部判断maybeReset / monthlyMaybeReset或新增独立函数
// 不要复用这里的函数,否则会出现"明明窗口未过期,用量却被清零"的隐蔽 bug。
//
// 未命中活跃记录时返回 ErrUserPlatformQuotaNotFound。
func (r *userPlatformQuotaRepository) ResetExpiredWindow(ctx context.Context, userID int64, platform string, window string, newStart time.Time) error {
client := clientFromContext(ctx, r.client)
upd := client.UserPlatformQuota.Update().
Where(
userplatformquota.UserIDEQ(userID),
userplatformquota.PlatformEQ(platform),
userplatformquota.DeletedAtIsNil(),
)
switch window {
case "daily":
upd = upd.SetDailyUsageUsd(0).SetDailyWindowStart(newStart)
case "weekly":
upd = upd.SetWeeklyUsageUsd(0).SetWeeklyWindowStart(newStart)
case "monthly":
upd = upd.SetMonthlyUsageUsd(0).SetMonthlyWindowStart(newStart)
default:
return fmt.Errorf("unknown window %q", window)
}
n, err := upd.Save(ctx)
if err != nil {
return err
}
if n == 0 {
return ErrUserPlatformQuotaNotFound
}
return nil
}
// withTx 在事务中执行 fn若 ctx 中已有事务则复用。
func (r *userPlatformQuotaRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error {
if tx := dbent.TxFromContext(ctx); tx != nil {
return fn(ctx, tx.Client())
}
tx, err := r.client.Tx(ctx)
if err != nil {
return fmt.Errorf("begin user_platform_quota transaction: %w", err)
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := fn(txCtx, tx.Client()); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("commit user_platform_quota transaction: %w", err)
}
return nil
}
// entQuotaToRecord 将 ent entity 映射为 repository record。
// 注意 ent 生成字段名为 DailyLimitUsd非 DailyLimitUSD
func entQuotaToRecord(e *dbent.UserPlatformQuota) *UserPlatformQuotaRecord {
return &UserPlatformQuotaRecord{
UserID: e.UserID,
Platform: e.Platform,
DailyLimitUSD: e.DailyLimitUsd,
WeeklyLimitUSD: e.WeeklyLimitUsd,
MonthlyLimitUSD: e.MonthlyLimitUsd,
DailyUsageUSD: e.DailyUsageUsd,
WeeklyUsageUSD: e.WeeklyUsageUsd,
MonthlyUsageUSD: e.MonthlyUsageUsd,
DailyWindowStart: e.DailyWindowStart,
WeeklyWindowStart: e.WeeklyWindowStart,
MonthlyWindowStart: e.MonthlyWindowStart,
}
}
// maybeReset 判断是否需要重置窗口用量:
// - 若 prevStart 为 nil 或与 currStart 不同,表示窗口已过期,返回 cost重置
// - 否则返回 prevUsage + cost累加
func maybeReset(prevUsage float64, prevStart *time.Time, currStart time.Time, cost float64) float64 {
if prevStart == nil || !prevStart.Equal(currStart) {
return cost
}
return prevUsage + cost
}
// monthlyMaybeReset 判断 30 天滚动月度窗口是否需要重置。
// 过期条件prevStart 为 nil 或 now - prevStart >= 30×24h与订阅模式 NeedsMonthlyReset 语义一致)。
// 过期时重置为 cost否则累加。返回 (newUsage, newWindowStart)。
func monthlyMaybeReset(prevUsage float64, prevStart *time.Time, cost float64, now time.Time) (float64, time.Time) {
if prevStart == nil || now.Sub(*prevStart) >= 30*24*time.Hour {
return cost, now
}
return prevUsage + cost, *prevStart
}
// UpsertForUser 全量替换该用户的所有平台限额(事务内):
// 1. 软删除未在 records 中出现的所有 active 行
// 2. 对每条 record 尝试 UPDATE含 deleted_at = NULL 兼容重激活);
// UPDATE 行数为 0 时 INSERT 新行
//
// 仅改 *_limit_usd + deleted_at + updated_at保留 *_usage_usd / *_window_start。
func (r *userPlatformQuotaRepository) UpsertForUser(ctx context.Context, userID int64, records []UserPlatformQuotaRecord) error {
return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
platforms := make([]string, 0, len(records))
for _, rec := range records {
platforms = append(platforms, rec.Platform)
}
now := time.Now()
if err := softDeleteMissingPlatforms(txCtx, txClient, userID, platforms, now); err != nil {
return err
}
for _, rec := range records {
affected, err := updateLimitsRow(txCtx, txClient, userID, rec, now)
if err != nil {
return err
}
if affected == 0 {
if err := insertLimitsRow(txCtx, txClient, userID, rec, now); err != nil {
return err
}
}
}
return nil
})
}
// softDeleteMissingPlatforms 软删除该用户所有不在 keepPlatforms 中的 active 行。
// keepPlatforms 为空时 → 软删用户所有 active 行。
// now 由调用方传入,与 updateLimitsRow / insertLimitsRow 共享同一个 Go time.Now()
// 保证事务内所有时间戳一致(避免 Postgres NOW() 与 Go time.Now() 的微小偏差)。
func softDeleteMissingPlatforms(ctx context.Context, client *dbent.Client, userID int64, keepPlatforms []string, now time.Time) error {
var (
query string
args []any
)
if len(keepPlatforms) == 0 {
query = `UPDATE user_platform_quotas SET deleted_at = $2, updated_at = $2
WHERE user_id = $1 AND deleted_at IS NULL`
args = []any{userID, now}
} else {
placeholders := make([]string, len(keepPlatforms))
args = make([]any, 0, len(keepPlatforms)+2)
args = append(args, userID, now)
for i, p := range keepPlatforms {
placeholders[i] = fmt.Sprintf("$%d", i+3)
args = append(args, p)
}
query = fmt.Sprintf(`UPDATE user_platform_quotas SET deleted_at = $2, updated_at = $2
WHERE user_id = $1 AND deleted_at IS NULL AND platform NOT IN (%s)`,
strings.Join(placeholders, ","))
}
_, err := client.ExecContext(ctx, query, args...)
return err
}
// updateLimitsRow 尝试 UPDATE active 行deleted_at IS NULL返回受影响行数。
// 仅更新 active 行:若存在多条历史软删记录,加 deleted_at IS NULL 守卫可避免
// 批量重激活导致的 partial unique indexuserplatformquota_user_id_platform_uq冲突。
// affected=0 时由调用方 UpsertForUser 走 insertLimitsRow 路径创建新行。
func updateLimitsRow(ctx context.Context, client *dbent.Client, userID int64, rec UserPlatformQuotaRecord, now time.Time) (int64, error) {
const query = `UPDATE user_platform_quotas
SET daily_limit_usd = $1, weekly_limit_usd = $2, monthly_limit_usd = $3,
deleted_at = NULL, updated_at = $4
WHERE user_id = $5 AND platform = $6 AND deleted_at IS NULL`
res, err := client.ExecContext(ctx, query,
rec.DailyLimitUSD, rec.WeeklyLimitUSD, rec.MonthlyLimitUSD, now,
userID, rec.Platform)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
// insertLimitsRow 插入新限额行usage 默认 0window_start 默认 NULL
// 带 ON CONFLICT ... DO NOTHING 守卫:防止两个并发请求同时为同一 user/platform 新建行时
// 触发 unique constraint 违反userplatformquota_user_id_platform_uq 部分唯一索引)。
// affected=0 时说明另一个并发请求刚完成 INSERTfallback 到 updateLimitsRow 覆写 limits 值。
func insertLimitsRow(ctx context.Context, client *dbent.Client, userID int64, rec UserPlatformQuotaRecord, now time.Time) error {
const query = `INSERT INTO user_platform_quotas
(user_id, platform, daily_limit_usd, weekly_limit_usd, monthly_limit_usd,
daily_usage_usd, weekly_usage_usd, monthly_usage_usd, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, 0, 0, 0, $6, $6)
ON CONFLICT (user_id, platform) WHERE deleted_at IS NULL DO NOTHING`
res, err := client.ExecContext(ctx, query,
userID, rec.Platform,
rec.DailyLimitUSD, rec.WeeklyLimitUSD, rec.MonthlyLimitUSD,
now)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
// 并发情形另一请求已插入该行fallback 到 UPDATE 覆写 limits 值last-writer-wins
_, err = updateLimitsRow(ctx, client, userID, rec, now)
return err
}
return nil
}