chore: merge upstream Wei-Shaw/sub2api v0.1.128 — keep fork customizations

Upstream 新功能 (34 commits, ~main..origin/main):
- feat(email): 通知邮件模板服务、模板编辑器、订阅/余额提醒邮件
- feat(notification): NotificationEmailService 注入到 Balance/Payment/Setting
- feat(payment): 支付成功通知邮件
- feat(usage): 用户 API Key 用量页支持按日明细
- feat(openai-gateway): Codex OAuth 浏览器 UA 自动改写规避 Cloudflare 质询
- feat(admin): 邮件模板管理接口
- fix(auth): 停用/删除分组后阻断 API Key
- fix(group): 修正分组账号可用计数口径
- fix(openai): /v1/responses respect force chat completions, images n 参数透传
- test(repository): AES Encryptor 单元测试
- chore: VERSION 0.1.128

冲突解决 (backend/cmd/server/wire_gen.go):
- 引入 upstream 新 wire providers: notificationEmailService,
  ProvidePaymentService(10 args), ProvideAdminSettingHandler(8 args)
- 保留 fork 独有依赖: rpmTokenBucketService (RPM 平滑),
  NewOpsHandler 3 参数版本 (requestEventBus, opsLogBroadcaster)
- ProvideBalanceNotifyService 接受 4 参数 (含 notificationEmailService)

修复 session-id helper 设计 (claude_code_session_id.go):
- 发现: TestGatewayService_AnthropicOAuth_InjectsClaudeCodeSessionHeaderFromMetadata
  在 OAuth + mimicClaudeCode=false 场景失败
- 重新审视设计原则: OAuth 凭证本身就是 Claude Code 客户端,可信任 metadata
  派生 session_id;不应受 mimicClaudeCode 标志阻止
- 修复: metadata 派生只看 tokenType=="oauth";UUID 兜底仍需 oauth && mimic
- 更新测试: OAuthNonMimicDerivesFromMetadata 取代原 IgnoresMetadata

所有 fork 独有功能保留:
- Claude Code 2.1.145 mimicry bundle (上个 commit 引入)
- RPM token bucket smoothing (commit 95814974)
- Windsurf/Antigravity/Omniroute 定制
- claudemask/ 校验包 (upstream 已删除,我们仍在 gateway_service 中使用)

不在范围:
- 不修复 baseline 既存的 2 个测试失败 (TestProxyImportData...,
  TestWindsurfTierAccessService_Snapshot_HappyPath) - 与 merge 无关
This commit is contained in:
win 2026-05-20 17:50:44 +08:00
commit 92433656f5
71 changed files with 6395 additions and 166 deletions

View File

@ -1 +1 @@
0.1.127
0.1.128

View File

@ -189,7 +189,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService)
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService)
rpmTokenBucketService := service.NewRPMTokenBucketService()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
@ -204,8 +205,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
registry := payment.ProvideRegistry()
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService)
paymentService := service.ProvidePaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService, notificationEmailService)
settingHandler := handler.ProvideAdminSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService, notificationEmailService)
requestEventBus := service.NewRequestEventBus()
opsLogBroadcaster := service.ProvideOpsLogBroadcaster()
opsHandler := admin.NewOpsHandler(opsService, requestEventBus, opsLogBroadcaster)
@ -253,7 +254,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, windsurfGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, userMessageQueueService, configConfig, settingService, requestEventBus)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo, notificationEmailService)
totpHandler := handler.NewTotpHandler(totpService)
handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService)
paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry)
@ -274,7 +275,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository, notificationEmailService)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
windsurfRefreshService := service.ProvideWindsurfRefreshService(configConfig, accountRepository, proxyRepository)

View File

@ -56,13 +56,14 @@ func firstNonEmpty(values ...string) string {
// SettingHandler 系统设置处理器
type SettingHandler struct {
settingService *service.SettingService
emailService *service.EmailService
turnstileService *service.TurnstileService
opsService *service.OpsService
paymentConfigService *service.PaymentConfigService
paymentService *service.PaymentService
userAttributeService *service.UserAttributeService
settingService *service.SettingService
emailService *service.EmailService
turnstileService *service.TurnstileService
opsService *service.OpsService
paymentConfigService *service.PaymentConfigService
paymentService *service.PaymentService
userAttributeService *service.UserAttributeService
notificationEmailService *service.NotificationEmailService
}
// NewSettingHandler 创建系统设置处理器
@ -78,6 +79,12 @@ func NewSettingHandler(settingService *service.SettingService, emailService *ser
}
}
// SetNotificationEmailService attaches the notification template service without changing
// the constructor signature used by existing unit tests.
func (h *SettingHandler) SetNotificationEmailService(notificationEmailService *service.NotificationEmailService) {
h.notificationEmailService = notificationEmailService
}
// GetSettings 获取所有系统设置
// GET /api/v1/admin/settings
func (h *SettingHandler) GetSettings(c *gin.Context) {
@ -247,6 +254,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
EnableAnthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection,
RewriteMessageCacheControl: settings.RewriteMessageCacheControl,
AntigravityUserAgentVersion: settings.AntigravityUserAgentVersion,
OpenAICodexUserAgent: settings.OpenAICodexUserAgent,
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource,
PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource,
@ -563,6 +571,7 @@ type UpdateSettingsRequest struct {
EnableAnthropicCacheTTL1hInjection *bool `json:"enable_anthropic_cache_ttl_1h_injection"`
RewriteMessageCacheControl *bool `json:"rewrite_message_cache_control"`
AntigravityUserAgentVersion *string `json:"antigravity_user_agent_version"`
OpenAICodexUserAgent *string `json:"openai_codex_user_agent"`
// Payment visible method routing
PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
@ -1404,6 +1413,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
}
if req.OpenAICodexUserAgent != nil {
normalized := strings.TrimSpace(*req.OpenAICodexUserAgent)
req.OpenAICodexUserAgent = &normalized
// 仅做长度上限保护,不限制具体格式(运维需要可自由调整 codex 版本号)
if len(normalized) > 512 {
response.Error(c, http.StatusBadRequest, "openai_codex_user_agent must be at most 512 characters")
return
}
}
// 交叉验证:如果同时设置了最低和最高版本号,最高版本号必须 >= 最低版本号
if req.MinClaudeCodeVersion != "" && req.MaxClaudeCodeVersion != "" {
@ -1597,6 +1615,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.AntigravityUserAgentVersion
}(),
OpenAICodexUserAgent: func() string {
if req.OpenAICodexUserAgent != nil {
return *req.OpenAICodexUserAgent
}
return previousSettings.OpenAICodexUserAgent
}(),
PaymentVisibleMethodAlipaySource: func() string {
if req.PaymentVisibleMethodAlipaySource != nil {
return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource)
@ -1956,6 +1980,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EnableAnthropicCacheTTL1hInjection: updatedSettings.EnableAnthropicCacheTTL1hInjection,
RewriteMessageCacheControl: updatedSettings.RewriteMessageCacheControl,
AntigravityUserAgentVersion: updatedSettings.AntigravityUserAgentVersion,
OpenAICodexUserAgent: updatedSettings.OpenAICodexUserAgent,
PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource,
PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource,
PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled,
@ -2411,6 +2436,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.AntigravityUserAgentVersion != after.AntigravityUserAgentVersion {
changed = append(changed, "antigravity_user_agent_version")
}
if before.OpenAICodexUserAgent != after.OpenAICodexUserAgent {
changed = append(changed, "openai_codex_user_agent")
}
if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource {
changed = append(changed, "payment_visible_method_alipay_source")
}
@ -3339,3 +3367,160 @@ func (h *SettingHandler) ensureUserAttributeDefinition(ctx context.Context, key,
}
slog.Info("dingtalk: created user attribute definition", "key", key, "name", name, "type", attrType)
}
// ListEmailTemplates returns all editable notification email templates.
// GET /api/v1/admin/settings/email-templates
func (h *SettingHandler) ListEmailTemplates(c *gin.Context) {
if h.notificationEmailService == nil {
response.InternalError(c, "notification email service is not configured")
return
}
events := h.notificationEmailService.ListEventInfos()
templates, err := h.notificationEmailService.ListTemplates(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.EmailTemplateListResponse{
Events: emailTemplateEventOptionsToDTO(events),
Locales: h.notificationEmailService.SupportedLocales(),
Templates: emailTemplateSummariesToDTO(templates),
Placeholders: emailTemplatePlaceholderUnion(events),
})
}
// GetEmailTemplate returns one editable notification email template.
// GET /api/v1/admin/settings/email-templates/:event/:locale
func (h *SettingHandler) GetEmailTemplate(c *gin.Context) {
if h.notificationEmailService == nil {
response.InternalError(c, "notification email service is not configured")
return
}
tmpl, err := h.notificationEmailService.GetTemplate(c.Request.Context(), c.Param("event"), c.Param("locale"))
if err != nil {
response.BadRequest(c, err.Error())
return
}
response.Success(c, emailTemplateDetailToDTO(tmpl))
}
// UpdateEmailTemplate saves an override for one event/locale template.
// PUT /api/v1/admin/settings/email-templates/:event/:locale
func (h *SettingHandler) UpdateEmailTemplate(c *gin.Context) {
if h.notificationEmailService == nil {
response.InternalError(c, "notification email service is not configured")
return
}
var req dto.UpdateEmailTemplateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
tmpl, err := h.notificationEmailService.UpdateTemplate(c.Request.Context(), c.Param("event"), c.Param("locale"), req.Subject, req.HTML)
if err != nil {
response.BadRequest(c, err.Error())
return
}
response.Success(c, emailTemplateDetailToDTO(tmpl))
}
// RestoreOfficialEmailTemplate removes an override and returns the built-in template.
// POST /api/v1/admin/settings/email-templates/:event/:locale/restore-official
func (h *SettingHandler) RestoreOfficialEmailTemplate(c *gin.Context) {
if h.notificationEmailService == nil {
response.InternalError(c, "notification email service is not configured")
return
}
tmpl, err := h.notificationEmailService.RestoreOfficialTemplate(c.Request.Context(), c.Param("event"), c.Param("locale"))
if err != nil {
response.BadRequest(c, err.Error())
return
}
response.Success(c, emailTemplateDetailToDTO(tmpl))
}
// PreviewEmailTemplate renders a template with safe sample variables without saving it.
// POST /api/v1/admin/settings/email-templates/preview
func (h *SettingHandler) PreviewEmailTemplate(c *gin.Context) {
if h.notificationEmailService == nil {
response.InternalError(c, "notification email service is not configured")
return
}
var req dto.PreviewEmailTemplateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
preview, err := h.notificationEmailService.PreviewTemplate(c.Request.Context(), service.NotificationEmailPreviewInput{
Event: req.Event,
Locale: req.Locale,
Subject: req.Subject,
HTML: req.HTML,
Variables: req.Variables,
})
if err != nil {
response.BadRequest(c, err.Error())
return
}
response.Success(c, dto.EmailTemplatePreviewResponse{Subject: preview.Subject, HTML: preview.HTML})
}
func emailTemplateEventOptionsToDTO(events []service.NotificationEmailEventInfo) []dto.EmailTemplateEventOption {
items := make([]dto.EmailTemplateEventOption, 0, len(events))
for _, event := range events {
items = append(items, dto.EmailTemplateEventOption{
Value: event.Event,
Label: event.Label,
Description: event.Description,
})
}
return items
}
func emailTemplateSummariesToDTO(templates []service.NotificationEmailTemplate) []dto.EmailTemplateSummary {
items := make([]dto.EmailTemplateSummary, 0, len(templates))
for _, tmpl := range templates {
items = append(items, dto.EmailTemplateSummary{
Event: tmpl.Event,
Locale: tmpl.Locale,
Subject: tmpl.Subject,
IsCustom: tmpl.IsCustom,
UpdatedAt: emailTemplateUpdatedAt(tmpl),
})
}
return items
}
func emailTemplateDetailToDTO(tmpl service.NotificationEmailTemplate) dto.EmailTemplateDetail {
return dto.EmailTemplateDetail{
Event: tmpl.Event,
Locale: tmpl.Locale,
Subject: tmpl.Subject,
HTML: tmpl.HTML,
IsCustom: tmpl.IsCustom,
UpdatedAt: emailTemplateUpdatedAt(tmpl),
Placeholders: tmpl.Placeholders,
}
}
func emailTemplateUpdatedAt(tmpl service.NotificationEmailTemplate) string {
if tmpl.UpdatedAt == nil {
return ""
}
return tmpl.UpdatedAt.Format("2006-01-02T15:04:05Z07:00")
}
func emailTemplatePlaceholderUnion(events []service.NotificationEmailEventInfo) []string {
seen := make(map[string]struct{})
placeholders := make([]string, 0)
for _, event := range events {
for _, placeholder := range event.Placeholders {
if _, ok := seen[placeholder]; ok {
continue
}
seen[placeholder] = struct{}{}
placeholders = append(placeholders, placeholder)
}
}
return placeholders
}

View File

@ -203,7 +203,7 @@ func (h *AuthHandler) SendVerifyCode(c *gin.Context) {
return
}
result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email)
result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email, c.GetHeader("Accept-Language"))
if err != nil {
response.ErrorFrom(c, err)
return
@ -602,7 +602,7 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
// Request password reset (async)
// Note: This returns success even if email doesn't exist (to prevent enumeration)
if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil {
if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL, c.GetHeader("Accept-Language")); err != nil {
response.ErrorFrom(c, err)
return
}

View File

@ -545,7 +545,7 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
return
}
result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email)
result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email, c.GetHeader("Accept-Language"))
if err != nil {
response.ErrorFrom(c, err)
return

View File

@ -181,6 +181,7 @@ type SystemSettings struct {
EnableAnthropicCacheTTL1hInjection bool `json:"enable_anthropic_cache_ttl_1h_injection"`
RewriteMessageCacheControl bool `json:"rewrite_message_cache_control"`
AntigravityUserAgentVersion string `json:"antigravity_user_agent_version"`
OpenAICodexUserAgent string `json:"openai_codex_user_agent"`
// Web Search Emulation
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
@ -377,6 +378,62 @@ type OpenAIFastPolicySettings struct {
Rules []OpenAIFastPolicyRule `json:"rules"`
}
// EmailTemplateEventOption describes an editable notification email event.
type EmailTemplateEventOption struct {
Value string `json:"value"`
Label string `json:"label,omitempty"`
Description string `json:"description,omitempty"`
}
// EmailTemplateSummary is shown in the admin email template list.
type EmailTemplateSummary struct {
Event string `json:"event"`
Locale string `json:"locale"`
Subject string `json:"subject"`
IsCustom bool `json:"is_custom,omitempty"`
UpdatedAt string `json:"updated_at,omitempty"`
}
// EmailTemplateListResponse is returned by GET /admin/settings/email-templates.
type EmailTemplateListResponse struct {
Events []EmailTemplateEventOption `json:"events"`
Locales []string `json:"locales"`
Templates []EmailTemplateSummary `json:"templates,omitempty"`
Placeholders []string `json:"placeholders,omitempty"`
}
// EmailTemplateDetail is returned for a specific event/locale template.
type EmailTemplateDetail struct {
Event string `json:"event"`
Locale string `json:"locale"`
Subject string `json:"subject"`
HTML string `json:"html"`
IsCustom bool `json:"is_custom,omitempty"`
UpdatedAt string `json:"updated_at,omitempty"`
Placeholders []string `json:"placeholders,omitempty"`
}
// UpdateEmailTemplateRequest updates a template override.
type UpdateEmailTemplateRequest struct {
Subject string `json:"subject"`
HTML string `json:"html"`
}
// PreviewEmailTemplateRequest previews a template without saving it.
type PreviewEmailTemplateRequest struct {
Event string `json:"event"`
Locale string `json:"locale"`
Subject string `json:"subject"`
HTML string `json:"html"`
Variables map[string]string `json:"variables,omitempty"`
}
// EmailTemplatePreviewResponse is the rendered preview payload.
type EmailTemplatePreviewResponse struct {
Subject string `json:"subject"`
HTML string `json:"html"`
}
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
// Returns empty slice on empty/invalid input.
func ParseCustomMenuItems(raw string) []CustomMenuItem {

View File

@ -1133,9 +1133,15 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
// 解析可选的日期范围参数(用于 model_stats 查询)
startTime, endTime := h.parseUsageDateRange(c)
days, ok := parseAPIKeyDailyUsageDays(c.DefaultQuery("days", ""))
if !ok {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Invalid days, allowed range is 1-90")
return
}
// Best-effort: 获取用量统计(按当前 API Key 过滤),失败不影响基础响应
usageData := h.buildUsageData(ctx, apiKey.ID)
dailyUsage := h.buildAPIKeyDailyUsage(c, subject.UserID, apiKey.ID, days)
// Best-effort: 获取模型统计
var modelStats any
@ -1149,11 +1155,11 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
isQuotaLimited := apiKey.Quota > 0 || apiKey.HasRateLimits()
if isQuotaLimited {
h.usageQuotaLimited(c, ctx, apiKey, usageData, modelStats)
h.usageQuotaLimited(c, ctx, apiKey, usageData, dailyUsage, modelStats)
return
}
h.usageUnrestricted(c, ctx, apiKey, subject, usageData, modelStats)
h.usageUnrestricted(c, ctx, apiKey, subject, usageData, dailyUsage, modelStats)
}
// parseUsageDateRange 解析 start_date / end_date query params默认返回近 30 天范围
@ -1211,8 +1217,20 @@ func (h *GatewayHandler) buildUsageData(ctx context.Context, apiKeyID int64) gin
}
}
func (h *GatewayHandler) buildAPIKeyDailyUsage(c *gin.Context, userID, apiKeyID int64, days int) any {
if h.usageService == nil {
return nil
}
startTime, endTime := apiKeyDailyUsageRange(days, c.Query("timezone"))
stats, err := h.usageService.GetAPIKeyDailyUsage(c.Request.Context(), userID, apiKeyID, startTime, endTime)
if err != nil {
return nil
}
return stats
}
// usageQuotaLimited 处理 quota_limited 模式的响应
func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, apiKey *service.APIKey, usageData gin.H, modelStats any) {
func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, apiKey *service.APIKey, usageData gin.H, dailyUsage any, modelStats any) {
resp := gin.H{
"mode": "quota_limited",
"isValid": apiKey.Status == service.StatusAPIKeyActive || apiKey.Status == service.StatusAPIKeyQuotaExhausted || apiKey.Status == service.StatusAPIKeyExpired,
@ -1294,6 +1312,9 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context,
if usageData != nil {
resp["usage"] = usageData
}
if dailyUsage != nil {
resp["daily_usage"] = dailyUsage
}
if modelStats != nil {
resp["model_stats"] = modelStats
}
@ -1302,7 +1323,7 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context,
}
// usageUnrestricted 处理 unrestricted 模式的响应(向后兼容)
func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, usageData gin.H, modelStats any) {
func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, usageData gin.H, dailyUsage any, modelStats any) {
// 订阅模式
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
resp := gin.H{
@ -1331,6 +1352,9 @@ func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context,
if usageData != nil {
resp["usage"] = usageData
}
if dailyUsage != nil {
resp["daily_usage"] = dailyUsage
}
if modelStats != nil {
resp["model_stats"] = modelStats
}
@ -1356,6 +1380,9 @@ func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context,
if usageData != nil {
resp["usage"] = usageData
}
if dailyUsage != nil {
resp["daily_usage"] = dailyUsage
}
if modelStats != nil {
resp["model_stats"] = modelStats
}

View File

@ -266,6 +266,7 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
PaymentSource: req.PaymentSource,
OrderType: req.OrderType,
PlanID: req.PlanID,
Locale: c.GetHeader("Accept-Language"),
})
if err != nil {
response.ErrorFrom(c, err)

View File

@ -1,6 +1,10 @@
package handler
import (
"html"
"net/http"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@ -10,8 +14,9 @@ import (
// SettingHandler 公开设置处理器(无需认证)
type SettingHandler struct {
settingService *service.SettingService
version string
settingService *service.SettingService
notificationEmailService *service.NotificationEmailService
version string
}
// NewSettingHandler 创建公开设置处理器
@ -22,6 +27,12 @@ func NewSettingHandler(settingService *service.SettingService, version string) *
}
}
// SetNotificationEmailService attaches the public notification email service without
// changing the constructor signature used by existing tests.
func (h *SettingHandler) SetNotificationEmailService(notificationEmailService *service.NotificationEmailService) {
h.notificationEmailService = notificationEmailService
}
// GetPublicSettings 获取公开设置
// GET /api/v1/settings/public
func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
@ -90,6 +101,27 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
})
}
// UnsubscribeNotificationEmail handles optional notification email opt-outs.
// GET /api/v1/settings/email-unsubscribe?token=...
func (h *SettingHandler) UnsubscribeNotificationEmail(c *gin.Context) {
if h.notificationEmailService == nil {
response.InternalError(c, "notification email service is not configured")
return
}
token := strings.TrimSpace(c.Query("token"))
if token == "" {
response.BadRequest(c, "token is required")
return
}
result, err := h.notificationEmailService.Unsubscribe(c.Request.Context(), token)
if err != nil {
response.BadRequest(c, err.Error())
return
}
body := "<!doctype html><html><head><meta charset=\"utf-8\"><title>Unsubscribed</title></head><body style=\"font-family:-apple-system,BlinkMacSystemFont,Segoe UI,sans-serif;padding:32px;\"><h1>Unsubscribed</h1><p>You have unsubscribed <strong>" + html.EscapeString(result.Email) + "</strong> from <strong>" + html.EscapeString(result.Event) + "</strong> emails.</p></body></html>"
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(body))
}
func publicLoginAgreementDocumentsToDTO(items []service.LoginAgreementDocument) []dto.LoginAgreementDocument {
result := make([]dto.LoginAgreementDocument, 0, len(items))
for _, item := range items {

View File

@ -172,7 +172,7 @@ func (h *TotpHandler) SendVerifyCode(c *gin.Context) {
return
}
if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil {
if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID, c.GetHeader("Accept-Language")); err != nil {
response.ErrorFrom(c, err)
return
}

View File

@ -298,6 +298,29 @@ func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
return startTime, endTime
}
const (
defaultAPIKeyDailyUsageDays = 30
maxAPIKeyDailyUsageDays = 90
)
func parseAPIKeyDailyUsageDays(raw string) (int, bool) {
if strings.TrimSpace(raw) == "" {
return defaultAPIKeyDailyUsageDays, true
}
days, err := strconv.Atoi(raw)
if err != nil || days <= 0 || days > maxAPIKeyDailyUsageDays {
return 0, false
}
return days, true
}
func apiKeyDailyUsageRange(days int, userTZ string) (time.Time, time.Time) {
now := timezone.NowInUserLocation(userTZ)
startTime := timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -(days-1)), userTZ)
endTime := timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
return startTime, endTime
}
// DashboardStats handles getting user dashboard statistics
// GET /api/v1/usage/dashboard/stats
func (h *UsageHandler) DashboardStats(c *gin.Context) {
@ -416,3 +439,55 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) {
response.Success(c, gin.H{"stats": stats})
}
// GetMyAPIKeyDailyUsage handles getting daily usage details for the current user's API key.
// GET /api/v1/user/api-keys/:id/usage/daily?days=30
func (h *UsageHandler) GetMyAPIKeyDailyUsage(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
apiKeyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid API key ID")
return
}
days, ok := parseAPIKeyDailyUsageDays(c.DefaultQuery("days", ""))
if !ok {
response.BadRequest(c, "Invalid days, allowed range is 1-90")
return
}
if h.apiKeyService == nil {
response.InternalError(c, "API key service is not configured")
return
}
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), apiKeyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if apiKey.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this API key's usage")
return
}
userTZ := c.Query("timezone")
startTime, endTime := apiKeyDailyUsageRange(days, userTZ)
items, err := h.usageService.GetAPIKeyDailyUsage(c.Request.Context(), subject.UserID, apiKeyID, startTime, endTime)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{
"items": items,
"days": days,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.AddDate(0, 0, -1).Format("2006-01-02"),
})
}

View File

