sub2api/backend/internal/handler/gateway_handler_billing_error_test.go
DaydreamCoding 6b39b344d8 feat(quota): 用户 × 平台 USD 配额
为用户在 anthropic/openai/gemini/antigravity 四个平台上提供日/周/月
三个窗口的 USD 配额管控。配额语义:未设置=不限制,0=禁用,>0=美元上限。

两层模型:
- 配置层:系统默认配额,以及 email/linuxdo/oidc/wechat/github/google/
  dingtalk 七个鉴权来源的默认配额,存于 settings,以嵌套 JSON 整体读写
  (系统 1 个 key + 每个来源 1 个 key),整体替换语义。
- 运行时层:user_platform_quota 表按用户记录实际配额,与配置层解耦。

后端:新增 ent schema 与 140_user_platform_quotas.sql 迁移、repository
与 service 端口、计费链路集成、管理端与用户端读写接口。
前端:管理端设置页配额编辑、用户配额管理 Modal、用户 Dashboard 展示、
中英文案。

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-26 10:49:20 +08:00

129 lines
4.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package handler
import (
"errors"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestBillingErrorDetails_MapsGroupRPMExceededToTooManyRequests(t *testing.T) {
status, code, msg, retryAfter := billingErrorDetails(service.ErrGroupRPMExceeded)
require.Equal(t, http.StatusTooManyRequests, status)
require.Equal(t, "rate_limit_exceeded", code)
require.NotEmpty(t, msg)
require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After")
require.LessOrEqual(t, retryAfter, 60)
}
// TestBillingErrorDetails_MapsUserRPMExceededToTooManyRequests verifies that a
// per-user RPM overrun maps to the same 429 / "rate_limit_exceeded" response as
// the group-level case, including a Retry-After hint between 1 and 60 seconds.
func TestBillingErrorDetails_MapsUserRPMExceededToTooManyRequests(t *testing.T) {
	const (
		wantStatus = http.StatusTooManyRequests
		wantCode   = "rate_limit_exceeded"
	)

	gotStatus, gotCode, gotMsg, gotRetryAfter := billingErrorDetails(service.ErrUserRPMExceeded)

	require.Equal(t, wantStatus, gotStatus)
	require.Equal(t, wantCode, gotCode)
	require.NotEmpty(t, gotMsg)
	require.Greater(t, gotRetryAfter, 0, "RPM exceeded should return positive Retry-After")
	require.LessOrEqual(t, gotRetryAfter, 60)
}
// TestBillingErrorDetails_APIKeyRateLimitStillMaps is a regression guard. Adding
// the RPM branches to billingErrorDetails must not change how the pre-existing
// API-key rate-limit errors (5h / 1d / 7d windows) are mapped. Each one must
// still yield HTTP 429 with the "rate_limit_exceeded" code.
func TestBillingErrorDetails_APIKeyRateLimitStillMaps(t *testing.T) {
	for _, err := range []error{
		service.ErrAPIKeyRateLimit5hExceeded,
		service.ErrAPIKeyRateLimit1dExceeded,
		service.ErrAPIKeyRateLimit7dExceeded,
	} {
		status, code, _, _ := billingErrorDetails(err)
		require.Equal(t, http.StatusTooManyRequests, status, "status for %v", err)
		// Tag the code assertion with the error as well, so a failure names the
		// window that regressed instead of just reporting a bare mismatch.
		require.Equal(t, "rate_limit_exceeded", code, "code for %v", err)
	}
}
func TestBillingErrorDetails_BillingServiceUnavailableMapsTo503(t *testing.T) {
status, code, _, retryAfter := billingErrorDetails(service.ErrBillingServiceUnavailable)
require.Equal(t, http.StatusServiceUnavailable, status)
require.Equal(t, "billing_service_error", code)
require.Equal(t, 0, retryAfter, "non-RPM errors should not set Retry-After")
}
func TestBillingErrorDetails_UnknownErrorFallsBackTo403(t *testing.T) {
status, code, msg, _ := billingErrorDetails(service.ErrInsufficientBalance)
require.Equal(t, http.StatusForbidden, status)
require.Equal(t, "billing_error", code)
require.NotEmpty(t, msg)
}
// TestExtractQuotaResetSeconds_T19_HappyPath checks that a well-formed
// "window_resets_at" timestamp about 10 seconds in the future becomes a
// Retry-After value of roughly 10 seconds.
//
// time.RFC3339 formatting drops the sub-second part, so the encoded reset time
// is now+10s truncated to a whole second, which can be up to 1s earlier than
// now+10s. If the remaining time is rounded up (math.Ceil), the result is
// usually 10. It is 9 when the test starts close to the end of a wall-clock
// second, and the old 10..11 bound turned that case into a flaky failure.
// 11 stays tolerated in case the implementation adds an extra second.
func TestExtractQuotaResetSeconds_T19_HappyPath(t *testing.T) {
	resetsAt := time.Now().Add(10 * time.Second).UTC().Format(time.RFC3339)
	err := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{
		"window_resets_at": resetsAt,
	})

	got := extractQuotaResetSeconds(err)
	if got < 9 || got > 11 {
		t.Errorf("T19: got %d, want 9..11 (RFC3339 second truncation + math.Ceil boundary)", got)
	}
}
// TestExtractQuotaResetSeconds_T20_NoMetadataFallback verifies that a plain
// error carrying no quota metadata yields the 60-second default.
func TestExtractQuotaResetSeconds_T20_NoMetadataFallback(t *testing.T) {
	const want = 60

	naked := errors.New("naked error")
	got := extractQuotaResetSeconds(naked)
	if got != want {
		t.Errorf("T20: got %d, want 60 fallback", got)
	}
}
// TestExtractQuotaResetSeconds_T21_BadFormatFallback verifies that an
// unparseable "window_resets_at" value falls back to the 60-second default
// instead of producing a bogus delay.
func TestExtractQuotaResetSeconds_T21_BadFormatFallback(t *testing.T) {
	metadata := map[string]string{"window_resets_at": "not-a-time"}
	quotaErr := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(metadata)

	got := extractQuotaResetSeconds(quotaErr)
	if got == 60 {
		return
	}
	t.Errorf("T21: got %d, want 60 fallback", got)
}
// TestExtractQuotaResetSeconds_T22_PastResetFallsBackToDefault covers a
// "window_resets_at" that is already in the past. The expected result is the
// 60-second fallback, not 1 second.
//
// A 1-second hint would make clients retry immediately, hit the still-active
// limit again, and spin in a tight back-off loop. 60 seconds lets them back off
// at a normal pace, so they do not keep hammering the service while the
// cache/DB state heals itself.
func TestExtractQuotaResetSeconds_T22_PastResetFallsBackToDefault(t *testing.T) {
	fiveSecondsAgo := time.Now().Add(-5 * time.Second).UTC().Format(time.RFC3339)
	quotaErr := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{
		"window_resets_at": fiveSecondsAgo,
	})

	if got := extractQuotaResetSeconds(quotaErr); got != 60 {
		t.Errorf("T22: got %d, want 60 (fallback on past reset)", got)
	}
}
// TestBillingErrorDetails_T10_QuotaExhaustedReturns429WithRetryAfter checks that
// an exhausted user-platform USD quota maps to HTTP 429 plus Retry-After
// (RFC 6585), the same as the RPM limits. This lets SDKs such as
// OpenAI-compatible clients back off automatically based on Retry-After. The
// old implementation returned 403, which made clients fail outright instead of
// backing off.
//
// The daily, weekly and monthly windows share one mapping branch. All three are
// exercised so that no window's status/code can regress unnoticed. Each reset
// is one hour ahead, so Retry-After should be about 3600 seconds.
func TestBillingErrorDetails_T10_QuotaExhaustedReturns429WithRetryAfter(t *testing.T) {
	// Build a fresh metadata map (and a fresh timestamp) for every case.
	resetInOneHour := func() map[string]string {
		return map[string]string{
			"window_resets_at": time.Now().Add(60 * time.Minute).UTC().Format(time.RFC3339),
		}
	}

	cases := []struct {
		name string
		err  error
	}{
		{name: "daily", err: service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(resetInOneHour())},
		{name: "weekly", err: service.ErrUserPlatformWeeklyQuotaExhausted.WithMetadata(resetInOneHour())},
		{name: "monthly", err: service.ErrUserPlatformMonthlyQuotaExhausted.WithMetadata(resetInOneHour())},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotStatus, gotCode, _, gotRetryAfter := billingErrorDetails(tc.err)

			if gotStatus != http.StatusTooManyRequests {
				t.Errorf("status = %d, want 429", gotStatus)
			}
			if gotCode != "rate_limit_exceeded" {
				t.Errorf("code = %q, want rate_limit_exceeded", gotCode)
			}
			// One second of slack on each side absorbs RFC3339 truncation and
			// rounding of the remaining time.
			if !(gotRetryAfter >= 3599 && gotRetryAfter <= 3601) {
				t.Errorf("retryAfter = %d, want ~3600", gotRetryAfter)
			}
		})
	}
}