@ -0,0 +1,195 @@
package handler
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type dailyUsageRepoStub struct {
service.UsageLogRepository
trend []usagestats.TrendDataPoint
called bool
startTime time.Time
endTime time.Time
granularity string
userID int64
apiKeyID int64
}
func (s *dailyUsageRepoStub) GetUsageTrendWithFilters(
ctx context.Context,
startTime, endTime time.Time,
granularity string,
userID, apiKeyID, accountID, groupID int64,
model string,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.TrendDataPoint, error) {
s.called = true
s.startTime = startTime
s.endTime = endTime
s.granularity = granularity
s.userID = userID
s.apiKeyID = apiKeyID
return s.trend, nil
}
type dailyUsageAPIKeyRepoStub struct {
service.APIKeyRepository
keys map[int64]*service.APIKey
}
func (s *dailyUsageAPIKeyRepoStub) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
key, ok := s.keys[id]
if !ok {
return nil, service.ErrAPIKeyNotFound
}
clone := *key
return &clone, nil
}
func newDailyUsageTestRouter(usageRepo *dailyUsageRepoStub, apiKeyRepo *dailyUsageAPIKeyRepoStub, userID int64) *gin.Engine {
gin.SetMode(gin.TestMode)
usageSvc := service.NewUsageService(usageRepo, nil, nil, nil)
apiKeySvc := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, nil)
handler := NewUsageHandler(usageSvc, apiKeySvc)
router := gin.New()
router.Use(func(c *gin.Context) {
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: userID})
c.Next()
})
router.GET("/user/api-keys/:id/usage/daily", handler.GetMyAPIKeyDailyUsage)
return router
}
type dailyUsageHandlerResponse struct {
Code int `json:"code"`
Data struct {
Items []usagestats.APIKeyDailyUsagePoint `json:"items"`
Days int `json:"days"`
} `json:"data"`
}
func TestGetMyAPIKeyDailyUsageRejectsCrossUserAccess(t *testing.T) {
usageRepo := &dailyUsageRepoStub{}
apiKeyRepo := &dailyUsageAPIKeyRepoStub{
keys: map[int64]*service.APIKey{
7: {ID: 7, UserID: 99, Status: service.StatusAPIKeyActive},
},
}
router := newDailyUsageTestRouter(usageRepo, apiKeyRepo, 42)
req := httptest.NewRequest(http.MethodGet, "/user/api-keys/7/usage/daily?days=30", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusForbidden, rec.Code)
require.False(t, usageRepo.called)
}
func TestGetMyAPIKeyDailyUsageRejectsInvalidDays(t *testing.T) {
for _, path := range []string{
"/user/api-keys/7/usage/daily?days=0",
"/user/api-keys/7/usage/daily?days=91",
} {
t.Run(path, func(t *testing.T) {
usageRepo := &dailyUsageRepoStub{}
apiKeyRepo := &dailyUsageAPIKeyRepoStub{
keys: map[int64]*service.APIKey{
7: {ID: 7, UserID: 42, Status: service.StatusAPIKeyActive},
},
}
router := newDailyUsageTestRouter(usageRepo, apiKeyRepo, 42)
req := httptest.NewRequest(http.MethodGet, path, nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.False(t, usageRepo.called)
})
}
}
func TestGetMyAPIKeyDailyUsageReturnsEmptyData(t *testing.T) {
usageRepo := &dailyUsageRepoStub{trend: []usagestats.TrendDataPoint{}}
apiKeyRepo := &dailyUsageAPIKeyRepoStub{
keys: map[int64]*service.APIKey{
7: {ID: 7, UserID: 42, Status: service.StatusAPIKeyActive},
},
}
router := newDailyUsageTestRouter(usageRepo, apiKeyRepo, 42)
req := httptest.NewRequest(http.MethodGet, "/user/api-keys/7/usage/daily", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var got dailyUsageHandlerResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
require.Equal(t, 30, got.Data.Days)
require.Empty(t, got.Data.Items)
}
func TestGetMyAPIKeyDailyUsageAggregatesByDayForOwnedKey(t *testing.T) {
usageRepo := &dailyUsageRepoStub{
trend: []usagestats.TrendDataPoint{
{
Date: "2026-05-19",
Requests: 3,
InputTokens: 10,
OutputTokens: 20,
CacheCreationTokens: 4,
CacheReadTokens: 6,
TotalTokens: 40,
Cost: 0.5,
ActualCost: 0.4,
},
},
}
apiKeyRepo := &dailyUsageAPIKeyRepoStub{
keys: map[int64]*service.APIKey{
7: {ID: 7, UserID: 42, Status: service.StatusAPIKeyActive},
},
}
router := newDailyUsageTestRouter(usageRepo, apiKeyRepo, 42)
req := httptest.NewRequest(http.MethodGet, "/user/api-keys/7/usage/daily?days=7", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.True(t, usageRepo.called)
require.Equal(t, "day", usageRepo.granularity)
require.Equal(t, int64(42), usageRepo.userID)
require.Equal(t, int64(7), usageRepo.apiKeyID)
require.True(t, usageRepo.startTime.Before(usageRepo.endTime))
var got dailyUsageHandlerResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
require.Equal(t, 7, got.Data.Days)
require.Len(t, got.Data.Items, 1)
require.Equal(t, usagestats.APIKeyDailyUsagePoint{
Date: "2026-05-19",
Requests: 3,
InputTokens: 10,
OutputTokens: 20,
CacheReadTokens: 6,
CacheWriteTokens: 4,
TotalTokens: 40,
Cost: 0.5,
ActualCost: 0.4,
}, got.Data.Items[0])
}

View File

@ -335,7 +335,7 @@ func (h *UserHandler) SendEmailBindingCode(c *gin.Context) {
return
}
if err := h.authService.SendEmailIdentityBindCode(c.Request.Context(), subject.UserID, req.Email); err != nil {
if err := h.authService.SendEmailIdentityBindCode(c.Request.Context(), subject.UserID, req.Email, c.GetHeader("Accept-Language")); err != nil {
response.ErrorFrom(c, err)
return
}
@ -363,7 +363,7 @@ func (h *UserHandler) SendNotifyEmailCode(c *gin.Context) {
return
}
err := h.userService.SendNotifyEmailCode(c.Request.Context(), subject.UserID, req.Email, h.emailService, h.emailCache)
err := h.userService.SendNotifyEmailCode(c.Request.Context(), subject.UserID, req.Email, h.emailService, h.emailCache, c.GetHeader("Accept-Language"))
if err != nil {
response.ErrorFrom(c, err)
return

View File

@ -90,8 +90,17 @@ func ProvideWindsurfHandler(authService *service.WindsurfAuthService, lsService
}
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
func ProvideSettingHandler(settingService *service.SettingService, buildInfo BuildInfo) *SettingHandler {
return NewSettingHandler(settingService, buildInfo.Version)
func ProvideSettingHandler(settingService *service.SettingService, buildInfo BuildInfo, notificationEmailService *service.NotificationEmailService) *SettingHandler {
h := NewSettingHandler(settingService, buildInfo.Version)
h.SetNotificationEmailService(notificationEmailService)
return h
}
// ProvideAdminSettingHandler creates admin.SettingHandler with notification template APIs.
func ProvideAdminSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, paymentConfigService *service.PaymentConfigService, paymentService *service.PaymentService, userAttributeService *service.UserAttributeService, notificationEmailService *service.NotificationEmailService) *admin.SettingHandler {
h := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService)
h.SetNotificationEmailService(notificationEmailService)
return h
}
// ProvideHandlers creates the Handlers struct
@ -169,7 +178,7 @@ var ProviderSet = wire.NewSet(
admin.NewProxyHandler,
admin.NewRedeemHandler,
admin.NewPromoHandler,
admin.NewSettingHandler,
ProvideAdminSettingHandler,
admin.NewOpsHandler,
ProvideSystemHandler,
admin.NewSubscriptionHandler,

View File

@ -0,0 +1,719 @@
package apicompat
import (
"encoding/json"
"fmt"
"strings"
"time"
)
// ResponsesToChatCompletionsRequest converts a Responses API request into a
// Chat Completions request for upstreams that only implement
// /v1/chat/completions.
func ResponsesToChatCompletionsRequest(req *ResponsesRequest) (*ChatCompletionsRequest, error) {
if req == nil {
return nil, fmt.Errorf("responses request is nil")
}
messages, err := responsesInputToChatMessages(req.Instructions, req.Input)
if err != nil {
return nil, err
}
out := &ChatCompletionsRequest{
Model: req.Model,
Messages: messages,
MaxCompletionTokens: req.MaxOutputTokens,
Temperature: req.Temperature,
TopP: req.TopP,
Stream: req.Stream,
ServiceTier: req.ServiceTier,
}
if req.Reasoning != nil {
out.ReasoningEffort = req.Reasoning.Effort
}
if len(req.Tools) > 0 {
out.Tools = responsesToolsToChatTools(req.Tools)
}
if len(req.ToolChoice) > 0 {
out.ToolChoice = responsesToolChoiceToChatToolChoice(req.ToolChoice)
}
return out, nil
}
func responsesInputToChatMessages(instructions string, inputRaw json.RawMessage) ([]ChatMessage, error) {
var messages []ChatMessage
if strings.TrimSpace(instructions) != "" {
content, _ := json.Marshal(instructions)
messages = append(messages, ChatMessage{
Role: "system",
Content: content,
})
}
inputRaw = bytesTrimSpace(inputRaw)
if len(inputRaw) == 0 || string(inputRaw) == "null" {
return messages, nil
}
var inputText string
if err := json.Unmarshal(inputRaw, &inputText); err == nil {
content, _ := json.Marshal(inputText)
messages = append(messages, ChatMessage{
Role: "user",
Content: content,
})
return messages, nil
}
var rawItems []json.RawMessage
if err := json.Unmarshal(inputRaw, &rawItems); err != nil {
return nil, fmt.Errorf("parse responses input: %w", err)
}
for _, raw := range rawItems {
raw = bytesTrimSpace(raw)
if len(raw) == 0 || string(raw) == "null" {
continue
}
var item map[string]json.RawMessage
if err := json.Unmarshal(raw, &item); err != nil {
var text string
if textErr := json.Unmarshal(raw, &text); textErr == nil {
content, _ := json.Marshal(text)
messages = append(messages, ChatMessage{Role: "user", Content: content})
continue
}
return nil, fmt.Errorf("parse responses input item: %w", err)
}
role := rawString(item["role"])
itemType := rawString(item["type"])
switch itemType {
case "function_call":
arguments := rawString(item["arguments"])
if strings.TrimSpace(arguments) == "" {
arguments = "{}"
}
messages = append(messages, ChatMessage{
Role: "assistant",
ToolCalls: []ChatToolCall{{
ID: rawString(item["call_id"]),
Type: "function",
Function: ChatFunctionCall{
Name: rawString(item["name"]),
Arguments: arguments,
},
}},
})
continue
case "function_call_output":
content, _ := json.Marshal(rawString(item["output"]))
messages = append(messages, ChatMessage{
Role: "tool",
ToolCallID: rawString(item["call_id"]),
Content: content,
})
continue
case "input_text", "text":
content, _ := json.Marshal(rawString(item["text"]))
messages = append(messages, ChatMessage{Role: "user", Content: content})
continue
case "input_image":
content, err := chatContentFromSingleResponsesPart(itemType, item)
if err != nil {
return nil, err
}
messages = append(messages, ChatMessage{Role: "user", Content: content})
continue
}
if role == "" {
role = "user"
}
content := item["content"]
if len(bytesTrimSpace(content)) == 0 {
if text := rawString(item["text"]); text != "" {
content, _ = json.Marshal(text)
}
}
chatContent, err := responsesContentToChatContent(content, role)
if err != nil {
return nil, err
}
messages = append(messages, ChatMessage{
Role: role,
Content: chatContent,
})
}
return messages, nil
}
func responsesContentToChatContent(raw json.RawMessage, role string) (json.RawMessage, error) {
raw = bytesTrimSpace(raw)
if len(raw) == 0 || string(raw) == "null" {
empty, _ := json.Marshal("")
return empty, nil
}
var text string
if err := json.Unmarshal(raw, &text); err == nil {
return raw, nil
}
var rawParts []json.RawMessage
if err := json.Unmarshal(raw, &rawParts); err == nil {
return responsesContentPartsToChatContent(rawParts, role)
}
var obj map[string]json.RawMessage
if err := json.Unmarshal(raw, &obj); err == nil {
return chatContentFromSingleResponsesPart(rawString(obj["type"]), obj)
}
return raw, nil
}
func responsesContentPartsToChatContent(rawParts []json.RawMessage, role string) (json.RawMessage, error) {
var textParts []string
var chatParts []ChatContentPart
hasNonText := false
for _, rawPart := range rawParts {
var part map[string]json.RawMessage
if err := json.Unmarshal(rawPart, &part); err != nil {
continue
}
partType := rawString(part["type"])
switch partType {
case "input_text", "output_text", "text", "":
text := rawString(part["text"])
if text == "" {
continue
}
textParts = append(textParts, text)
chatParts = append(chatParts, ChatContentPart{Type: "text", Text: text})
case "input_image", "image_url":
imageURL := rawString(part["image_url"])
if imageURL == "" {
imageURL = rawNestedString(part["image_url"], "url")
}
if imageURL == "" {
continue
}
hasNonText = true
chatParts = append(chatParts, ChatContentPart{
Type: "image_url",
ImageURL: &ChatImageURL{URL: imageURL},
})
}
}
if !hasNonText {
joined, _ := json.Marshal(strings.Join(textParts, "\n\n"))
return joined, nil
}
if role != "user" {
joined, _ := json.Marshal(strings.Join(textParts, "\n\n"))
return joined, nil
}
if len(chatParts) == 0 {
empty, _ := json.Marshal("")
return empty, nil
}
return json.Marshal(chatParts)
}
func chatContentFromSingleResponsesPart(partType string, part map[string]json.RawMessage) (json.RawMessage, error) {
switch partType {
case "input_image", "image_url":
imageURL := rawString(part["image_url"])
if imageURL == "" {
imageURL = rawNestedString(part["image_url"], "url")
}
return json.Marshal([]ChatContentPart{{
Type: "image_url",
ImageURL: &ChatImageURL{URL: imageURL},
}})
default:
return json.Marshal(rawString(part["text"]))
}
}
func responsesToolsToChatTools(tools []ResponsesTool) []ChatTool {
out := make([]ChatTool, 0, len(tools))
for _, tool := range tools {
if tool.Type != "function" {
continue
}
out = append(out, ChatTool{
Type: "function",
Function: &ChatFunction{
Name: tool.Name,
Description: tool.Description,
Parameters: tool.Parameters,
Strict: tool.Strict,
},
})
}
return out
}
func responsesToolChoiceToChatToolChoice(raw json.RawMessage) json.RawMessage {
var choice map[string]json.RawMessage
if err := json.Unmarshal(raw, &choice); err != nil {
return raw
}
if rawString(choice["type"]) != "function" {
return raw
}
name := rawString(choice["name"])
if name == "" {
name = rawNestedString(choice["function"], "name")
}
if name == "" {
return raw
}
out, err := json.Marshal(map[string]any{
"type": "function",
"function": map[string]string{
"name": name,
},
})
if err != nil {
return raw
}
return out
}
// ChatCompletionsResponseToResponses converts a non-streaming Chat Completions
// response into a Responses API response.
func ChatCompletionsResponseToResponses(resp *ChatCompletionsResponse, model string) *ResponsesResponse {
id := ""
if resp != nil {
id = resp.ID
}
if id == "" {
id = generateResponsesID()
}
out := &ResponsesResponse{
ID: id,
Object: "response",
Model: model,
Status: "completed",
}
if resp == nil {
out.Output = []ResponsesOutput{emptyResponsesMessageOutput()}
return out
}
if out.Model == "" {
out.Model = resp.Model
}
if len(resp.Choices) > 0 {
choice := resp.Choices[0]
out.Output = chatMessageToResponsesOutput(choice.Message)
if choice.FinishReason == "length" {
out.Status = "incomplete"
out.IncompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"}
}
}
if len(out.Output) == 0 {
out.Output = []ResponsesOutput{emptyResponsesMessageOutput()}
}
if resp.Usage != nil {
out.Usage = ChatUsageToResponsesUsage(resp.Usage)
}
return out
}
func chatMessageToResponsesOutput(message ChatMessage) []ResponsesOutput {
var outputs []ResponsesOutput
if message.ReasoningContent != "" {
outputs = append(outputs, ResponsesOutput{
Type: "reasoning",
ID: generateItemID(),
Summary: []ResponsesSummary{{
Type: "summary_text",
Text: message.ReasoningContent,
}},
})
}
text := chatMessageContentText(message.Content)
if text != "" || len(message.ToolCalls) == 0 {
outputs = append(outputs, ResponsesOutput{
Type: "message",
ID: generateItemID(),
Role: "assistant",
Content: []ResponsesContentPart{{
Type: "output_text",
Text: text,
}},
Status: "completed",
})
}
for _, toolCall := range message.ToolCalls {
arguments := toolCall.Function.Arguments
if strings.TrimSpace(arguments) == "" {
arguments = "{}"
}
outputs = append(outputs, ResponsesOutput{
Type: "function_call",
ID: generateItemID(),
CallID: toolCall.ID,
Name: toolCall.Function.Name,
Arguments: arguments,
Status: "completed",
})
}
return outputs
}
func emptyResponsesMessageOutput() ResponsesOutput {
return ResponsesOutput{
Type: "message",
ID: generateItemID(),
Role: "assistant",
Content: []ResponsesContentPart{{Type: "output_text", Text: ""}},
Status: "completed",
}
}
func chatMessageContentText(raw json.RawMessage) string {
raw = bytesTrimSpace(raw)
if len(raw) == 0 || string(raw) == "null" {
return ""
}
var text string
if err := json.Unmarshal(raw, &text); err == nil {
return text
}
var parts []ChatContentPart
if err := json.Unmarshal(raw, &parts); err == nil {
var texts []string
for _, part := range parts {
if part.Type == "text" && part.Text != "" {
texts = append(texts, part.Text)
}
}
return strings.Join(texts, "\n\n")
}
return ""
}
// ChatUsageToResponsesUsage converts Chat Completions token usage to Responses
// usage shape.
func ChatUsageToResponsesUsage(usage *ChatUsage) *ResponsesUsage {
if usage == nil {
return nil
}
out := &ResponsesUsage{
InputTokens: usage.PromptTokens,
OutputTokens: usage.CompletionTokens,
TotalTokens: usage.TotalTokens,
}
if out.TotalTokens == 0 {
out.TotalTokens = out.InputTokens + out.OutputTokens
}
if usage.PromptTokensDetails != nil && usage.PromptTokensDetails.CachedTokens > 0 {
out.InputTokensDetails = &ResponsesInputTokensDetails{
CachedTokens: usage.PromptTokensDetails.CachedTokens,
}
}
return out
}
// ChatCompletionsToResponsesStreamState tracks state while converting Chat
// Completions SSE chunks into Responses SSE events.
type ChatCompletionsToResponsesStreamState struct {
ResponseID string
Model string
Created int64
SequenceNumber int
CreatedSent bool
CompletedSent bool
MessageItemID string
Text strings.Builder
Reasoning strings.Builder
ToolCalls map[int]*ChatToolCall
FinishReason string
Usage *ResponsesUsage
}
// NewChatCompletionsToResponsesStreamState returns an initialized stream state.
func NewChatCompletionsToResponsesStreamState(model string) *ChatCompletionsToResponsesStreamState {
return &ChatCompletionsToResponsesStreamState{
ResponseID: generateResponsesID(),
Model: model,
Created: time.Now().Unix(),
ToolCalls: make(map[int]*ChatToolCall),
}
}
// ChatCompletionsChunkToResponsesEvents converts one Chat Completions stream
// chunk into zero or more Responses stream events.
func ChatCompletionsChunkToResponsesEvents(
chunk *ChatCompletionsChunk,
state *ChatCompletionsToResponsesStreamState,
) []ResponsesStreamEvent {
if chunk == nil || state == nil {
return nil
}
if chunk.ID != "" {
state.ResponseID = chunk.ID
}
if state.Model == "" && chunk.Model != "" {
state.Model = chunk.Model
}
if chunk.Usage != nil {
state.Usage = ChatUsageToResponsesUsage(chunk.Usage)
}
var events []ResponsesStreamEvent
events = append(events, ensureChatToResponsesCreated(state)...)
for _, choice := range chunk.Choices {
if choice.Delta.Content != nil {
events = append(events, ensureChatToResponsesMessageItem(state)...)
_, _ = state.Text.WriteString(*choice.Delta.Content)
events = append(events, chatToResponsesEvent(state, "response.output_text.delta", &ResponsesStreamEvent{
OutputIndex: 0,
ContentIndex: 0,
Delta: *choice.Delta.Content,
ItemID: state.MessageItemID,
}))
}
if choice.Delta.ReasoningContent != nil {
_, _ = state.Reasoning.WriteString(*choice.Delta.ReasoningContent)
events = append(events, chatToResponsesEvent(state, "response.reasoning_summary_text.delta", &ResponsesStreamEvent{
OutputIndex: 0,
SummaryIndex: 0,
Delta: *choice.Delta.ReasoningContent,
}))
}
for _, toolCall := range choice.Delta.ToolCalls {
idx := 0
if toolCall.Index != nil {
idx = *toolCall.Index
}
stored, ok := state.ToolCalls[idx]
if !ok {
copyCall := toolCall
if copyCall.ID == "" {
copyCall.ID = generateItemID()
}
copyCall.Type = "function"
state.ToolCalls[idx] = &copyCall
stored = &copyCall
events = append(events, chatToResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
OutputIndex: idx + 1,
Item: &ResponsesOutput{
Type: "function_call",
ID: generateItemID(),
CallID: stored.ID,
Name: stored.Function.Name,
Status: "in_progress",
},
}))
} else {
if toolCall.ID != "" {
stored.ID = toolCall.ID
}
if toolCall.Function.Name != "" {
stored.Function.Name = toolCall.Function.Name
}
}
if toolCall.Function.Arguments != "" {
stored.Function.Arguments += toolCall.Function.Arguments
events = append(events, chatToResponsesEvent(state, "response.function_call_arguments.delta", &ResponsesStreamEvent{
OutputIndex: idx + 1,
Delta: toolCall.Function.Arguments,
CallID: stored.ID,
Name: stored.Function.Name,
}))
}
}
if choice.FinishReason != nil && *choice.FinishReason != "" {
state.FinishReason = *choice.FinishReason
}
}
return events
}
// FinalizeChatCompletionsResponsesStream emits terminal Responses events.
func FinalizeChatCompletionsResponsesStream(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent {
if state == nil || state.CompletedSent {
return nil
}
var events []ResponsesStreamEvent
events = append(events, ensureChatToResponsesCreated(state)...)
if state.MessageItemID != "" {
events = append(events, chatToResponsesEvent(state, "response.output_text.done", &ResponsesStreamEvent{
OutputIndex: 0,
ContentIndex: 0,
Text: state.Text.String(),
ItemID: state.MessageItemID,
}))
events = append(events, chatToResponsesEvent(state, "response.output_item.done", &ResponsesStreamEvent{
OutputIndex: 0,
Item: &ResponsesOutput{
Type: "message",
ID: state.MessageItemID,
Role: "assistant",
Status: "completed",
},
}))
}
status := "completed"
var incompleteDetails *ResponsesIncompleteDetails
if state.FinishReason == "length" {
status = "incomplete"
incompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"}
}
state.CompletedSent = true
events = append(events, chatToResponsesEvent(state, "response.completed", &ResponsesStreamEvent{
Response: &ResponsesResponse{
ID: state.ResponseID,
Object: "response",
Model: state.Model,
Status: status,
Output: state.chatOutput(),
Usage: state.Usage,
IncompleteDetails: incompleteDetails,
},
}))
return events
}
func ensureChatToResponsesCreated(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent {
if state.CreatedSent {
return nil
}
state.CreatedSent = true
return []ResponsesStreamEvent{chatToResponsesEvent(state, "response.created", &ResponsesStreamEvent{
Response: &ResponsesResponse{
ID: state.ResponseID,
Object: "response",
Model: state.Model,
Status: "in_progress",
Output: []ResponsesOutput{},
},
})}
}
func ensureChatToResponsesMessageItem(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent {
if state.MessageItemID != "" {
return nil
}
state.MessageItemID = generateItemID()
return []ResponsesStreamEvent{chatToResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
OutputIndex: 0,
Item: &ResponsesOutput{
Type: "message",
ID: state.MessageItemID,
Role: "assistant",
Status: "in_progress",
},
})}
}
func (state *ChatCompletionsToResponsesStreamState) chatOutput() []ResponsesOutput {
var outputs []ResponsesOutput
if state.Reasoning.Len() > 0 {
outputs = append(outputs, ResponsesOutput{
Type: "reasoning",
ID: generateItemID(),
Summary: []ResponsesSummary{{
Type: "summary_text",
Text: state.Reasoning.String(),
}},
})
}
if state.MessageItemID != "" || len(state.ToolCalls) == 0 {
outputs = append(outputs, ResponsesOutput{
Type: "message",
ID: nonEmpty(state.MessageItemID, generateItemID()),
Role: "assistant",
Content: []ResponsesContentPart{{
Type: "output_text",
Text: state.Text.String(),
}},
Status: "completed",
})
}
for i := 0; i < len(state.ToolCalls); i++ {
toolCall, ok := state.ToolCalls[i]
if !ok || toolCall == nil {
continue
}
arguments := toolCall.Function.Arguments
if strings.TrimSpace(arguments) == "" {
arguments = "{}"
}
outputs = append(outputs, ResponsesOutput{
Type: "function_call",
ID: generateItemID(),
CallID: toolCall.ID,
Name: toolCall.Function.Name,
Arguments: arguments,
Status: "completed",
})
}
return outputs
}
func chatToResponsesEvent(
state *ChatCompletionsToResponsesStreamState,
eventType string,
template *ResponsesStreamEvent,
) ResponsesStreamEvent {
seq := state.SequenceNumber
state.SequenceNumber++
evt := *template
evt.Type = eventType
evt.SequenceNumber = seq
return evt
}
func rawString(raw json.RawMessage) string {
raw = bytesTrimSpace(raw)
if len(raw) == 0 || string(raw) == "null" {
return ""
}
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return s
}
return ""
}
func rawNestedString(raw json.RawMessage, key string) string {
var obj map[string]json.RawMessage
if err := json.Unmarshal(raw, &obj); err != nil {
return ""
}
return rawString(obj[key])
}
func bytesTrimSpace(raw json.RawMessage) json.RawMessage {
return json.RawMessage(strings.TrimSpace(string(raw)))
}
func nonEmpty(value, fallback string) string {
if value != "" {
return value
}
return fallback
}

View File

@ -30,6 +30,17 @@ var CodexOfficialClientOriginatorPrefixes = []string{
"codex ",
}
// IsBrowserUserAgent 判断 User-Agent 是否来自浏览器Chrome/Firefox/Safari/Edge/Opera 等)。
// 所有现代浏览器的 UA 均以 "Mozilla/" 作为前缀CLI 工具codex/claude/curl/postman/python-requests 等)不会。
// 该判定用于避免 Cloudflare 对浏览器型 UA 在 OpenAI 上游接口上触发 JS 质询。
func IsBrowserUserAgent(userAgent string) bool {
ua := strings.TrimSpace(userAgent)
if ua == "" {
return false
}
return strings.HasPrefix(strings.ToLower(ua), "mozilla/")
}
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
func IsCodexCLIRequest(userAgent string) bool {
ua := normalizeCodexClientHeader(userAgent)

View File

@ -198,6 +198,19 @@ type APIKeyUsageTrendPoint struct {
Tokens int64 `json:"tokens"`
}
// APIKeyDailyUsagePoint represents one day of usage for a single API key.
type APIKeyDailyUsagePoint struct {
Date string `json:"date"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheReadTokens int64 `json:"cache_read_tokens"`
CacheWriteTokens int64 `json:"cache_write_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// UserDashboardStats 用户仪表盘统计
type UserDashboardStats struct {
// API Key 统计

View File

@ -0,0 +1,219 @@
//go:build unit
package repository
import (
"encoding/base64"
"encoding/hex"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ── 测试辅助 ─────────────────────────────────────────────────────────────────
// aesHexKey 构造一个全填充为 b 的 n 字节密钥并以 hex 编码返回。
func aesHexKey(n int, b byte) string {
raw := make([]byte, n)
for i := range raw {
raw[i] = b
}
return hex.EncodeToString(raw)
}
// aesTestCfg 用给定 hex 密钥字符串构造最小 Config。
func aesTestCfg(keyHex string) *config.Config {
return &config.Config{
Totp: config.TotpConfig{EncryptionKey: keyHex},
}
}
// aesEncryptor 创建一个持有合法 32 字节密钥的加密器,测试失败时立即终止。
func aesEncryptor(t *testing.T) *AESEncryptor {
t.Helper()
enc, err := NewAESEncryptor(aesTestCfg(aesHexKey(32, 0x42)))
require.NoError(t, err)
require.NotNil(t, enc)
return enc.(*AESEncryptor)
}
// ── NewAESEncryptor ──────────────────────────────────────────────────────────
func TestNewAESEncryptor_ValidKey32Bytes(t *testing.T) {
enc, err := NewAESEncryptor(aesTestCfg(aesHexKey(32, 0x01)))
require.NoError(t, err)
require.NotNil(t, enc)
}
// 16 / 24 字节密钥在 AES 体系内合法,但本实现仅接受 AES-25632 字节)。
func TestNewAESEncryptor_WrongKeyLength(t *testing.T) {
tests := []struct {
name string
keySize int
}{
{"16_bytes_AES128", 16},
{"24_bytes_AES192", 24},
{"1_byte", 1},
{"31_bytes", 31},
{"33_bytes", 33},
{"64_bytes", 64},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewAESEncryptor(aesTestCfg(aesHexKey(tt.keySize, 0x00)))
require.Error(t, err)
assert.Contains(t, err.Error(), "32 bytes")
})
}
}
// "配置缺失"场景:空字符串与非法 hex 编码。
func TestNewAESEncryptor_MissingOrInvalidConfig(t *testing.T) {
tests := []struct {
name string
keyHex string
wantContain string
}{
{"empty_key", "", "32 bytes"},
{"invalid_hex_odd_length", "abcde", "invalid totp encryption key"},
{"invalid_hex_chars", "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz", "invalid totp encryption key"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewAESEncryptor(aesTestCfg(tt.keyHex))
require.Error(t, err)
assert.Contains(t, err.Error(), tt.wantContain)
})
}
}
// ── 加解密往返Roundtrip───────────────────────────────────────────────────
func TestAESEncryptor_RoundTrip(t *testing.T) {
enc := aesEncryptor(t)
tests := []struct {
name string
plaintext string
}{
{"ascii", "Hello, Sub2API!"},
{"chinese_multibyte", "你好,世界!这是多字节 UTF-8 文本。"},
{"empty_string", ""},
{"long_string_gt_1KB", strings.Repeat("x", 2048)},
{"special_chars", "!@#$%^&*()_+-=[]{}|;':\",./<>?"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ct, err := enc.Encrypt(tt.plaintext)
require.NoError(t, err)
require.NotEmpty(t, ct, "密文不应为空(即便明文为空字符串)")
got, err := enc.Decrypt(ct)
require.NoError(t, err)
assert.Equal(t, tt.plaintext, got)
})
}
}
// ── IV/Nonce 随机性 ──────────────────────────────────────────────────────────
func TestAESEncryptor_Encrypt_NonceRandomness(t *testing.T) {
enc := aesEncryptor(t)
const iterations = 30
plaintext := "same plaintext for every iteration"
seen := make(map[string]struct{}, iterations)
for i := 0; i < iterations; i++ {
ct, err := enc.Encrypt(plaintext)
require.NoError(t, err)
seen[ct] = struct{}{}
}
// 30 次加密相同明文,每次因随机 Nonce 应产生不同密文。
assert.Len(t, seen, iterations,
"每次加密应因随机 Nonce 产生唯一密文,共 %d 次", iterations)
}
// ── Decrypt 错误路径 ──────────────────────────────────────────────────────────
func TestAESDecrypt_InvalidBase64(t *testing.T) {
enc := aesEncryptor(t)
_, err := enc.Decrypt("!!!not-valid-base64!!!")
require.Error(t, err)
assert.Contains(t, err.Error(), "decode base64")
}
func TestAESDecrypt_TooShort(t *testing.T) {
enc := aesEncryptor(t)
// GCM Nonce 为 12 字节;仅提供 2 字节,必然短于 NonceSize。
short := base64.StdEncoding.EncodeToString([]byte{0x01, 0x02})
_, err := enc.Decrypt(short)
require.Error(t, err)
assert.Contains(t, err.Error(), "too short")
}
func TestAESDecrypt_TamperedCiphertext(t *testing.T) {
enc := aesEncryptor(t)
ct, err := enc.Encrypt("sensitive payload")
require.NoError(t, err)
raw, err := base64.StdEncoding.DecodeString(ct)
require.NoError(t, err)
// Nonce 占前 12 字节;翻转其后第一个字节(密文体)。
raw[12] ^= 0xFF
_, err = enc.Decrypt(base64.StdEncoding.EncodeToString(raw))
require.Error(t, err, "篡改密文体后解密应失败")
}
func TestAESDecrypt_TamperedTag(t *testing.T) {
enc := aesEncryptor(t)
ct, err := enc.Encrypt("sensitive payload")
require.NoError(t, err)
raw, err := base64.StdEncoding.DecodeString(ct)
require.NoError(t, err)
// GCM 认证标签占最后 16 字节;翻转最后一个字节。
raw[len(raw)-1] ^= 0xFF
_, err = enc.Decrypt(base64.StdEncoding.EncodeToString(raw))
require.Error(t, err, "篡改 GCM 标签后解密应失败")
}
// ── 跨实例Cross-instance──────────────────────────────────────────────────
func TestAESEncryptor_CrossInstance_SameKey_CanDecrypt(t *testing.T) {
keyHex := aesHexKey(32, 0xDE)
enc1, err := NewAESEncryptor(aesTestCfg(keyHex))
require.NoError(t, err)
enc2, err := NewAESEncryptor(aesTestCfg(keyHex))
require.NoError(t, err)
plaintext := "cross-instance roundtrip"
ct, err := enc1.Encrypt(plaintext)
require.NoError(t, err)
got, err := enc2.Decrypt(ct)
require.NoError(t, err)
assert.Equal(t, plaintext, got, "相同密钥构造的两个实例应可互相解密")
}
func TestAESEncryptor_CrossInstance_DifferentKey_CannotDecrypt(t *testing.T) {
enc1, err := NewAESEncryptor(aesTestCfg(aesHexKey(32, 0xAA)))
require.NoError(t, err)
enc2, err := NewAESEncryptor(aesTestCfg(aesHexKey(32, 0xBB)))
require.NoError(t, err)
ct, err := enc1.Encrypt("secret message")
require.NoError(t, err)
_, err = enc2.Decrypt(ct)
require.Error(t, err, "不同密钥的实例不应能解密对方的密文")
}

View File

@ -80,7 +80,7 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
}
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
func TestGroupRepository_DeleteCascade_PreservesApiKeyGroupID(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
entClient := tx.Client()
@ -138,8 +138,10 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
require.NotContains(t, uAfter.AllowedGroups, targetGroup.ID)
require.Contains(t, uAfter.AllowedGroups, otherGroup.ID)
// API keys bound to the deleted group should have group_id cleared.
// API keys keep their group_id so auth can reject keys bound to a deleted group.
keyAfter, err := apiKeyRepo.GetByID(ctx, key.ID)
require.NoError(t, err)
require.Nil(t, keyAfter.GroupID)
require.NotNil(t, keyAfter.GroupID)
require.Equal(t, targetGroup.ID, *keyAfter.GroupID)
require.Nil(t, keyAfter.Group)
}

View File

@ -9,7 +9,6 @@ import (
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@ -94,9 +93,13 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
if err != nil {
return nil, err
}
total, active, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = total
out.ActiveAccountCount = active
counts, err := r.loadAccountCounts(ctx, []int64{out.ID})
if err == nil {
c := counts[out.ID]
out.AccountCount = c.Total
out.ActiveAccountCount = c.Active
out.RateLimitedAccountCount = c.RateLimited
}
return out, nil
}
@ -538,15 +541,12 @@ func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) {
var rateLimited int64
err = scanSingleRow(ctx, r.sql,
`SELECT COUNT(*),
COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true),
COUNT(*) FILTER (WHERE a.status = 'active' AND (
a.rate_limit_reset_at > NOW() OR
a.overload_until > NOW() OR
a.temp_unschedulable_until > NOW()
))
fmt.Sprintf(`SELECT
COUNT(*) FILTER (WHERE a.deleted_at IS NULL),
COUNT(*) FILTER (WHERE %s),
COUNT(*) FILTER (WHERE %s)
FROM account_groups ag JOIN accounts a ON a.id = ag.account_id
WHERE ag.group_id = $1`,
WHERE ag.group_id = $1`, groupAccountAvailableSQL, groupAccountTemporarilyLimitedSQL),
[]any{groupID}, &total, &active, &rateLimited)
return
}
@ -636,28 +636,18 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
}
}
// 2. Clear group_id for api keys bound to this group.
// 仅更新未软删除的记录,避免修改已删除数据,保证审计与历史回溯一致性。
// 与 APIKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
if _, err := txClient.APIKey.Update().
Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()).
ClearGroupID().
Save(ctx); err != nil {
return nil, err
}
// 3. Remove the group id from user_allowed_groups join table.
// 2. Remove the group id from user_allowed_groups join table.
// Legacy users.allowed_groups 列已弃用,不再同步。
if _, err := exec.ExecContext(ctx, "DELETE FROM user_allowed_groups WHERE group_id = $1", id); err != nil {
return nil, err
}
// 4. Delete account_groups join rows.
// 3. Delete account_groups join rows.
if _, err := exec.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", id); err != nil {
return nil, err
}
// 5. Soft-delete group itself.
// 4. Soft-delete group itself.
if _, err := txClient.Group.Delete().Where(group.IDEQ(id)).Exec(ctx); err != nil {
return nil, err
}
@ -680,6 +670,28 @@ type groupAccountCounts struct {
RateLimited int64
}
const (
// 分组页的"可用"账号数必须与账号仓储的 ListSchedulableByGroupID 过滤口径一致。
groupAccountAvailableSQL = `a.deleted_at IS NULL
AND a.status = 'active'
AND a.schedulable = true
AND (a.expires_at IS NULL OR a.expires_at > NOW() OR a.auto_pause_on_expired = FALSE)
AND (a.rate_limit_reset_at IS NULL OR a.rate_limit_reset_at <= NOW())
AND (a.overload_until IS NULL OR a.overload_until <= NOW())
AND (a.temp_unschedulable_until IS NULL OR a.temp_unschedulable_until <= NOW())`
// 这里沿用历史字段名 RateLimitedAccountCount但统计的是会让账号暂时退出调度的时间窗口。
groupAccountTemporarilyLimitedSQL = `a.deleted_at IS NULL
AND a.status = 'active'
AND a.schedulable = true
AND (a.expires_at IS NULL OR a.expires_at > NOW() OR a.auto_pause_on_expired = FALSE)
AND (
a.rate_limit_reset_at > NOW() OR
a.overload_until > NOW() OR
a.temp_unschedulable_until > NOW()
)`
)
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]groupAccountCounts, err error) {
counts = make(map[int64]groupAccountCounts, len(groupIDs))
if len(groupIDs) == 0 {
@ -688,18 +700,14 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
rows, err := r.sql.QueryContext(
ctx,
`SELECT ag.group_id,
COUNT(*) AS total,
COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true) AS active,
COUNT(*) FILTER (WHERE a.status = 'active' AND (
a.rate_limit_reset_at > NOW() OR
a.overload_until > NOW() OR
a.temp_unschedulable_until > NOW()
)) AS rate_limited
fmt.Sprintf(`SELECT ag.group_id,
COUNT(*) FILTER (WHERE a.deleted_at IS NULL) AS total,
COUNT(*) FILTER (WHERE %s) AS active,
COUNT(*) FILTER (WHERE %s) AS rate_limited
FROM account_groups ag
JOIN accounts a ON a.id = ag.account_id
WHERE ag.group_id = ANY($1)
GROUP BY ag.group_id`,
GROUP BY ag.group_id`, groupAccountAvailableSQL, groupAccountTemporarilyLimitedSQL),
pq.Array(groupIDs),
)
if err != nil {

View File

@ -651,6 +651,164 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
s.Require().Zero(count)
}
// TestListWithFilters_ActiveAccountCount_LessThanTotal 验证 ActiveAccountCount 正确区分可用与不可用账号。
// 当分组内存在 disabled 或 schedulable=false 的账号时ActiveAccountCount 必须小于 AccountCount
// 且与 GetAccountCount 返回的 active 值一致。
func (s *GroupRepoSuite) TestListWithFilters_ActiveAccountCount_LessThanTotal() {
g := &service.Group{
Name: "g-mixed-status",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, g))
insertAccount := func(name, status string, schedulable bool) int64 {
var id int64
s.Require().NoError(scanSingleRow(
s.ctx, s.tx,
"INSERT INTO accounts (name, platform, type, status, schedulable) VALUES ($1, $2, $3, $4, $5) RETURNING id",
[]any{name, service.PlatformAnthropic, service.AccountTypeOAuth, status, schedulable},
&id,
))
return id
}
link := func(accountID int64, priority int) {
_, err := s.tx.ExecContext(s.ctx,
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
accountID, g.ID, priority)
s.Require().NoError(err)
}
// account 1: active + schedulable → counts toward both total and active
link(insertAccount("acc-active-sched", service.StatusActive, true), 1)
// account 2: disabled → counts toward total only
link(insertAccount("acc-disabled", service.StatusDisabled, true), 2)
// account 3: active + not schedulable → counts toward total only
link(insertAccount("acc-unschedulable", service.StatusActive, false), 3)
// --- ListWithFilters path ---
isExclusive := false
groups, _, err := s.repo.ListWithFilters(s.ctx,
pagination.PaginationParams{Page: 1, PageSize: 100},
service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
s.Require().NoError(err)
var found *service.Group
for i := range groups {
if groups[i].ID == g.ID {
found = &groups[i]
break
}
}
s.Require().NotNil(found, "created group must appear in ListWithFilters result")
s.Assert().Equal(int64(3), found.AccountCount, "AccountCount must count all 3 accounts")
s.Assert().Equal(int64(1), found.ActiveAccountCount, "ActiveAccountCount must count only the active+schedulable account")
// --- GetAccountCount must return identical values ---
total, active, err := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().NoError(err)
s.Assert().Equal(found.AccountCount, total, "GetAccountCount total must match ListWithFilters AccountCount")
s.Assert().Equal(found.ActiveAccountCount, active, "GetAccountCount active must match ListWithFilters ActiveAccountCount")
}
// TestListWithFilters_RateLimitedAccountCount 验证临时受限账号不会计入可用账号数。
// rate_limit / overload / temp_unschedulable 都会让账号退出当前调度池,
// 因此 ActiveAccountCount 必须与真实调度查询口径一致。
func (s *GroupRepoSuite) TestListWithFilters_RateLimitedAccountCount() {
g := &service.Group{
Name: "g-rate-limited",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, g))
var normalID int64
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
[]any{"acc-normal", service.PlatformAnthropic, service.AccountTypeOAuth},
&normalID))
var rateLimitedID int64
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
"INSERT INTO accounts (name, platform, type, rate_limit_reset_at) VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour') RETURNING id",
[]any{"acc-rate-limited", service.PlatformAnthropic, service.AccountTypeOAuth},
&rateLimitedID))
var overloadedID int64
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
"INSERT INTO accounts (name, platform, type, overload_until) VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour') RETURNING id",
[]any{"acc-overloaded", service.PlatformAnthropic, service.AccountTypeOAuth},
&overloadedID))
var tempUnschedulableID int64
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
"INSERT INTO accounts (name, platform, type, temp_unschedulable_until) VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour') RETURNING id",
[]any{"acc-temp-unschedulable", service.PlatformAnthropic, service.AccountTypeOAuth},
&tempUnschedulableID))
var expiredID int64
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
"INSERT INTO accounts (name, platform, type, expires_at, auto_pause_on_expired) VALUES ($1, $2, $3, NOW() - INTERVAL '1 hour', TRUE) RETURNING id",
[]any{"acc-expired", service.PlatformAnthropic, service.AccountTypeOAuth},
&expiredID))
_, err := s.tx.ExecContext(s.ctx,
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
normalID, g.ID, 1)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx,
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
rateLimitedID, g.ID, 2)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx,
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
overloadedID, g.ID, 3)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx,
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
tempUnschedulableID, g.ID, 4)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx,
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
expiredID, g.ID, 5)
s.Require().NoError(err)
isExclusive := false
groups, _, err := s.repo.ListWithFilters(s.ctx,
pagination.PaginationParams{Page: 1, PageSize: 100},
service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
s.Require().NoError(err)
var found *service.Group
for i := range groups {
if groups[i].ID == g.ID {
found = &groups[i]
break
}
}
s.Require().NotNil(found, "created group must appear in ListWithFilters result")
s.Assert().Equal(int64(5), found.AccountCount, "AccountCount must include all linked accounts")
s.Assert().Equal(int64(1), found.ActiveAccountCount, "ActiveAccountCount must include only currently schedulable accounts")
s.Assert().Equal(int64(3), found.RateLimitedAccountCount, "RateLimitedAccountCount must include temporarily limited accounts")
total, active, err := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().NoError(err)
s.Assert().Equal(found.AccountCount, total, "GetAccountCount total must match ListWithFilters AccountCount")
s.Assert().Equal(found.ActiveAccountCount, active, "GetAccountCount active must match ListWithFilters ActiveAccountCount")
detail, err := s.repo.GetByID(s.ctx, g.ID)
s.Require().NoError(err)
s.Assert().Equal(found.AccountCount, detail.AccountCount, "GetByID AccountCount must match ListWithFilters")
s.Assert().Equal(found.ActiveAccountCount, detail.ActiveAccountCount, "GetByID ActiveAccountCount must match ListWithFilters")
s.Assert().Equal(found.RateLimitedAccountCount, detail.RateLimitedAccountCount, "GetByID RateLimitedAccountCount must match ListWithFilters")
}
// --- DeleteAccountGroupsByGroupID ---
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {

View File

@ -7,6 +7,67 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
)
// TestListWithAccountCountSort_AttachesActiveCount 验证通过 account_count 排序时,
// ActiveAccountCount 与 AccountCount 都被正确附加到返回结果中,
// 且排序基于 total 账号数而非 active 账号数。
func (s *GroupRepoSuite) TestListWithAccountCountSort_AttachesActiveCount() {
// Group A: 2 total, 1 active (1 disabled account)
gA := &service.Group{Name: "sort-count-a", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard}
// Group B: 1 total, 1 active
gB := &service.Group{Name: "sort-count-b", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard}
s.Require().NoError(s.repo.Create(s.ctx, gA))
s.Require().NoError(s.repo.Create(s.ctx, gB))
insertAccount := func(name, status string) int64 {
var id int64
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
"INSERT INTO accounts (name, platform, type, status) VALUES ($1, $2, $3, $4) RETURNING id",
[]any{name, service.PlatformAnthropic, service.AccountTypeOAuth, status},
&id))
return id
}
link := func(accountID, groupID int64, priority int) {
_, err := s.tx.ExecContext(s.ctx,
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
accountID, groupID, priority)
s.Require().NoError(err)
}
// gA: 1 active + 1 disabled → total=2, active=1
link(insertAccount("sa-active", service.StatusActive), gA.ID, 1)
link(insertAccount("sa-disabled", service.StatusDisabled), gA.ID, 2)
// gB: 1 active → total=1, active=1
link(insertAccount("sb-active", service.StatusActive), gB.ID, 1)
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1, PageSize: 100, SortBy: "account_count", SortOrder: "desc",
}, service.PlatformAnthropic, service.StatusActive, "", nil)
s.Require().NoError(err)
byID := make(map[int64]service.Group, len(groups))
for _, g := range groups {
byID[g.ID] = g
}
s.Require().Contains(byID, gA.ID, "gA must appear in results")
s.Require().Contains(byID, gB.ID, "gB must appear in results")
cA := byID[gA.ID]
s.Assert().Equal(int64(2), cA.AccountCount, "gA AccountCount must be 2")
s.Assert().Equal(int64(1), cA.ActiveAccountCount, "gA ActiveAccountCount must be 1")
cB := byID[gB.ID]
s.Assert().Equal(int64(1), cB.AccountCount, "gB AccountCount must be 1")
s.Assert().Equal(int64(1), cB.ActiveAccountCount, "gB ActiveAccountCount must be 1")
// Sort is by total (not active): gA (total=2) must rank higher than gB (total=1) in desc order
indexByID := make(map[int64]int, len(groups))
for i, g := range groups {
indexByID[g.ID] = i
}
s.Assert().Less(indexByID[gA.ID], indexByID[gB.ID], "gA (total=2) must rank above gB (total=1) with account_count desc")
}
func (s *GroupRepoSuite) TestList_DefaultSortBySortOrderAsc() {
g1 := &service.Group{Name: "g1", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 20}
g2 := &service.Group{Name: "g2", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 10}

View File

@ -833,6 +833,7 @@ func TestAPIContracts(t *testing.T) {
"payment_visible_method_alipay_enabled": true,
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": true,
"openai_codex_user_agent": "",
"openai_fast_policy_settings": {
"rules": []
},
@ -1058,6 +1059,7 @@ func TestAPIContracts(t *testing.T) {
"payment_visible_method_alipay_enabled": false,
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": false,
"openai_codex_user_agent": "",
"openai_fast_policy_settings": {
"rules": []
},

View File

@ -109,6 +109,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
return
}
if abortIfAPIKeyGroupUnavailable(c, apiKey) {
return
}
// ── 4. SimpleMode → early return ─────────────────────────────
@ -251,3 +254,26 @@ func setGroupContext(c *gin.Context, group *service.Group) {
ctx := context.WithValue(c.Request.Context(), ctxkey.Group, group)
c.Request = c.Request.WithContext(ctx)
}
func abortIfAPIKeyGroupUnavailable(c *gin.Context, apiKey *service.APIKey) bool {
code, message, ok := validateAPIKeyGroupAvailable(apiKey)
if ok {
return false
}
AbortWithError(c, 403, code, message)
return true
}
func validateAPIKeyGroupAvailable(apiKey *service.APIKey) (string, string, bool) {
if apiKey == nil || apiKey.GroupID == nil {
return "", "", true
}
group := apiKey.Group
if group == nil || strings.EqualFold(group.Status, "deleted") {
return "GROUP_DELETED", "API Key 所属分组已删除", false
}
if !group.IsActive() {
return "GROUP_DISABLED", "API Key 所属分组已停用", false
}
return "", "", true
}

View File

@ -54,6 +54,10 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
abortWithGoogleError(c, 401, "User account is not active")
return
}
if _, message, ok := validateAPIKeyGroupAvailable(apiKey); !ok {
abortWithGoogleError(c, 403, message)
return
}
// 简易模式:跳过余额和订阅检查
if cfg.RunMode == config.RunModeSimple {

View File

@ -300,6 +300,104 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
require.Equal(t, http.StatusOK, w.Code)
}
func TestAPIKeyAuthRejectsUnavailableGroup(t *testing.T) {
gin.SetMode(gin.TestMode)
groupID := int64(101)
user := &service.User{
ID: 7,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
tests := []struct {
name string
group *service.Group
wantStatus int
wantCode string
}{
{
name: "active group passes",
group: &service.Group{
ID: groupID,
Name: "active",
Status: service.StatusActive,
Platform: service.PlatformAnthropic,
Hydrated: true,
},
wantStatus: http.StatusOK,
},
{
name: "disabled group is forbidden",
group: &service.Group{
ID: groupID,
Name: "disabled",
Status: service.StatusDisabled,
Platform: service.PlatformAnthropic,
Hydrated: true,
},
wantStatus: http.StatusForbidden,
wantCode: "GROUP_DISABLED",
},
{
name: "deleted status group is forbidden",
group: &service.Group{
ID: groupID,
Name: "deleted",
Status: "deleted",
Platform: service.PlatformAnthropic,
Hydrated: true,
},
wantStatus: http.StatusForbidden,
wantCode: "GROUP_DELETED",
},
{
name: "missing group edge is forbidden",
group: nil,
wantStatus: http.StatusForbidden,
wantCode: "GROUP_DELETED",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
apiKey := &service.APIKey{
ID: 100,
UserID: user.ID,
GroupID: &groupID,
Key: "test-key",
Status: service.StatusActive,
User: user,
Group: tt.group,
}
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
}
cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := newAuthTestRouter(apiKeyService, nil, cfg)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("x-api-key", apiKey.Key)
router.ServeHTTP(w, req)
require.Equal(t, tt.wantStatus, w.Code)
if tt.wantCode != "" {
require.Contains(t, w.Body.String(), tt.wantCode)
}
})
}
}
func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@ -422,6 +422,11 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSMTPConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
adminSettings.GET("/email-templates", h.Admin.Setting.ListEmailTemplates)
adminSettings.POST("/email-template-preview", h.Admin.Setting.PreviewEmailTemplate)
adminSettings.GET("/email-templates/:event/:locale", h.Admin.Setting.GetEmailTemplate)
adminSettings.PUT("/email-templates/:event/:locale", h.Admin.Setting.UpdateEmailTemplate)
adminSettings.POST("/email-templates/:event/:locale/restore-official", h.Admin.Setting.RestoreOfficialEmailTemplate)
// Admin API Key 管理
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)

View File

@ -214,6 +214,7 @@ func RegisterAuthRoutes(
settings := v1.Group("/settings")
{
settings.GET("/public", h.Setting.GetPublicSettings)
settings.GET("/email-unsubscribe", h.Setting.UnsubscribeNotificationEmail)
}
// 需要认证的当前用户信息

View File

@ -31,6 +31,7 @@ func RegisterUserRoutes(
user.POST("/account-bindings/email", h.User.BindEmailIdentity)
user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity)
user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding)
user.GET("/api-keys/:id/usage/daily", h.Usage.GetMyAPIKeyDailyUsage)
// 通知邮箱管理
notifyEmail := user.Group("/notify-email")

View File

@ -244,6 +244,21 @@ func (s *groupRepoStub) UpdateSortOrders(ctx context.Context, updates []GroupSor
return nil
}
type deleteGroupAPIKeyRepoStub struct {
apiKeyRepoStubForGroupUpdate
keys []string
listErr error
listGroupIDs []int64
}
func (s *deleteGroupAPIKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
s.listGroupIDs = append(s.listGroupIDs, groupID)
if s.listErr != nil {
return nil, s.listErr
}
return s.keys, nil
}
type proxyRepoStub struct {
deleteErr error
countErr error
@ -500,6 +515,23 @@ func TestAdminService_DeleteGroup_Success_WithCacheInvalidation(t *testing.T) {
}, calls)
}
func TestAdminService_DeleteGroup_InvalidatesAuthCacheForBoundKeys(t *testing.T) {
repo := &groupRepoStub{}
apiKeyRepo := &deleteGroupAPIKeyRepoStub{keys: []string{"k1", "k2"}}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
groupRepo: repo,
apiKeyRepo: apiKeyRepo,
authCacheInvalidator: invalidator,
}
err := svc.DeleteGroup(context.Background(), 5)
require.NoError(t, err)
require.Equal(t, []int64{5}, repo.deleteCalls)
require.Equal(t, []int64{5}, apiKeyRepo.listGroupIDs)
require.Equal(t, []string{"k1", "k2"}, invalidator.keys)
}
func TestAdminService_DeleteGroup_NotFound(t *testing.T) {
repo := &groupRepoStub{deleteErr: ErrGroupNotFound}
svc := &adminServiceImpl{groupRepo: repo}

View File

@ -14,7 +14,7 @@ import (
"github.com/dgraph-io/ristretto"
)
const apiKeyAuthSnapshotVersion = 9 // v9: added API Key name for audit logs
const apiKeyAuthSnapshotVersion = 10 // v10: reload snapshots for group availability checks
type apiKeyAuthCacheConfig struct {
l1Size int

View File

@ -94,7 +94,7 @@ func (s *AuthService) BindEmailIdentity(
}
// SendEmailIdentityBindCode sends a verification code for authenticated email binding flows.
func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string) error {
func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string, locale ...string) error {
if s == nil {
return ErrServiceUnavailable
}
@ -128,7 +128,7 @@ func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int6
if s.settingService != nil {
siteName = s.settingService.GetSiteName(ctx)
}
return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName)
return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName, firstEmailLocale(locale))
}
func normalizeEmailForIdentityBinding(email string) (string, error) {

View File

@ -28,7 +28,7 @@ func normalizeOAuthSignupSource(signupSource string) string {
// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth
// account-creation flows without relying on the public registration gate.
func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string, locale ...string) (*SendVerifyCodeResult, error) {
email = strings.TrimSpace(strings.ToLower(email))
if email == "" {
return nil, ErrEmailVerifyRequired
@ -47,7 +47,7 @@ func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email stri
if s.settingService != nil {
siteName = s.settingService.GetSiteName(ctx)
}
if err := s.emailService.SendVerifyCode(ctx, email, siteName); err != nil {
if err := s.emailService.SendVerifyCode(ctx, email, siteName, firstEmailLocale(locale)); err != nil {
return nil, err
}
return &SendVerifyCodeResult{

View File

@ -273,7 +273,7 @@ type SendVerifyCodeResult struct {
}
// SendVerifyCode 发送邮箱验证码(同步方式)
func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
func (s *AuthService) SendVerifyCode(ctx context.Context, email string, locale ...string) error {
// 检查是否开放注册(默认关闭)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return ErrRegDisabled
@ -307,11 +307,11 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
siteName = s.settingService.GetSiteName(ctx)
}
return s.emailService.SendVerifyCode(ctx, email, siteName)
return s.emailService.SendVerifyCode(ctx, email, siteName, firstEmailLocale(locale))
}
// SendVerifyCodeAsync 异步发送邮箱验证码并返回倒计时
func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string, locale ...string) (*SendVerifyCodeResult, error) {
logger.LegacyPrintf("service.auth", "[Auth] SendVerifyCodeAsync called for email: %s", email)
// 检查是否开放注册(默认关闭)
@ -352,7 +352,7 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
// 异步发送
logger.LegacyPrintf("service.auth", "[Auth] Enqueueing verify code for: %s", email)
if err := s.emailQueueService.EnqueueVerifyCode(email, siteName); err != nil {
if err := s.emailQueueService.EnqueueVerifyCode(email, siteName, firstEmailLocale(locale)); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to enqueue: %v", err)
return nil, fmt.Errorf("enqueue verify code: %w", err)
}
@ -1251,7 +1251,7 @@ func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendB
// RequestPasswordReset 请求密码重置(同步发送)
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string) error {
func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string, locale ...string) error {
if !s.IsPasswordResetEnabled(ctx) {
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
}
@ -1264,7 +1264,7 @@ func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendB
return nil // Silent success to prevent enumeration
}
if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL, firstEmailLocale(locale)); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to send password reset email to %s: %v", email, err)
return nil // Silent success to prevent enumeration
}
@ -1275,7 +1275,7 @@ func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendB
// RequestPasswordResetAsync 异步请求密码重置(队列发送)
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string) error {
func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string, locale ...string) error {
if !s.IsPasswordResetEnabled(ctx) {
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
}
@ -1288,7 +1288,7 @@ func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, fron
return nil // Silent success to prevent enumeration
}
if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL); err != nil {
if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL, firstEmailLocale(locale)); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to enqueue password reset email for %s: %v", email, err)
return nil // Silent success to prevent enumeration
}

View File

@ -39,9 +39,10 @@ type AccountQuotaReader interface {
// BalanceNotifyService handles balance and quota threshold notifications.
type BalanceNotifyService struct {
emailService *EmailService
settingRepo SettingRepository
accountRepo AccountQuotaReader
emailService *EmailService
settingRepo SettingRepository
accountRepo AccountQuotaReader
notificationEmailService *NotificationEmailService
}
// NewBalanceNotifyService creates a new BalanceNotifyService.
@ -53,6 +54,10 @@ func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepo
}
}
func (s *BalanceNotifyService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) {
s.notificationEmailService = notificationEmailService
}
// resolveBalanceThreshold returns the effective balance threshold.
// For percentage type, it computes threshold = totalRecharged * percentage / 100.
func resolveBalanceThreshold(threshold float64, thresholdType string, totalRecharged float64) float64 {
@ -125,7 +130,7 @@ func (s *BalanceNotifyService) dispatchBalanceLowEmail(ctx context.Context, user
slog.Error("panic in balance notification", "recover", r)
}
}()
s.sendBalanceLowEmails(recipients, user.Username, user.Email, newBalance, threshold, siteName, rechargeURL)
s.sendBalanceLowEmails(recipients, user.ID, user.Username, user.Email, newBalance, threshold, siteName, rechargeURL)
}()
}
@ -342,11 +347,44 @@ func (s *BalanceNotifyService) sendEmails(recipients []string, subject, body str
}
// sendBalanceLowEmails sends balance low notification to all recipients.
func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userName, userEmail string, balance, threshold float64, siteName, rechargeURL string) {
func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userID int64, userName, userEmail string, balance, threshold float64, siteName, rechargeURL string) {
displayName := userName
if displayName == "" {
displayName = userEmail
}
if s.notificationEmailService != nil {
fallbackRecipients := make([]string, 0, len(recipients))
for _, to := range recipients {
ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout)
err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
Event: NotificationEmailEventBalanceLow,
RecipientEmail: to,
RecipientName: displayName,
UserID: userID,
SourceType: "balance_low",
SourceID: firstNonEmpty(strconv.FormatInt(userID, 10), userEmail),
ReminderKey: time.Now().UTC().Format("2006-01-02"),
Variables: map[string]string{
"current_balance": fmt.Sprintf("%.2f", balance),
"threshold": fmt.Sprintf("%.2f", threshold),
"recharge_url": rechargeURL,
},
})
cancel()
if err != nil {
if shouldFallbackNotificationEmail(err) {
slog.Warn("template balance low notification failed; falling back to built-in body", "to", to, "err", err.Error())
fallbackRecipients = append(fallbackRecipients, to)
} else {
slog.Warn("template balance low notification delivery failed; not sending fallback to avoid duplicates", "to", to, "err", err.Error())
}
}
}
if len(fallbackRecipients) == 0 {
return
}
recipients = fallbackRecipients
}
subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", sanitizeEmailHeader(siteName))
body := s.buildBalanceLowEmailBody(html.EscapeString(displayName), balance, threshold, html.EscapeString(siteName), rechargeURL)
s.sendEmails(recipients, subject, body, "user_email", userEmail, "balance", balance)
@ -369,6 +407,44 @@ func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accoun
remaining = 0
}
if s.notificationEmailService != nil {
fallbackRecipients := make([]string, 0, len(adminEmails))
for _, to := range adminEmails {
ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout)
err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
Event: NotificationEmailEventAccountQuotaAlert,
RecipientEmail: to,
RecipientName: emailRecipientName(to),
SourceType: "account_quota",
SourceID: fmt.Sprintf("%d-%s", accountID, dim.name),
ReminderKey: time.Now().UTC().Format("2006-01-02"),
Variables: map[string]string{
"account_id": strconv.FormatInt(accountID, 10),
"account_name": accountName,
"platform": platform,
"quota_dimension": dimLabel,
"quota_used": fmt.Sprintf("%.2f", used),
"quota_limit": fmt.Sprintf("%.2f", dim.limit),
"quota_remaining": fmt.Sprintf("%.2f", remaining),
"quota_threshold": thresholdDisplay,
},
})
cancel()
if err != nil {
if shouldFallbackNotificationEmail(err) {
slog.Warn("template account quota alert failed; falling back to built-in body", "to", to, "account_id", accountID, "dimension", dim.name, "err", err.Error())
fallbackRecipients = append(fallbackRecipients, to)
} else {
slog.Warn("template account quota alert delivery failed; not sending fallback to avoid duplicates", "to", to, "account_id", accountID, "dimension", dim.name, "err", err.Error())
}
}
}
if len(fallbackRecipients) == 0 {
return
}
adminEmails = fallbackRecipients
}
subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", sanitizeEmailHeader(siteName), sanitizeEmailHeader(accountName))
body := s.buildQuotaAlertEmailBody(accountID, html.EscapeString(accountName), html.EscapeString(platform), html.EscapeString(dimLabel), used, dim.limit, remaining, thresholdDisplay, html.EscapeString(siteName))
s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dim.name)

View File

@ -15,16 +15,19 @@ import (
//
// 行为按 tokenType / mimicClaudeCode 分两条路径:
//
// OAuth mimic 路径 (tokenType == "oauth" && mimicClaudeCode):
// 1. body 中 metadata.user_id 派生的 SessionID 是合法 UUID → canonicalize 写入
// 2. 请求 header 中已有合法 UUID → canonicalize 保留
// 3. 否则 → 兜底生成 UUID
// OAuth 路径 (tokenType == "oauth"):
// OAuth 账号本身就是真实 Claude Code 客户端的凭证,可以信任 body 中的
// metadata.user_id 派生 session id。
// 1. metadata.user_id 派生 SessionID 是合法 UUID → canonical 写入
// 2. header 已有合法 UUID → canonical 保留
// 3. mimicClaudeCode == true → 兜底生成新 UUID
// (mimicClaudeCode == false 且无 metadata 时不强制注入)
//
// API key 透传 / 非 mimic 路径:
// - 不从 body 合成 header避免污染客户端原始语义
// - 若客户端在 header 中传入 X-Claude-Code-Session-Id
// 合法 UUID → canonicalize 保留
// 非法值 → 删除(不向上游转发恶意值,符合 UUID 校验承诺
// API key 透传路径 (tokenType != "oauth"):
// - 不从 body metadata 派生 header避免污染客户端原始语义
// - 若客户端在 header 中传入 X-Claude-Code-Session-Id
// 合法 UUID → canonical 保留
// 非法值 → 删除(不向上游转发恶意值
// - 不兜底生成
//
// 安全说明metadata.user_id 由客户端控制ParseMetadataUserID 的正则仅约束字符集,
@ -37,10 +40,10 @@ func ensureClaudeCodeSessionID(req *http.Request, body []byte, tokenType string,
req.Header = make(http.Header)
}
isOAuthMimic := tokenType == "oauth" && mimicClaudeCode
isOAuth := tokenType == "oauth"
// OAuth mimic 路径:从 metadata 派生(仅在 mimic 场景写 header)。
if isOAuthMimic {
// OAuth 路径:从 metadata 派生OAuth 凭证可信任)。
if isOAuth {
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
if parsed := ParseMetadataUserID(uid); parsed != nil {
if id, err := uuid.Parse(parsed.SessionID); err == nil {
@ -65,9 +68,9 @@ func ensureClaudeCodeSessionID(req *http.Request, body []byte, tokenType string,
req.Header.Del("X-Claude-Code-Session-Id")
}
// OAuth mimic 兜底生成(仅 mimic 场景API key 不污染)。
// OAuth mimic 兜底生成(仅 mimic 场景API key/非 mimic 不污染)。
// uuid.NewString() 走 crypto/rand。
if isOAuthMimic {
if isOAuth && mimicClaudeCode {
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", uuid.NewString())
}
}

View File

@ -136,15 +136,17 @@ func TestEnsureClaudeCodeSessionID_APIKeyIgnoresMetadata(t *testing.T) {
}
}
// OAuth 但非 mimic 模式也不应该从 metadata 派生 header。
func TestEnsureClaudeCodeSessionID_OAuthNonMimicIgnoresMetadata(t *testing.T) {
// OAuth 路径即使 mimic=false 也应该从 metadata 派生 header
// OAuth 凭证本身就是 Claude Code 类型账号metadata.user_id 可信任。
// 这与 API key 路径不同API key 是任意第三方调用方)。
func TestEnsureClaudeCodeSessionID_OAuthNonMimicDerivesFromMetadata(t *testing.T) {
req := newReq(t)
body := []byte(`{"metadata":{"user_id":"{\"device_id\":\"abc\",\"account_uuid\":\"\",\"session_id\":\"` + testValidUUID + `\"}"}}`)
ensureClaudeCodeSessionID(req, body, "oauth", false)
got := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id")
if got != "" {
t.Fatalf("Non-mimic OAuth must NOT derive session-id from metadata, got %q", got)
if got != testValidUUID {
t.Fatalf("OAuth must derive session-id from metadata regardless of mimic flag, got %q want %q", got, testValidUUID)
}
}

View File

@ -1463,6 +1463,24 @@ func (s *ContentModerationService) applyFlaggedSideEffects(ctx context.Context,
func (s *ContentModerationService) sendViolationEmail(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) error {
siteName := s.siteName(ctx)
if s.emailService.notificationEmailService != nil {
if err := s.emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{
Event: NotificationEmailEventContentModerationViolation,
RecipientEmail: log.UserEmail,
RecipientName: emailRecipientName(log.UserEmail),
UserID: contentModerationEmailUserID(log),
SourceType: "content_moderation",
SourceID: contentModerationEmailSourceID(log),
Variables: contentModerationEmailVariables(log, cfg),
}); err == nil {
return nil
} else {
if !shouldFallbackNotificationEmail(err) {
return err
}
slog.Warn("template content moderation violation email failed; falling back to built-in body", "log_id", log.ID, "recipient_hash", notificationEmailHash(log.UserEmail), "err", err.Error())
}
}
subject := fmt.Sprintf("[%s] 账户风控提醒 / Risk Control Notice", sanitizeEmailHeader(siteName))
body := buildContentModerationViolationEmailBody(siteName, log, cfg)
return s.emailService.SendEmail(ctx, log.UserEmail, subject, body)
@ -1470,11 +1488,71 @@ func (s *ContentModerationService) sendViolationEmail(ctx context.Context, cfg *
func (s *ContentModerationService) sendAccountDisabledEmail(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) error {
siteName := s.siteName(ctx)
if s.emailService.notificationEmailService != nil {
if err := s.emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{
Event: NotificationEmailEventContentModerationDisabled,
RecipientEmail: log.UserEmail,
RecipientName: emailRecipientName(log.UserEmail),
UserID: contentModerationEmailUserID(log),
SourceType: "content_moderation",
SourceID: contentModerationEmailSourceID(log),
Variables: contentModerationEmailVariables(log, cfg),
}); err == nil {
return nil
} else {
if !shouldFallbackNotificationEmail(err) {
return err
}
slog.Warn("template content moderation disabled email failed; falling back to built-in body", "log_id", log.ID, "recipient_hash", notificationEmailHash(log.UserEmail), "err", err.Error())
}
}
subject := fmt.Sprintf("[%s] 账户已被禁用 / Account Disabled", sanitizeEmailHeader(siteName))
body := buildContentModerationAccountDisabledEmailBody(siteName, log, cfg)
return s.emailService.SendEmail(ctx, log.UserEmail, subject, body)
}
func contentModerationEmailUserID(log *ContentModerationLog) int64 {
if log == nil || log.UserID == nil {
return 0
}
return *log.UserID
}
func contentModerationEmailSourceID(log *ContentModerationLog) string {
if log == nil || log.ID <= 0 {
return ""
}
return fmt.Sprintf("%d", log.ID)
}
func contentModerationEmailVariables(log *ContentModerationLog, cfg *ContentModerationConfig) map[string]string {
variables := map[string]string{
"triggered_at": time.Now().UTC().Format(time.RFC3339),
"group_name": "-",
"moderation_category": "-",
"moderation_score": "0.000",
"violation_count": "0",
"ban_threshold": "0",
}
if log != nil {
if !log.CreatedAt.IsZero() {
variables["triggered_at"] = log.CreatedAt.UTC().Format(time.RFC3339)
}
if strings.TrimSpace(log.GroupName) != "" {
variables["group_name"] = strings.TrimSpace(log.GroupName)
}
if strings.TrimSpace(log.HighestCategory) != "" {
variables["moderation_category"] = strings.TrimSpace(log.HighestCategory)
}
variables["moderation_score"] = fmt.Sprintf("%.3f", log.HighestScore)
variables["violation_count"] = fmt.Sprintf("%d", log.ViolationCount)
}
if cfg != nil {
variables["ban_threshold"] = fmt.Sprintf("%d", cfg.BanThreshold)
}
return variables
}
func (s *ContentModerationService) siteName(ctx context.Context) string {
if s == nil || s.settingRepo == nil {
return "Sub2API"

View File

@ -401,6 +401,10 @@ const (
SettingKeyRewriteMessageCacheControl = "rewrite_message_cache_control"
// SettingKeyAntigravityUserAgentVersion Antigravity 上游 User-Agent 版本号(空值使用环境变量/默认值)
SettingKeyAntigravityUserAgentVersion = "antigravity_user_agent_version"
// SettingKeyOpenAICodexUserAgent OpenAI Codex 完整 User-Agent空值使用内置默认
// 当客户端 UA 被识别为浏览器Chrome/Firefox/Safari/Edge 等)时,转发给 OpenAI 上游前会替换为此值,
// 用于避免 Cloudflare 对浏览器型 UA 的质询拦截。
SettingKeyOpenAICodexUserAgent = "openai_codex_user_agent"
// Balance Low Notification
SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关

View File

@ -21,6 +21,7 @@ type EmailTask struct {
SiteName string
TaskType string // "verify_code" or "password_reset"
ResetURL string // Only used for password_reset task type
Locale string // Optional Accept-Language locale hint
}
// EmailQueueService 异步邮件队列服务
@ -82,13 +83,13 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
switch task.TaskType {
case TaskTypeVerifyCode:
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil {
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName, task.Locale); err != nil {
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
} else {
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
}
case TaskTypePasswordReset:
if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil {
if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL, task.Locale); err != nil {
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err)
} else {
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email)
@ -99,11 +100,12 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
}
// EnqueueVerifyCode 将验证码发送任务加入队列
func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string, locale ...string) error {
task := EmailTask{
Email: email,
SiteName: siteName,
TaskType: TaskTypeVerifyCode,
Locale: firstEmailLocale(locale),
}
select {
@ -116,12 +118,13 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
}
// EnqueuePasswordReset 将密码重置邮件任务加入队列
func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL string) error {
func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL string, locale ...string) error {
task := EmailTask{
Email: email,
SiteName: siteName,
TaskType: TaskTypePasswordReset,
ResetURL: resetURL,
Locale: firstEmailLocale(locale),
}
select {

View File

@ -94,8 +94,9 @@ type SMTPConfig struct {
// EmailService 邮件服务
type EmailService struct {
settingRepo SettingRepository
cache EmailCache
settingRepo SettingRepository
cache EmailCache
notificationEmailService *NotificationEmailService
}
// NewEmailService 创建邮件服务实例
@ -106,6 +107,28 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
}
}
func (s *EmailService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) {
s.notificationEmailService = notificationEmailService
}
func firstEmailLocale(locales []string) string {
if len(locales) == 0 {
return ""
}
return strings.TrimSpace(locales[0])
}
func emailRecipientName(email string) string {
trimmed := strings.TrimSpace(email)
if trimmed == "" {
return ""
}
if at := strings.Index(trimmed, "@"); at > 0 {
return trimmed[:at]
}
return trimmed
}
// GetSMTPConfig 从数据库获取SMTP配置
func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
keys := []string{
@ -301,7 +324,7 @@ func (s *EmailService) GenerateVerifyCode() (string, error) {
}
// SendVerifyCode 发送验证码邮件
func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string) error {
func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string, locale ...string) error {
// 检查是否在冷却期内
existing, err := s.cache.GetVerificationCode(ctx, email)
if err == nil && existing != nil {
@ -327,6 +350,26 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
return fmt.Errorf("save verify code: %w", err)
}
if s.notificationEmailService != nil {
err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
Event: NotificationEmailEventAuthVerifyCode,
Locale: firstEmailLocale(locale),
RecipientEmail: email,
RecipientName: emailRecipientName(email),
Variables: map[string]string{
"verification_code": code,
"expires_in_minutes": strconv.Itoa(int(verifyCodeTTL / time.Minute)),
},
})
if err == nil {
return nil
}
if !shouldFallbackNotificationEmail(err) {
return err
}
slog.Warn("failed to send templated verification email, falling back to legacy template", "recipient_hash", notificationEmailHash(email), "error", err)
}
// 构建邮件内容
subject := fmt.Sprintf("[%s] Email Verification Code", siteName)
body := s.buildVerifyCodeEmailBody(code, siteName)
@ -469,7 +512,7 @@ func (s *EmailService) GeneratePasswordResetToken() (string, error) {
}
// SendPasswordResetEmail sends a password reset email with a reset link
func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteName, resetURL string) error {
func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteName, resetURL string, locale ...string) error {
var token string
var needSaveToken bool
@ -502,6 +545,26 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa
// Build full reset URL with URL-encoded token and email
fullResetURL := fmt.Sprintf("%s?email=%s&token=%s", resetURL, url.QueryEscape(email), url.QueryEscape(token))
if s.notificationEmailService != nil {
err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
Event: NotificationEmailEventAuthPasswordReset,
Locale: firstEmailLocale(locale),
RecipientEmail: email,
RecipientName: emailRecipientName(email),
Variables: map[string]string{
"reset_url": fullResetURL,
"expires_in_minutes": strconv.Itoa(int(passwordResetTokenTTL / time.Minute)),
},
})
if err == nil {
return nil
}
if !shouldFallbackNotificationEmail(err) {
return err
}
slog.Warn("failed to send templated password reset email, falling back to legacy template", "recipient_hash", notificationEmailHash(email), "error", err)
}
// Build email content
subject := fmt.Sprintf("[%s] 密码重置请求", siteName)
body := s.buildPasswordResetEmailBody(fullResetURL, siteName)
@ -516,7 +579,7 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa
// SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker)
// This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error {
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string, locale ...string) error {
// Check email cooldown to prevent email bombing
if s.cache.IsPasswordResetEmailInCooldown(ctx, email) {
slog.Info("password reset email skipped due to cooldown", "email", email)
@ -524,7 +587,7 @@ func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, e
}
// Send email using core method
if err := s.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
if err := s.SendPasswordResetEmail(ctx, email, siteName, resetURL, firstEmailLocale(locale)); err != nil {
return err
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,571 @@
package service
import (
"bufio"
"context"
"errors"
"net"
"strings"
"sync"
"sync/atomic"
"testing"
"github.com/stretchr/testify/require"
)
func TestNotificationEmailPreviewEscapesHTMLAndSanitizesSubject(t *testing.T) {
ctx := context.Background()
svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil)
preview, err := svc.PreviewTemplate(ctx, NotificationEmailPreviewInput{
Event: NotificationEmailEventBalanceLow,
Locale: "en-US,en;q=0.9",
Subject: "Low balance for {{recipient_name}}\r\nInjected",
HTML: `<p>{{recipient_name}}</p><a href="{{recharge_url}}">Recharge</a>`,
Variables: map[string]string{
"recipient_name": `<script>alert("x")</script>`,
"recharge_url": `javascript:alert(1)`,
},
})
require.NoError(t, err)
require.NotContains(t, preview.Subject, "\r")
require.NotContains(t, preview.Subject, "\n")
require.Contains(t, preview.Subject, `Low balance for <script>alert("x")</script>Injected`)
require.Contains(t, preview.HTML, `&lt;script&gt;alert(&#34;x&#34;)&lt;/script&gt;`)
require.NotContains(t, preview.HTML, `javascript:alert(1)`)
require.Contains(t, preview.HTML, `href=""`)
}
func TestNotificationEmailTemplateOverrideAndRestore(t *testing.T) {
ctx := context.Background()
repo := newNotificationEmailMemorySettingRepo()
svc := NewNotificationEmailService(repo, nil)
official, err := svc.GetTemplate(ctx, NotificationEmailEventBalanceRechargeSuccess, "en")
require.NoError(t, err)
require.False(t, official.IsCustom)
updated, err := svc.UpdateTemplate(
ctx,
NotificationEmailEventBalanceRechargeSuccess,
"zh-Hans",
"充值完成:{{recharge_amount}}",
"<p>{{recipient_name}} 已充值 {{recharge_amount}}</p>",
)
require.NoError(t, err)
require.True(t, updated.IsCustom)
require.Equal(t, "zh", updated.Locale)
require.Equal(t, "充值完成:{{recharge_amount}}", updated.Subject)
require.NotNil(t, updated.UpdatedAt)
restored, err := svc.RestoreOfficialTemplate(ctx, NotificationEmailEventBalanceRechargeSuccess, "zh")
require.NoError(t, err)
require.False(t, restored.IsCustom)
require.NotEqual(t, updated.Subject, restored.Subject)
_, err = repo.GetValue(ctx, notificationEmailTemplateKey(NotificationEmailEventBalanceRechargeSuccess, "zh"))
require.ErrorIs(t, err, ErrSettingNotFound)
}
func TestNotificationEmailTemplateRejectsUnsupportedPlaceholder(t *testing.T) {
ctx := context.Background()
svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil)
_, err := svc.UpdateTemplate(
ctx,
NotificationEmailEventSubscriptionPurchaseSuccess,
"en",
"Purchased {{not_allowed}}",
"<p>{{subscription_group}}</p>",
)
require.Error(t, err)
require.Contains(t, err.Error(), "unsupported placeholder")
}
func TestNotificationEmailAuthTemplatesAreListedAndPreviewable(t *testing.T) {
ctx := context.Background()
svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil)
infos := svc.ListEventInfos()
events := make(map[string]NotificationEmailEventInfo, len(infos))
for _, info := range infos {
events[info.Event] = info
}
require.Contains(t, events, NotificationEmailEventAuthVerifyCode)
require.Contains(t, events, NotificationEmailEventAuthPasswordReset)
require.False(t, events[NotificationEmailEventAuthVerifyCode].Optional)
require.False(t, events[NotificationEmailEventAuthPasswordReset].Optional)
require.Contains(t, events[NotificationEmailEventAuthVerifyCode].Placeholders, "verification_code")
require.Contains(t, events[NotificationEmailEventAuthPasswordReset].Placeholders, "reset_url")
verifyPreview, err := svc.PreviewTemplate(ctx, NotificationEmailPreviewInput{
Event: NotificationEmailEventAuthVerifyCode,
Locale: "zh-CN",
Variables: map[string]string{
"verification_code": "654321",
"expires_in_minutes": "15",
},
})
require.NoError(t, err)
require.Contains(t, verifyPreview.Subject, "邮箱验证码")
require.Contains(t, verifyPreview.HTML, "654321")
resetPreview, err := svc.PreviewTemplate(ctx, NotificationEmailPreviewInput{
Event: NotificationEmailEventAuthPasswordReset,
Locale: "en",
Variables: map[string]string{
"reset_url": "https://example.com/reset?token=abc",
"expires_in_minutes": "30",
},
})
require.NoError(t, err)
require.Contains(t, resetPreview.Subject, "Password reset")
require.Contains(t, resetPreview.HTML, "https://example.com/reset?token=abc")
}
func TestNotificationEmailAdditionalEventsAreListedAndPreviewable(t *testing.T) {
ctx := context.Background()
svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil)
infos := svc.ListEventInfos()
events := make(map[string]NotificationEmailEventInfo, len(infos))
for _, info := range infos {
events[info.Event] = info
}
checks := []struct {
event string
placeholder string
}{
{NotificationEmailEventNotificationEmailVerifyCode, "verification_code"},
{NotificationEmailEventAccountQuotaAlert, "account_name"},
{NotificationEmailEventContentModerationViolation, "moderation_category"},
{NotificationEmailEventContentModerationDisabled, "violation_count"},
{NotificationEmailEventOpsAlert, "rule_name"},
{NotificationEmailEventOpsScheduledReport, "report_html"},
}
for _, check := range checks {
info, ok := events[check.event]
require.Truef(t, ok, "expected %s to be listed", check.event)
require.False(t, info.Optional)
require.Contains(t, info.Placeholders, check.placeholder)
preview, err := svc.PreviewTemplate(ctx, NotificationEmailPreviewInput{Event: check.event, Locale: "zh"})
require.NoError(t, err)
require.NotEmpty(t, preview.Subject)
require.NotEmpty(t, preview.HTML)
}
}
func TestNotificationEmailRawHTMLVariablesAreTrustedOnlyForHTMLPlaceholders(t *testing.T) {
require.True(t, notificationEmailRawHTMLAllowed(NotificationEmailEventOpsScheduledReport, "report_html"))
require.False(t, notificationEmailRawHTMLAllowed(NotificationEmailEventOpsScheduledReport, "recipient_name"))
require.False(t, notificationEmailRawHTMLAllowed(NotificationEmailEventOpsAlert, "report_html"))
preview, err := renderNotificationEmail(
NotificationEmailEventOpsScheduledReport,
"Report for {{recipient_name}}",
`<section>{{report_html}}</section><p>{{recipient_name}}</p>`,
map[string]string{
"recipient_name": `<script>alert("x")</script>`,
"report_html": `<p>escaped report</p>`,
},
map[string]string{
"report_html": `<table><tr><td>trusted report</td></tr></table>`,
},
)
require.NoError(t, err)
require.Contains(t, preview.HTML, `<table><tr><td>trusted report</td></tr></table>`)
require.NotContains(t, preview.HTML, `escaped report`)
require.Contains(t, preview.HTML, `&lt;script&gt;alert(&#34;x&#34;)&lt;/script&gt;`)
require.Contains(t, preview.Subject, `<script>alert("x")</script>`)
preview, err = renderNotificationEmail(
NotificationEmailEventOpsScheduledReport,
"Recipient {{recipient_name}}",
`<p>{{recipient_name}}</p>`,
map[string]string{"recipient_name": `<em>escaped</em>`},
map[string]string{"recipient_name": `<strong>raw</strong>`},
)
require.NoError(t, err)
require.Contains(t, preview.HTML, `&lt;em&gt;escaped&lt;/em&gt;`)
require.NotContains(t, preview.HTML, `<strong>raw</strong>`)
}
func TestNotificationEmailFallbackClassification(t *testing.T) {
templateErr := notificationEmailTemplateErr(errors.New("bad template"))
configErr := notificationEmailConfigErr(errors.New("missing email service"))
deliveryErr := notificationEmailDeliveryErr(errors.New("smtp timeout"))
require.True(t, shouldFallbackNotificationEmail(templateErr))
require.True(t, shouldFallbackNotificationEmail(configErr))
require.False(t, shouldFallbackNotificationEmail(deliveryErr))
require.True(t, isNotificationEmailDeliveryError(deliveryErr))
require.False(t, isNotificationEmailDeliveryError(templateErr))
require.False(t, shouldFallbackNotificationEmail(nil))
}
func TestEmailQueueTasksPreserveLocaleHints(t *testing.T) {
queue := &EmailQueueService{taskChan: make(chan EmailTask, 2)}
require.NoError(t, queue.EnqueueVerifyCode("user@example.com", "Sub2API", "zh-CN"))
require.NoError(t, queue.EnqueuePasswordReset("user@example.com", "Sub2API", "https://example.com/reset", "en-US"))
verifyTask := <-queue.taskChan
require.Equal(t, TaskTypeVerifyCode, verifyTask.TaskType)
require.Equal(t, "zh-CN", verifyTask.Locale)
resetTask := <-queue.taskChan
require.Equal(t, TaskTypePasswordReset, resetTask.TaskType)
require.Equal(t, "en-US", resetTask.Locale)
}
func TestOpsScheduledReportDeliverySourceIDIncludesReportIdentity(t *testing.T) {
report := &opsScheduledReport{Name: "日报", ReportType: "daily_summary", Schedule: "0 9 * * *"}
sourceID := opsScheduledReportDeliverySourceID(report)
require.Contains(t, sourceID, "daily_summary")
require.Contains(t, sourceID, "日报")
require.Contains(t, sourceID, "0 9 * * *")
require.NotEqual(t, sourceID, opsScheduledReportDeliverySourceID(&opsScheduledReport{Name: "周报", ReportType: "weekly_summary", Schedule: "0 9 * * 1"}))
require.Equal(t, "scheduled_report", opsScheduledReportDeliverySourceID(nil))
}
func TestNotificationEmailUnsubscribeOnlyAllowsOptionalEvents(t *testing.T) {
ctx := context.Background()
svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil)
token, err := svc.createUnsubscribeToken(ctx, "User@Example.com", NotificationEmailEventBalanceLow)
require.NoError(t, err)
result, err := svc.Unsubscribe(ctx, token)
require.NoError(t, err)
require.True(t, result.Done)
require.Equal(t, NotificationEmailEventBalanceLow, result.Event)
unsubscribed, err := svc.IsUnsubscribed(ctx, "user@example.com", NotificationEmailEventBalanceLow)
require.NoError(t, err)
require.True(t, unsubscribed)
transactionalToken, err := svc.createUnsubscribeToken(ctx, "user@example.com", NotificationEmailEventBalanceRechargeSuccess)
require.NoError(t, err)
_, err = svc.Unsubscribe(ctx, transactionalToken)
require.Error(t, err)
require.Contains(t, err.Error(), "transactional")
authToken, err := svc.createUnsubscribeToken(ctx, "user@example.com", NotificationEmailEventAuthVerifyCode)
require.NoError(t, err)
_, err = svc.Unsubscribe(ctx, authToken)
require.Error(t, err)
require.Contains(t, err.Error(), "transactional")
}
func TestNotificationEmailLocaleMemoryNormalizesAcceptLanguage(t *testing.T) {
ctx := context.Background()
svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil)
svc.RememberRecipientLocale(ctx, 42, "User@Example.com", "zh-CN,zh;q=0.9,en;q=0.8")
require.Equal(t, "zh", svc.ResolveRecipientLocale(ctx, 42, "user@example.com"))
require.Equal(t, "zh", svc.ResolveRecipientLocale(ctx, 0, "user@example.com"))
}
func TestNotificationEmailDeliveryKeyUsesShortStableHash(t *testing.T) {
key := notificationEmailDeliveryKey(
NotificationEmailEventSubscriptionExpiryReminder,
"user_subscription",
"1234567890",
"User@Example.com",
"7d",
)
require.NotEmpty(t, key)
require.LessOrEqual(t, len(key), 100)
require.True(t, strings.HasPrefix(key, notificationEmailDeliveryKeyPrefix+"v2:"))
require.Equal(t, key, notificationEmailDeliveryKey(
NotificationEmailEventSubscriptionExpiryReminder,
"user_subscription",
"1234567890",
"user@example.com",
"7d",
))
require.NotEqual(t, key, notificationEmailDeliveryKey(
NotificationEmailEventSubscriptionExpiryReminder,
"user_subscription",
"1234567890",
"user@example.com",
"3d",
))
legacyKey := legacyNotificationEmailDeliveryKey(
NotificationEmailEventSubscriptionExpiryReminder,
"user_subscription",
"1234567890",
"user@example.com",
"7d",
)
require.Greater(t, len(legacyKey), 100)
}
func TestNotificationEmailPreferenceKeyUsesShortStableHashAndReadsLegacyKey(t *testing.T) {
ctx := context.Background()
repo := newNotificationEmailMemorySettingRepo()
svc := NewNotificationEmailService(repo, nil)
key := notificationEmailPreferenceKey(NotificationEmailEventSubscriptionExpiryReminder, "User@Example.com")
require.NotEmpty(t, key)
require.LessOrEqual(t, len(key), 100)
require.True(t, strings.HasPrefix(key, notificationEmailPreferenceKeyPrefix+"v2:"))
require.Equal(t, key, notificationEmailPreferenceKey(NotificationEmailEventSubscriptionExpiryReminder, "user@example.com"))
legacyKey := legacyNotificationEmailPreferenceKey(NotificationEmailEventSubscriptionExpiryReminder, "user@example.com")
require.Greater(t, len(legacyKey), 100)
require.NoError(t, repo.Set(ctx, legacyKey, "unsubscribed"))
unsubscribed, err := svc.IsUnsubscribed(ctx, "User@Example.com", NotificationEmailEventSubscriptionExpiryReminder)
require.NoError(t, err)
require.True(t, unsubscribed)
}
func TestNotificationEmailSendDeduplicatesSubscriptionExpiryReminder(t *testing.T) {
ctx := context.Background()
repo := newNotificationEmailMemorySettingRepo()
smtpServer := startNotificationEmailTestSMTPServer(t)
require.NoError(t, repo.SetMultiple(ctx, smtpServer.settings()))
emailSvc := NewEmailService(repo, nil)
svc := NewNotificationEmailService(repo, emailSvc)
input := NotificationEmailSendInput{
Event: NotificationEmailEventSubscriptionExpiryReminder,
RecipientEmail: "User@Example.com",
RecipientName: "User",
UserID: 42,
SourceType: "user_subscription",
SourceID: "1234567890",
ReminderKey: "7d",
Variables: map[string]string{
"subscription_group": "Codex",
"expiry_time": "2026-05-27 12:00",
"days_remaining": "7",
},
}
require.NoError(t, svc.Send(ctx, input))
require.Equal(t, int64(1), smtpServer.messageCount())
key := notificationEmailDeliveryKey(input.Event, input.SourceType, input.SourceID, input.RecipientEmail, input.ReminderKey)
require.LessOrEqual(t, len(key), 100)
_, err := repo.GetValue(ctx, key)
require.NoError(t, err)
require.NoError(t, svc.Send(ctx, input))
require.Equal(t, int64(1), smtpServer.messageCount())
}
func TestNotificationEmailSendRespectsLegacyDeliveryKey(t *testing.T) {
ctx := context.Background()
repo := newNotificationEmailMemorySettingRepo()
svc := NewNotificationEmailService(repo, nil)
input := NotificationEmailSendInput{
Event: NotificationEmailEventSubscriptionExpiryReminder,
RecipientEmail: "user@example.com",
SourceType: "user_subscription",
SourceID: "1234567890",
ReminderKey: "7d",
}
legacyKey := legacyNotificationEmailDeliveryKey(input.Event, input.SourceType, input.SourceID, input.RecipientEmail, input.ReminderKey)
require.NoError(t, repo.Set(ctx, legacyKey, "sent"))
require.NoError(t, svc.Send(ctx, input))
}
type notificationEmailMemorySettingRepo struct {
mu sync.RWMutex
values map[string]string
}
func newNotificationEmailMemorySettingRepo() *notificationEmailMemorySettingRepo {
return &notificationEmailMemorySettingRepo{values: make(map[string]string)}
}
func (r *notificationEmailMemorySettingRepo) Get(_ context.Context, key string) (*Setting, error) {
r.mu.RLock()
defer r.mu.RUnlock()
value, ok := r.values[key]
if !ok {
return nil, ErrSettingNotFound
}
return &Setting{Key: key, Value: value}, nil
}
func (r *notificationEmailMemorySettingRepo) GetValue(ctx context.Context, key string) (string, error) {
setting, err := r.Get(ctx, key)
if err != nil {
return "", err
}
return setting.Value, nil
}
func (r *notificationEmailMemorySettingRepo) Set(_ context.Context, key, value string) error {
r.mu.Lock()
defer r.mu.Unlock()
r.values[key] = value
return nil
}
func (r *notificationEmailMemorySettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
r.mu.RLock()
defer r.mu.RUnlock()
out := make(map[string]string, len(keys))
for _, key := range keys {
if value, ok := r.values[key]; ok {
out[key] = value
}
}
return out, nil
}
func (r *notificationEmailMemorySettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
r.mu.Lock()
defer r.mu.Unlock()
for key, value := range settings {
r.values[key] = value
}
return nil
}
func (r *notificationEmailMemorySettingRepo) GetAll(_ context.Context) (map[string]string, error) {
r.mu.RLock()
defer r.mu.RUnlock()
out := make(map[string]string, len(r.values))
for key, value := range r.values {
out[key] = value
}
return out, nil
}
func (r *notificationEmailMemorySettingRepo) Delete(_ context.Context, key string) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, ok := r.values[key]; !ok {
return ErrSettingNotFound
}
delete(r.values, key)
return nil
}
func TestNotificationEmailMemorySettingRepoSatisfiesInterface(t *testing.T) {
var _ SettingRepository = (*notificationEmailMemorySettingRepo)(nil)
require.False(t, strings.Contains(notificationEmailPreferenceKey(NotificationEmailEventBalanceLow, "User@Example.com"), "User@Example.com"))
}
type notificationEmailTestSMTPServer struct {
listener net.Listener
wg sync.WaitGroup
messages atomic.Int64
}
func startNotificationEmailTestSMTPServer(t *testing.T) *notificationEmailTestSMTPServer {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
server := &notificationEmailTestSMTPServer{listener: listener}
server.wg.Add(1)
go server.serve()
t.Cleanup(server.close)
return server
}
func (s *notificationEmailTestSMTPServer) settings() map[string]string {
host, port, _ := net.SplitHostPort(s.listener.Addr().String())
return map[string]string{
SettingKeySMTPHost: host,
SettingKeySMTPPort: port,
SettingKeySMTPUsername: "user",
SettingKeySMTPPassword: "password",
SettingKeySMTPFrom: "noreply@example.com",
SettingKeySMTPFromName: "Sub2API",
SettingKeySMTPUseTLS: "false",
}
}
func (s *notificationEmailTestSMTPServer) messageCount() int64 {
return s.messages.Load()
}
func (s *notificationEmailTestSMTPServer) close() {
_ = s.listener.Close()
s.wg.Wait()
}
func (s *notificationEmailTestSMTPServer) serve() {
defer s.wg.Done()
for {
conn, err := s.listener.Accept()
if err != nil {
return
}
s.handleConn(conn)
}
}
func (s *notificationEmailTestSMTPServer) handleConn(conn net.Conn) {
defer func() { _ = conn.Close() }()
rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
writeLine := func(line string) bool {
if _, err := rw.WriteString(line + "\r\n"); err != nil {
return false
}
return rw.Flush() == nil
}
if !writeLine("220 localhost ESMTP") {
return
}
for {
line, err := rw.ReadString('\n')
if err != nil {
return
}
cmd := strings.ToUpper(strings.TrimRight(line, "\r\n"))
switch {
case strings.HasPrefix(cmd, "EHLO"), strings.HasPrefix(cmd, "HELO"):
if _, err := rw.WriteString("250-localhost\r\n250 AUTH PLAIN\r\n"); err != nil {
return
}
if err := rw.Flush(); err != nil {
return
}
case strings.HasPrefix(cmd, "AUTH"):
if !writeLine("235 2.7.0 Authentication successful") {
return
}
case strings.HasPrefix(cmd, "MAIL FROM:"):
if !writeLine("250 2.1.0 OK") {
return
}
case strings.HasPrefix(cmd, "RCPT TO:"):
if !writeLine("250 2.1.5 OK") {
return
}
case strings.HasPrefix(cmd, "DATA"):
if !writeLine("354 End data with <CR><LF>.<CR><LF>") {
return
}
for {
dataLine, err := rw.ReadString('\n')
if err != nil {
return
}
if strings.TrimRight(dataLine, "\r\n") == "." {
break
}
}
s.messages.Add(1)
if !writeLine("250 2.0.0 OK") {
return
}
case strings.HasPrefix(cmd, "QUIT"):
_ = writeLine("221 2.0.0 Bye")
return
default:
if !writeLine("250 OK") {
return
}
}
}
}

View File

@ -0,0 +1,428 @@
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// forwardResponsesViaRawChatCompletions serves /v1/responses clients through an
// upstream that only supports /v1/chat/completions.
func (s *OpenAIGatewayService) forwardResponsesViaRawChatCompletions(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
) (*OpenAIForwardResult, error) {
startTime := time.Now()
var responsesReq apicompat.ResponsesRequest
if err := json.Unmarshal(body, &responsesReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"type": "invalid_request_error",
"message": "Failed to parse request body",
},
})
return nil, fmt.Errorf("parse responses request: %w", err)
}
originalModel := strings.TrimSpace(responsesReq.Model)
if originalModel == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"type": "invalid_request_error",
"message": "model is required",
},
})
return nil, fmt.Errorf("missing model in request")
}
clientStream := responsesReq.Stream
reasoningEffort := extractOpenAIReasoningEffortFromBody(body, originalModel)
serviceTier := extractOpenAIServiceTierFromBody(body)
chatReq, err := apicompat.ResponsesToChatCompletionsRequest(&responsesReq)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"type": "invalid_request_error",
"message": err.Error(),
},
})
return nil, fmt.Errorf("convert responses to chat completions: %w", err)
}
billingModel := resolveOpenAIForwardModel(account, originalModel, "")
upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)
chatReq.Model = upstreamModel
if clientStream {
chatReq.StreamOptions = &apicompat.ChatStreamOptions{IncludeUsage: true}
}
chatBody, err := json.Marshal(chatReq)
if err != nil {
return nil, fmt.Errorf("marshal chat completions fallback request: %w", err)
}
chatBody, err = s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, chatBody)
if err != nil {
var blocked *OpenAIFastBlockedError
if errors.As(err, &blocked) {
writeOpenAIFastPolicyBlockedResponse(c, blocked)
}
return nil, err
}
if serviceTier == nil {
serviceTier = extractOpenAIServiceTierFromBody(chatBody)
}
logger.L().Debug("openai responses: forwarding via raw chat completions",
zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel),
zap.String("billing_model", billingModel),
zap.String("upstream_model", upstreamModel),
zap.Bool("stream", clientStream),
)
apiKey := account.GetOpenAIApiKey()
if apiKey == "" {
return nil, fmt.Errorf("account %d missing api_key", account.ID)
}
baseURL := account.GetOpenAIBaseURL()
if baseURL == "" {
baseURL = "https://api.openai.com"
}
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, fmt.Errorf("invalid base_url: %w", err)
}
targetURL := buildOpenAIChatCompletionsURL(validatedURL)
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(chatBody))
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
if clientStream {
upstreamReq.Header.Set("Accept", "text/event-stream")
} else {
upstreamReq.Header.Set("Accept", "application/json")
}
for key, values := range c.Request.Header {
lowerKey := strings.ToLower(key)
if openaiCCRawAllowedHeaders[lowerKey] {
for _, v := range values {
upstreamReq.Header.Add(key, v)
}
}
}
if customUA := account.GetOpenAIUserAgent(); customUA != "" {
upstreamReq.Header.Set("user-agent", customUA)
}
proxyURL := ""
if account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
return s.handleErrorResponse(ctx, resp, c, account, chatBody)
}
if clientStream {
return s.streamChatCompletionsAsResponses(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
}
return s.bufferChatCompletionsAsResponses(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
}
func (s *OpenAIGatewayService) bufferChatCompletionsAsResponses(
c *gin.Context,
resp *http.Response,
originalModel string,
billingModel string,
upstreamModel string,
reasoningEffort *string,
serviceTier *string,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
if err != nil {
if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "api_error",
"message": "Failed to read upstream response",
},
})
}
return nil, fmt.Errorf("read upstream body: %w", err)
}
var ccResp apicompat.ChatCompletionsResponse
if err := json.Unmarshal(respBody, &ccResp); err != nil {
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "api_error",
"message": "Failed to parse upstream response",
},
})
return nil, fmt.Errorf("parse chat completions response: %w", err)
}
responsesResp := apicompat.ChatCompletionsResponseToResponses(&ccResp, originalModel)
usage := OpenAIUsage{}
if parsed, ok := extractOpenAIUsageFromJSONBytes(respBody); ok {
usage = parsed
}
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.JSON(http.StatusOK, responsesResp)
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: billingModel,
UpstreamModel: upstreamModel,
ReasoningEffort: reasoningEffort,
ServiceTier: serviceTier,
Stream: false,
Duration: time.Since(startTime),
}, nil
}
func (s *OpenAIGatewayService) streamChatCompletionsAsResponses(
c *gin.Context,
resp *http.Response,
originalModel string,
billingModel string,
upstreamModel string,
reasoningEffort *string,
serviceTier *string,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
headersWritten := false
writeStreamHeaders := func() {
if headersWritten {
return
}
headersWritten = true
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
}
state := apicompat.NewChatCompletionsToResponsesStreamState(originalModel)
var usage OpenAIUsage
var firstTokenMs *int
clientDisconnected := false
sawDone := false
writeEvents := func(events []apicompat.ResponsesStreamEvent) {
if clientDisconnected || len(events) == 0 {
return
}
writeStreamHeaders()
for _, event := range events {
sse, err := apicompat.ResponsesEventToSSE(event)
if err != nil {
logger.L().Warn("openai responses chat fallback: failed to marshal stream event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
clientDisconnected = true
logger.L().Debug("openai responses chat fallback: client disconnected, continuing to drain upstream for billing",
zap.Error(err),
zap.String("request_id", requestID),
)
return
}
}
c.Writer.Flush()
}
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
for scanner.Scan() {
line := scanner.Text()
payload, ok := extractOpenAISSEDataLine(line)
if !ok {
continue
}
payload = strings.TrimSpace(payload)
if payload == "" {
continue
}
if payload == "[DONE]" {
sawDone = true
break
}
if u := extractCCStreamUsage(payload); u != nil {
usage = *u
}
var chunk apicompat.ChatCompletionsChunk
if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
logger.L().Warn("openai responses chat fallback: failed to parse chat stream chunk",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
if firstTokenMs == nil && !isOpenAIChatUsageOnlyStreamChunk(payload) && chatChunkStartsResponsesOutput(&chunk) {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
writeEvents(apicompat.ChatCompletionsChunkToResponsesEvents(&chunk, state))
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai responses chat fallback: stream read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: billingModel,
UpstreamModel: upstreamModel,
ReasoningEffort: reasoningEffort,
ServiceTier: serviceTier,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, fmt.Errorf("stream usage incomplete: %w", err)
}
writeEvents(apicompat.FinalizeChatCompletionsResponsesStream(state))
if !clientDisconnected {
writeStreamHeaders()
if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil {
clientDisconnected = true
}
if !clientDisconnected {
c.Writer.Flush()
}
}
if !sawDone {
logger.L().Debug("openai responses chat fallback: upstream stream ended without done sentinel",
zap.String("request_id", requestID),
)
}
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: billingModel,
UpstreamModel: upstreamModel,
ReasoningEffort: reasoningEffort,
ServiceTier: serviceTier,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
func chatChunkStartsResponsesOutput(chunk *apicompat.ChatCompletionsChunk) bool {
if chunk == nil {
return false
}
for _, choice := range chunk.Choices {
if choice.Delta.Content != nil || choice.Delta.ReasoningContent != nil || len(choice.Delta.ToolCalls) > 0 {
return true
}
}
return false
}

View File

@ -0,0 +1,145 @@
//go:build unit
package service
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestForwardResponses_ForceChatCompletionsRoutesNonStreamingToChatCompletions(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-5.4","input":"hello","stream":false}`)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_resp_chat_json"}},
Body: io.NopCloser(strings.NewReader(
`{"id":"chatcmpl_json","object":"chat.completion","model":"gpt-5.4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":3,"completion_tokens":2,"total_tokens":5,"prompt_tokens_details":{"cached_tokens":1}}}`,
)),
}}
svc := &OpenAIGatewayService{
cfg: rawChatCompletionsTestConfig(),
httpUpstream: upstream,
}
result, err := svc.Forward(context.Background(), c, forceChatResponsesFallbackAccount(), body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "http://upstream.example/v1/chat/completions", upstream.lastReq.URL.String())
require.Equal(t, "hello", gjson.GetBytes(upstream.lastBody, "messages.0.content").String())
require.False(t, gjson.GetBytes(upstream.lastBody, "input").Exists())
require.Equal(t, "response", gjson.Get(rec.Body.String(), "object").String())
require.Equal(t, "ok", gjson.Get(rec.Body.String(), "output.0.content.0.text").String())
require.Equal(t, 3, result.Usage.InputTokens)
require.Equal(t, 2, result.Usage.OutputTokens)
require.Equal(t, 1, result.Usage.CacheReadInputTokens)
require.False(t, result.Stream)
}
func TestForwardResponses_ForceChatCompletionsRoutesStreamingToChatCompletions(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-5.4","input":"hello","stream":true}`)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := strings.Join([]string{
`data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}`,
"",
`data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"he"},"finish_reason":null}]}`,
"",
`data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"llo"},"finish_reason":null}]}`,
"",
`data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
"",
`data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":4,"completion_tokens":3,"total_tokens":7}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_resp_chat_stream"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{
cfg: rawChatCompletionsTestConfig(),
httpUpstream: upstream,
}
result, err := svc.Forward(context.Background(), c, forceChatResponsesFallbackAccount(), body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "http://upstream.example/v1/chat/completions", upstream.lastReq.URL.String())
require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool())
require.Contains(t, rec.Body.String(), "event: response.output_text.delta")
require.Contains(t, rec.Body.String(), `"delta":"he"`)
require.Contains(t, rec.Body.String(), "event: response.completed")
require.Contains(t, rec.Body.String(), `"input_tokens":4`)
require.Contains(t, rec.Body.String(), "data: [DONE]")
require.Equal(t, 4, result.Usage.InputTokens)
require.Equal(t, 3, result.Usage.OutputTokens)
require.True(t, result.Stream)
require.NotNil(t, result.FirstTokenMs)
}
func TestForwardResponses_AutoSupportedAccountStillUsesResponsesEndpoint(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-5.4","input":"hello","stream":false}`)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_resp_native"}},
Body: io.NopCloser(strings.NewReader(
`{"id":"resp_native","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"ok"}],"status":"completed"}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}`,
)),
}}
svc := &OpenAIGatewayService{
cfg: rawChatCompletionsTestConfig(),
httpUpstream: upstream,
}
account := rawChatCompletionsTestAccount()
account.Extra = map[string]any{
openai_compat.ExtraKeyResponsesMode: string(openai_compat.ResponsesSupportModeAuto),
openai_compat.ExtraKeyResponsesSupported: true,
}
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "http://upstream.example/v1/responses", upstream.lastReq.URL.String())
require.True(t, gjson.GetBytes(upstream.lastBody, "input").Exists())
require.False(t, gjson.GetBytes(upstream.lastBody, "messages").Exists())
require.Equal(t, "ok", gjson.Get(rec.Body.String(), "output.0.content.0.text").String())
}
func forceChatResponsesFallbackAccount() *Account {
account := rawChatCompletionsTestAccount()
account.Extra = map[string]any{
openai_compat.ExtraKeyResponsesMode: string(openai_compat.ResponsesSupportModeForceChatCompletions),
}
return account
}

View File

@ -24,6 +24,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/cespare/xxhash/v2"
@ -2018,6 +2019,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
originalBody := body
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
originalModel := reqModel
if account.Type == AccountTypeAPIKey && !openai_compat.ShouldUseResponsesAPI(account.Extra) {
return s.forwardResponsesViaRawChatCompletions(ctx, c, account, body)
}
compatMessagesBridge := isOpenAICompatMessagesBridgeBody(body)
setOpenAICompatMessagesBridgeContext(c, compatMessagesBridge)
@ -3231,6 +3237,10 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
req.Header.Set("user-agent", codexCLIUserAgent)
}
// 浏览器型 UA 兜底:仅 OAuthChatGPT 内部接口)账号生效,若最终 user-agent 仍为浏览器
// Chrome/Firefox/Safari/Edge 等),替换为后台配置的 Codex UA避免 Cloudflare 触发 JS 质询。
s.overrideBrowserUserAgent(ctx, account, req)
if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json")
}
@ -3947,6 +3957,10 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
req.Header.Set("user-agent", codexCLIUserAgent)
}
// 浏览器型 UA 兜底:仅 OAuthChatGPT 内部接口)账号生效,若最终 user-agent 仍为浏览器
// Chrome/Firefox/Safari/Edge 等),替换为后台配置的 Codex UA避免 Cloudflare 触发 JS 质询。
s.overrideBrowserUserAgent(ctx, account, req)
// Ensure required headers exist
if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json")
@ -3955,6 +3969,30 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
return req, nil
}
// overrideBrowserUserAgent 检查请求的最终 user-agent若为浏览器 UA 则替换为后台配置的 Codex UA。
// 用于规避 Cloudflare 对浏览器型 UA 在 ChatGPT 内部接口上的访问质询。
// 影响范围严格限定:仅 OAuthCodex/ChatGPT 内部接口账号生效API Key 等其他账号原样透传。
// 仅在识别为浏览器Mozilla/...)时改写,其他 CLI/工具 UA 不动。
func (s *OpenAIGatewayService) overrideBrowserUserAgent(ctx context.Context, account *Account, req *http.Request) {
if req == nil || account == nil {
return
}
if account.Type != AccountTypeOAuth {
return
}
currentUA := req.Header.Get("user-agent")
if !openai.IsBrowserUserAgent(currentUA) {
return
}
codexUA := DefaultOpenAICodexUserAgent
if s != nil && s.settingService != nil {
if v := strings.TrimSpace(s.settingService.GetOpenAICodexUserAgent(ctx)); v != "" {
codexUA = v
}
}
req.Header.Set("user-agent", codexUA)
}
func (s *OpenAIGatewayService) handleErrorResponse(
ctx context.Context,
resp *http.Response,

View File

@ -262,6 +262,9 @@ func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel st
tool := []byte(`{"type":"image_generation","action":"","model":""}`)
tool, _ = sjson.SetBytes(tool, "action", action)
tool, _ = sjson.SetBytes(tool, "model", strings.TrimSpace(toolModel))
if shouldPassOpenAIImagesN(toolModel, parsed.N) {
tool, _ = sjson.SetBytes(tool, "n", parsed.N)
}
for _, field := range []struct {
path string
@ -302,6 +305,13 @@ func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel st
return req, nil
}
func shouldPassOpenAIImagesN(model string, n int) bool {
if n <= 1 {
return false
}
return !strings.EqualFold(strings.TrimSpace(model), "dall-e-3")
}
func extractOpenAIImagesFromResponsesCompleted(payload []byte) ([]openAIResponsesImageResult, int64, []byte, openAIResponsesImageResult, error) {
if gjson.GetBytes(payload, "type").String() != "response.completed" {
return nil, 0, nil, openAIResponsesImageResult{}, fmt.Errorf("unexpected event type")
@ -957,16 +967,6 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
account.Type,
len(parsed.Uploads),
)
if parsed.N > 1 {
logger.LegacyPrintf(
"service.openai_gateway",
"[Warning] Codex /responses image tool requested n=%d; falling back to n=1 request_model=%s endpoint=%s",
parsed.N,
requestModel,
parsed.Endpoint,
)
}
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
defer releaseUpstreamCtx()

View File

@ -474,9 +474,9 @@ func findOpenAIImageTestSSEEvent(events []openAIImageTestSSEEvent, name string)
return openAIImageTestSSEEvent{}, false
}
func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
func TestOpenAIGatewayServiceForwardImages_OAuthPassesNAndReturnsAllImages(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":2}`)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":3}`)
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
@ -497,7 +497,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
"X-Request-Id": []string{"req_img_123"},
},
Body: io.NopCloser(strings.NewReader(
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000000,\"usage\":{\"input_tokens\":11,\"output_tokens\":22,\"input_tokens_details\":{\"cached_tokens\":3},\"output_tokens_details\":{\"image_tokens\":7}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000000,\"usage\":{\"input_tokens\":11,\"output_tokens\":22,\"input_tokens_details\":{\"cached_tokens\":3},\"output_tokens_details\":{\"image_tokens\":7}},\"tool_usage\":{\"image_gen\":{\"images\":3}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMQ==\",\"revised_prompt\":\"draw a cat 1\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"},{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMg==\",\"revised_prompt\":\"draw a cat 2\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"},{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMw==\",\"revised_prompt\":\"draw a cat 3\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
"data: [DONE]\n\n",
)),
},
@ -520,7 +520,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
require.NotNil(t, result)
require.Equal(t, "gpt-image-2", result.Model)
require.Equal(t, "gpt-image-2", result.UpstreamModel)
require.Equal(t, 1, result.ImageCount)
require.Equal(t, 3, result.ImageCount)
require.Equal(t, 11, result.Usage.InputTokens)
require.Equal(t, 22, result.Usage.OutputTokens)
require.Equal(t, 7, result.Usage.ImageOutputTokens)
@ -540,13 +540,17 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String())
require.Equal(t, "1024x1024", gjson.GetBytes(upstream.lastBody, "tools.0.size").String())
require.Equal(t, "high", gjson.GetBytes(upstream.lastBody, "tools.0.quality").String())
require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.n").Exists())
require.Equal(t, int64(3), gjson.GetBytes(upstream.lastBody, "tools.0.n").Int())
require.Equal(t, "draw a cat", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String())
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "gpt-image-2", gjson.Get(rec.Body.String(), "model").String())
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
require.Len(t, gjson.Get(rec.Body.String(), "data").Array(), 3)
require.Equal(t, "aW1hZ2UtMQ==", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
require.Equal(t, "aW1hZ2UtMg==", gjson.Get(rec.Body.String(), "data.1.b64_json").String())
require.Equal(t, "aW1hZ2UtMw==", gjson.Get(rec.Body.String(), "data.2.b64_json").String())
require.Equal(t, "draw a cat 1", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
require.Equal(t, "draw a cat 3", gjson.Get(rec.Body.String(), "data.2.revised_prompt").String())
}
func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseURL(t *testing.T) {
@ -1112,7 +1116,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthEditsStreamingTransformsEvents(t
require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists())
}
func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *testing.T) {
func TestBuildOpenAIImagesResponsesRequest_PassesThroughNForMultiImageModels(t *testing.T) {
parsed := &OpenAIImagesRequest{
Endpoint: openAIImagesGenerationsEndpoint,
Model: "gpt-image-2",
@ -1123,11 +1127,26 @@ func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *t
body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2")
require.NoError(t, err)
require.NotNil(t, body)
require.False(t, gjson.GetBytes(body, "tools.0.n").Exists())
require.Equal(t, int64(2), gjson.GetBytes(body, "tools.0.n").Int())
require.Equal(t, "gpt-image-2", gjson.GetBytes(body, "tools.0.model").String())
require.Equal(t, "draw a cat", gjson.GetBytes(body, "input.0.content.0.text").String())
}
func TestBuildOpenAIImagesResponsesRequest_DoesNotPassNForDallE3(t *testing.T) {
parsed := &OpenAIImagesRequest{
Endpoint: openAIImagesGenerationsEndpoint,
Model: "dall-e-3",
Prompt: "draw a cat",
N: 2,
}
body, err := buildOpenAIImagesResponsesRequest(parsed, "dall-e-3")
require.NoError(t, err)
require.NotNil(t, body)
require.False(t, gjson.GetBytes(body, "tools.0.n").Exists())
require.Equal(t, "dall-e-3", gjson.GetBytes(body, "tools.0.model").String())
}
func TestBuildOpenAIImagesResponsesRequest_StripsInputFidelity(t *testing.T) {
parsed := &OpenAIImagesRequest{
Endpoint: openAIImagesEditsEndpoint,

View File

@ -686,6 +686,21 @@ func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runt
if !s.emailLimiter.Allow(time.Now().UTC()) {
continue
}
if s.emailService.notificationEmailService != nil {
if err := s.emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{
Event: NotificationEmailEventOpsAlert,
RecipientEmail: addr,
RecipientName: emailRecipientName(addr),
SourceType: "ops_alert",
SourceID: fmt.Sprintf("%d", event.ID),
Variables: opsAlertEmailVariables(rule, event),
}); err == nil {
anySent = true
continue
} else if !shouldFallbackNotificationEmail(err) {
continue
}
}
if err := s.emailService.SendEmail(ctx, addr, subject, body); err != nil {
// Ignore per-recipient failures; continue best-effort.
continue
@ -699,6 +714,46 @@ func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runt
return anySent
}
func opsAlertEmailVariables(rule *OpsAlertRule, event *OpsAlertEvent) map[string]string {
variables := map[string]string{
"rule_name": "-",
"severity": "-",
"alert_status": "-",
"metric_type": "-",
"operator": "-",
"metric_value": "-",
"threshold_value": "-",
"triggered_at": time.Now().UTC().Format(time.RFC3339),
"alert_description": "-",
}
if rule != nil {
variables["rule_name"] = strings.TrimSpace(rule.Name)
variables["severity"] = strings.TrimSpace(rule.Severity)
variables["metric_type"] = strings.TrimSpace(rule.MetricType)
variables["operator"] = strings.TrimSpace(rule.Operator)
variables["threshold_value"] = fmt.Sprintf("%.2f", rule.Threshold)
if strings.TrimSpace(rule.Description) != "" {
variables["alert_description"] = strings.TrimSpace(rule.Description)
}
}
if event != nil {
variables["alert_status"] = strings.TrimSpace(event.Status)
if event.MetricValue != nil {
variables["metric_value"] = fmt.Sprintf("%.2f", *event.MetricValue)
}
if event.ThresholdValue != nil {
variables["threshold_value"] = fmt.Sprintf("%.2f", *event.ThresholdValue)
}
if !event.FiredAt.IsZero() {
variables["triggered_at"] = event.FiredAt.UTC().Format(time.RFC3339)
}
if strings.TrimSpace(event.Description) != "" {
variables["alert_description"] = strings.TrimSpace(event.Description)
}
}
return variables
}
func buildOpsAlertEmailBody(rule *OpsAlertRule, event *OpsAlertEvent) string {
if rule == nil || event == nil {
return ""

View File

@ -337,6 +337,7 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc
}
subject := fmt.Sprintf("[Ops Report] %s", strings.TrimSpace(report.Name))
templateVariables := opsScheduledReportEmailVariables(report, now)
attempts := 0
for _, to := range recipients {
@ -345,6 +346,24 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc
continue
}
attempts++
if s.emailService.notificationEmailService != nil {
if err := s.emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{
Event: NotificationEmailEventOpsScheduledReport,
RecipientEmail: addr,
RecipientName: emailRecipientName(addr),
SourceType: "ops_scheduled_report",
SourceID: opsScheduledReportDeliverySourceID(report),
ReminderKey: now.UTC().Format("2006-01-02T15:04"),
Variables: templateVariables,
RawHTMLVariables: map[string]string{
"report_html": content,
},
}); err == nil {
continue
} else if !shouldFallbackNotificationEmail(err) {
continue
}
}
if err := s.emailService.SendEmail(ctx, addr, subject, content); err != nil {
// Ignore per-recipient failures; continue best-effort.
continue
@ -353,6 +372,46 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc
return attempts, nil
}
func opsScheduledReportDeliverySourceID(report *opsScheduledReport) string {
if report == nil {
return "scheduled_report"
}
parts := []string{
strings.TrimSpace(report.ReportType),
strings.TrimSpace(report.Name),
strings.TrimSpace(report.Schedule),
}
joined := strings.Trim(strings.Join(parts, ":"), ":")
if joined == "" {
return "scheduled_report"
}
return joined
}
func opsScheduledReportEmailVariables(report *opsScheduledReport, now time.Time) map[string]string {
end := now.UTC()
start := end
name := "Ops report"
reportType := "scheduled_report"
if report != nil {
if strings.TrimSpace(report.Name) != "" {
name = strings.TrimSpace(report.Name)
}
if strings.TrimSpace(report.ReportType) != "" {
reportType = strings.TrimSpace(report.ReportType)
}
if report.TimeRange > 0 {
start = end.Add(-report.TimeRange)
}
}
return map[string]string{
"report_name": name,
"report_type": reportType,
"report_start_time": start.Format(time.RFC3339),
"report_end_time": end.Format(time.RFC3339),
}
}
func (s *OpsScheduledReportService) generateReportHTML(ctx context.Context, report *opsScheduledReport, now time.Time) (string, error) {
if s == nil || s.opsService == nil || report == nil {
return "", fmt.Errorf("service not initialized")

View File

@ -310,9 +310,87 @@ func (s *PaymentService) markCompleted(ctx context.Context, o *dbent.PaymentOrde
"creditedAmount": o.Amount,
"payAmount": o.PayAmount,
})
s.dispatchPaymentFulfillmentNotification(o, auditAction)
return nil
}
func (s *PaymentService) dispatchPaymentFulfillmentNotification(o *dbent.PaymentOrder, auditAction string) {
if s == nil || s.notificationEmailService == nil || o == nil {
return
}
go func() {
ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout)
defer cancel()
var err error
switch auditAction {
case "RECHARGE_SUCCESS":
err = s.sendBalanceRechargeSuccessNotification(ctx, o)
case "SUBSCRIPTION_SUCCESS":
err = s.sendSubscriptionPurchaseSuccessNotification(ctx, o)
default:
return
}
if err != nil {
slog.Warn("payment fulfillment notification email failed", "order_id", o.ID, "action", auditAction, "err", err.Error())
}
}()
}
func (s *PaymentService) sendBalanceRechargeSuccessNotification(ctx context.Context, o *dbent.PaymentOrder) error {
currentBalance := ""
if s.userRepo != nil {
if user, err := s.userRepo.GetByID(ctx, o.UserID); err == nil && user != nil {
currentBalance = fmt.Sprintf("%.2f", user.Balance)
}
}
return s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
Event: NotificationEmailEventBalanceRechargeSuccess,
RecipientEmail: o.UserEmail,
RecipientName: firstNonEmpty(o.UserName, o.UserEmail),
UserID: o.UserID,
SourceType: "payment_order",
SourceID: strconv.FormatInt(o.ID, 10),
Variables: map[string]string{
"recharge_amount": fmt.Sprintf("%.2f", o.Amount),
"current_balance": currentBalance,
"order_id": strconv.FormatInt(o.ID, 10),
},
})
}
func (s *PaymentService) sendSubscriptionPurchaseSuccessNotification(ctx context.Context, o *dbent.PaymentOrder) error {
variables := map[string]string{
"subscription_group": "Subscription",
"subscription_days": "",
"expiry_time": "",
"order_id": strconv.FormatInt(o.ID, 10),
}
if o.SubscriptionDays != nil {
variables["subscription_days"] = strconv.Itoa(*o.SubscriptionDays)
}
if o.SubscriptionGroupID != nil {
if s.groupRepo != nil {
if group, err := s.groupRepo.GetByID(ctx, *o.SubscriptionGroupID); err == nil && group != nil && strings.TrimSpace(group.Name) != "" {
variables["subscription_group"] = group.Name
}
}
if s.subscriptionSvc != nil {
if sub, err := s.subscriptionSvc.GetActiveSubscription(ctx, o.UserID, *o.SubscriptionGroupID); err == nil && sub != nil {
variables["expiry_time"] = sub.ExpiresAt.Format("2006-01-02 15:04")
}
}
}
return s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
Event: NotificationEmailEventSubscriptionPurchaseSuccess,
RecipientEmail: o.UserEmail,
RecipientName: firstNonEmpty(o.UserName, o.UserEmail),
UserID: o.UserID,
SourceType: "payment_order",
SourceID: strconv.FormatInt(o.ID, 10),
Variables: variables,
})
}
func (s *PaymentService) ExecuteSubscriptionFulfillment(ctx context.Context, oid int64) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {

View File

@ -48,6 +48,9 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest
if user.Status != payment.EntityStatusActive {
return nil, infraerrors.Forbidden("USER_INACTIVE", "user account is disabled")
}
if s.notificationEmailService != nil {
s.notificationEmailService.RememberRecipientLocale(ctx, req.UserID, user.Email, req.Locale)
}
orderAmount := req.Amount
limitAmount := req.Amount
if plan != nil {

View File

@ -83,6 +83,7 @@ type CreateOrderRequest struct {
PaymentSource string
OrderType string
PlanID int64
Locale string
}
type CreateOrderResponse struct {
@ -174,18 +175,19 @@ type TopUserStat struct {
// --- Service ---
type PaymentService struct {
providerMu sync.Mutex
providersLoaded bool
entClient *dbent.Client
registry *payment.Registry
loadBalancer payment.LoadBalancer
redeemService *RedeemService
subscriptionSvc *SubscriptionService
configService *PaymentConfigService
userRepo UserRepository
groupRepo GroupRepository
resumeService *PaymentResumeService
affiliateService *AffiliateService
providerMu sync.Mutex
providersLoaded bool
entClient *dbent.Client
registry *payment.Registry
loadBalancer payment.LoadBalancer
redeemService *RedeemService
subscriptionSvc *SubscriptionService
configService *PaymentConfigService
userRepo UserRepository
groupRepo GroupRepository
resumeService *PaymentResumeService
affiliateService *AffiliateService
notificationEmailService *NotificationEmailService
}
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository, affiliateService *AffiliateService) *PaymentService {
@ -194,6 +196,10 @@ func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, load
return svc
}
func (s *PaymentService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) {
s.notificationEmailService = notificationEmailService
}
// --- Provider Registry ---
// EnsureProviders lazily initializes the provider registry on first call.

View File

@ -128,6 +128,19 @@ const antigravityUserAgentVersionCacheTTL = 60 * time.Second
const antigravityUserAgentVersionErrorTTL = 5 * time.Second
const antigravityUserAgentVersionDBTimeout = 5 * time.Second
// DefaultOpenAICodexUserAgent OpenAI Codex 默认 User-Agent用于规避 Cloudflare 对浏览器 UA 的质询)
const DefaultOpenAICodexUserAgent = "codex-tui/0.125.0 (Ubuntu 22.4.0; x86_64) xterm-256color (codex-tui; 0.125.0)"
// cachedOpenAICodexUserAgent 缓存 OpenAI Codex UA进程内缓存60s TTL
type cachedOpenAICodexUserAgent struct {
value string
expiresAt int64 // unix nano
}
const openAICodexUserAgentCacheTTL = 60 * time.Second
const openAICodexUserAgentErrorTTL = 5 * time.Second
const openAICodexUserAgentDBTimeout = 5 * time.Second
// DefaultSubscriptionGroupReader validates group references used by default subscriptions.
type DefaultSubscriptionGroupReader interface {
GetByID(ctx context.Context, id int64) (*Group, error)
@ -148,6 +161,8 @@ type SettingService struct {
webSearchManagerBuilder WebSearchManagerBuilder
antigravityUAVersionCache atomic.Value // *cachedAntigravityUserAgentVersion
antigravityUAVersionSF singleflight.Group
openAICodexUACache atomic.Value // *cachedOpenAICodexUserAgent
openAICodexUASF singleflight.Group
}
type ProviderDefaultGrantSettings struct {
@ -907,6 +922,55 @@ func (s *SettingService) GetAntigravityUserAgentVersion(ctx context.Context) str
return fallback
}
// GetOpenAICodexUserAgent 返回 OpenAI Codex 上游请求使用的 User-Agent。
// 后台设置优先;为空时回退到内置默认值。
func (s *SettingService) GetOpenAICodexUserAgent(ctx context.Context) string {
fallback := DefaultOpenAICodexUserAgent
if s == nil || s.settingRepo == nil {
return fallback
}
if cached, ok := s.openAICodexUACache.Load().(*cachedOpenAICodexUserAgent); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.value
}
}
result, _, _ := s.openAICodexUASF.Do("openai_codex_user_agent", func() (any, error) {
if cached, ok := s.openAICodexUACache.Load().(*cachedOpenAICodexUserAgent); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.value, nil
}
}
if ctx == nil {
ctx = context.Background()
}
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAICodexUserAgentDBTimeout)
defer cancel()
value, err := s.settingRepo.GetValue(dbCtx, SettingKeyOpenAICodexUserAgent)
if err != nil && !errors.Is(err, ErrSettingNotFound) {
slog.Warn("failed to get openai codex user agent setting", "error", err)
s.openAICodexUACache.Store(&cachedOpenAICodexUserAgent{
value: fallback,
expiresAt: time.Now().Add(openAICodexUserAgentErrorTTL).UnixNano(),
})
return fallback, nil
}
ua := strings.TrimSpace(value)
if ua == "" {
ua = fallback
}
s.openAICodexUACache.Store(&cachedOpenAICodexUserAgent{
value: ua,
expiresAt: time.Now().Add(openAICodexUserAgentCacheTTL).UnixNano(),
})
return ua, nil
})
if ua, ok := result.(string); ok && ua != "" {
return ua
}
return fallback
}
// SetOnUpdateCallback sets a callback function to be called when settings are updated
// This is used for cache invalidation (e.g., HTML cache in frontend server)
func (s *SettingService) SetOnUpdateCallback(callback func()) {
@ -1706,6 +1770,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
updates[SettingKeyEnableAnthropicCacheTTL1hInjection] = strconv.FormatBool(settings.EnableAnthropicCacheTTL1hInjection)
updates[SettingKeyRewriteMessageCacheControl] = strconv.FormatBool(settings.RewriteMessageCacheControl)
updates[SettingKeyAntigravityUserAgentVersion] = antigravity.NormalizeUserAgentVersion(settings.AntigravityUserAgentVersion)
updates[SettingKeyOpenAICodexUserAgent] = strings.TrimSpace(settings.OpenAICodexUserAgent)
updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource
updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource
updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled)
@ -1788,6 +1853,15 @@ func (s *SettingService) refreshCachedSettings(settings *SystemSettings) {
version: antigravityUserAgentVersion,
expiresAt: time.Now().Add(antigravityUserAgentVersionCacheTTL).UnixNano(),
})
s.openAICodexUASF.Forget("openai_codex_user_agent")
codexUA := strings.TrimSpace(settings.OpenAICodexUserAgent)
if codexUA == "" {
codexUA = DefaultOpenAICodexUserAgent
}
s.openAICodexUACache.Store(&cachedOpenAICodexUserAgent{
value: codexUA,
expiresAt: time.Now().Add(openAICodexUserAgentCacheTTL).UnixNano(),
})
openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
enabled: settings.OpenAIAdvancedSchedulerEnabled,
@ -2529,6 +2603,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyEnableAnthropicCacheTTL1hInjection: "false",
SettingKeyRewriteMessageCacheControl: strconv.FormatBool(s.defaultRewriteMessageCacheControl()),
SettingKeyAntigravityUserAgentVersion: "",
SettingKeyOpenAICodexUserAgent: "",
SettingPaymentVisibleMethodAlipaySource: "",
SettingPaymentVisibleMethodWxpaySource: "",
SettingPaymentVisibleMethodAlipayEnabled: "false",
@ -3041,6 +3116,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.RewriteMessageCacheControl = s.defaultRewriteMessageCacheControl()
}
result.AntigravityUserAgentVersion = antigravity.NormalizeUserAgentVersion(settings[SettingKeyAntigravityUserAgentVersion])
result.OpenAICodexUserAgent = strings.TrimSpace(settings[SettingKeyOpenAICodexUserAgent])
// Web search emulation: quick enabled check from the JSON config
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {

View File

@ -193,6 +193,7 @@ type SystemSettings struct {
EnableAnthropicCacheTTL1hInjection bool // 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl默认 false
RewriteMessageCacheControl bool // 是否改写 messages[*].content[*].cache_control默认 false
AntigravityUserAgentVersion string // Antigravity 上游 User-Agent 版本号;空值使用配置/默认值
OpenAICodexUserAgent string // OpenAI Codex 上游完整 User-Agent空值使用内置默认
// Web Search Emulation
WebSearchEmulationEnabled bool // 是否启用 web search 模拟

View File

@ -2,18 +2,23 @@ package service
import (
"context"
"fmt"
"log"
"strconv"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
// SubscriptionExpiryService periodically updates expired subscription status.
type SubscriptionExpiryService struct {
userSubRepo UserSubscriptionRepository
interval time.Duration
stopCh chan struct{}
stopOnce sync.Once
wg sync.WaitGroup
userSubRepo UserSubscriptionRepository
notificationEmailService *NotificationEmailService
interval time.Duration
stopCh chan struct{}
stopOnce sync.Once
wg sync.WaitGroup
}
func NewSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, interval time.Duration) *SubscriptionExpiryService {
@ -24,6 +29,10 @@ func NewSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, interv
}
}
func (s *SubscriptionExpiryService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) {
s.notificationEmailService = notificationEmailService
}
func (s *SubscriptionExpiryService) Start() {
if s == nil || s.userSubRepo == nil || s.interval <= 0 {
return
@ -68,4 +77,50 @@ func (s *SubscriptionExpiryService) runOnce() {
if updated > 0 {
log.Printf("[SubscriptionExpiry] Updated %d expired subscriptions", updated)
}
s.sendExpiryReminders(ctx)
}
func (s *SubscriptionExpiryService) sendExpiryReminders(ctx context.Context) {
if s == nil || s.userSubRepo == nil || s.notificationEmailService == nil {
return
}
for page := 1; ; page++ {
subs, pag, err := s.userSubRepo.List(ctx, pagination.PaginationParams{Page: page, PageSize: 200}, nil, nil, SubscriptionStatusActive, "", "expires_at", "asc")
if err != nil {
log.Printf("[SubscriptionExpiry] List active subscriptions for reminder failed: %v", err)
return
}
for i := range subs {
s.sendExpiryReminderIfDue(ctx, &subs[i])
}
if pag == nil || page >= pag.Pages || len(subs) == 0 {
return
}
}
}
func (s *SubscriptionExpiryService) sendExpiryReminderIfDue(ctx context.Context, sub *UserSubscription) {
if sub == nil || sub.User == nil || sub.Group == nil || sub.User.Email == "" {
return
}
daysRemaining := sub.DaysRemaining()
if daysRemaining != 7 && daysRemaining != 3 && daysRemaining != 1 {
return
}
if err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
Event: NotificationEmailEventSubscriptionExpiryReminder,
RecipientEmail: sub.User.Email,
RecipientName: firstNonEmpty(sub.User.Username, sub.User.Email),
UserID: sub.UserID,
SourceType: "user_subscription",
SourceID: strconv.FormatInt(sub.ID, 10),
ReminderKey: fmt.Sprintf("%dd", daysRemaining),
Variables: map[string]string{
"subscription_group": sub.Group.Name,
"expiry_time": sub.ExpiresAt.Format("2006-01-02 15:04"),
"days_remaining": strconv.Itoa(daysRemaining),
},
}); err != nil {
log.Printf("[SubscriptionExpiry] Send expiry reminder failed: subscription=%d user=%d err=%v", sub.ID, sub.UserID, err)
}
}

View File

@ -517,7 +517,7 @@ func (s *TotpService) GetVerificationMethod(ctx context.Context) *VerificationMe
}
// SendVerifyCode sends an email verification code for TOTP operations
func (s *TotpService) SendVerifyCode(ctx context.Context, userID int64) error {
func (s *TotpService) SendVerifyCode(ctx context.Context, userID int64, locale ...string) error {
// Check if email verification is enabled
if !s.settingService.IsEmailVerifyEnabled(ctx) {
return infraerrors.BadRequest("EMAIL_VERIFY_NOT_ENABLED", "email verification is not enabled")
@ -533,5 +533,5 @@ func (s *TotpService) SendVerifyCode(ctx context.Context, userID int64) error {
siteName := s.settingService.GetSiteName(ctx)
// Send verification code via queue
return s.emailQueueService.EnqueueVerifyCode(user.Email, siteName)
return s.emailQueueService.EnqueueVerifyCode(user.Email, siteName, firstEmailLocale(locale))
}

View File

@ -324,6 +324,30 @@ func (s *UsageService) GetAPIKeyModelStats(ctx context.Context, apiKeyID int64,
return stats, nil
}
// GetAPIKeyDailyUsage returns daily usage stats for a user's API key.
func (s *UsageService) GetAPIKeyDailyUsage(ctx context.Context, userID, apiKeyID int64, startTime, endTime time.Time) ([]usagestats.APIKeyDailyUsagePoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, "day", userID, apiKeyID, 0, 0, "", nil, nil, nil)
if err != nil {
return nil, fmt.Errorf("get api key daily usage: %w", err)
}
points := make([]usagestats.APIKeyDailyUsagePoint, 0, len(trend))
for _, row := range trend {
points = append(points, usagestats.APIKeyDailyUsagePoint{
Date: row.Date,
Requests: row.Requests,
InputTokens: row.InputTokens,
OutputTokens: row.OutputTokens,
CacheReadTokens: row.CacheReadTokens,
CacheWriteTokens: row.CacheCreationTokens,
TotalTokens: row.TotalTokens,
Cost: row.Cost,
ActualCost: row.ActualCost,
})
}
return points, nil
}
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)

View File

@ -1122,7 +1122,7 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error {
}
// SendNotifyEmailCode sends a verification code to the extra notification email.
func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, email string, emailService *EmailService, cache EmailCache) error {
func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, email string, emailService *EmailService, cache EmailCache, locale ...string) error {
if err := checkNotifyCodeRateLimit(ctx, cache, userID, email); err != nil {
return err
}
@ -1134,7 +1134,7 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema
// Send email first — if SMTP fails, don't write cache or increment counters,
// so the user is not locked out by cooldown/rate-limit for a code they never received.
if err := s.sendNotifyVerifyEmail(ctx, emailService, email, code); err != nil {
if err := s.sendNotifyVerifyEmail(ctx, emailService, userID, email, code, firstEmailLocale(locale)); err != nil {
return err
}
@ -1180,13 +1180,33 @@ func saveNotifyVerifyCode(ctx context.Context, cache EmailCache, email, code str
}
// sendNotifyVerifyEmail builds and sends the verification email.
func (s *UserService) sendNotifyVerifyEmail(ctx context.Context, emailService *EmailService, email, code string) error {
func (s *UserService) sendNotifyVerifyEmail(ctx context.Context, emailService *EmailService, userID int64, email, code, locale string) error {
siteName := "Sub2API"
if s.settingRepo != nil {
if name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName); err == nil && name != "" {
siteName = name
}
}
if emailService.notificationEmailService != nil {
if err := emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{
Event: NotificationEmailEventNotificationEmailVerifyCode,
Locale: locale,
RecipientEmail: email,
RecipientName: emailRecipientName(email),
UserID: userID,
Variables: map[string]string{
"verification_code": code,
"expires_in_minutes": strconv.Itoa(int(verifyCodeTTL / time.Minute)),
},
}); err == nil {
return nil
} else {
if !shouldFallbackNotificationEmail(err) {
return err
}
slog.Warn("template notification email verification failed; falling back to built-in body", "recipient_hash", notificationEmailHash(email), "err", err.Error())
}
}
subject := fmt.Sprintf("[%s] 通知邮箱验证码 / Notification Email Verification", siteName)
body := buildNotifyVerifyEmailBody(code, siteName)
return emailService.SendEmail(ctx, email, subject, body)

View File

@ -154,8 +154,9 @@ func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpirySe
}
// ProvideSubscriptionExpiryService creates and starts SubscriptionExpiryService.
func ProvideSubscriptionExpiryService(userSubRepo UserSubscriptionRepository) *SubscriptionExpiryService {
func ProvideSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, notificationEmailService *NotificationEmailService) *SubscriptionExpiryService {
svc := NewSubscriptionExpiryService(userSubRepo, time.Minute)
svc.SetNotificationEmailService(notificationEmailService)
svc.Start()
return svc
}
@ -484,6 +485,7 @@ var ProviderSet = wire.NewSet(
ProvideOpsCleanupService,
ProvideOpsScheduledReportService,
NewEmailService,
NewNotificationEmailService,
ProvideEmailQueueService,
NewTurnstileService,
NewSubscriptionService,
@ -520,7 +522,7 @@ var ProviderSet = wire.NewSet(
NewContentModerationService,
NewAffiliateService,
ProvidePaymentConfigService,
NewPaymentService,
ProvidePaymentService,
ProvidePaymentOrderExpiryService,
ProvideBalanceNotifyService,
ProvideWindsurfAuthService,
@ -648,8 +650,17 @@ func ProvidePaymentConfigService(entClient *dbent.Client, settingRepo SettingRep
}
// ProvideBalanceNotifyService creates BalanceNotifyService
func ProvideBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountRepository) *BalanceNotifyService {
return NewBalanceNotifyService(emailService, settingRepo, accountRepo)
func ProvideBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountRepository, notificationEmailService *NotificationEmailService) *BalanceNotifyService {
svc := NewBalanceNotifyService(emailService, settingRepo, accountRepo)
svc.SetNotificationEmailService(notificationEmailService)
return svc
}
// ProvidePaymentService creates PaymentService and attaches notification email delivery.
func ProvidePaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository, affiliateService *AffiliateService, notificationEmailService *NotificationEmailService) *PaymentService {
svc := NewPaymentService(entClient, registry, loadBalancer, redeemService, subscriptionSvc, configService, userRepo, groupRepo, affiliateService)
svc.SetNotificationEmailService(notificationEmailService)
return svc
}
// ProvidePaymentOrderExpiryService creates and starts PaymentOrderExpiryService.

View File

@ -5,6 +5,45 @@
import { config } from '@vue/test-utils'
import { vi } from 'vitest'
function createMemoryStorage(): Storage {
const values = new Map<string, string>()
return {
get length() {
return values.size
},
clear() {
values.clear()
},
getItem(key: string) {
return values.has(key) ? values.get(key)! : null
},
key(index: number) {
return Array.from(values.keys())[index] ?? null
},
removeItem(key: string) {
values.delete(key)
},
setItem(key: string, value: string) {
values.set(key, String(value))
}
}
}
if (typeof globalThis.localStorage === 'undefined' || typeof globalThis.localStorage.getItem !== 'function') {
Object.defineProperty(globalThis, 'localStorage', {
configurable: true,
value: createMemoryStorage()
})
}
if (typeof window !== 'undefined' && typeof window.localStorage.getItem !== 'function') {
Object.defineProperty(window, 'localStorage', {
configurable: true,
value: globalThis.localStorage
})
}
// Mock requestIdleCallback (Safari < 15 不支持)
if (typeof globalThis.requestIdleCallback === 'undefined') {
globalThis.requestIdleCallback = ((callback: IdleRequestCallback) => {

View File

@ -505,6 +505,7 @@ export interface SystemSettings {
enable_anthropic_cache_ttl_1h_injection: boolean;
rewrite_message_cache_control: boolean;
antigravity_user_agent_version: string;
openai_codex_user_agent: string;
web_search_emulation_enabled?: boolean;
// Payment configuration
@ -726,6 +727,7 @@ export interface UpdateSettingsRequest {
enable_anthropic_cache_ttl_1h_injection?: boolean;
rewrite_message_cache_control?: boolean;
antigravity_user_agent_version?: string;
openai_codex_user_agent?: string;
// Payment configuration
payment_enabled?: boolean;
risk_control_enabled?: boolean;
@ -854,6 +856,105 @@ export async function sendTestEmail(
return data;
}
// ==================== Email Template Settings ====================
export interface EmailTemplateOption {
value: string;
label?: string;
description?: string;
}
export type EmailTemplateEventOption = string | EmailTemplateOption;
export interface EmailTemplateSummary {
event: string;
locale: string;
subject: string;
is_custom?: boolean;
updated_at?: string;
}
export interface EmailTemplateListResponse {
events: EmailTemplateEventOption[];
locales: string[];
templates?: EmailTemplateSummary[];
placeholders?: string[];
}
export interface EmailTemplateDetail {
event: string;
locale: string;
subject: string;
html: string;
is_custom?: boolean;
updated_at?: string;
placeholders?: string[];
}
export interface UpdateEmailTemplateRequest {
subject: string;
html: string;
}
export interface PreviewEmailTemplateRequest extends UpdateEmailTemplateRequest {
event: string;
locale: string;
}
export interface EmailTemplatePreviewResponse {
subject: string;
html: string;
}
export async function getEmailTemplates(): Promise<EmailTemplateListResponse> {
const { data } = await apiClient.get<EmailTemplateListResponse>(
"/admin/settings/email-templates",
);
return data;
}
export async function getEmailTemplate(
event: string,
locale: string,
): Promise<EmailTemplateDetail> {
const { data } = await apiClient.get<EmailTemplateDetail>(
`/admin/settings/email-templates/${encodeURIComponent(event)}/${encodeURIComponent(locale)}`,
);
return data;
}
export async function updateEmailTemplate(
event: string,
locale: string,
request: UpdateEmailTemplateRequest,
): Promise<EmailTemplateDetail> {
const { data } = await apiClient.put<EmailTemplateDetail>(
`/admin/settings/email-templates/${encodeURIComponent(event)}/${encodeURIComponent(locale)}`,
request,
);
return data;
}
export async function restoreOfficialEmailTemplate(
event: string,
locale: string,
): Promise<EmailTemplateDetail> {
const { data } = await apiClient.post<EmailTemplateDetail>(
`/admin/settings/email-templates/${encodeURIComponent(event)}/${encodeURIComponent(locale)}/restore-official`,
);
return data;
}
export async function previewEmailTemplate(
request: PreviewEmailTemplateRequest,
): Promise<EmailTemplatePreviewResponse> {
const { data } = await apiClient.post<EmailTemplatePreviewResponse>(
"/admin/settings/email-template-preview",
request,
);
return data;
}
/**
* Admin API Key status response
*/
@ -1160,6 +1261,11 @@ export const settingsAPI = {
updateSettings,
testSmtpConnection,
sendTestEmail,
getEmailTemplates,
getEmailTemplate,
updateEmailTemplate,
restoreOfficialEmailTemplate,
previewEmailTemplate,
getAdminApiKey,
regenerateAdminApiKey,
deleteAdminApiKey,

View File

@ -69,6 +69,25 @@ export interface ModelStatsResponse {
end_date: string
}
export interface ApiKeyDailyUsagePoint {
date: string
requests: number
input_tokens: number
output_tokens: number
cache_read_tokens: number
cache_write_tokens: number
total_tokens: number
cost: number
actual_cost: number
}
export interface ApiKeyDailyUsageResponse {
items: ApiKeyDailyUsagePoint[]
days: number
start_date: string
end_date: string
}
/**
* List usage logs with optional filters
* @param page - Page number (default: 1)
@ -234,6 +253,23 @@ export async function getDashboardModels(params?: {
return data
}
/**
* Get daily usage details for one API key owned by the current user.
* @param apiKeyId - API key ID
* @param days - Number of days to include (1-90)
* @returns Daily usage detail rows
*/
export async function getMyApiKeyDailyUsage(
apiKeyId: number,
days: number = 30
): Promise<ApiKeyDailyUsageResponse> {
const { data } = await apiClient.get<ApiKeyDailyUsageResponse>(
`/user/api-keys/${apiKeyId}/usage/daily`,
{ params: { days } }
)
return data
}
export interface BatchApiKeyUsageStats {
api_key_id: number
today_actual_cost: number
@ -279,6 +315,7 @@ export const usageAPI = {
getDashboardStats,
getDashboardTrend,
getDashboardModels,
getMyApiKeyDailyUsage,
getDashboardApiKeysUsage
}

View File

@ -123,19 +123,23 @@ export default {
dateRangeToday: 'Today',
dateRange7d: '7 Days',
dateRange30d: '30 Days',
dateRange90d: '90 Days',
dateRangeCustom: 'Custom',
apply: 'Apply',
used: 'Used',
detailInfo: 'Detail Information',
tokenStats: 'Token Statistics',
dailyDetail: 'Daily Detail',
modelStats: 'Model Usage Statistics',
// Table headers
date: 'Date',
model: 'Model',
requests: 'Requests',
inputTokens: 'Input Tokens',
outputTokens: 'Output Tokens',
cacheCreationTokens: 'Cache Creation',
cacheReadTokens: 'Cache Read',
cacheWriteTokens: 'Cache Write',
totalTokens: 'Total Tokens',
cost: 'Cost',
// Status
@ -179,6 +183,7 @@ export default {
querySuccess: 'Query successful',
queryFailed: 'Query failed',
queryFailedRetry: 'Query failed, please try again later',
noDailyUsage: 'No daily usage data',
},
// Setup Wizard
@ -4176,6 +4181,22 @@ export default {
},
userPrefix: 'User #{id}',
exportCsv: 'Export CSV',
batchUpdate: 'Batch Update',
batchUpdateTitle: 'Batch Update Redeem Codes',
selectedCount: '{count} redeem code(s) selected',
clearSelection: 'Clear selection',
selectCodesFirst: 'Select redeem codes first',
noBatchFieldsSelected: 'Select at least one field to update',
batchUpdateSuccess: 'Updated {count} redeem code(s)',
failedToBatchUpdate: 'Failed to batch update redeem codes',
batchFields: {
status: 'Status',
expiresAt: 'Expires At',
notes: 'Notes',
group: 'Group'
},
batchNotesPlaceholder: 'Enter the new note, or leave blank to clear it',
clearGroup: 'Clear group',
deleteAllUnused: 'Delete All Unused Codes',
deleteCode: 'Delete Redeem Code',
deleteCodeConfirm:
@ -5515,6 +5536,9 @@ export default {
antigravityUserAgentVersion: 'Antigravity UA Version',
antigravityUserAgentVersionPlaceholder: '1.23.2',
antigravityUserAgentVersionHint: 'Leave empty to use ANTIGRAVITY_USER_AGENT_VERSION or the built-in default 1.23.2; when set, the admin setting takes precedence.',
openaiCodexUserAgent: 'OpenAI Codex UA',
openaiCodexUserAgentPlaceholder: 'codex-tui/0.125.0 (Ubuntu 22.4.0; x86_64) xterm-256color (codex-tui; 0.125.0)',
openaiCodexUserAgentHint: 'Used to bypass Cloudflare browser-UA challenges on the OpenAI upstream. Only applies when the client User-Agent is detected as a browser (Mozilla/...). Leave empty to use the built-in default.',
},
webSearchEmulation: {
title: 'Web Search Emulation',
@ -5854,6 +5878,36 @@ export default {
sending: 'Sending...',
enterRecipientHint: 'Please enter a recipient email address'
},
emailTemplates: {
title: 'Email Templates',
description: 'Customize notification email subjects and HTML content for each event and locale.',
event: 'Event',
locale: 'Locale',
localeEn: 'English',
localeZh: 'Chinese',
subject: 'Subject',
subjectPlaceholder: 'Enter the email subject',
html: 'HTML Template',
htmlPlaceholder: 'Edit the email HTML template',
placeholders: 'Available Placeholders',
placeholdersHelp: 'Click a placeholder to copy it. The backend replaces these values when sending emails.',
livePreview: 'Live Preview',
previewSecurityHint: 'Preview HTML is generated by the backend preview endpoint and displayed in a sandboxed iframe with scripts disabled.',
preview: 'Preview / Refresh',
previewing: 'Previewing...',
save: 'Save Template',
saving: 'Saving...',
restoreOfficial: 'Restore Official',
restoring: 'Restoring...',
restoreConfirm: 'Restore the official template for this event and locale? Your custom version will be replaced.',
restoreSuccess: 'Official template restored',
saveSuccess: 'Email template saved',
placeholderCopied: 'Placeholder copied',
validationRequired: 'Subject and HTML template are required',
empty: 'No email template events or locales are available yet.',
noPreview: 'Refresh the preview to see the rendered email subject.',
customized: 'Customized'
},
opsMonitoring: {
title: 'Ops Monitoring',
description: 'Enable ops monitoring for troubleshooting and health visibility',

View File

@ -123,19 +123,23 @@ export default {
dateRangeToday: '今日',
dateRange7d: '7 天',
dateRange30d: '30 天',
dateRange90d: '90 天',
dateRangeCustom: '自定义',
apply: '应用',
used: '已使用',
detailInfo: '详细信息',
tokenStats: 'Token 统计',
dailyDetail: '按日明细',
modelStats: '模型用量统计',
// Table headers
date: '日期',
model: '模型',
requests: '请求数',
inputTokens: '输入 Tokens',
outputTokens: '输出 Tokens',
cacheCreationTokens: '缓存创建',
cacheReadTokens: '缓存读取',
cacheWriteTokens: '缓存写入',
totalTokens: '总 Tokens',
cost: '费用',
// Status
@ -179,6 +183,7 @@ export default {
querySuccess: '查询成功',
queryFailed: '查询失败',
queryFailedRetry: '查询失败,请稍后重试',
noDailyUsage: '暂无按日用量数据',
},
// Setup Wizard
@ -4310,6 +4315,22 @@ export default {
used: '已使用',
searchCodes: '搜索兑换码或邮箱...',
exportCsv: '导出 CSV',
batchUpdate: '批量修改',
batchUpdateTitle: '批量修改兑换码',
selectedCount: '已选择 {count} 个兑换码',
clearSelection: '清空选择',
selectCodesFirst: '请先选择兑换码',
noBatchFieldsSelected: '请至少勾选一个要修改的字段',
batchUpdateSuccess: '成功修改 {count} 个兑换码',
failedToBatchUpdate: '批量修改兑换码失败',
batchFields: {
status: '状态',
expiresAt: '过期时间',
notes: '备注',
group: '分组'
},
batchNotesPlaceholder: '输入新的备注,留空可清空备注',
clearGroup: '清空分组',
deleteAllUnused: '删除全部未使用',
deleteCodeConfirm: '确定要删除此兑换码吗?此操作无法撤销。',
deleteAllUnusedConfirm: '确定要删除全部未使用的兑换码吗?此操作无法撤销。',
@ -5673,6 +5694,9 @@ export default {
antigravityUserAgentVersion: 'Antigravity UA 版本',
antigravityUserAgentVersionPlaceholder: '1.23.2',
antigravityUserAgentVersionHint: '留空时使用 ANTIGRAVITY_USER_AGENT_VERSION 或内置默认值 1.23.2;填写后后台设置优先。',
openaiCodexUserAgent: 'OpenAI Codex UA',
openaiCodexUserAgentPlaceholder: 'codex-tui/0.125.0 (Ubuntu 22.4.0; x86_64) xterm-256color (codex-tui; 0.125.0)',
openaiCodexUserAgentHint: '用于规避 OpenAI 上游 Cloudflare 对浏览器 UA 的访问质询。仅在检测到客户端 User-Agent 为浏览器Mozilla/...)时生效,其他客户端原样透传。留空使用内置默认值。',
},
webSearchEmulation: {
title: 'Web Search 模拟',
@ -6014,6 +6038,36 @@ export default {
sending: '发送中...',
enterRecipientHint: '请输入收件人邮箱地址'
},
emailTemplates: {
title: '邮件模板',
description: '按事件和语言自定义通知邮件主题与 HTML 内容。',
event: '事件',
locale: '语言',
localeEn: '英文',
localeZh: '中文',
subject: '主题',
subjectPlaceholder: '输入邮件主题',
html: 'HTML 模板',
htmlPlaceholder: '编辑邮件 HTML 模板',
placeholders: '可用占位符',
placeholdersHelp: '点击占位符可复制。后端发送邮件时会替换这些值。',
livePreview: '实时预览',
previewSecurityHint: '预览 HTML 由后端预览接口生成,并在禁用脚本的沙盒 iframe 中展示。',
preview: '预览 / 刷新',
previewing: '预览中...',
save: '保存模板',
saving: '保存中...',
restoreOfficial: '恢复官方模板',
restoring: '恢复中...',
restoreConfirm: '确定恢复此事件和语言的官方模板吗?当前自定义版本将被替换。',
restoreSuccess: '已恢复官方模板',
saveSuccess: '邮件模板已保存',
placeholderCopied: '占位符已复制',
validationRequired: '主题和 HTML 模板不能为空',
empty: '暂无可用的邮件模板事件或语言。',
noPreview: '刷新预览后查看渲染后的邮件主题。',
customized: '已自定义'
},
opsMonitoring: {
title: '运维监控',
description: '启用运维监控模块,用于排障与健康可视化',

View File

@ -289,6 +289,62 @@
</div>
</div>
<!-- Daily Usage Table -->
<div
v-if="showDailyUsage"
class="fade-up fade-up-delay-4 rounded-2xl border border-gray-200 bg-white/90 backdrop-blur-sm overflow-hidden dark:border-dark-700 dark:bg-dark-900/90"
>
<div class="flex flex-col gap-3 px-8 py-5 border-b border-gray-200 dark:border-dark-700 sm:flex-row sm:items-center sm:justify-between">
<h3 class="text-sm font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.dailyDetail') }}</h3>
<div class="inline-flex rounded-lg border border-gray-200 bg-white p-0.5 dark:border-dark-700 dark:bg-dark-950">
<button
v-for="option in dailyUsageOptions"
:key="option.value"
@click="setDailyUsageDays(option.value)"
class="min-w-12 rounded-md px-3 py-1.5 text-xs font-medium transition-colors"
:class="dailyUsageDays === option.value
? 'bg-primary-500 text-white'
: 'text-gray-600 hover:bg-gray-100 dark:text-dark-300 dark:hover:bg-dark-800'"
>
{{ option.label }}
</button>
</div>
</div>
<div v-if="dailyUsageRows.length > 0" class="overflow-x-auto">
<table class="w-full">
<thead>
<tr class="border-b border-gray-200 bg-gray-50 dark:border-dark-700 dark:bg-dark-950">
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.date') }}</th>
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.requests') }}</th>
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.inputTokens') }}</th>
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.outputTokens') }}</th>
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.cacheReadTokens') }}</th>
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.cacheWriteTokens') }}</th>
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.cost') }}</th>
</tr>
</thead>
<tbody>
<tr
v-for="row in dailyUsageRows"
:key="row.date"
class="border-b border-gray-100 last:border-b-0 dark:border-dark-800"
>
<td class="px-4 py-3 text-sm font-medium whitespace-nowrap text-gray-900 dark:text-white">{{ row.date }}</td>
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(row.requests) }}</td>
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(row.input_tokens) }}</td>
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(row.output_tokens) }}</td>
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(row.cache_read_tokens) }}</td>
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(row.cache_write_tokens) }}</td>
<td class="px-4 py-3 text-sm tabular-nums text-right font-medium text-gray-900 dark:text-white">{{ usd(row.actual_cost != null ? row.actual_cost : row.cost) }}</td>
</tr>
</tbody>
</table>
</div>
<div v-else class="px-8 py-8 text-center text-sm text-gray-500 dark:text-dark-400">
{{ t('keyUsage.noDailyUsage') }}
</div>
</div>
<!-- Model Stats Table -->
<div
v-if="modelStats.length > 0"
@ -408,6 +464,7 @@ type DateRangeKey = 'today' | '7d' | '30d' | 'custom'
const currentRange = ref<DateRangeKey>('today')
const customStartDate = ref('')
const customEndDate = ref('')
const dailyUsageDays = ref<7 | 30 | 90>(30)
const dateRanges = computed(() => [
{ key: 'today' as const, label: t('keyUsage.dateRangeToday') },
@ -416,6 +473,12 @@ const dateRanges = computed(() => [
{ key: 'custom' as const, label: t('keyUsage.dateRangeCustom') },
])
const dailyUsageOptions = computed(() => [
{ value: 7 as const, label: t('keyUsage.dateRange7d') },
{ value: 30 as const, label: t('keyUsage.dateRange30d') },
{ value: 90 as const, label: t('keyUsage.dateRange90d') },
])
function setDateRange(key: DateRangeKey) {
currentRange.value = key
if (key !== 'custom') {
@ -426,23 +489,36 @@ function setDateRange(key: DateRangeKey) {
function getDateParams(): string {
const now = new Date()
const fmt = (d: Date) => d.toISOString().split('T')[0]
const params = new URLSearchParams()
if (currentRange.value === 'custom') {
if (customStartDate.value && customEndDate.value) {
return `start_date=${customStartDate.value}&end_date=${customEndDate.value}`
params.set('start_date', customStartDate.value)
params.set('end_date', customEndDate.value)
}
return ''
} else {
const end = fmt(now)
let start: string
switch (currentRange.value) {
case 'today': start = end; break
case '7d': start = fmt(new Date(now.getTime() - 7 * 86400000)); break
case '30d': start = fmt(new Date(now.getTime() - 30 * 86400000)); break
default: start = fmt(new Date(now.getTime() - 30 * 86400000))
}
params.set('start_date', start)
params.set('end_date', end)
}
params.set('days', String(dailyUsageDays.value))
params.set('timezone', getBrowserTimezone())
return params.toString()
}
const end = fmt(now)
let start: string
switch (currentRange.value) {
case 'today': start = end; break
case '7d': start = fmt(new Date(now.getTime() - 7 * 86400000)); break
case '30d': start = fmt(new Date(now.getTime() - 30 * 86400000)); break
default: start = fmt(new Date(now.getTime() - 30 * 86400000))
function setDailyUsageDays(days: 7 | 30 | 90) {
if (dailyUsageDays.value === days) return
dailyUsageDays.value = days
if (resultData.value && apiKey.value.trim()) {
queryKey()
}
return `start_date=${start}&end_date=${end}`
}
// ==================== Ring Animation ====================
@ -731,6 +807,24 @@ const usageStatCells = computed<StatCell[]>(() => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const modelStats = computed<any[]>(() => resultData.value?.model_stats || [])
interface DailyUsageRow {
date: string
requests: number
input_tokens: number
output_tokens: number
cache_read_tokens: number
cache_write_tokens: number
cost: number
actual_cost?: number
}
const dailyUsageRows = computed<DailyUsageRow[]>(() => {
const rows = resultData.value?.daily_usage
return Array.isArray(rows) ? rows : []
})
const showDailyUsage = computed(() => Boolean(resultData.value && Array.isArray(resultData.value.daily_usage)))
// ==================== Utility Functions ====================
function usd(value: number | null | undefined): string {
@ -750,6 +844,14 @@ function formatDate(iso: string | null | undefined): string {
return d.toLocaleDateString(loc, { year: 'numeric', month: 'long', day: 'numeric' })
}
function getBrowserTimezone(): string {
try {
return Intl.DateTimeFormat().resolvedOptions().timeZone || 'UTC'
} catch {
return 'UTC'
}
}
// ==================== API Query ====================
async function fetchUsage(key: string) {

View File

@ -0,0 +1,208 @@
import { describe, expect, it, beforeEach, afterEach, vi } from 'vitest'
import { flushPromises, mount } from '@vue/test-utils'
import { nextTick } from 'vue'
import KeyUsageView from '../KeyUsageView.vue'
const { showInfo, showSuccess, showError, fetchPublicSettings } = vi.hoisted(() => ({
showInfo: vi.fn(),
showSuccess: vi.fn(),
showError: vi.fn(),
fetchPublicSettings: vi.fn(),
}))
const messages: Record<string, string> = {
'keyUsage.title': 'API Key Usage',
'keyUsage.subtitle': 'Usage status',
'keyUsage.placeholder': 'sk-test',
'keyUsage.query': 'Query',
'keyUsage.querying': 'Querying...',
'keyUsage.privacyNote': 'Privacy note',
'keyUsage.dateRange': 'Date Range:',
'keyUsage.dateRangeToday': 'Today',
'keyUsage.dateRange7d': '7 Days',
'keyUsage.dateRange30d': '30 Days',
'keyUsage.dateRange90d': '90 Days',
'keyUsage.dateRangeCustom': 'Custom',
'keyUsage.apply': 'Apply',
'keyUsage.used': 'Used',
'keyUsage.detailInfo': 'Detail Information',
'keyUsage.tokenStats': 'Token Statistics',
'keyUsage.dailyDetail': 'Daily Detail',
'keyUsage.date': 'Date',
'keyUsage.requests': 'Requests',
'keyUsage.inputTokens': 'Input Tokens',
'keyUsage.outputTokens': 'Output Tokens',
'keyUsage.cacheReadTokens': 'Cache Read',
'keyUsage.cacheWriteTokens': 'Cache Write',
'keyUsage.cost': 'Cost',
'keyUsage.quotaMode': 'Key Quota Mode',
'keyUsage.walletBalance': 'Wallet Balance',
'keyUsage.totalQuota': 'Total Quota',
'keyUsage.limit5h': '5-Hour Limit',
'keyUsage.limitDaily': 'Daily Limit',
'keyUsage.limit7d': '7-Day Limit',
'keyUsage.limitWeekly': 'Weekly Limit',
'keyUsage.limitMonthly': 'Monthly Limit',
'keyUsage.remainingQuota': 'Remaining Quota',
'keyUsage.usedQuota': 'Used Quota',
'keyUsage.subscriptionType': 'Subscription Type',
'keyUsage.todayRequests': 'Today Requests',
'keyUsage.todayInputTokens': 'Today Input',
'keyUsage.todayOutputTokens': 'Today Output',
'keyUsage.todayTokens': 'Today Tokens',
'keyUsage.todayCacheCreation': 'Today Cache Creation',
'keyUsage.todayCacheRead': 'Today Cache Read',
'keyUsage.todayCost': 'Today Cost',
'keyUsage.rpmTpm': 'RPM / TPM',
'keyUsage.totalRequests': 'Total Requests',
'keyUsage.totalInputTokens': 'Total Input',
'keyUsage.totalOutputTokens': 'Total Output',
'keyUsage.totalTokensLabel': 'Total Tokens',
'keyUsage.totalCacheCreation': 'Total Cache Creation',
'keyUsage.totalCacheRead': 'Total Cache Read',
'keyUsage.totalCost': 'Total Cost',
'keyUsage.avgDuration': 'Avg Duration',
'keyUsage.querySuccess': 'Query successful',
'keyUsage.queryFailed': 'Query failed',
'keyUsage.queryFailedRetry': 'Query failed, please try again later',
'home.viewDocs': 'Docs',
'home.switchToLight': 'Light',
'home.switchToDark': 'Dark',
'home.footer.allRightsReserved': 'All rights reserved.',
}
vi.mock('vue-i18n', async () => {
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
return {
...actual,
useI18n: () => ({
t: (key: string) => messages[key] ?? key,
locale: { value: 'en' },
}),
}
})
vi.mock('@/stores', () => ({
useAppStore: () => ({
cachedPublicSettings: null,
siteName: 'Sub2API',
siteLogo: '',
docUrl: '',
publicSettingsLoaded: true,
fetchPublicSettings,
showInfo,
showSuccess,
showError,
}),
}))
describe('KeyUsageView daily detail', () => {
beforeEach(() => {
showInfo.mockReset()
showSuccess.mockReset()
showError.mockReset()
fetchPublicSettings.mockReset()
localStorage.clear()
Object.defineProperty(window, 'matchMedia', {
configurable: true,
value: vi.fn().mockReturnValue({ matches: false }),
})
vi.stubGlobal('requestAnimationFrame', (cb: FrameRequestCallback) => window.setTimeout(() => cb(0), 0))
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
ok: true,
json: async () => ({
mode: 'quota_limited',
isValid: true,
status: 'active',
quota: {
limit: 10,
used: 1,
remaining: 9,
unit: 'USD',
},
usage: {
today: {
requests: 1,
input_tokens: 10,
output_tokens: 20,
cache_creation_tokens: 0,
cache_read_tokens: 0,
total_tokens: 30,
actual_cost: 0.01,
},
total: {
requests: 12,
input_tokens: 100,
output_tokens: 200,
cache_creation_tokens: 10,
cache_read_tokens: 30,
total_tokens: 340,
actual_cost: 0.12,
},
rpm: 0,
tpm: 0,
},
daily_usage: [
{
date: '2026-05-19',
requests: 12,
input_tokens: 100,
output_tokens: 200,
cache_read_tokens: 30,
cache_write_tokens: 10,
total_tokens: 340,
cost: 0.15,
actual_cost: 0.12,
},
],
}),
}))
})
afterEach(() => {
vi.unstubAllGlobals()
})
it('renders daily usage detail rows after a successful query', async () => {
const wrapper = mount(KeyUsageView, {
global: {
stubs: {
RouterLink: { template: '<a><slot /></a>' },
LocaleSwitcher: true,
Icon: true,
},
},
})
await wrapper.find('input').setValue('sk-test-key')
await wrapper.find('input').trigger('keydown.enter')
await flushPromises()
await nextTick()
const fetchMock = vi.mocked(fetch)
expect(fetchMock).toHaveBeenCalledWith(
expect.stringContaining('/v1/usage?'),
expect.objectContaining({
headers: { Authorization: 'Bearer sk-test-key' },
})
)
expect(String(fetchMock.mock.calls[0][0])).toContain('days=30')
const text = wrapper.text()
expect(text).toContain('Daily Detail')
expect(text).toContain('Date')
expect(text).toContain('Cache Read')
expect(text).toContain('Cache Write')
expect(text).toContain('2026-05-19')
expect(text).toContain('12')
expect(text).toContain('100')
expect(text).toContain('200')
expect(text).toContain('30')
expect(text).toContain('10')
expect(text).toContain('$0.12')
wrapper.unmount()
})
})

View File

@ -3769,6 +3769,36 @@
}}
</p>
</div>
<!-- OpenAI Codex UA -->
<div>
<label
class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300"
>
{{
t(
"admin.settings.gatewayForwarding.openaiCodexUserAgent",
)
}}
</label>
<input
v-model="form.openai_codex_user_agent"
type="text"
class="input w-full font-mono text-sm"
:placeholder="
t(
'admin.settings.gatewayForwarding.openaiCodexUserAgentPlaceholder',
)
"
/>
<p class="mt-1.5 text-xs text-gray-500 dark:text-gray-400">
{{
t(
"admin.settings.gatewayForwarding.openaiCodexUserAgentHint",
)
}}
</p>
</div>
</div>
</div>
<!-- Web Search Emulation -->
@ -6225,6 +6255,9 @@
</div>
</div>
</div>
<EmailTemplateEditor />
<!-- Balance Low Notification -->
<div class="card">
<div
@ -6482,6 +6515,7 @@ import Toggle from "@/components/common/Toggle.vue";
import ProxySelector from "@/components/common/ProxySelector.vue";
import ImageUpload from "@/components/common/ImageUpload.vue";
import BackupSettings from "@/views/admin/BackupView.vue";
import EmailTemplateEditor from "@/views/admin/settings/EmailTemplateEditor.vue";
import { useClipboard } from "@/composables/useClipboard";
import { affiliatesAPI, type AffiliateAdminEntry, type SimpleUser as AffiliateSimpleUser } from "@/api/admin/affiliates";
import { extractApiErrorMessage, extractI18nErrorMessage } from "@/utils/apiError";
@ -6943,6 +6977,7 @@ const form = reactive<SettingsForm>({
enable_anthropic_cache_ttl_1h_injection: false,
rewrite_message_cache_control: false,
antigravity_user_agent_version: "",
openai_codex_user_agent: "",
// Balance & quota notification
balance_low_notify_enabled: false,
balance_low_notify_threshold: 0,
@ -8044,6 +8079,8 @@ async function saveSettings() {
rewrite_message_cache_control: form.rewrite_message_cache_control,
antigravity_user_agent_version:
form.antigravity_user_agent_version?.trim() || "",
openai_codex_user_agent:
form.openai_codex_user_agent?.trim() || "",
// Payment configuration
payment_enabled: form.payment_enabled,
risk_control_enabled: form.risk_control_enabled,

View File

@ -371,6 +371,7 @@ const baseSettingsResponse = {
enable_anthropic_cache_ttl_1h_injection: false,
rewrite_message_cache_control: false,
antigravity_user_agent_version: "",
openai_codex_user_agent: "",
payment_enabled: true,
payment_min_amount: 1,
payment_max_amount: 10000,

View File

@ -0,0 +1,483 @@
<template>
<div class="card">
<div
class="flex flex-col gap-3 border-b border-gray-100 px-6 py-4 dark:border-dark-700 lg:flex-row lg:items-start lg:justify-between"
>
<div>
<h2 class="text-lg font-semibold text-gray-900 dark:text-white">
{{ t("admin.settings.emailTemplates.title") }}
</h2>
<p class="mt-1 text-sm text-gray-500 dark:text-gray-400">
{{ t("admin.settings.emailTemplates.description") }}
</p>
</div>
<div class="flex flex-wrap gap-2">
<button
type="button"
class="btn btn-secondary btn-sm"
:disabled="loadingTemplate || previewing || !canPreview"
@click="refreshPreview"
>
{{ previewing ? t("admin.settings.emailTemplates.previewing") : t("admin.settings.emailTemplates.preview") }}
</button>
<button
type="button"
class="btn btn-secondary btn-sm"
:disabled="loadingTemplate || restoring || !selectedEvent || !selectedLocale"
@click="restoreOfficial"
>
{{ restoring ? t("admin.settings.emailTemplates.restoring") : t("admin.settings.emailTemplates.restoreOfficial") }}
</button>
<button
type="button"
class="btn btn-primary btn-sm"
:disabled="loadingTemplate || saving || !canSave"
@click="saveTemplate"
>
{{ saving ? t("admin.settings.emailTemplates.saving") : t("admin.settings.emailTemplates.save") }}
</button>
</div>
</div>
<div class="space-y-6 p-6">
<div
v-if="loadingList"
class="flex items-center gap-2 text-sm text-gray-500 dark:text-gray-400"
>
<span
class="h-4 w-4 animate-spin rounded-full border-b-2 border-primary-600"
></span>
{{ t("common.loading") }}
</div>
<template v-else>
<div class="grid grid-cols-1 gap-4 md:grid-cols-2">
<div>
<label class="input-label" for="email-template-event">
{{ t("admin.settings.emailTemplates.event") }}
</label>
<select
id="email-template-event"
v-model="selectedEvent"
class="input"
:disabled="loadingTemplate || eventOptions.length === 0"
>
<option
v-for="option in eventOptions"
:key="option.value"
:value="option.value"
>
{{ option.label || option.value }}
</option>
</select>
<p v-if="selectedEventDescription" class="input-hint">
{{ selectedEventDescription }}
</p>
</div>
<div>
<label class="input-label" for="email-template-locale">
{{ t("admin.settings.emailTemplates.locale") }}
</label>
<select
id="email-template-locale"
v-model="selectedLocale"
class="input"
:disabled="loadingTemplate || localeOptions.length === 0"
>
<option
v-for="localeOption in localeOptions"
:key="localeOption"
:value="localeOption"
>
{{ formatLocale(localeOption) }}
</option>
</select>
</div>
</div>
<div
v-if="!eventOptions.length || !localeOptions.length"
class="rounded-lg border border-amber-200 bg-amber-50 p-4 text-sm text-amber-700 dark:border-amber-800 dark:bg-amber-900/20 dark:text-amber-300"
>
{{ t("admin.settings.emailTemplates.empty") }}
</div>
<div v-else class="grid grid-cols-1 gap-6 xl:grid-cols-2">
<div class="space-y-4">
<div>
<label class="input-label" for="email-template-subject">
{{ t("admin.settings.emailTemplates.subject") }}
</label>
<input
id="email-template-subject"
v-model="subject"
type="text"
class="input"
:disabled="loadingTemplate"
:placeholder="t('admin.settings.emailTemplates.subjectPlaceholder')"
/>
</div>
<div>
<label class="input-label" for="email-template-html">
{{ t("admin.settings.emailTemplates.html") }}
</label>
<textarea
id="email-template-html"
v-model="html"
rows="18"
class="input min-h-[28rem] resize-y font-mono text-sm leading-6"
:disabled="loadingTemplate"
:placeholder="t('admin.settings.emailTemplates.htmlPlaceholder')"
></textarea>
</div>
<div
class="rounded-lg border border-gray-200 bg-gray-50 p-4 dark:border-dark-700 dark:bg-dark-800/60"
>
<div class="text-sm font-medium text-gray-900 dark:text-white">
{{ t("admin.settings.emailTemplates.placeholders") }}
</div>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t("admin.settings.emailTemplates.placeholdersHelp") }}
</p>
<div class="mt-3 flex flex-wrap gap-2">
<button
v-for="placeholder in placeholderList"
:key="placeholder"
type="button"
class="rounded-full border border-gray-200 bg-white px-3 py-1 font-mono text-xs text-gray-700 transition-colors hover:border-primary-300 hover:text-primary-600 dark:border-dark-600 dark:bg-dark-700 dark:text-gray-200 dark:hover:border-primary-500 dark:hover:text-primary-300"
@click="copyPlaceholder(placeholder)"
>
{{ placeholder }}
</button>
</div>
</div>
</div>
<div class="space-y-4">
<div
class="rounded-lg border border-gray-200 bg-white dark:border-dark-700 dark:bg-dark-800"
>
<div
class="flex items-center justify-between border-b border-gray-100 px-4 py-3 dark:border-dark-700"
>
<div>
<div class="text-sm font-medium text-gray-900 dark:text-white">
{{ t("admin.settings.emailTemplates.livePreview") }}
</div>
<div class="mt-0.5 text-xs text-gray-500 dark:text-gray-400">
{{ previewSubject || t("admin.settings.emailTemplates.noPreview") }}
</div>
</div>
<span
v-if="isCustomTemplate"
class="rounded-full bg-primary-50 px-2.5 py-1 text-xs font-medium text-primary-700 dark:bg-primary-900/30 dark:text-primary-300"
>
{{ t("admin.settings.emailTemplates.customized") }}
</span>
</div>
<div class="bg-gray-100 p-3 dark:bg-dark-900">
<iframe
class="h-[36rem] w-full rounded-md border border-gray-200 bg-white dark:border-dark-700"
sandbox=""
:srcdoc="previewHtml"
:title="t('admin.settings.emailTemplates.livePreview')"
></iframe>
</div>
</div>
<p class="text-xs text-gray-500 dark:text-gray-400">
{{ t("admin.settings.emailTemplates.previewSecurityHint") }}
</p>
</div>
</div>
</template>
</div>
</div>
</template>
<script setup lang="ts">
import { computed, onMounted, ref, watch } from "vue";
import { useI18n } from "vue-i18n";
import { adminAPI } from "@/api";
import type {
EmailTemplateEventOption,
EmailTemplateOption,
} from "@/api/admin/settings";
import { useAppStore } from "@/stores";
import { extractApiErrorMessage } from "@/utils/apiError";
const { t, locale } = useI18n();
const appStore = useAppStore();
const fallbackPlaceholders = [
"{{site_name}}",
"{{recipient_name}}",
"{{recipient_email}}",
"{{verification_code}}",
"{{expires_in_minutes}}",
"{{reset_url}}",
"{{subscription_group}}",
"{{subscription_days}}",
"{{expiry_time}}",
"{{days_remaining}}",
"{{current_balance}}",
"{{threshold}}",
"{{recharge_url}}",
"{{recharge_amount}}",
"{{order_id}}",
"{{unsubscribe_url}}",
"{{account_id}}",
"{{account_name}}",
"{{platform}}",
"{{quota_dimension}}",
"{{quota_used}}",
"{{quota_limit}}",
"{{quota_remaining}}",
"{{quota_threshold}}",
"{{triggered_at}}",
"{{group_name}}",
"{{moderation_category}}",
"{{moderation_score}}",
"{{violation_count}}",
"{{ban_threshold}}",
"{{rule_name}}",
"{{severity}}",
"{{alert_status}}",
"{{metric_type}}",
"{{operator}}",
"{{metric_value}}",
"{{threshold_value}}",
"{{alert_description}}",
"{{report_name}}",
"{{report_type}}",
"{{report_start_time}}",
"{{report_end_time}}",
"{{report_html}}",
];
const loadingList = ref(true);
const loadingTemplate = ref(false);
const saving = ref(false);
const previewing = ref(false);
const restoring = ref(false);
const eventOptions = ref<EmailTemplateOption[]>([]);
const localeOptions = ref<string[]>([]);
const selectedEvent = ref("");
const selectedLocale = ref("");
const subject = ref("");
const html = ref("");
const isCustomTemplate = ref(false);
const placeholders = ref<string[]>([]);
const previewSubject = ref("");
const previewHtml = ref("");
const initializingSelection = ref(false);
function normalizeEventOption(option: EmailTemplateEventOption): EmailTemplateOption {
if (typeof option === "string") {
return { value: option };
}
return option;
}
const selectedEventDescription = computed(() => {
return (
eventOptions.value.find((option) => option.value === selectedEvent.value)
?.description || ""
);
});
const placeholderList = computed(() => {
const combined = [...placeholders.value, ...fallbackPlaceholders];
return Array.from(
new Set(
combined
.map((item) => formatPlaceholder(item))
.filter((item) => item.length > 0),
),
);
});
function formatPlaceholder(placeholder: string): string {
const trimmed = placeholder.trim();
if (!trimmed) return "";
if (trimmed.startsWith("{{") && trimmed.endsWith("}}")) return trimmed;
return `{{${trimmed}}}`;
}
const canSave = computed(
() =>
Boolean(selectedEvent.value && selectedLocale.value) &&
subject.value.trim().length > 0 &&
html.value.trim().length > 0,
);
const canPreview = computed(
() => Boolean(selectedEvent.value && selectedLocale.value) && html.value.trim().length > 0,
);
function formatLocale(locale: string): string {
const lower = locale.toLowerCase();
if (lower === "zh" || lower.startsWith("zh-")) {
return t("admin.settings.emailTemplates.localeZh");
}
if (lower === "en" || lower.startsWith("en-")) {
return t("admin.settings.emailTemplates.localeEn");
}
return locale;
}
function selectInitialLocale(locales: string[]): string {
const currentLocale = locale.value.toLowerCase();
const exactMatch = locales.find(
(availableLocale) => availableLocale.toLowerCase() === currentLocale,
);
if (exactMatch) return exactMatch;
const currentLanguage = currentLocale.split("-")[0];
const languageMatch = locales.find(
(availableLocale) => availableLocale.toLowerCase().split("-")[0] === currentLanguage,
);
if (languageMatch) return languageMatch;
return locales[0] || "";
}
function applyTemplate(template: {
subject: string;
html: string;
is_custom?: boolean;
placeholders?: string[];
}) {
subject.value = template.subject;
html.value = template.html;
isCustomTemplate.value = template.is_custom === true;
placeholders.value = template.placeholders || [];
}
async function loadTemplate() {
if (!selectedEvent.value || !selectedLocale.value) return;
loadingTemplate.value = true;
try {
const template = await adminAPI.settings.getEmailTemplate(
selectedEvent.value,
selectedLocale.value,
);
applyTemplate(template);
await refreshPreview();
} catch (err: unknown) {
appStore.showError(extractApiErrorMessage(err, t("common.error")));
} finally {
loadingTemplate.value = false;
}
}
async function loadTemplateList() {
loadingList.value = true;
try {
const response = await adminAPI.settings.getEmailTemplates();
eventOptions.value = response.events.map(normalizeEventOption);
localeOptions.value = response.locales;
placeholders.value = response.placeholders || [];
initializingSelection.value = true;
selectedEvent.value = eventOptions.value[0]?.value || "";
selectedLocale.value = selectInitialLocale(response.locales);
await loadTemplate();
initializingSelection.value = false;
} catch (err: unknown) {
initializingSelection.value = false;
appStore.showError(extractApiErrorMessage(err, t("common.error")));
} finally {
loadingList.value = false;
}
}
async function saveTemplate() {
if (!canSave.value) {
appStore.showError(t("admin.settings.emailTemplates.validationRequired"));
return;
}
saving.value = true;
try {
const template = await adminAPI.settings.updateEmailTemplate(
selectedEvent.value,
selectedLocale.value,
{
subject: subject.value,
html: html.value,
},
);
applyTemplate(template);
await refreshPreview();
appStore.showSuccess(t("admin.settings.emailTemplates.saveSuccess"));
} catch (err: unknown) {
appStore.showError(extractApiErrorMessage(err, t("common.error")));
} finally {
saving.value = false;
}
}
async function refreshPreview() {
if (!canPreview.value) {
previewSubject.value = "";
previewHtml.value = "";
return;
}
previewing.value = true;
try {
const preview = await adminAPI.settings.previewEmailTemplate({
event: selectedEvent.value,
locale: selectedLocale.value,
subject: subject.value,
html: html.value,
});
previewSubject.value = preview.subject;
previewHtml.value = preview.html;
} catch (err: unknown) {
appStore.showError(extractApiErrorMessage(err, t("common.error")));
} finally {
previewing.value = false;
}
}
async function restoreOfficial() {
if (!selectedEvent.value || !selectedLocale.value) return;
if (!window.confirm(t("admin.settings.emailTemplates.restoreConfirm"))) return;
restoring.value = true;
try {
const template = await adminAPI.settings.restoreOfficialEmailTemplate(
selectedEvent.value,
selectedLocale.value,
);
applyTemplate(template);
await refreshPreview();
appStore.showSuccess(t("admin.settings.emailTemplates.restoreSuccess"));
} catch (err: unknown) {
appStore.showError(extractApiErrorMessage(err, t("common.error")));
} finally {
restoring.value = false;
}
}
async function copyPlaceholder(placeholder: string) {
try {
await navigator.clipboard.writeText(placeholder);
appStore.showSuccess(t("admin.settings.emailTemplates.placeholderCopied"));
} catch {
appStore.showError(t("common.error"));
}
}
watch([selectedEvent, selectedLocale], ([eventValue, localeValue], [oldEvent, oldLocale]) => {
if (initializingSelection.value) return;
if (!eventValue || !localeValue) return;
if (eventValue === oldEvent && localeValue === oldLocale) return;
void loadTemplate();
});
onMounted(() => {
void loadTemplateList();
});
</script>

View File

@ -13,6 +13,7 @@ export default defineConfig({
test: {
globals: true,
environment: 'jsdom',
setupFiles: ['./src/__tests__/setup.ts'],
include: ['src/**/*.{test,spec}.{js,ts,jsx,tsx}'],
exclude: ['node_modules', 'dist'],
coverage: {