diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 74799d81..f3185747 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.127 +0.1.128 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 9afc20ec..af535d5c 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -189,7 +189,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { channelRepository := repository.NewChannelRepository(db) channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService) modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) - balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository) + notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService) + balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService) rpmTokenBucketService := service.NewRPMTokenBucketService() gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI) @@ -204,8 +205,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey) registry := payment.ProvideRegistry() defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) - paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService) - settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService) + paymentService := service.ProvidePaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService, notificationEmailService) + settingHandler := handler.ProvideAdminSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService, notificationEmailService) requestEventBus := service.NewRequestEventBus() opsLogBroadcaster := service.ProvideOpsLogBroadcaster() opsHandler := admin.NewOpsHandler(opsService, requestEventBus, opsLogBroadcaster) @@ -253,7 +254,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, windsurfGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, userMessageQueueService, configConfig, settingService, requestEventBus) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, configConfig) - handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) + handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo, notificationEmailService) totpHandler := handler.NewTotpHandler(totpService) handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService) paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry) @@ -274,7 +275,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) - subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) + subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository, notificationEmailService) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) windsurfRefreshService := service.ProvideWindsurfRefreshService(configConfig, accountRepository, proxyRepository) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 9907d441..14f5dce0 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -56,13 +56,14 @@ func firstNonEmpty(values ...string) string { // SettingHandler 系统设置处理器 type SettingHandler struct { - settingService *service.SettingService - emailService *service.EmailService - turnstileService *service.TurnstileService - opsService *service.OpsService - paymentConfigService *service.PaymentConfigService - paymentService *service.PaymentService - userAttributeService *service.UserAttributeService + settingService *service.SettingService + emailService *service.EmailService + turnstileService *service.TurnstileService + opsService *service.OpsService + paymentConfigService *service.PaymentConfigService + paymentService *service.PaymentService + userAttributeService *service.UserAttributeService + notificationEmailService *service.NotificationEmailService } // NewSettingHandler 创建系统设置处理器 @@ -78,6 +79,12 @@ func NewSettingHandler(settingService *service.SettingService, emailService *ser } } +// SetNotificationEmailService attaches the notification template service without changing +// the constructor signature used by existing unit tests. +func (h *SettingHandler) SetNotificationEmailService(notificationEmailService *service.NotificationEmailService) { + h.notificationEmailService = notificationEmailService +} + // GetSettings 获取所有系统设置 // GET /api/v1/admin/settings func (h *SettingHandler) GetSettings(c *gin.Context) { @@ -247,6 +254,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { EnableAnthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection, RewriteMessageCacheControl: settings.RewriteMessageCacheControl, AntigravityUserAgentVersion: settings.AntigravityUserAgentVersion, + OpenAICodexUserAgent: settings.OpenAICodexUserAgent, WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled, PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource, PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource, @@ -563,6 +571,7 @@ type UpdateSettingsRequest struct { EnableAnthropicCacheTTL1hInjection *bool `json:"enable_anthropic_cache_ttl_1h_injection"` RewriteMessageCacheControl *bool `json:"rewrite_message_cache_control"` AntigravityUserAgentVersion *string `json:"antigravity_user_agent_version"` + OpenAICodexUserAgent *string `json:"openai_codex_user_agent"` // Payment visible method routing PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"` @@ -1404,6 +1413,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { return } } + if req.OpenAICodexUserAgent != nil { + normalized := strings.TrimSpace(*req.OpenAICodexUserAgent) + req.OpenAICodexUserAgent = &normalized + // 仅做长度上限保护,不限制具体格式(运维需要可自由调整 codex 版本号) + if len(normalized) > 512 { + response.Error(c, http.StatusBadRequest, "openai_codex_user_agent must be at most 512 characters") + return + } + } // 交叉验证:如果同时设置了最低和最高版本号,最高版本号必须 >= 最低版本号 if req.MinClaudeCodeVersion != "" && req.MaxClaudeCodeVersion != "" { @@ -1597,6 +1615,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } return previousSettings.AntigravityUserAgentVersion }(), + OpenAICodexUserAgent: func() string { + if req.OpenAICodexUserAgent != nil { + return *req.OpenAICodexUserAgent + } + return previousSettings.OpenAICodexUserAgent + }(), PaymentVisibleMethodAlipaySource: func() string { if req.PaymentVisibleMethodAlipaySource != nil { return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource) @@ -1956,6 +1980,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { EnableAnthropicCacheTTL1hInjection: updatedSettings.EnableAnthropicCacheTTL1hInjection, RewriteMessageCacheControl: updatedSettings.RewriteMessageCacheControl, AntigravityUserAgentVersion: updatedSettings.AntigravityUserAgentVersion, + OpenAICodexUserAgent: updatedSettings.OpenAICodexUserAgent, PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource, PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource, PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled, @@ -2411,6 +2436,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.AntigravityUserAgentVersion != after.AntigravityUserAgentVersion { changed = append(changed, "antigravity_user_agent_version") } + if before.OpenAICodexUserAgent != after.OpenAICodexUserAgent { + changed = append(changed, "openai_codex_user_agent") + } if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource { changed = append(changed, "payment_visible_method_alipay_source") } @@ -3339,3 +3367,160 @@ func (h *SettingHandler) ensureUserAttributeDefinition(ctx context.Context, key, } slog.Info("dingtalk: created user attribute definition", "key", key, "name", name, "type", attrType) } + +// ListEmailTemplates returns all editable notification email templates. +// GET /api/v1/admin/settings/email-templates +func (h *SettingHandler) ListEmailTemplates(c *gin.Context) { + if h.notificationEmailService == nil { + response.InternalError(c, "notification email service is not configured") + return + } + events := h.notificationEmailService.ListEventInfos() + templates, err := h.notificationEmailService.ListTemplates(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, dto.EmailTemplateListResponse{ + Events: emailTemplateEventOptionsToDTO(events), + Locales: h.notificationEmailService.SupportedLocales(), + Templates: emailTemplateSummariesToDTO(templates), + Placeholders: emailTemplatePlaceholderUnion(events), + }) +} + +// GetEmailTemplate returns one editable notification email template. +// GET /api/v1/admin/settings/email-templates/:event/:locale +func (h *SettingHandler) GetEmailTemplate(c *gin.Context) { + if h.notificationEmailService == nil { + response.InternalError(c, "notification email service is not configured") + return + } + tmpl, err := h.notificationEmailService.GetTemplate(c.Request.Context(), c.Param("event"), c.Param("locale")) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + response.Success(c, emailTemplateDetailToDTO(tmpl)) +} + +// UpdateEmailTemplate saves an override for one event/locale template. +// PUT /api/v1/admin/settings/email-templates/:event/:locale +func (h *SettingHandler) UpdateEmailTemplate(c *gin.Context) { + if h.notificationEmailService == nil { + response.InternalError(c, "notification email service is not configured") + return + } + var req dto.UpdateEmailTemplateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + tmpl, err := h.notificationEmailService.UpdateTemplate(c.Request.Context(), c.Param("event"), c.Param("locale"), req.Subject, req.HTML) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + response.Success(c, emailTemplateDetailToDTO(tmpl)) +} + +// RestoreOfficialEmailTemplate removes an override and returns the built-in template. +// POST /api/v1/admin/settings/email-templates/:event/:locale/restore-official +func (h *SettingHandler) RestoreOfficialEmailTemplate(c *gin.Context) { + if h.notificationEmailService == nil { + response.InternalError(c, "notification email service is not configured") + return + } + tmpl, err := h.notificationEmailService.RestoreOfficialTemplate(c.Request.Context(), c.Param("event"), c.Param("locale")) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + response.Success(c, emailTemplateDetailToDTO(tmpl)) +} + +// PreviewEmailTemplate renders a template with safe sample variables without saving it. +// POST /api/v1/admin/settings/email-templates/preview +func (h *SettingHandler) PreviewEmailTemplate(c *gin.Context) { + if h.notificationEmailService == nil { + response.InternalError(c, "notification email service is not configured") + return + } + var req dto.PreviewEmailTemplateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + preview, err := h.notificationEmailService.PreviewTemplate(c.Request.Context(), service.NotificationEmailPreviewInput{ + Event: req.Event, + Locale: req.Locale, + Subject: req.Subject, + HTML: req.HTML, + Variables: req.Variables, + }) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + response.Success(c, dto.EmailTemplatePreviewResponse{Subject: preview.Subject, HTML: preview.HTML}) +} + +func emailTemplateEventOptionsToDTO(events []service.NotificationEmailEventInfo) []dto.EmailTemplateEventOption { + items := make([]dto.EmailTemplateEventOption, 0, len(events)) + for _, event := range events { + items = append(items, dto.EmailTemplateEventOption{ + Value: event.Event, + Label: event.Label, + Description: event.Description, + }) + } + return items +} + +func emailTemplateSummariesToDTO(templates []service.NotificationEmailTemplate) []dto.EmailTemplateSummary { + items := make([]dto.EmailTemplateSummary, 0, len(templates)) + for _, tmpl := range templates { + items = append(items, dto.EmailTemplateSummary{ + Event: tmpl.Event, + Locale: tmpl.Locale, + Subject: tmpl.Subject, + IsCustom: tmpl.IsCustom, + UpdatedAt: emailTemplateUpdatedAt(tmpl), + }) + } + return items +} + +func emailTemplateDetailToDTO(tmpl service.NotificationEmailTemplate) dto.EmailTemplateDetail { + return dto.EmailTemplateDetail{ + Event: tmpl.Event, + Locale: tmpl.Locale, + Subject: tmpl.Subject, + HTML: tmpl.HTML, + IsCustom: tmpl.IsCustom, + UpdatedAt: emailTemplateUpdatedAt(tmpl), + Placeholders: tmpl.Placeholders, + } +} + +func emailTemplateUpdatedAt(tmpl service.NotificationEmailTemplate) string { + if tmpl.UpdatedAt == nil { + return "" + } + return tmpl.UpdatedAt.Format("2006-01-02T15:04:05Z07:00") +} + +func emailTemplatePlaceholderUnion(events []service.NotificationEmailEventInfo) []string { + seen := make(map[string]struct{}) + placeholders := make([]string, 0) + for _, event := range events { + for _, placeholder := range event.Placeholders { + if _, ok := seen[placeholder]; ok { + continue + } + seen[placeholder] = struct{}{} + placeholders = append(placeholders, placeholder) + } + } + return placeholders +} diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index a9af910d..592a0d82 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -203,7 +203,7 @@ func (h *AuthHandler) SendVerifyCode(c *gin.Context) { return } - result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email) + result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email, c.GetHeader("Accept-Language")) if err != nil { response.ErrorFrom(c, err) return @@ -602,7 +602,7 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) { // Request password reset (async) // Note: This returns success even if email doesn't exist (to prevent enumeration) - if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil { + if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL, c.GetHeader("Accept-Language")); err != nil { response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 1014a3e8..550363fd 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -545,7 +545,7 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) { return } - result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email) + result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email, c.GetHeader("Accept-Language")) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 45ad7a70..bdad5572 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -181,6 +181,7 @@ type SystemSettings struct { EnableAnthropicCacheTTL1hInjection bool `json:"enable_anthropic_cache_ttl_1h_injection"` RewriteMessageCacheControl bool `json:"rewrite_message_cache_control"` AntigravityUserAgentVersion string `json:"antigravity_user_agent_version"` + OpenAICodexUserAgent string `json:"openai_codex_user_agent"` // Web Search Emulation WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"` @@ -377,6 +378,62 @@ type OpenAIFastPolicySettings struct { Rules []OpenAIFastPolicyRule `json:"rules"` } +// EmailTemplateEventOption describes an editable notification email event. +type EmailTemplateEventOption struct { + Value string `json:"value"` + Label string `json:"label,omitempty"` + Description string `json:"description,omitempty"` +} + +// EmailTemplateSummary is shown in the admin email template list. +type EmailTemplateSummary struct { + Event string `json:"event"` + Locale string `json:"locale"` + Subject string `json:"subject"` + IsCustom bool `json:"is_custom,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +// EmailTemplateListResponse is returned by GET /admin/settings/email-templates. +type EmailTemplateListResponse struct { + Events []EmailTemplateEventOption `json:"events"` + Locales []string `json:"locales"` + Templates []EmailTemplateSummary `json:"templates,omitempty"` + Placeholders []string `json:"placeholders,omitempty"` +} + +// EmailTemplateDetail is returned for a specific event/locale template. +type EmailTemplateDetail struct { + Event string `json:"event"` + Locale string `json:"locale"` + Subject string `json:"subject"` + HTML string `json:"html"` + IsCustom bool `json:"is_custom,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` + Placeholders []string `json:"placeholders,omitempty"` +} + +// UpdateEmailTemplateRequest updates a template override. +type UpdateEmailTemplateRequest struct { + Subject string `json:"subject"` + HTML string `json:"html"` +} + +// PreviewEmailTemplateRequest previews a template without saving it. +type PreviewEmailTemplateRequest struct { + Event string `json:"event"` + Locale string `json:"locale"` + Subject string `json:"subject"` + HTML string `json:"html"` + Variables map[string]string `json:"variables,omitempty"` +} + +// EmailTemplatePreviewResponse is the rendered preview payload. +type EmailTemplatePreviewResponse struct { + Subject string `json:"subject"` + HTML string `json:"html"` +} + // ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem. // Returns empty slice on empty/invalid input. func ParseCustomMenuItems(raw string) []CustomMenuItem { diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 800328bb..406ed819 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -1133,9 +1133,15 @@ func (h *GatewayHandler) Usage(c *gin.Context) { // 解析可选的日期范围参数(用于 model_stats 查询) startTime, endTime := h.parseUsageDateRange(c) + days, ok := parseAPIKeyDailyUsageDays(c.DefaultQuery("days", "")) + if !ok { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Invalid days, allowed range is 1-90") + return + } // Best-effort: 获取用量统计(按当前 API Key 过滤),失败不影响基础响应 usageData := h.buildUsageData(ctx, apiKey.ID) + dailyUsage := h.buildAPIKeyDailyUsage(c, subject.UserID, apiKey.ID, days) // Best-effort: 获取模型统计 var modelStats any @@ -1149,11 +1155,11 @@ func (h *GatewayHandler) Usage(c *gin.Context) { isQuotaLimited := apiKey.Quota > 0 || apiKey.HasRateLimits() if isQuotaLimited { - h.usageQuotaLimited(c, ctx, apiKey, usageData, modelStats) + h.usageQuotaLimited(c, ctx, apiKey, usageData, dailyUsage, modelStats) return } - h.usageUnrestricted(c, ctx, apiKey, subject, usageData, modelStats) + h.usageUnrestricted(c, ctx, apiKey, subject, usageData, dailyUsage, modelStats) } // parseUsageDateRange 解析 start_date / end_date query params,默认返回近 30 天范围 @@ -1211,8 +1217,20 @@ func (h *GatewayHandler) buildUsageData(ctx context.Context, apiKeyID int64) gin } } +func (h *GatewayHandler) buildAPIKeyDailyUsage(c *gin.Context, userID, apiKeyID int64, days int) any { + if h.usageService == nil { + return nil + } + startTime, endTime := apiKeyDailyUsageRange(days, c.Query("timezone")) + stats, err := h.usageService.GetAPIKeyDailyUsage(c.Request.Context(), userID, apiKeyID, startTime, endTime) + if err != nil { + return nil + } + return stats +} + // usageQuotaLimited 处理 quota_limited 模式的响应 -func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, apiKey *service.APIKey, usageData gin.H, modelStats any) { +func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, apiKey *service.APIKey, usageData gin.H, dailyUsage any, modelStats any) { resp := gin.H{ "mode": "quota_limited", "isValid": apiKey.Status == service.StatusAPIKeyActive || apiKey.Status == service.StatusAPIKeyQuotaExhausted || apiKey.Status == service.StatusAPIKeyExpired, @@ -1294,6 +1312,9 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, if usageData != nil { resp["usage"] = usageData } + if dailyUsage != nil { + resp["daily_usage"] = dailyUsage + } if modelStats != nil { resp["model_stats"] = modelStats } @@ -1302,7 +1323,7 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, } // usageUnrestricted 处理 unrestricted 模式的响应(向后兼容) -func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, usageData gin.H, modelStats any) { +func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, usageData gin.H, dailyUsage any, modelStats any) { // 订阅模式 if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() { resp := gin.H{ @@ -1331,6 +1352,9 @@ func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context, if usageData != nil { resp["usage"] = usageData } + if dailyUsage != nil { + resp["daily_usage"] = dailyUsage + } if modelStats != nil { resp["model_stats"] = modelStats } @@ -1356,6 +1380,9 @@ func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context, if usageData != nil { resp["usage"] = usageData } + if dailyUsage != nil { + resp["daily_usage"] = dailyUsage + } if modelStats != nil { resp["model_stats"] = modelStats } diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go index 1bb81190..f3c16f5d 100644 --- a/backend/internal/handler/payment_handler.go +++ b/backend/internal/handler/payment_handler.go @@ -266,6 +266,7 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) { PaymentSource: req.PaymentSource, OrderType: req.OrderType, PlanID: req.PlanID, + Locale: c.GetHeader("Accept-Language"), }) if err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index c4ba43e4..7413b840 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -1,6 +1,10 @@ package handler import ( + "html" + "net/http" + "strings" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -10,8 +14,9 @@ import ( // SettingHandler 公开设置处理器(无需认证) type SettingHandler struct { - settingService *service.SettingService - version string + settingService *service.SettingService + notificationEmailService *service.NotificationEmailService + version string } // NewSettingHandler 创建公开设置处理器 @@ -22,6 +27,12 @@ func NewSettingHandler(settingService *service.SettingService, version string) * } } +// SetNotificationEmailService attaches the public notification email service without +// changing the constructor signature used by existing tests. +func (h *SettingHandler) SetNotificationEmailService(notificationEmailService *service.NotificationEmailService) { + h.notificationEmailService = notificationEmailService +} + // GetPublicSettings 获取公开设置 // GET /api/v1/settings/public func (h *SettingHandler) GetPublicSettings(c *gin.Context) { @@ -90,6 +101,27 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { }) } +// UnsubscribeNotificationEmail handles optional notification email opt-outs. +// GET /api/v1/settings/email-unsubscribe?token=... +func (h *SettingHandler) UnsubscribeNotificationEmail(c *gin.Context) { + if h.notificationEmailService == nil { + response.InternalError(c, "notification email service is not configured") + return + } + token := strings.TrimSpace(c.Query("token")) + if token == "" { + response.BadRequest(c, "token is required") + return + } + result, err := h.notificationEmailService.Unsubscribe(c.Request.Context(), token) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + body := "
You have unsubscribed " + html.EscapeString(result.Email) + " from " + html.EscapeString(result.Event) + " emails.
" + c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(body)) +} + func publicLoginAgreementDocumentsToDTO(items []service.LoginAgreementDocument) []dto.LoginAgreementDocument { result := make([]dto.LoginAgreementDocument, 0, len(items)) for _, item := range items { diff --git a/backend/internal/handler/totp_handler.go b/backend/internal/handler/totp_handler.go index 5c5eb567..f9151dab 100644 --- a/backend/internal/handler/totp_handler.go +++ b/backend/internal/handler/totp_handler.go @@ -172,7 +172,7 @@ func (h *TotpHandler) SendVerifyCode(c *gin.Context) { return } - if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil { + if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID, c.GetHeader("Accept-Language")); err != nil { response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/usage_handler.go b/backend/internal/handler/usage_handler.go index b8506154..daa5695d 100644 --- a/backend/internal/handler/usage_handler.go +++ b/backend/internal/handler/usage_handler.go @@ -298,6 +298,29 @@ func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) { return startTime, endTime } +const ( + defaultAPIKeyDailyUsageDays = 30 + maxAPIKeyDailyUsageDays = 90 +) + +func parseAPIKeyDailyUsageDays(raw string) (int, bool) { + if strings.TrimSpace(raw) == "" { + return defaultAPIKeyDailyUsageDays, true + } + days, err := strconv.Atoi(raw) + if err != nil || days <= 0 || days > maxAPIKeyDailyUsageDays { + return 0, false + } + return days, true +} + +func apiKeyDailyUsageRange(days int, userTZ string) (time.Time, time.Time) { + now := timezone.NowInUserLocation(userTZ) + startTime := timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -(days-1)), userTZ) + endTime := timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ) + return startTime, endTime +} + // DashboardStats handles getting user dashboard statistics // GET /api/v1/usage/dashboard/stats func (h *UsageHandler) DashboardStats(c *gin.Context) { @@ -416,3 +439,55 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) { response.Success(c, gin.H{"stats": stats}) } + +// GetMyAPIKeyDailyUsage handles getting daily usage details for the current user's API key. +// GET /api/v1/user/api-keys/:id/usage/daily?days=30 +func (h *UsageHandler) GetMyAPIKeyDailyUsage(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + apiKeyID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid API key ID") + return + } + + days, ok := parseAPIKeyDailyUsageDays(c.DefaultQuery("days", "")) + if !ok { + response.BadRequest(c, "Invalid days, allowed range is 1-90") + return + } + + if h.apiKeyService == nil { + response.InternalError(c, "API key service is not configured") + return + } + + apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), apiKeyID) + if err != nil { + response.ErrorFrom(c, err) + return + } + if apiKey.UserID != subject.UserID { + response.Forbidden(c, "Not authorized to access this API key's usage") + return + } + + userTZ := c.Query("timezone") + startTime, endTime := apiKeyDailyUsageRange(days, userTZ) + items, err := h.usageService.GetAPIKeyDailyUsage(c.Request.Context(), subject.UserID, apiKeyID, startTime, endTime) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{ + "items": items, + "days": days, + "start_date": startTime.Format("2006-01-02"), + "end_date": endTime.AddDate(0, 0, -1).Format("2006-01-02"), + }) +} diff --git a/backend/internal/handler/usage_handler_daily_test.go b/backend/internal/handler/usage_handler_daily_test.go new file mode 100644 index 00000000..36311fac --- /dev/null +++ b/backend/internal/handler/usage_handler_daily_test.go @@ -0,0 +1,195 @@ +package handler + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type dailyUsageRepoStub struct { + service.UsageLogRepository + trend []usagestats.TrendDataPoint + + called bool + startTime time.Time + endTime time.Time + granularity string + userID int64 + apiKeyID int64 +} + +func (s *dailyUsageRepoStub) GetUsageTrendWithFilters( + ctx context.Context, + startTime, endTime time.Time, + granularity string, + userID, apiKeyID, accountID, groupID int64, + model string, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.TrendDataPoint, error) { + s.called = true + s.startTime = startTime + s.endTime = endTime + s.granularity = granularity + s.userID = userID + s.apiKeyID = apiKeyID + return s.trend, nil +} + +type dailyUsageAPIKeyRepoStub struct { + service.APIKeyRepository + keys map[int64]*service.APIKey +} + +func (s *dailyUsageAPIKeyRepoStub) GetByID(ctx context.Context, id int64) (*service.APIKey, error) { + key, ok := s.keys[id] + if !ok { + return nil, service.ErrAPIKeyNotFound + } + clone := *key + return &clone, nil +} + +func newDailyUsageTestRouter(usageRepo *dailyUsageRepoStub, apiKeyRepo *dailyUsageAPIKeyRepoStub, userID int64) *gin.Engine { + gin.SetMode(gin.TestMode) + usageSvc := service.NewUsageService(usageRepo, nil, nil, nil) + apiKeySvc := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, nil) + handler := NewUsageHandler(usageSvc, apiKeySvc) + router := gin.New() + router.Use(func(c *gin.Context) { + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: userID}) + c.Next() + }) + router.GET("/user/api-keys/:id/usage/daily", handler.GetMyAPIKeyDailyUsage) + return router +} + +type dailyUsageHandlerResponse struct { + Code int `json:"code"` + Data struct { + Items []usagestats.APIKeyDailyUsagePoint `json:"items"` + Days int `json:"days"` + } `json:"data"` +} + +func TestGetMyAPIKeyDailyUsageRejectsCrossUserAccess(t *testing.T) { + usageRepo := &dailyUsageRepoStub{} + apiKeyRepo := &dailyUsageAPIKeyRepoStub{ + keys: map[int64]*service.APIKey{ + 7: {ID: 7, UserID: 99, Status: service.StatusAPIKeyActive}, + }, + } + router := newDailyUsageTestRouter(usageRepo, apiKeyRepo, 42) + + req := httptest.NewRequest(http.MethodGet, "/user/api-keys/7/usage/daily?days=30", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusForbidden, rec.Code) + require.False(t, usageRepo.called) +} + +func TestGetMyAPIKeyDailyUsageRejectsInvalidDays(t *testing.T) { + for _, path := range []string{ + "/user/api-keys/7/usage/daily?days=0", + "/user/api-keys/7/usage/daily?days=91", + } { + t.Run(path, func(t *testing.T) { + usageRepo := &dailyUsageRepoStub{} + apiKeyRepo := &dailyUsageAPIKeyRepoStub{ + keys: map[int64]*service.APIKey{ + 7: {ID: 7, UserID: 42, Status: service.StatusAPIKeyActive}, + }, + } + router := newDailyUsageTestRouter(usageRepo, apiKeyRepo, 42) + + req := httptest.NewRequest(http.MethodGet, path, nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.False(t, usageRepo.called) + }) + } +} + +func TestGetMyAPIKeyDailyUsageReturnsEmptyData(t *testing.T) { + usageRepo := &dailyUsageRepoStub{trend: []usagestats.TrendDataPoint{}} + apiKeyRepo := &dailyUsageAPIKeyRepoStub{ + keys: map[int64]*service.APIKey{ + 7: {ID: 7, UserID: 42, Status: service.StatusAPIKeyActive}, + }, + } + router := newDailyUsageTestRouter(usageRepo, apiKeyRepo, 42) + + req := httptest.NewRequest(http.MethodGet, "/user/api-keys/7/usage/daily", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + var got dailyUsageHandlerResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, 30, got.Data.Days) + require.Empty(t, got.Data.Items) +} + +func TestGetMyAPIKeyDailyUsageAggregatesByDayForOwnedKey(t *testing.T) { + usageRepo := &dailyUsageRepoStub{ + trend: []usagestats.TrendDataPoint{ + { + Date: "2026-05-19", + Requests: 3, + InputTokens: 10, + OutputTokens: 20, + CacheCreationTokens: 4, + CacheReadTokens: 6, + TotalTokens: 40, + Cost: 0.5, + ActualCost: 0.4, + }, + }, + } + apiKeyRepo := &dailyUsageAPIKeyRepoStub{ + keys: map[int64]*service.APIKey{ + 7: {ID: 7, UserID: 42, Status: service.StatusAPIKeyActive}, + }, + } + router := newDailyUsageTestRouter(usageRepo, apiKeyRepo, 42) + + req := httptest.NewRequest(http.MethodGet, "/user/api-keys/7/usage/daily?days=7", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.True(t, usageRepo.called) + require.Equal(t, "day", usageRepo.granularity) + require.Equal(t, int64(42), usageRepo.userID) + require.Equal(t, int64(7), usageRepo.apiKeyID) + require.True(t, usageRepo.startTime.Before(usageRepo.endTime)) + + var got dailyUsageHandlerResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, 7, got.Data.Days) + require.Len(t, got.Data.Items, 1) + require.Equal(t, usagestats.APIKeyDailyUsagePoint{ + Date: "2026-05-19", + Requests: 3, + InputTokens: 10, + OutputTokens: 20, + CacheReadTokens: 6, + CacheWriteTokens: 4, + TotalTokens: 40, + Cost: 0.5, + ActualCost: 0.4, + }, got.Data.Items[0]) +} diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index f1dbf4e1..95cb1482 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -335,7 +335,7 @@ func (h *UserHandler) SendEmailBindingCode(c *gin.Context) { return } - if err := h.authService.SendEmailIdentityBindCode(c.Request.Context(), subject.UserID, req.Email); err != nil { + if err := h.authService.SendEmailIdentityBindCode(c.Request.Context(), subject.UserID, req.Email, c.GetHeader("Accept-Language")); err != nil { response.ErrorFrom(c, err) return } @@ -363,7 +363,7 @@ func (h *UserHandler) SendNotifyEmailCode(c *gin.Context) { return } - err := h.userService.SendNotifyEmailCode(c.Request.Context(), subject.UserID, req.Email, h.emailService, h.emailCache) + err := h.userService.SendNotifyEmailCode(c.Request.Context(), subject.UserID, req.Email, h.emailService, h.emailCache, c.GetHeader("Accept-Language")) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index cb4ab0a4..1b74d873 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -90,8 +90,17 @@ func ProvideWindsurfHandler(authService *service.WindsurfAuthService, lsService } // ProvideSettingHandler creates SettingHandler with version from BuildInfo -func ProvideSettingHandler(settingService *service.SettingService, buildInfo BuildInfo) *SettingHandler { - return NewSettingHandler(settingService, buildInfo.Version) +func ProvideSettingHandler(settingService *service.SettingService, buildInfo BuildInfo, notificationEmailService *service.NotificationEmailService) *SettingHandler { + h := NewSettingHandler(settingService, buildInfo.Version) + h.SetNotificationEmailService(notificationEmailService) + return h +} + +// ProvideAdminSettingHandler creates admin.SettingHandler with notification template APIs. +func ProvideAdminSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, paymentConfigService *service.PaymentConfigService, paymentService *service.PaymentService, userAttributeService *service.UserAttributeService, notificationEmailService *service.NotificationEmailService) *admin.SettingHandler { + h := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService) + h.SetNotificationEmailService(notificationEmailService) + return h } // ProvideHandlers creates the Handlers struct @@ -169,7 +178,7 @@ var ProviderSet = wire.NewSet( admin.NewProxyHandler, admin.NewRedeemHandler, admin.NewPromoHandler, - admin.NewSettingHandler, + ProvideAdminSettingHandler, admin.NewOpsHandler, ProvideSystemHandler, admin.NewSubscriptionHandler, diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_bridge.go b/backend/internal/pkg/apicompat/chatcompletions_responses_bridge.go new file mode 100644 index 00000000..8fb82ef4 --- /dev/null +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_bridge.go @@ -0,0 +1,719 @@ +package apicompat + +import ( + "encoding/json" + "fmt" + "strings" + "time" +) + +// ResponsesToChatCompletionsRequest converts a Responses API request into a +// Chat Completions request for upstreams that only implement +// /v1/chat/completions. +func ResponsesToChatCompletionsRequest(req *ResponsesRequest) (*ChatCompletionsRequest, error) { + if req == nil { + return nil, fmt.Errorf("responses request is nil") + } + + messages, err := responsesInputToChatMessages(req.Instructions, req.Input) + if err != nil { + return nil, err + } + + out := &ChatCompletionsRequest{ + Model: req.Model, + Messages: messages, + MaxCompletionTokens: req.MaxOutputTokens, + Temperature: req.Temperature, + TopP: req.TopP, + Stream: req.Stream, + ServiceTier: req.ServiceTier, + } + if req.Reasoning != nil { + out.ReasoningEffort = req.Reasoning.Effort + } + if len(req.Tools) > 0 { + out.Tools = responsesToolsToChatTools(req.Tools) + } + if len(req.ToolChoice) > 0 { + out.ToolChoice = responsesToolChoiceToChatToolChoice(req.ToolChoice) + } + + return out, nil +} + +func responsesInputToChatMessages(instructions string, inputRaw json.RawMessage) ([]ChatMessage, error) { + var messages []ChatMessage + if strings.TrimSpace(instructions) != "" { + content, _ := json.Marshal(instructions) + messages = append(messages, ChatMessage{ + Role: "system", + Content: content, + }) + } + + inputRaw = bytesTrimSpace(inputRaw) + if len(inputRaw) == 0 || string(inputRaw) == "null" { + return messages, nil + } + + var inputText string + if err := json.Unmarshal(inputRaw, &inputText); err == nil { + content, _ := json.Marshal(inputText) + messages = append(messages, ChatMessage{ + Role: "user", + Content: content, + }) + return messages, nil + } + + var rawItems []json.RawMessage + if err := json.Unmarshal(inputRaw, &rawItems); err != nil { + return nil, fmt.Errorf("parse responses input: %w", err) + } + + for _, raw := range rawItems { + raw = bytesTrimSpace(raw) + if len(raw) == 0 || string(raw) == "null" { + continue + } + + var item map[string]json.RawMessage + if err := json.Unmarshal(raw, &item); err != nil { + var text string + if textErr := json.Unmarshal(raw, &text); textErr == nil { + content, _ := json.Marshal(text) + messages = append(messages, ChatMessage{Role: "user", Content: content}) + continue + } + return nil, fmt.Errorf("parse responses input item: %w", err) + } + + role := rawString(item["role"]) + itemType := rawString(item["type"]) + switch itemType { + case "function_call": + arguments := rawString(item["arguments"]) + if strings.TrimSpace(arguments) == "" { + arguments = "{}" + } + messages = append(messages, ChatMessage{ + Role: "assistant", + ToolCalls: []ChatToolCall{{ + ID: rawString(item["call_id"]), + Type: "function", + Function: ChatFunctionCall{ + Name: rawString(item["name"]), + Arguments: arguments, + }, + }}, + }) + continue + case "function_call_output": + content, _ := json.Marshal(rawString(item["output"])) + messages = append(messages, ChatMessage{ + Role: "tool", + ToolCallID: rawString(item["call_id"]), + Content: content, + }) + continue + case "input_text", "text": + content, _ := json.Marshal(rawString(item["text"])) + messages = append(messages, ChatMessage{Role: "user", Content: content}) + continue + case "input_image": + content, err := chatContentFromSingleResponsesPart(itemType, item) + if err != nil { + return nil, err + } + messages = append(messages, ChatMessage{Role: "user", Content: content}) + continue + } + + if role == "" { + role = "user" + } + content := item["content"] + if len(bytesTrimSpace(content)) == 0 { + if text := rawString(item["text"]); text != "" { + content, _ = json.Marshal(text) + } + } + chatContent, err := responsesContentToChatContent(content, role) + if err != nil { + return nil, err + } + messages = append(messages, ChatMessage{ + Role: role, + Content: chatContent, + }) + } + + return messages, nil +} + +func responsesContentToChatContent(raw json.RawMessage, role string) (json.RawMessage, error) { + raw = bytesTrimSpace(raw) + if len(raw) == 0 || string(raw) == "null" { + empty, _ := json.Marshal("") + return empty, nil + } + + var text string + if err := json.Unmarshal(raw, &text); err == nil { + return raw, nil + } + + var rawParts []json.RawMessage + if err := json.Unmarshal(raw, &rawParts); err == nil { + return responsesContentPartsToChatContent(rawParts, role) + } + + var obj map[string]json.RawMessage + if err := json.Unmarshal(raw, &obj); err == nil { + return chatContentFromSingleResponsesPart(rawString(obj["type"]), obj) + } + + return raw, nil +} + +func responsesContentPartsToChatContent(rawParts []json.RawMessage, role string) (json.RawMessage, error) { + var textParts []string + var chatParts []ChatContentPart + hasNonText := false + + for _, rawPart := range rawParts { + var part map[string]json.RawMessage + if err := json.Unmarshal(rawPart, &part); err != nil { + continue + } + partType := rawString(part["type"]) + switch partType { + case "input_text", "output_text", "text", "": + text := rawString(part["text"]) + if text == "" { + continue + } + textParts = append(textParts, text) + chatParts = append(chatParts, ChatContentPart{Type: "text", Text: text}) + case "input_image", "image_url": + imageURL := rawString(part["image_url"]) + if imageURL == "" { + imageURL = rawNestedString(part["image_url"], "url") + } + if imageURL == "" { + continue + } + hasNonText = true + chatParts = append(chatParts, ChatContentPart{ + Type: "image_url", + ImageURL: &ChatImageURL{URL: imageURL}, + }) + } + } + + if !hasNonText { + joined, _ := json.Marshal(strings.Join(textParts, "\n\n")) + return joined, nil + } + if role != "user" { + joined, _ := json.Marshal(strings.Join(textParts, "\n\n")) + return joined, nil + } + if len(chatParts) == 0 { + empty, _ := json.Marshal("") + return empty, nil + } + return json.Marshal(chatParts) +} + +func chatContentFromSingleResponsesPart(partType string, part map[string]json.RawMessage) (json.RawMessage, error) { + switch partType { + case "input_image", "image_url": + imageURL := rawString(part["image_url"]) + if imageURL == "" { + imageURL = rawNestedString(part["image_url"], "url") + } + return json.Marshal([]ChatContentPart{{ + Type: "image_url", + ImageURL: &ChatImageURL{URL: imageURL}, + }}) + default: + return json.Marshal(rawString(part["text"])) + } +} + +func responsesToolsToChatTools(tools []ResponsesTool) []ChatTool { + out := make([]ChatTool, 0, len(tools)) + for _, tool := range tools { + if tool.Type != "function" { + continue + } + out = append(out, ChatTool{ + Type: "function", + Function: &ChatFunction{ + Name: tool.Name, + Description: tool.Description, + Parameters: tool.Parameters, + Strict: tool.Strict, + }, + }) + } + return out +} + +func responsesToolChoiceToChatToolChoice(raw json.RawMessage) json.RawMessage { + var choice map[string]json.RawMessage + if err := json.Unmarshal(raw, &choice); err != nil { + return raw + } + if rawString(choice["type"]) != "function" { + return raw + } + name := rawString(choice["name"]) + if name == "" { + name = rawNestedString(choice["function"], "name") + } + if name == "" { + return raw + } + out, err := json.Marshal(map[string]any{ + "type": "function", + "function": map[string]string{ + "name": name, + }, + }) + if err != nil { + return raw + } + return out +} + +// ChatCompletionsResponseToResponses converts a non-streaming Chat Completions +// response into a Responses API response. +func ChatCompletionsResponseToResponses(resp *ChatCompletionsResponse, model string) *ResponsesResponse { + id := "" + if resp != nil { + id = resp.ID + } + if id == "" { + id = generateResponsesID() + } + + out := &ResponsesResponse{ + ID: id, + Object: "response", + Model: model, + Status: "completed", + } + if resp == nil { + out.Output = []ResponsesOutput{emptyResponsesMessageOutput()} + return out + } + if out.Model == "" { + out.Model = resp.Model + } + + if len(resp.Choices) > 0 { + choice := resp.Choices[0] + out.Output = chatMessageToResponsesOutput(choice.Message) + if choice.FinishReason == "length" { + out.Status = "incomplete" + out.IncompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"} + } + } + if len(out.Output) == 0 { + out.Output = []ResponsesOutput{emptyResponsesMessageOutput()} + } + if resp.Usage != nil { + out.Usage = ChatUsageToResponsesUsage(resp.Usage) + } + return out +} + +func chatMessageToResponsesOutput(message ChatMessage) []ResponsesOutput { + var outputs []ResponsesOutput + if message.ReasoningContent != "" { + outputs = append(outputs, ResponsesOutput{ + Type: "reasoning", + ID: generateItemID(), + Summary: []ResponsesSummary{{ + Type: "summary_text", + Text: message.ReasoningContent, + }}, + }) + } + + text := chatMessageContentText(message.Content) + if text != "" || len(message.ToolCalls) == 0 { + outputs = append(outputs, ResponsesOutput{ + Type: "message", + ID: generateItemID(), + Role: "assistant", + Content: []ResponsesContentPart{{ + Type: "output_text", + Text: text, + }}, + Status: "completed", + }) + } + + for _, toolCall := range message.ToolCalls { + arguments := toolCall.Function.Arguments + if strings.TrimSpace(arguments) == "" { + arguments = "{}" + } + outputs = append(outputs, ResponsesOutput{ + Type: "function_call", + ID: generateItemID(), + CallID: toolCall.ID, + Name: toolCall.Function.Name, + Arguments: arguments, + Status: "completed", + }) + } + + return outputs +} + +func emptyResponsesMessageOutput() ResponsesOutput { + return ResponsesOutput{ + Type: "message", + ID: generateItemID(), + Role: "assistant", + Content: []ResponsesContentPart{{Type: "output_text", Text: ""}}, + Status: "completed", + } +} + +func chatMessageContentText(raw json.RawMessage) string { + raw = bytesTrimSpace(raw) + if len(raw) == 0 || string(raw) == "null" { + return "" + } + var text string + if err := json.Unmarshal(raw, &text); err == nil { + return text + } + var parts []ChatContentPart + if err := json.Unmarshal(raw, &parts); err == nil { + var texts []string + for _, part := range parts { + if part.Type == "text" && part.Text != "" { + texts = append(texts, part.Text) + } + } + return strings.Join(texts, "\n\n") + } + return "" +} + +// ChatUsageToResponsesUsage converts Chat Completions token usage to Responses +// usage shape. +func ChatUsageToResponsesUsage(usage *ChatUsage) *ResponsesUsage { + if usage == nil { + return nil + } + out := &ResponsesUsage{ + InputTokens: usage.PromptTokens, + OutputTokens: usage.CompletionTokens, + TotalTokens: usage.TotalTokens, + } + if out.TotalTokens == 0 { + out.TotalTokens = out.InputTokens + out.OutputTokens + } + if usage.PromptTokensDetails != nil && usage.PromptTokensDetails.CachedTokens > 0 { + out.InputTokensDetails = &ResponsesInputTokensDetails{ + CachedTokens: usage.PromptTokensDetails.CachedTokens, + } + } + return out +} + +// ChatCompletionsToResponsesStreamState tracks state while converting Chat +// Completions SSE chunks into Responses SSE events. +type ChatCompletionsToResponsesStreamState struct { + ResponseID string + Model string + Created int64 + SequenceNumber int + CreatedSent bool + CompletedSent bool + + MessageItemID string + Text strings.Builder + Reasoning strings.Builder + ToolCalls map[int]*ChatToolCall + + FinishReason string + Usage *ResponsesUsage +} + +// NewChatCompletionsToResponsesStreamState returns an initialized stream state. +func NewChatCompletionsToResponsesStreamState(model string) *ChatCompletionsToResponsesStreamState { + return &ChatCompletionsToResponsesStreamState{ + ResponseID: generateResponsesID(), + Model: model, + Created: time.Now().Unix(), + ToolCalls: make(map[int]*ChatToolCall), + } +} + +// ChatCompletionsChunkToResponsesEvents converts one Chat Completions stream +// chunk into zero or more Responses stream events. +func ChatCompletionsChunkToResponsesEvents( + chunk *ChatCompletionsChunk, + state *ChatCompletionsToResponsesStreamState, +) []ResponsesStreamEvent { + if chunk == nil || state == nil { + return nil + } + if chunk.ID != "" { + state.ResponseID = chunk.ID + } + if state.Model == "" && chunk.Model != "" { + state.Model = chunk.Model + } + if chunk.Usage != nil { + state.Usage = ChatUsageToResponsesUsage(chunk.Usage) + } + + var events []ResponsesStreamEvent + events = append(events, ensureChatToResponsesCreated(state)...) + + for _, choice := range chunk.Choices { + if choice.Delta.Content != nil { + events = append(events, ensureChatToResponsesMessageItem(state)...) + _, _ = state.Text.WriteString(*choice.Delta.Content) + events = append(events, chatToResponsesEvent(state, "response.output_text.delta", &ResponsesStreamEvent{ + OutputIndex: 0, + ContentIndex: 0, + Delta: *choice.Delta.Content, + ItemID: state.MessageItemID, + })) + } + if choice.Delta.ReasoningContent != nil { + _, _ = state.Reasoning.WriteString(*choice.Delta.ReasoningContent) + events = append(events, chatToResponsesEvent(state, "response.reasoning_summary_text.delta", &ResponsesStreamEvent{ + OutputIndex: 0, + SummaryIndex: 0, + Delta: *choice.Delta.ReasoningContent, + })) + } + for _, toolCall := range choice.Delta.ToolCalls { + idx := 0 + if toolCall.Index != nil { + idx = *toolCall.Index + } + stored, ok := state.ToolCalls[idx] + if !ok { + copyCall := toolCall + if copyCall.ID == "" { + copyCall.ID = generateItemID() + } + copyCall.Type = "function" + state.ToolCalls[idx] = ©Call + stored = ©Call + events = append(events, chatToResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{ + OutputIndex: idx + 1, + Item: &ResponsesOutput{ + Type: "function_call", + ID: generateItemID(), + CallID: stored.ID, + Name: stored.Function.Name, + Status: "in_progress", + }, + })) + } else { + if toolCall.ID != "" { + stored.ID = toolCall.ID + } + if toolCall.Function.Name != "" { + stored.Function.Name = toolCall.Function.Name + } + } + if toolCall.Function.Arguments != "" { + stored.Function.Arguments += toolCall.Function.Arguments + events = append(events, chatToResponsesEvent(state, "response.function_call_arguments.delta", &ResponsesStreamEvent{ + OutputIndex: idx + 1, + Delta: toolCall.Function.Arguments, + CallID: stored.ID, + Name: stored.Function.Name, + })) + } + } + if choice.FinishReason != nil && *choice.FinishReason != "" { + state.FinishReason = *choice.FinishReason + } + } + + return events +} + +// FinalizeChatCompletionsResponsesStream emits terminal Responses events. +func FinalizeChatCompletionsResponsesStream(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent { + if state == nil || state.CompletedSent { + return nil + } + var events []ResponsesStreamEvent + events = append(events, ensureChatToResponsesCreated(state)...) + if state.MessageItemID != "" { + events = append(events, chatToResponsesEvent(state, "response.output_text.done", &ResponsesStreamEvent{ + OutputIndex: 0, + ContentIndex: 0, + Text: state.Text.String(), + ItemID: state.MessageItemID, + })) + events = append(events, chatToResponsesEvent(state, "response.output_item.done", &ResponsesStreamEvent{ + OutputIndex: 0, + Item: &ResponsesOutput{ + Type: "message", + ID: state.MessageItemID, + Role: "assistant", + Status: "completed", + }, + })) + } + + status := "completed" + var incompleteDetails *ResponsesIncompleteDetails + if state.FinishReason == "length" { + status = "incomplete" + incompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"} + } + + state.CompletedSent = true + events = append(events, chatToResponsesEvent(state, "response.completed", &ResponsesStreamEvent{ + Response: &ResponsesResponse{ + ID: state.ResponseID, + Object: "response", + Model: state.Model, + Status: status, + Output: state.chatOutput(), + Usage: state.Usage, + IncompleteDetails: incompleteDetails, + }, + })) + return events +} + +func ensureChatToResponsesCreated(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent { + if state.CreatedSent { + return nil + } + state.CreatedSent = true + return []ResponsesStreamEvent{chatToResponsesEvent(state, "response.created", &ResponsesStreamEvent{ + Response: &ResponsesResponse{ + ID: state.ResponseID, + Object: "response", + Model: state.Model, + Status: "in_progress", + Output: []ResponsesOutput{}, + }, + })} +} + +func ensureChatToResponsesMessageItem(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent { + if state.MessageItemID != "" { + return nil + } + state.MessageItemID = generateItemID() + return []ResponsesStreamEvent{chatToResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{ + OutputIndex: 0, + Item: &ResponsesOutput{ + Type: "message", + ID: state.MessageItemID, + Role: "assistant", + Status: "in_progress", + }, + })} +} + +func (state *ChatCompletionsToResponsesStreamState) chatOutput() []ResponsesOutput { + var outputs []ResponsesOutput + if state.Reasoning.Len() > 0 { + outputs = append(outputs, ResponsesOutput{ + Type: "reasoning", + ID: generateItemID(), + Summary: []ResponsesSummary{{ + Type: "summary_text", + Text: state.Reasoning.String(), + }}, + }) + } + if state.MessageItemID != "" || len(state.ToolCalls) == 0 { + outputs = append(outputs, ResponsesOutput{ + Type: "message", + ID: nonEmpty(state.MessageItemID, generateItemID()), + Role: "assistant", + Content: []ResponsesContentPart{{ + Type: "output_text", + Text: state.Text.String(), + }}, + Status: "completed", + }) + } + for i := 0; i < len(state.ToolCalls); i++ { + toolCall, ok := state.ToolCalls[i] + if !ok || toolCall == nil { + continue + } + arguments := toolCall.Function.Arguments + if strings.TrimSpace(arguments) == "" { + arguments = "{}" + } + outputs = append(outputs, ResponsesOutput{ + Type: "function_call", + ID: generateItemID(), + CallID: toolCall.ID, + Name: toolCall.Function.Name, + Arguments: arguments, + Status: "completed", + }) + } + return outputs +} + +func chatToResponsesEvent( + state *ChatCompletionsToResponsesStreamState, + eventType string, + template *ResponsesStreamEvent, +) ResponsesStreamEvent { + seq := state.SequenceNumber + state.SequenceNumber++ + evt := *template + evt.Type = eventType + evt.SequenceNumber = seq + return evt +} + +func rawString(raw json.RawMessage) string { + raw = bytesTrimSpace(raw) + if len(raw) == 0 || string(raw) == "null" { + return "" + } + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return s + } + return "" +} + +func rawNestedString(raw json.RawMessage, key string) string { + var obj map[string]json.RawMessage + if err := json.Unmarshal(raw, &obj); err != nil { + return "" + } + return rawString(obj[key]) +} + +func bytesTrimSpace(raw json.RawMessage) json.RawMessage { + return json.RawMessage(strings.TrimSpace(string(raw))) +} + +func nonEmpty(value, fallback string) string { + if value != "" { + return value + } + return fallback +} diff --git a/backend/internal/pkg/openai/request.go b/backend/internal/pkg/openai/request.go index dd8fe566..ae3886d6 100644 --- a/backend/internal/pkg/openai/request.go +++ b/backend/internal/pkg/openai/request.go @@ -30,6 +30,17 @@ var CodexOfficialClientOriginatorPrefixes = []string{ "codex ", } +// IsBrowserUserAgent 判断 User-Agent 是否来自浏览器(Chrome/Firefox/Safari/Edge/Opera 等)。 +// 所有现代浏览器的 UA 均以 "Mozilla/" 作为前缀,CLI 工具(codex/claude/curl/postman/python-requests 等)不会。 +// 该判定用于避免 Cloudflare 对浏览器型 UA 在 OpenAI 上游接口上触发 JS 质询。 +func IsBrowserUserAgent(userAgent string) bool { + ua := strings.TrimSpace(userAgent) + if ua == "" { + return false + } + return strings.HasPrefix(strings.ToLower(ua), "mozilla/") +} + // IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request func IsCodexCLIRequest(userAgent string) bool { ua := normalizeCodexClientHeader(userAgent) diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index 39283d22..5307389d 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -198,6 +198,19 @@ type APIKeyUsageTrendPoint struct { Tokens int64 `json:"tokens"` } +// APIKeyDailyUsagePoint represents one day of usage for a single API key. +type APIKeyDailyUsagePoint struct { + Date string `json:"date"` + Requests int64 `json:"requests"` + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + CacheReadTokens int64 `json:"cache_read_tokens"` + CacheWriteTokens int64 `json:"cache_write_tokens"` + TotalTokens int64 `json:"total_tokens"` + Cost float64 `json:"cost"` // 标准计费 + ActualCost float64 `json:"actual_cost"` // 实际扣除 +} + // UserDashboardStats 用户仪表盘统计 type UserDashboardStats struct { // API Key 统计 diff --git a/backend/internal/repository/aes_encryptor_test.go b/backend/internal/repository/aes_encryptor_test.go new file mode 100644 index 00000000..25bff622 --- /dev/null +++ b/backend/internal/repository/aes_encryptor_test.go @@ -0,0 +1,219 @@ +//go:build unit + +package repository + +import ( + "encoding/base64" + "encoding/hex" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ── 测试辅助 ───────────────────────────────────────────────────────────────── + +// aesHexKey 构造一个全填充为 b 的 n 字节密钥并以 hex 编码返回。 +func aesHexKey(n int, b byte) string { + raw := make([]byte, n) + for i := range raw { + raw[i] = b + } + return hex.EncodeToString(raw) +} + +// aesTestCfg 用给定 hex 密钥字符串构造最小 Config。 +func aesTestCfg(keyHex string) *config.Config { + return &config.Config{ + Totp: config.TotpConfig{EncryptionKey: keyHex}, + } +} + +// aesEncryptor 创建一个持有合法 32 字节密钥的加密器,测试失败时立即终止。 +func aesEncryptor(t *testing.T) *AESEncryptor { + t.Helper() + enc, err := NewAESEncryptor(aesTestCfg(aesHexKey(32, 0x42))) + require.NoError(t, err) + require.NotNil(t, enc) + return enc.(*AESEncryptor) +} + +// ── NewAESEncryptor ────────────────────────────────────────────────────────── + +func TestNewAESEncryptor_ValidKey32Bytes(t *testing.T) { + enc, err := NewAESEncryptor(aesTestCfg(aesHexKey(32, 0x01))) + require.NoError(t, err) + require.NotNil(t, enc) +} + +// 16 / 24 字节密钥在 AES 体系内合法,但本实现仅接受 AES-256(32 字节)。 +func TestNewAESEncryptor_WrongKeyLength(t *testing.T) { + tests := []struct { + name string + keySize int + }{ + {"16_bytes_AES128", 16}, + {"24_bytes_AES192", 24}, + {"1_byte", 1}, + {"31_bytes", 31}, + {"33_bytes", 33}, + {"64_bytes", 64}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewAESEncryptor(aesTestCfg(aesHexKey(tt.keySize, 0x00))) + require.Error(t, err) + assert.Contains(t, err.Error(), "32 bytes") + }) + } +} + +// "配置缺失"场景:空字符串与非法 hex 编码。 +func TestNewAESEncryptor_MissingOrInvalidConfig(t *testing.T) { + tests := []struct { + name string + keyHex string + wantContain string + }{ + {"empty_key", "", "32 bytes"}, + {"invalid_hex_odd_length", "abcde", "invalid totp encryption key"}, + {"invalid_hex_chars", "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz", "invalid totp encryption key"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewAESEncryptor(aesTestCfg(tt.keyHex)) + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantContain) + }) + } +} + +// ── 加解密往返(Roundtrip)─────────────────────────────────────────────────── + +func TestAESEncryptor_RoundTrip(t *testing.T) { + enc := aesEncryptor(t) + + tests := []struct { + name string + plaintext string + }{ + {"ascii", "Hello, Sub2API!"}, + {"chinese_multibyte", "你好,世界!这是多字节 UTF-8 文本。"}, + {"empty_string", ""}, + {"long_string_gt_1KB", strings.Repeat("x", 2048)}, + {"special_chars", "!@#$%^&*()_+-=[]{}|;':\",./<>?"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ct, err := enc.Encrypt(tt.plaintext) + require.NoError(t, err) + require.NotEmpty(t, ct, "密文不应为空(即便明文为空字符串)") + + got, err := enc.Decrypt(ct) + require.NoError(t, err) + assert.Equal(t, tt.plaintext, got) + }) + } +} + +// ── IV/Nonce 随机性 ────────────────────────────────────────────────────────── + +func TestAESEncryptor_Encrypt_NonceRandomness(t *testing.T) { + enc := aesEncryptor(t) + const iterations = 30 + plaintext := "same plaintext for every iteration" + + seen := make(map[string]struct{}, iterations) + for i := 0; i < iterations; i++ { + ct, err := enc.Encrypt(plaintext) + require.NoError(t, err) + seen[ct] = struct{}{} + } + + // 30 次加密相同明文,每次因随机 Nonce 应产生不同密文。 + assert.Len(t, seen, iterations, + "每次加密应因随机 Nonce 产生唯一密文,共 %d 次", iterations) +} + +// ── Decrypt 错误路径 ────────────────────────────────────────────────────────── + +func TestAESDecrypt_InvalidBase64(t *testing.T) { + enc := aesEncryptor(t) + _, err := enc.Decrypt("!!!not-valid-base64!!!") + require.Error(t, err) + assert.Contains(t, err.Error(), "decode base64") +} + +func TestAESDecrypt_TooShort(t *testing.T) { + enc := aesEncryptor(t) + // GCM Nonce 为 12 字节;仅提供 2 字节,必然短于 NonceSize。 + short := base64.StdEncoding.EncodeToString([]byte{0x01, 0x02}) + _, err := enc.Decrypt(short) + require.Error(t, err) + assert.Contains(t, err.Error(), "too short") +} + +func TestAESDecrypt_TamperedCiphertext(t *testing.T) { + enc := aesEncryptor(t) + + ct, err := enc.Encrypt("sensitive payload") + require.NoError(t, err) + + raw, err := base64.StdEncoding.DecodeString(ct) + require.NoError(t, err) + + // Nonce 占前 12 字节;翻转其后第一个字节(密文体)。 + raw[12] ^= 0xFF + _, err = enc.Decrypt(base64.StdEncoding.EncodeToString(raw)) + require.Error(t, err, "篡改密文体后解密应失败") +} + +func TestAESDecrypt_TamperedTag(t *testing.T) { + enc := aesEncryptor(t) + + ct, err := enc.Encrypt("sensitive payload") + require.NoError(t, err) + + raw, err := base64.StdEncoding.DecodeString(ct) + require.NoError(t, err) + + // GCM 认证标签占最后 16 字节;翻转最后一个字节。 + raw[len(raw)-1] ^= 0xFF + _, err = enc.Decrypt(base64.StdEncoding.EncodeToString(raw)) + require.Error(t, err, "篡改 GCM 标签后解密应失败") +} + +// ── 跨实例(Cross-instance)────────────────────────────────────────────────── + +func TestAESEncryptor_CrossInstance_SameKey_CanDecrypt(t *testing.T) { + keyHex := aesHexKey(32, 0xDE) + + enc1, err := NewAESEncryptor(aesTestCfg(keyHex)) + require.NoError(t, err) + enc2, err := NewAESEncryptor(aesTestCfg(keyHex)) + require.NoError(t, err) + + plaintext := "cross-instance roundtrip" + ct, err := enc1.Encrypt(plaintext) + require.NoError(t, err) + + got, err := enc2.Decrypt(ct) + require.NoError(t, err) + assert.Equal(t, plaintext, got, "相同密钥构造的两个实例应可互相解密") +} + +func TestAESEncryptor_CrossInstance_DifferentKey_CannotDecrypt(t *testing.T) { + enc1, err := NewAESEncryptor(aesTestCfg(aesHexKey(32, 0xAA))) + require.NoError(t, err) + enc2, err := NewAESEncryptor(aesTestCfg(aesHexKey(32, 0xBB))) + require.NoError(t, err) + + ct, err := enc1.Encrypt("secret message") + require.NoError(t, err) + + _, err = enc2.Decrypt(ct) + require.Error(t, err, "不同密钥的实例不应能解密对方的密文") +} diff --git a/backend/internal/repository/allowed_groups_contract_integration_test.go b/backend/internal/repository/allowed_groups_contract_integration_test.go index b0af0d54..5aa95a91 100644 --- a/backend/internal/repository/allowed_groups_contract_integration_test.go +++ b/backend/internal/repository/allowed_groups_contract_integration_test.go @@ -80,7 +80,7 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te require.NotContains(t, u2After.AllowedGroups, targetGroup.ID) } -func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) { +func TestGroupRepository_DeleteCascade_PreservesApiKeyGroupID(t *testing.T) { ctx := context.Background() tx := testEntTx(t) entClient := tx.Client() @@ -138,8 +138,10 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t require.NotContains(t, uAfter.AllowedGroups, targetGroup.ID) require.Contains(t, uAfter.AllowedGroups, otherGroup.ID) - // API keys bound to the deleted group should have group_id cleared. + // API keys keep their group_id so auth can reject keys bound to a deleted group. keyAfter, err := apiKeyRepo.GetByID(ctx, key.ID) require.NoError(t, err) - require.Nil(t, keyAfter.GroupID) + require.NotNil(t, keyAfter.GroupID) + require.Equal(t, targetGroup.ID, *keyAfter.GroupID) + require.Nil(t, keyAfter.Group) } diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 9b6377bc..9c3b2010 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -9,7 +9,6 @@ import ( "strings" dbent "github.com/Wei-Shaw/sub2api/ent" - "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" @@ -94,9 +93,13 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group if err != nil { return nil, err } - total, active, _ := r.GetAccountCount(ctx, out.ID) - out.AccountCount = total - out.ActiveAccountCount = active + counts, err := r.loadAccountCounts(ctx, []int64{out.ID}) + if err == nil { + c := counts[out.ID] + out.AccountCount = c.Total + out.ActiveAccountCount = c.Active + out.RateLimitedAccountCount = c.RateLimited + } return out, nil } @@ -538,15 +541,12 @@ func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) { var rateLimited int64 err = scanSingleRow(ctx, r.sql, - `SELECT COUNT(*), - COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true), - COUNT(*) FILTER (WHERE a.status = 'active' AND ( - a.rate_limit_reset_at > NOW() OR - a.overload_until > NOW() OR - a.temp_unschedulable_until > NOW() - )) + fmt.Sprintf(`SELECT + COUNT(*) FILTER (WHERE a.deleted_at IS NULL), + COUNT(*) FILTER (WHERE %s), + COUNT(*) FILTER (WHERE %s) FROM account_groups ag JOIN accounts a ON a.id = ag.account_id - WHERE ag.group_id = $1`, + WHERE ag.group_id = $1`, groupAccountAvailableSQL, groupAccountTemporarilyLimitedSQL), []any{groupID}, &total, &active, &rateLimited) return } @@ -636,28 +636,18 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, } } - // 2. Clear group_id for api keys bound to this group. - // 仅更新未软删除的记录,避免修改已删除数据,保证审计与历史回溯一致性。 - // 与 APIKeyRepository 的软删除语义保持一致,减少跨模块行为差异。 - if _, err := txClient.APIKey.Update(). - Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()). - ClearGroupID(). - Save(ctx); err != nil { - return nil, err - } - - // 3. Remove the group id from user_allowed_groups join table. + // 2. Remove the group id from user_allowed_groups join table. // Legacy users.allowed_groups 列已弃用,不再同步。 if _, err := exec.ExecContext(ctx, "DELETE FROM user_allowed_groups WHERE group_id = $1", id); err != nil { return nil, err } - // 4. Delete account_groups join rows. + // 3. Delete account_groups join rows. if _, err := exec.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", id); err != nil { return nil, err } - // 5. Soft-delete group itself. + // 4. Soft-delete group itself. if _, err := txClient.Group.Delete().Where(group.IDEQ(id)).Exec(ctx); err != nil { return nil, err } @@ -680,6 +670,28 @@ type groupAccountCounts struct { RateLimited int64 } +const ( + // 分组页的"可用"账号数必须与账号仓储的 ListSchedulableByGroupID 过滤口径一致。 + groupAccountAvailableSQL = `a.deleted_at IS NULL + AND a.status = 'active' + AND a.schedulable = true + AND (a.expires_at IS NULL OR a.expires_at > NOW() OR a.auto_pause_on_expired = FALSE) + AND (a.rate_limit_reset_at IS NULL OR a.rate_limit_reset_at <= NOW()) + AND (a.overload_until IS NULL OR a.overload_until <= NOW()) + AND (a.temp_unschedulable_until IS NULL OR a.temp_unschedulable_until <= NOW())` + + // 这里沿用历史字段名 RateLimitedAccountCount,但统计的是会让账号暂时退出调度的时间窗口。 + groupAccountTemporarilyLimitedSQL = `a.deleted_at IS NULL + AND a.status = 'active' + AND a.schedulable = true + AND (a.expires_at IS NULL OR a.expires_at > NOW() OR a.auto_pause_on_expired = FALSE) + AND ( + a.rate_limit_reset_at > NOW() OR + a.overload_until > NOW() OR + a.temp_unschedulable_until > NOW() + )` +) + func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]groupAccountCounts, err error) { counts = make(map[int64]groupAccountCounts, len(groupIDs)) if len(groupIDs) == 0 { @@ -688,18 +700,14 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6 rows, err := r.sql.QueryContext( ctx, - `SELECT ag.group_id, - COUNT(*) AS total, - COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true) AS active, - COUNT(*) FILTER (WHERE a.status = 'active' AND ( - a.rate_limit_reset_at > NOW() OR - a.overload_until > NOW() OR - a.temp_unschedulable_until > NOW() - )) AS rate_limited + fmt.Sprintf(`SELECT ag.group_id, + COUNT(*) FILTER (WHERE a.deleted_at IS NULL) AS total, + COUNT(*) FILTER (WHERE %s) AS active, + COUNT(*) FILTER (WHERE %s) AS rate_limited FROM account_groups ag JOIN accounts a ON a.id = ag.account_id WHERE ag.group_id = ANY($1) - GROUP BY ag.group_id`, + GROUP BY ag.group_id`, groupAccountAvailableSQL, groupAccountTemporarilyLimitedSQL), pq.Array(groupIDs), ) if err != nil { diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index f91dae43..68183b2b 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ b/backend/internal/repository/group_repo_integration_test.go @@ -651,6 +651,164 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() { s.Require().Zero(count) } +// TestListWithFilters_ActiveAccountCount_LessThanTotal 验证 ActiveAccountCount 正确区分可用与不可用账号。 +// 当分组内存在 disabled 或 schedulable=false 的账号时,ActiveAccountCount 必须小于 AccountCount, +// 且与 GetAccountCount 返回的 active 值一致。 +func (s *GroupRepoSuite) TestListWithFilters_ActiveAccountCount_LessThanTotal() { + g := &service.Group{ + Name: "g-mixed-status", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + s.Require().NoError(s.repo.Create(s.ctx, g)) + + insertAccount := func(name, status string, schedulable bool) int64 { + var id int64 + s.Require().NoError(scanSingleRow( + s.ctx, s.tx, + "INSERT INTO accounts (name, platform, type, status, schedulable) VALUES ($1, $2, $3, $4, $5) RETURNING id", + []any{name, service.PlatformAnthropic, service.AccountTypeOAuth, status, schedulable}, + &id, + )) + return id + } + link := func(accountID int64, priority int) { + _, err := s.tx.ExecContext(s.ctx, + "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", + accountID, g.ID, priority) + s.Require().NoError(err) + } + + // account 1: active + schedulable → counts toward both total and active + link(insertAccount("acc-active-sched", service.StatusActive, true), 1) + // account 2: disabled → counts toward total only + link(insertAccount("acc-disabled", service.StatusDisabled, true), 2) + // account 3: active + not schedulable → counts toward total only + link(insertAccount("acc-unschedulable", service.StatusActive, false), 3) + + // --- ListWithFilters path --- + isExclusive := false + groups, _, err := s.repo.ListWithFilters(s.ctx, + pagination.PaginationParams{Page: 1, PageSize: 100}, + service.PlatformAnthropic, service.StatusActive, "", &isExclusive) + s.Require().NoError(err) + + var found *service.Group + for i := range groups { + if groups[i].ID == g.ID { + found = &groups[i] + break + } + } + s.Require().NotNil(found, "created group must appear in ListWithFilters result") + s.Assert().Equal(int64(3), found.AccountCount, "AccountCount must count all 3 accounts") + s.Assert().Equal(int64(1), found.ActiveAccountCount, "ActiveAccountCount must count only the active+schedulable account") + + // --- GetAccountCount must return identical values --- + total, active, err := s.repo.GetAccountCount(s.ctx, g.ID) + s.Require().NoError(err) + s.Assert().Equal(found.AccountCount, total, "GetAccountCount total must match ListWithFilters AccountCount") + s.Assert().Equal(found.ActiveAccountCount, active, "GetAccountCount active must match ListWithFilters ActiveAccountCount") +} + +// TestListWithFilters_RateLimitedAccountCount 验证临时受限账号不会计入可用账号数。 +// rate_limit / overload / temp_unschedulable 都会让账号退出当前调度池, +// 因此 ActiveAccountCount 必须与真实调度查询口径一致。 +func (s *GroupRepoSuite) TestListWithFilters_RateLimitedAccountCount() { + g := &service.Group{ + Name: "g-rate-limited", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + s.Require().NoError(s.repo.Create(s.ctx, g)) + + var normalID int64 + s.Require().NoError(scanSingleRow(s.ctx, s.tx, + "INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id", + []any{"acc-normal", service.PlatformAnthropic, service.AccountTypeOAuth}, + &normalID)) + + var rateLimitedID int64 + s.Require().NoError(scanSingleRow(s.ctx, s.tx, + "INSERT INTO accounts (name, platform, type, rate_limit_reset_at) VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour') RETURNING id", + []any{"acc-rate-limited", service.PlatformAnthropic, service.AccountTypeOAuth}, + &rateLimitedID)) + + var overloadedID int64 + s.Require().NoError(scanSingleRow(s.ctx, s.tx, + "INSERT INTO accounts (name, platform, type, overload_until) VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour') RETURNING id", + []any{"acc-overloaded", service.PlatformAnthropic, service.AccountTypeOAuth}, + &overloadedID)) + + var tempUnschedulableID int64 + s.Require().NoError(scanSingleRow(s.ctx, s.tx, + "INSERT INTO accounts (name, platform, type, temp_unschedulable_until) VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour') RETURNING id", + []any{"acc-temp-unschedulable", service.PlatformAnthropic, service.AccountTypeOAuth}, + &tempUnschedulableID)) + + var expiredID int64 + s.Require().NoError(scanSingleRow(s.ctx, s.tx, + "INSERT INTO accounts (name, platform, type, expires_at, auto_pause_on_expired) VALUES ($1, $2, $3, NOW() - INTERVAL '1 hour', TRUE) RETURNING id", + []any{"acc-expired", service.PlatformAnthropic, service.AccountTypeOAuth}, + &expiredID)) + + _, err := s.tx.ExecContext(s.ctx, + "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", + normalID, g.ID, 1) + s.Require().NoError(err) + _, err = s.tx.ExecContext(s.ctx, + "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", + rateLimitedID, g.ID, 2) + s.Require().NoError(err) + _, err = s.tx.ExecContext(s.ctx, + "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", + overloadedID, g.ID, 3) + s.Require().NoError(err) + _, err = s.tx.ExecContext(s.ctx, + "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", + tempUnschedulableID, g.ID, 4) + s.Require().NoError(err) + _, err = s.tx.ExecContext(s.ctx, + "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", + expiredID, g.ID, 5) + s.Require().NoError(err) + + isExclusive := false + groups, _, err := s.repo.ListWithFilters(s.ctx, + pagination.PaginationParams{Page: 1, PageSize: 100}, + service.PlatformAnthropic, service.StatusActive, "", &isExclusive) + s.Require().NoError(err) + + var found *service.Group + for i := range groups { + if groups[i].ID == g.ID { + found = &groups[i] + break + } + } + s.Require().NotNil(found, "created group must appear in ListWithFilters result") + s.Assert().Equal(int64(5), found.AccountCount, "AccountCount must include all linked accounts") + s.Assert().Equal(int64(1), found.ActiveAccountCount, "ActiveAccountCount must include only currently schedulable accounts") + s.Assert().Equal(int64(3), found.RateLimitedAccountCount, "RateLimitedAccountCount must include temporarily limited accounts") + + total, active, err := s.repo.GetAccountCount(s.ctx, g.ID) + s.Require().NoError(err) + s.Assert().Equal(found.AccountCount, total, "GetAccountCount total must match ListWithFilters AccountCount") + s.Assert().Equal(found.ActiveAccountCount, active, "GetAccountCount active must match ListWithFilters ActiveAccountCount") + + detail, err := s.repo.GetByID(s.ctx, g.ID) + s.Require().NoError(err) + s.Assert().Equal(found.AccountCount, detail.AccountCount, "GetByID AccountCount must match ListWithFilters") + s.Assert().Equal(found.ActiveAccountCount, detail.ActiveAccountCount, "GetByID ActiveAccountCount must match ListWithFilters") + s.Assert().Equal(found.RateLimitedAccountCount, detail.RateLimitedAccountCount, "GetByID RateLimitedAccountCount must match ListWithFilters") +} + // --- DeleteAccountGroupsByGroupID --- func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() { diff --git a/backend/internal/repository/group_repo_sort_integration_test.go b/backend/internal/repository/group_repo_sort_integration_test.go index 85b2efcc..39e1ec78 100644 --- a/backend/internal/repository/group_repo_sort_integration_test.go +++ b/backend/internal/repository/group_repo_sort_integration_test.go @@ -7,6 +7,67 @@ import ( "github.com/Wei-Shaw/sub2api/internal/service" ) +// TestListWithAccountCountSort_AttachesActiveCount 验证通过 account_count 排序时, +// ActiveAccountCount 与 AccountCount 都被正确附加到返回结果中, +// 且排序基于 total 账号数而非 active 账号数。 +func (s *GroupRepoSuite) TestListWithAccountCountSort_AttachesActiveCount() { + // Group A: 2 total, 1 active (1 disabled account) + gA := &service.Group{Name: "sort-count-a", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard} + // Group B: 1 total, 1 active + gB := &service.Group{Name: "sort-count-b", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard} + s.Require().NoError(s.repo.Create(s.ctx, gA)) + s.Require().NoError(s.repo.Create(s.ctx, gB)) + + insertAccount := func(name, status string) int64 { + var id int64 + s.Require().NoError(scanSingleRow(s.ctx, s.tx, + "INSERT INTO accounts (name, platform, type, status) VALUES ($1, $2, $3, $4) RETURNING id", + []any{name, service.PlatformAnthropic, service.AccountTypeOAuth, status}, + &id)) + return id + } + link := func(accountID, groupID int64, priority int) { + _, err := s.tx.ExecContext(s.ctx, + "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", + accountID, groupID, priority) + s.Require().NoError(err) + } + + // gA: 1 active + 1 disabled → total=2, active=1 + link(insertAccount("sa-active", service.StatusActive), gA.ID, 1) + link(insertAccount("sa-disabled", service.StatusDisabled), gA.ID, 2) + // gB: 1 active → total=1, active=1 + link(insertAccount("sb-active", service.StatusActive), gB.ID, 1) + + groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{ + Page: 1, PageSize: 100, SortBy: "account_count", SortOrder: "desc", + }, service.PlatformAnthropic, service.StatusActive, "", nil) + s.Require().NoError(err) + + byID := make(map[int64]service.Group, len(groups)) + for _, g := range groups { + byID[g.ID] = g + } + + s.Require().Contains(byID, gA.ID, "gA must appear in results") + s.Require().Contains(byID, gB.ID, "gB must appear in results") + + cA := byID[gA.ID] + s.Assert().Equal(int64(2), cA.AccountCount, "gA AccountCount must be 2") + s.Assert().Equal(int64(1), cA.ActiveAccountCount, "gA ActiveAccountCount must be 1") + + cB := byID[gB.ID] + s.Assert().Equal(int64(1), cB.AccountCount, "gB AccountCount must be 1") + s.Assert().Equal(int64(1), cB.ActiveAccountCount, "gB ActiveAccountCount must be 1") + + // Sort is by total (not active): gA (total=2) must rank higher than gB (total=1) in desc order + indexByID := make(map[int64]int, len(groups)) + for i, g := range groups { + indexByID[g.ID] = i + } + s.Assert().Less(indexByID[gA.ID], indexByID[gB.ID], "gA (total=2) must rank above gB (total=1) with account_count desc") +} + func (s *GroupRepoSuite) TestList_DefaultSortBySortOrderAsc() { g1 := &service.Group{Name: "g1", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 20} g2 := &service.Group{Name: "g2", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 10} diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 3770d585..65757b62 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -833,6 +833,7 @@ func TestAPIContracts(t *testing.T) { "payment_visible_method_alipay_enabled": true, "payment_visible_method_wxpay_enabled": false, "openai_advanced_scheduler_enabled": true, + "openai_codex_user_agent": "", "openai_fast_policy_settings": { "rules": [] }, @@ -1058,6 +1059,7 @@ func TestAPIContracts(t *testing.T) { "payment_visible_method_alipay_enabled": false, "payment_visible_method_wxpay_enabled": false, "openai_advanced_scheduler_enabled": false, + "openai_codex_user_agent": "", "openai_fast_policy_settings": { "rules": [] }, diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index c15f534e..ee439cda 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -109,6 +109,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti AbortWithError(c, 401, "USER_INACTIVE", "User account is not active") return } + if abortIfAPIKeyGroupUnavailable(c, apiKey) { + return + } // ── 4. SimpleMode → early return ───────────────────────────── @@ -251,3 +254,26 @@ func setGroupContext(c *gin.Context, group *service.Group) { ctx := context.WithValue(c.Request.Context(), ctxkey.Group, group) c.Request = c.Request.WithContext(ctx) } + +func abortIfAPIKeyGroupUnavailable(c *gin.Context, apiKey *service.APIKey) bool { + code, message, ok := validateAPIKeyGroupAvailable(apiKey) + if ok { + return false + } + AbortWithError(c, 403, code, message) + return true +} + +func validateAPIKeyGroupAvailable(apiKey *service.APIKey) (string, string, bool) { + if apiKey == nil || apiKey.GroupID == nil { + return "", "", true + } + group := apiKey.Group + if group == nil || strings.EqualFold(group.Status, "deleted") { + return "GROUP_DELETED", "API Key 所属分组已删除", false + } + if !group.IsActive() { + return "GROUP_DISABLED", "API Key 所属分组已停用", false + } + return "", "", true +} diff --git a/backend/internal/server/middleware/api_key_auth_google.go b/backend/internal/server/middleware/api_key_auth_google.go index 84d93edc..3ed71f71 100644 --- a/backend/internal/server/middleware/api_key_auth_google.go +++ b/backend/internal/server/middleware/api_key_auth_google.go @@ -54,6 +54,10 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs abortWithGoogleError(c, 401, "User account is not active") return } + if _, message, ok := validateAPIKeyGroupAvailable(apiKey); !ok { + abortWithGoogleError(c, 403, message) + return + } // 简易模式:跳过余额和订阅检查 if cfg.RunMode == config.RunModeSimple { diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index d6760d8d..a00f70c7 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -300,6 +300,104 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) { require.Equal(t, http.StatusOK, w.Code) } +func TestAPIKeyAuthRejectsUnavailableGroup(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(101) + user := &service.User{ + ID: 7, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + + tests := []struct { + name string + group *service.Group + wantStatus int + wantCode string + }{ + { + name: "active group passes", + group: &service.Group{ + ID: groupID, + Name: "active", + Status: service.StatusActive, + Platform: service.PlatformAnthropic, + Hydrated: true, + }, + wantStatus: http.StatusOK, + }, + { + name: "disabled group is forbidden", + group: &service.Group{ + ID: groupID, + Name: "disabled", + Status: service.StatusDisabled, + Platform: service.PlatformAnthropic, + Hydrated: true, + }, + wantStatus: http.StatusForbidden, + wantCode: "GROUP_DISABLED", + }, + { + name: "deleted status group is forbidden", + group: &service.Group{ + ID: groupID, + Name: "deleted", + Status: "deleted", + Platform: service.PlatformAnthropic, + Hydrated: true, + }, + wantStatus: http.StatusForbidden, + wantCode: "GROUP_DELETED", + }, + { + name: "missing group edge is forbidden", + group: nil, + wantStatus: http.StatusForbidden, + wantCode: "GROUP_DELETED", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + apiKey := &service.APIKey{ + ID: 100, + UserID: user.ID, + GroupID: &groupID, + Key: "test-key", + Status: service.StatusActive, + User: user, + Group: tt.group, + } + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + } + cfg := &config.Config{RunMode: config.RunModeStandard} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + router := newAuthTestRouter(apiKeyService, nil, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, tt.wantStatus, w.Code) + if tt.wantCode != "" { + require.Contains(t, w.Body.String(), tt.wantCode) + } + }) + } +} + func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index b8af9cc5..65eed1ee 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -422,6 +422,11 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { adminSettings.PUT("", h.Admin.Setting.UpdateSettings) adminSettings.POST("/test-smtp", h.Admin.Setting.TestSMTPConnection) adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail) + adminSettings.GET("/email-templates", h.Admin.Setting.ListEmailTemplates) + adminSettings.POST("/email-template-preview", h.Admin.Setting.PreviewEmailTemplate) + adminSettings.GET("/email-templates/:event/:locale", h.Admin.Setting.GetEmailTemplate) + adminSettings.PUT("/email-templates/:event/:locale", h.Admin.Setting.UpdateEmailTemplate) + adminSettings.POST("/email-templates/:event/:locale/restore-official", h.Admin.Setting.RestoreOfficialEmailTemplate) // Admin API Key 管理 adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey) adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey) diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 19d0fd2a..2c44a2b3 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -214,6 +214,7 @@ func RegisterAuthRoutes( settings := v1.Group("/settings") { settings.GET("/public", h.Setting.GetPublicSettings) + settings.GET("/email-unsubscribe", h.Setting.UnsubscribeNotificationEmail) } // 需要认证的当前用户信息 diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go index 9976954c..e79d3ee3 100644 --- a/backend/internal/server/routes/user.go +++ b/backend/internal/server/routes/user.go @@ -31,6 +31,7 @@ func RegisterUserRoutes( user.POST("/account-bindings/email", h.User.BindEmailIdentity) user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity) user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding) + user.GET("/api-keys/:id/usage/daily", h.Usage.GetMyAPIKeyDailyUsage) // 通知邮箱管理 notifyEmail := user.Group("/notify-email") diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index a9492a1d..c1da92d1 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -244,6 +244,21 @@ func (s *groupRepoStub) UpdateSortOrders(ctx context.Context, updates []GroupSor return nil } +type deleteGroupAPIKeyRepoStub struct { + apiKeyRepoStubForGroupUpdate + keys []string + listErr error + listGroupIDs []int64 +} + +func (s *deleteGroupAPIKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + s.listGroupIDs = append(s.listGroupIDs, groupID) + if s.listErr != nil { + return nil, s.listErr + } + return s.keys, nil +} + type proxyRepoStub struct { deleteErr error countErr error @@ -500,6 +515,23 @@ func TestAdminService_DeleteGroup_Success_WithCacheInvalidation(t *testing.T) { }, calls) } +func TestAdminService_DeleteGroup_InvalidatesAuthCacheForBoundKeys(t *testing.T) { + repo := &groupRepoStub{} + apiKeyRepo := &deleteGroupAPIKeyRepoStub{keys: []string{"k1", "k2"}} + invalidator := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{ + groupRepo: repo, + apiKeyRepo: apiKeyRepo, + authCacheInvalidator: invalidator, + } + + err := svc.DeleteGroup(context.Background(), 5) + require.NoError(t, err) + require.Equal(t, []int64{5}, repo.deleteCalls) + require.Equal(t, []int64{5}, apiKeyRepo.listGroupIDs) + require.Equal(t, []string{"k1", "k2"}, invalidator.keys) +} + func TestAdminService_DeleteGroup_NotFound(t *testing.T) { repo := &groupRepoStub{deleteErr: ErrGroupNotFound} svc := &adminServiceImpl{groupRepo: repo} diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 877888b1..c752ce28 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -14,7 +14,7 @@ import ( "github.com/dgraph-io/ristretto" ) -const apiKeyAuthSnapshotVersion = 9 // v9: added API Key name for audit logs +const apiKeyAuthSnapshotVersion = 10 // v10: reload snapshots for group availability checks type apiKeyAuthCacheConfig struct { l1Size int diff --git a/backend/internal/service/auth_email_binding.go b/backend/internal/service/auth_email_binding.go index 78f1185d..84f61d78 100644 --- a/backend/internal/service/auth_email_binding.go +++ b/backend/internal/service/auth_email_binding.go @@ -94,7 +94,7 @@ func (s *AuthService) BindEmailIdentity( } // SendEmailIdentityBindCode sends a verification code for authenticated email binding flows. -func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string) error { +func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string, locale ...string) error { if s == nil { return ErrServiceUnavailable } @@ -128,7 +128,7 @@ func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int6 if s.settingService != nil { siteName = s.settingService.GetSiteName(ctx) } - return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName) + return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName, firstEmailLocale(locale)) } func normalizeEmailForIdentityBinding(email string) (string, error) { diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go index 3478fda5..cf0be652 100644 --- a/backend/internal/service/auth_oauth_email_flow.go +++ b/backend/internal/service/auth_oauth_email_flow.go @@ -28,7 +28,7 @@ func normalizeOAuthSignupSource(signupSource string) string { // SendPendingOAuthVerifyCode sends a local verification code for pending OAuth // account-creation flows without relying on the public registration gate. -func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string) (*SendVerifyCodeResult, error) { +func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string, locale ...string) (*SendVerifyCodeResult, error) { email = strings.TrimSpace(strings.ToLower(email)) if email == "" { return nil, ErrEmailVerifyRequired @@ -47,7 +47,7 @@ func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email stri if s.settingService != nil { siteName = s.settingService.GetSiteName(ctx) } - if err := s.emailService.SendVerifyCode(ctx, email, siteName); err != nil { + if err := s.emailService.SendVerifyCode(ctx, email, siteName, firstEmailLocale(locale)); err != nil { return nil, err } return &SendVerifyCodeResult{ diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index ce2b3fa3..4e5b7b94 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -273,7 +273,7 @@ type SendVerifyCodeResult struct { } // SendVerifyCode 发送邮箱验证码(同步方式) -func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { +func (s *AuthService) SendVerifyCode(ctx context.Context, email string, locale ...string) error { // 检查是否开放注册(默认关闭) if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { return ErrRegDisabled @@ -307,11 +307,11 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { siteName = s.settingService.GetSiteName(ctx) } - return s.emailService.SendVerifyCode(ctx, email, siteName) + return s.emailService.SendVerifyCode(ctx, email, siteName, firstEmailLocale(locale)) } // SendVerifyCodeAsync 异步发送邮箱验证码并返回倒计时 -func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) { +func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string, locale ...string) (*SendVerifyCodeResult, error) { logger.LegacyPrintf("service.auth", "[Auth] SendVerifyCodeAsync called for email: %s", email) // 检查是否开放注册(默认关闭) @@ -352,7 +352,7 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S // 异步发送 logger.LegacyPrintf("service.auth", "[Auth] Enqueueing verify code for: %s", email) - if err := s.emailQueueService.EnqueueVerifyCode(email, siteName); err != nil { + if err := s.emailQueueService.EnqueueVerifyCode(email, siteName, firstEmailLocale(locale)); err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to enqueue: %v", err) return nil, fmt.Errorf("enqueue verify code: %w", err) } @@ -1251,7 +1251,7 @@ func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendB // RequestPasswordReset 请求密码重置(同步发送) // Security: Returns the same response regardless of whether the email exists (prevent user enumeration) -func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string) error { +func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string, locale ...string) error { if !s.IsPasswordResetEnabled(ctx) { return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled") } @@ -1264,7 +1264,7 @@ func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendB return nil // Silent success to prevent enumeration } - if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil { + if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL, firstEmailLocale(locale)); err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to send password reset email to %s: %v", email, err) return nil // Silent success to prevent enumeration } @@ -1275,7 +1275,7 @@ func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendB // RequestPasswordResetAsync 异步请求密码重置(队列发送) // Security: Returns the same response regardless of whether the email exists (prevent user enumeration) -func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string) error { +func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string, locale ...string) error { if !s.IsPasswordResetEnabled(ctx) { return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled") } @@ -1288,7 +1288,7 @@ func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, fron return nil // Silent success to prevent enumeration } - if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL); err != nil { + if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL, firstEmailLocale(locale)); err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to enqueue password reset email for %s: %v", email, err) return nil // Silent success to prevent enumeration } diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index 5b7e413a..26803275 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -39,9 +39,10 @@ type AccountQuotaReader interface { // BalanceNotifyService handles balance and quota threshold notifications. type BalanceNotifyService struct { - emailService *EmailService - settingRepo SettingRepository - accountRepo AccountQuotaReader + emailService *EmailService + settingRepo SettingRepository + accountRepo AccountQuotaReader + notificationEmailService *NotificationEmailService } // NewBalanceNotifyService creates a new BalanceNotifyService. @@ -53,6 +54,10 @@ func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepo } } +func (s *BalanceNotifyService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) { + s.notificationEmailService = notificationEmailService +} + // resolveBalanceThreshold returns the effective balance threshold. // For percentage type, it computes threshold = totalRecharged * percentage / 100. func resolveBalanceThreshold(threshold float64, thresholdType string, totalRecharged float64) float64 { @@ -125,7 +130,7 @@ func (s *BalanceNotifyService) dispatchBalanceLowEmail(ctx context.Context, user slog.Error("panic in balance notification", "recover", r) } }() - s.sendBalanceLowEmails(recipients, user.Username, user.Email, newBalance, threshold, siteName, rechargeURL) + s.sendBalanceLowEmails(recipients, user.ID, user.Username, user.Email, newBalance, threshold, siteName, rechargeURL) }() } @@ -342,11 +347,44 @@ func (s *BalanceNotifyService) sendEmails(recipients []string, subject, body str } // sendBalanceLowEmails sends balance low notification to all recipients. -func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userName, userEmail string, balance, threshold float64, siteName, rechargeURL string) { +func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userID int64, userName, userEmail string, balance, threshold float64, siteName, rechargeURL string) { displayName := userName if displayName == "" { displayName = userEmail } + if s.notificationEmailService != nil { + fallbackRecipients := make([]string, 0, len(recipients)) + for _, to := range recipients { + ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout) + err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{ + Event: NotificationEmailEventBalanceLow, + RecipientEmail: to, + RecipientName: displayName, + UserID: userID, + SourceType: "balance_low", + SourceID: firstNonEmpty(strconv.FormatInt(userID, 10), userEmail), + ReminderKey: time.Now().UTC().Format("2006-01-02"), + Variables: map[string]string{ + "current_balance": fmt.Sprintf("%.2f", balance), + "threshold": fmt.Sprintf("%.2f", threshold), + "recharge_url": rechargeURL, + }, + }) + cancel() + if err != nil { + if shouldFallbackNotificationEmail(err) { + slog.Warn("template balance low notification failed; falling back to built-in body", "to", to, "err", err.Error()) + fallbackRecipients = append(fallbackRecipients, to) + } else { + slog.Warn("template balance low notification delivery failed; not sending fallback to avoid duplicates", "to", to, "err", err.Error()) + } + } + } + if len(fallbackRecipients) == 0 { + return + } + recipients = fallbackRecipients + } subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", sanitizeEmailHeader(siteName)) body := s.buildBalanceLowEmailBody(html.EscapeString(displayName), balance, threshold, html.EscapeString(siteName), rechargeURL) s.sendEmails(recipients, subject, body, "user_email", userEmail, "balance", balance) @@ -369,6 +407,44 @@ func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accoun remaining = 0 } + if s.notificationEmailService != nil { + fallbackRecipients := make([]string, 0, len(adminEmails)) + for _, to := range adminEmails { + ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout) + err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{ + Event: NotificationEmailEventAccountQuotaAlert, + RecipientEmail: to, + RecipientName: emailRecipientName(to), + SourceType: "account_quota", + SourceID: fmt.Sprintf("%d-%s", accountID, dim.name), + ReminderKey: time.Now().UTC().Format("2006-01-02"), + Variables: map[string]string{ + "account_id": strconv.FormatInt(accountID, 10), + "account_name": accountName, + "platform": platform, + "quota_dimension": dimLabel, + "quota_used": fmt.Sprintf("%.2f", used), + "quota_limit": fmt.Sprintf("%.2f", dim.limit), + "quota_remaining": fmt.Sprintf("%.2f", remaining), + "quota_threshold": thresholdDisplay, + }, + }) + cancel() + if err != nil { + if shouldFallbackNotificationEmail(err) { + slog.Warn("template account quota alert failed; falling back to built-in body", "to", to, "account_id", accountID, "dimension", dim.name, "err", err.Error()) + fallbackRecipients = append(fallbackRecipients, to) + } else { + slog.Warn("template account quota alert delivery failed; not sending fallback to avoid duplicates", "to", to, "account_id", accountID, "dimension", dim.name, "err", err.Error()) + } + } + } + if len(fallbackRecipients) == 0 { + return + } + adminEmails = fallbackRecipients + } + subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", sanitizeEmailHeader(siteName), sanitizeEmailHeader(accountName)) body := s.buildQuotaAlertEmailBody(accountID, html.EscapeString(accountName), html.EscapeString(platform), html.EscapeString(dimLabel), used, dim.limit, remaining, thresholdDisplay, html.EscapeString(siteName)) s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dim.name) diff --git a/backend/internal/service/claude_code_session_id.go b/backend/internal/service/claude_code_session_id.go index f087c004..e000108e 100644 --- a/backend/internal/service/claude_code_session_id.go +++ b/backend/internal/service/claude_code_session_id.go @@ -15,16 +15,19 @@ import ( // // 行为按 tokenType / mimicClaudeCode 分两条路径: // -// OAuth mimic 路径 (tokenType == "oauth" && mimicClaudeCode): -// 1. body 中 metadata.user_id 派生的 SessionID 是合法 UUID → canonicalize 写入 -// 2. 请求 header 中已有合法 UUID → canonicalize 保留 -// 3. 否则 → 兜底生成 UUID +// OAuth 路径 (tokenType == "oauth"): +// OAuth 账号本身就是真实 Claude Code 客户端的凭证,可以信任 body 中的 +// metadata.user_id 派生 session id。 +// 1. metadata.user_id 派生 SessionID 是合法 UUID → canonical 写入 +// 2. header 已有合法 UUID → canonical 保留 +// 3. mimicClaudeCode == true → 兜底生成新 UUID +// (mimicClaudeCode == false 且无 metadata 时不强制注入) // -// API key 透传 / 非 mimic 路径: -// - 不从 body 合成 header(避免污染客户端原始语义) -// - 但若客户端在 header 中传入了 X-Claude-Code-Session-Id: -// 合法 UUID → canonicalize 保留 -// 非法值 → 删除(不向上游转发恶意值,符合 UUID 校验承诺) +// API key 透传路径 (tokenType != "oauth"): +// - 不从 body metadata 派生 header(避免污染客户端原始语义) +// - 若客户端在 header 中传入 X-Claude-Code-Session-Id: +// 合法 UUID → canonical 保留 +// 非法值 → 删除(不向上游转发恶意值) // - 不兜底生成 // // 安全说明:metadata.user_id 由客户端控制,ParseMetadataUserID 的正则仅约束字符集, @@ -37,10 +40,10 @@ func ensureClaudeCodeSessionID(req *http.Request, body []byte, tokenType string, req.Header = make(http.Header) } - isOAuthMimic := tokenType == "oauth" && mimicClaudeCode + isOAuth := tokenType == "oauth" - // OAuth mimic 路径:从 metadata 派生(仅在 mimic 场景写 header)。 - if isOAuthMimic { + // OAuth 路径:从 metadata 派生(OAuth 凭证可信任)。 + if isOAuth { if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { if parsed := ParseMetadataUserID(uid); parsed != nil { if id, err := uuid.Parse(parsed.SessionID); err == nil { @@ -65,9 +68,9 @@ func ensureClaudeCodeSessionID(req *http.Request, body []byte, tokenType string, req.Header.Del("X-Claude-Code-Session-Id") } - // OAuth mimic 兜底生成(仅 mimic 场景;API key 不污染)。 + // OAuth mimic 兜底生成(仅 mimic 场景;API key/非 mimic 不污染)。 // uuid.NewString() 走 crypto/rand。 - if isOAuthMimic { + if isOAuth && mimicClaudeCode { setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", uuid.NewString()) } } diff --git a/backend/internal/service/claude_code_session_id_test.go b/backend/internal/service/claude_code_session_id_test.go index d5c9023b..17ec1138 100644 --- a/backend/internal/service/claude_code_session_id_test.go +++ b/backend/internal/service/claude_code_session_id_test.go @@ -136,15 +136,17 @@ func TestEnsureClaudeCodeSessionID_APIKeyIgnoresMetadata(t *testing.T) { } } -// OAuth 但非 mimic 模式也不应该从 metadata 派生 header。 -func TestEnsureClaudeCodeSessionID_OAuthNonMimicIgnoresMetadata(t *testing.T) { +// OAuth 路径即使 mimic=false 也应该从 metadata 派生 header: +// OAuth 凭证本身就是 Claude Code 类型账号,metadata.user_id 可信任。 +// 这与 API key 路径不同(API key 是任意第三方调用方)。 +func TestEnsureClaudeCodeSessionID_OAuthNonMimicDerivesFromMetadata(t *testing.T) { req := newReq(t) body := []byte(`{"metadata":{"user_id":"{\"device_id\":\"abc\",\"account_uuid\":\"\",\"session_id\":\"` + testValidUUID + `\"}"}}`) ensureClaudeCodeSessionID(req, body, "oauth", false) got := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id") - if got != "" { - t.Fatalf("Non-mimic OAuth must NOT derive session-id from metadata, got %q", got) + if got != testValidUUID { + t.Fatalf("OAuth must derive session-id from metadata regardless of mimic flag, got %q want %q", got, testValidUUID) } } diff --git a/backend/internal/service/content_moderation.go b/backend/internal/service/content_moderation.go index 6a7c9904..2d066298 100644 --- a/backend/internal/service/content_moderation.go +++ b/backend/internal/service/content_moderation.go @@ -1463,6 +1463,24 @@ func (s *ContentModerationService) applyFlaggedSideEffects(ctx context.Context, func (s *ContentModerationService) sendViolationEmail(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) error { siteName := s.siteName(ctx) + if s.emailService.notificationEmailService != nil { + if err := s.emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{ + Event: NotificationEmailEventContentModerationViolation, + RecipientEmail: log.UserEmail, + RecipientName: emailRecipientName(log.UserEmail), + UserID: contentModerationEmailUserID(log), + SourceType: "content_moderation", + SourceID: contentModerationEmailSourceID(log), + Variables: contentModerationEmailVariables(log, cfg), + }); err == nil { + return nil + } else { + if !shouldFallbackNotificationEmail(err) { + return err + } + slog.Warn("template content moderation violation email failed; falling back to built-in body", "log_id", log.ID, "recipient_hash", notificationEmailHash(log.UserEmail), "err", err.Error()) + } + } subject := fmt.Sprintf("[%s] 账户风控提醒 / Risk Control Notice", sanitizeEmailHeader(siteName)) body := buildContentModerationViolationEmailBody(siteName, log, cfg) return s.emailService.SendEmail(ctx, log.UserEmail, subject, body) @@ -1470,11 +1488,71 @@ func (s *ContentModerationService) sendViolationEmail(ctx context.Context, cfg * func (s *ContentModerationService) sendAccountDisabledEmail(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) error { siteName := s.siteName(ctx) + if s.emailService.notificationEmailService != nil { + if err := s.emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{ + Event: NotificationEmailEventContentModerationDisabled, + RecipientEmail: log.UserEmail, + RecipientName: emailRecipientName(log.UserEmail), + UserID: contentModerationEmailUserID(log), + SourceType: "content_moderation", + SourceID: contentModerationEmailSourceID(log), + Variables: contentModerationEmailVariables(log, cfg), + }); err == nil { + return nil + } else { + if !shouldFallbackNotificationEmail(err) { + return err + } + slog.Warn("template content moderation disabled email failed; falling back to built-in body", "log_id", log.ID, "recipient_hash", notificationEmailHash(log.UserEmail), "err", err.Error()) + } + } subject := fmt.Sprintf("[%s] 账户已被禁用 / Account Disabled", sanitizeEmailHeader(siteName)) body := buildContentModerationAccountDisabledEmailBody(siteName, log, cfg) return s.emailService.SendEmail(ctx, log.UserEmail, subject, body) } +func contentModerationEmailUserID(log *ContentModerationLog) int64 { + if log == nil || log.UserID == nil { + return 0 + } + return *log.UserID +} + +func contentModerationEmailSourceID(log *ContentModerationLog) string { + if log == nil || log.ID <= 0 { + return "" + } + return fmt.Sprintf("%d", log.ID) +} + +func contentModerationEmailVariables(log *ContentModerationLog, cfg *ContentModerationConfig) map[string]string { + variables := map[string]string{ + "triggered_at": time.Now().UTC().Format(time.RFC3339), + "group_name": "-", + "moderation_category": "-", + "moderation_score": "0.000", + "violation_count": "0", + "ban_threshold": "0", + } + if log != nil { + if !log.CreatedAt.IsZero() { + variables["triggered_at"] = log.CreatedAt.UTC().Format(time.RFC3339) + } + if strings.TrimSpace(log.GroupName) != "" { + variables["group_name"] = strings.TrimSpace(log.GroupName) + } + if strings.TrimSpace(log.HighestCategory) != "" { + variables["moderation_category"] = strings.TrimSpace(log.HighestCategory) + } + variables["moderation_score"] = fmt.Sprintf("%.3f", log.HighestScore) + variables["violation_count"] = fmt.Sprintf("%d", log.ViolationCount) + } + if cfg != nil { + variables["ban_threshold"] = fmt.Sprintf("%d", cfg.BanThreshold) + } + return variables +} + func (s *ContentModerationService) siteName(ctx context.Context) string { if s == nil || s.settingRepo == nil { return "Sub2API" diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 4f2d40d8..8d591086 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -401,6 +401,10 @@ const ( SettingKeyRewriteMessageCacheControl = "rewrite_message_cache_control" // SettingKeyAntigravityUserAgentVersion Antigravity 上游 User-Agent 版本号(空值使用环境变量/默认值) SettingKeyAntigravityUserAgentVersion = "antigravity_user_agent_version" + // SettingKeyOpenAICodexUserAgent OpenAI Codex 完整 User-Agent(空值使用内置默认) + // 当客户端 UA 被识别为浏览器(Chrome/Firefox/Safari/Edge 等)时,转发给 OpenAI 上游前会替换为此值, + // 用于避免 Cloudflare 对浏览器型 UA 的质询拦截。 + SettingKeyOpenAICodexUserAgent = "openai_codex_user_agent" // Balance Low Notification SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关 diff --git a/backend/internal/service/email_queue_service.go b/backend/internal/service/email_queue_service.go index d8f0a518..a933e6bb 100644 --- a/backend/internal/service/email_queue_service.go +++ b/backend/internal/service/email_queue_service.go @@ -21,6 +21,7 @@ type EmailTask struct { SiteName string TaskType string // "verify_code" or "password_reset" ResetURL string // Only used for password_reset task type + Locale string // Optional Accept-Language locale hint } // EmailQueueService 异步邮件队列服务 @@ -82,13 +83,13 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) { switch task.TaskType { case TaskTypeVerifyCode: - if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil { + if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName, task.Locale); err != nil { logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err) } else { logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email) } case TaskTypePasswordReset: - if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil { + if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL, task.Locale); err != nil { logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err) } else { logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email) @@ -99,11 +100,12 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) { } // EnqueueVerifyCode 将验证码发送任务加入队列 -func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error { +func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string, locale ...string) error { task := EmailTask{ Email: email, SiteName: siteName, TaskType: TaskTypeVerifyCode, + Locale: firstEmailLocale(locale), } select { @@ -116,12 +118,13 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error { } // EnqueuePasswordReset 将密码重置邮件任务加入队列 -func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL string) error { +func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL string, locale ...string) error { task := EmailTask{ Email: email, SiteName: siteName, TaskType: TaskTypePasswordReset, ResetURL: resetURL, + Locale: firstEmailLocale(locale), } select { diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index 9a03ea30..2cf42d73 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -94,8 +94,9 @@ type SMTPConfig struct { // EmailService 邮件服务 type EmailService struct { - settingRepo SettingRepository - cache EmailCache + settingRepo SettingRepository + cache EmailCache + notificationEmailService *NotificationEmailService } // NewEmailService 创建邮件服务实例 @@ -106,6 +107,28 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ } } +func (s *EmailService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) { + s.notificationEmailService = notificationEmailService +} + +func firstEmailLocale(locales []string) string { + if len(locales) == 0 { + return "" + } + return strings.TrimSpace(locales[0]) +} + +func emailRecipientName(email string) string { + trimmed := strings.TrimSpace(email) + if trimmed == "" { + return "" + } + if at := strings.Index(trimmed, "@"); at > 0 { + return trimmed[:at] + } + return trimmed +} + // GetSMTPConfig 从数据库获取SMTP配置 func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) { keys := []string{ @@ -301,7 +324,7 @@ func (s *EmailService) GenerateVerifyCode() (string, error) { } // SendVerifyCode 发送验证码邮件 -func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string) error { +func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string, locale ...string) error { // 检查是否在冷却期内 existing, err := s.cache.GetVerificationCode(ctx, email) if err == nil && existing != nil { @@ -327,6 +350,26 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin return fmt.Errorf("save verify code: %w", err) } + if s.notificationEmailService != nil { + err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{ + Event: NotificationEmailEventAuthVerifyCode, + Locale: firstEmailLocale(locale), + RecipientEmail: email, + RecipientName: emailRecipientName(email), + Variables: map[string]string{ + "verification_code": code, + "expires_in_minutes": strconv.Itoa(int(verifyCodeTTL / time.Minute)), + }, + }) + if err == nil { + return nil + } + if !shouldFallbackNotificationEmail(err) { + return err + } + slog.Warn("failed to send templated verification email, falling back to legacy template", "recipient_hash", notificationEmailHash(email), "error", err) + } + // 构建邮件内容 subject := fmt.Sprintf("[%s] Email Verification Code", siteName) body := s.buildVerifyCodeEmailBody(code, siteName) @@ -469,7 +512,7 @@ func (s *EmailService) GeneratePasswordResetToken() (string, error) { } // SendPasswordResetEmail sends a password reset email with a reset link -func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteName, resetURL string) error { +func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteName, resetURL string, locale ...string) error { var token string var needSaveToken bool @@ -502,6 +545,26 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa // Build full reset URL with URL-encoded token and email fullResetURL := fmt.Sprintf("%s?email=%s&token=%s", resetURL, url.QueryEscape(email), url.QueryEscape(token)) + if s.notificationEmailService != nil { + err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{ + Event: NotificationEmailEventAuthPasswordReset, + Locale: firstEmailLocale(locale), + RecipientEmail: email, + RecipientName: emailRecipientName(email), + Variables: map[string]string{ + "reset_url": fullResetURL, + "expires_in_minutes": strconv.Itoa(int(passwordResetTokenTTL / time.Minute)), + }, + }) + if err == nil { + return nil + } + if !shouldFallbackNotificationEmail(err) { + return err + } + slog.Warn("failed to send templated password reset email, falling back to legacy template", "recipient_hash", notificationEmailHash(email), "error", err) + } + // Build email content subject := fmt.Sprintf("[%s] 密码重置请求", siteName) body := s.buildPasswordResetEmailBody(fullResetURL, siteName) @@ -516,7 +579,7 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa // SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker) // This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing -func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error { +func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string, locale ...string) error { // Check email cooldown to prevent email bombing if s.cache.IsPasswordResetEmailInCooldown(ctx, email) { slog.Info("password reset email skipped due to cooldown", "email", email) @@ -524,7 +587,7 @@ func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, e } // Send email using core method - if err := s.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil { + if err := s.SendPasswordResetEmail(ctx, email, siteName, resetURL, firstEmailLocale(locale)); err != nil { return err } diff --git a/backend/internal/service/notification_email_service.go b/backend/internal/service/notification_email_service.go new file mode 100644 index 00000000..d21363b6 --- /dev/null +++ b/backend/internal/service/notification_email_service.go @@ -0,0 +1,1347 @@ +package service + +import ( + "context" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "html" + "log/slog" + "net/url" + "regexp" + "strconv" + "strings" + "time" +) + +const ( + NotificationEmailEventAuthVerifyCode = "auth.verify_code" + NotificationEmailEventAuthPasswordReset = "auth.password_reset" + NotificationEmailEventNotificationEmailVerifyCode = "notification_email.verify_code" + NotificationEmailEventSubscriptionPurchaseSuccess = "subscription.purchase_success" + NotificationEmailEventSubscriptionExpiryReminder = "subscription.expiry_reminder" + NotificationEmailEventBalanceLow = "balance.low" + NotificationEmailEventBalanceRechargeSuccess = "balance.recharge_success" + NotificationEmailEventAccountQuotaAlert = "account.quota_alert" + NotificationEmailEventContentModerationViolation = "content_moderation.violation_notice" + NotificationEmailEventContentModerationDisabled = "content_moderation.account_disabled" + NotificationEmailEventOpsAlert = "ops.alert" + NotificationEmailEventOpsScheduledReport = "ops.scheduled_report" + + notificationEmailTemplateKeyPrefix = "notification_email_template:" + notificationEmailPreferenceKeyPrefix = "notification_email_preference:" + notificationEmailDeliveryKeyPrefix = "notification_email_delivery:" + notificationEmailLocaleUserKeyPrefix = "notification_email_locale:user:" + notificationEmailLocaleEmailKeyPrefix = "notification_email_locale:email:" + notificationEmailUnsubscribeSecretKey = "notification_email_unsubscribe_secret" + notificationEmailDefaultLocale = "en" + notificationEmailLocaleChinese = "zh" + notificationEmailMaxSubjectLength = 200 + notificationEmailMaxHTMLLength = 30000 + notificationEmailUnsubscribeTTL = 365 * 24 * time.Hour +) + +var ( + notificationEmailPlaceholderPattern = regexp.MustCompile(`{{\s*([a-zA-Z][a-zA-Z0-9_]*)\s*}}`) + notificationEmailLocales = []string{notificationEmailDefaultLocale, notificationEmailLocaleChinese} + notificationEmailCommonPlaceholders = []string{"site_name", "recipient_name", "recipient_email"} +) + +type NotificationEmailService struct { + settingRepo SettingRepository + emailService *EmailService +} + +type NotificationEmailEventInfo struct { + Event string `json:"event"` + Label string `json:"label"` + Description string `json:"description"` + Category string `json:"category"` + Optional bool `json:"optional"` + Placeholders []string `json:"placeholders"` +} + +type NotificationEmailTemplate struct { + Event string `json:"event"` + Locale string `json:"locale"` + Subject string `json:"subject"` + HTML string `json:"html"` + IsCustom bool `json:"is_custom"` + UpdatedAt *time.Time `json:"updated_at,omitempty"` + Placeholders []string `json:"placeholders"` +} + +type NotificationEmailPreview struct { + Subject string `json:"subject"` + HTML string `json:"html"` +} + +type NotificationEmailPreviewInput struct { + Event string `json:"event"` + Locale string `json:"locale"` + Subject string `json:"subject"` + HTML string `json:"html"` + Variables map[string]string `json:"variables,omitempty"` +} + +type NotificationEmailSendInput struct { + Event string + Locale string + RecipientEmail string + RecipientName string + UserID int64 + SourceType string + SourceID string + ReminderKey string + Variables map[string]string + RawHTMLVariables map[string]string +} + +type NotificationEmailUnsubscribeResult struct { + Event string `json:"event"` + Email string `json:"email"` + Done bool `json:"done"` +} + +type notificationEmailStoredTemplate struct { + Subject string `json:"subject"` + HTML string `json:"html"` + UpdatedAt time.Time `json:"updated_at"` +} + +type notificationEmailOfficialTemplate struct { + Subject string + HTML string +} + +type notificationEmailTemplateError struct { + Err error +} + +func (e notificationEmailTemplateError) Error() string { + return e.Err.Error() +} + +func (e notificationEmailTemplateError) Unwrap() error { + return e.Err +} + +type notificationEmailConfigError struct { + Err error +} + +func (e notificationEmailConfigError) Error() string { + return e.Err.Error() +} + +func (e notificationEmailConfigError) Unwrap() error { + return e.Err +} + +type notificationEmailDeliveryError struct { + Err error +} + +func (e notificationEmailDeliveryError) Error() string { + return e.Err.Error() +} + +func (e notificationEmailDeliveryError) Unwrap() error { + return e.Err +} + +type notificationEmailUnsubscribeClaims struct { + Email string `json:"email"` + Event string `json:"event"` + Exp int64 `json:"exp"` +} + +func NewNotificationEmailService(settingRepo SettingRepository, emailService *EmailService) *NotificationEmailService { + svc := &NotificationEmailService{settingRepo: settingRepo, emailService: emailService} + if emailService != nil { + emailService.SetNotificationEmailService(svc) + } + return svc +} + +func notificationEmailTemplateErr(err error) error { + if err == nil { + return nil + } + return notificationEmailTemplateError{Err: err} +} + +func notificationEmailConfigErr(err error) error { + if err == nil { + return nil + } + return notificationEmailConfigError{Err: err} +} + +func notificationEmailDeliveryErr(err error) error { + if err == nil { + return nil + } + return notificationEmailDeliveryError{Err: err} +} + +func shouldFallbackNotificationEmail(err error) bool { + if err == nil { + return false + } + var templateErr notificationEmailTemplateError + if errors.As(err, &templateErr) { + return true + } + var configErr notificationEmailConfigError + return errors.As(err, &configErr) +} + +func isNotificationEmailDeliveryError(err error) bool { + var deliveryErr notificationEmailDeliveryError + return errors.As(err, &deliveryErr) +} + +func (s *NotificationEmailService) ListEventInfos() []NotificationEmailEventInfo { + infos := make([]NotificationEmailEventInfo, 0, len(notificationEmailEventDefinitions)) + for _, event := range notificationEmailEventOrder { + info := notificationEmailEventDefinitions[event] + info.Placeholders = append([]string(nil), info.Placeholders...) + infos = append(infos, info) + } + return infos +} + +func (s *NotificationEmailService) SupportedLocales() []string { + return append([]string(nil), notificationEmailLocales...) +} + +func (s *NotificationEmailService) ListTemplates(ctx context.Context) ([]NotificationEmailTemplate, error) { + items := make([]NotificationEmailTemplate, 0, len(notificationEmailEventOrder)*len(notificationEmailLocales)) + for _, event := range notificationEmailEventOrder { + for _, locale := range notificationEmailLocales { + tmpl, err := s.GetTemplate(ctx, event, locale) + if err != nil { + return nil, err + } + items = append(items, tmpl) + } + } + return items, nil +} + +func (s *NotificationEmailService) GetTemplate(ctx context.Context, event, locale string) (NotificationEmailTemplate, error) { + info, normalizedEvent, err := s.eventInfo(event) + if err != nil { + return NotificationEmailTemplate{}, err + } + normalizedLocale := normalizeNotificationLocale(locale) + official, ok := notificationEmailOfficialTemplates[normalizedEvent][normalizedLocale] + if !ok { + return NotificationEmailTemplate{}, fmt.Errorf("official template not found for %s/%s", normalizedEvent, normalizedLocale) + } + + tmpl := NotificationEmailTemplate{ + Event: normalizedEvent, + Locale: normalizedLocale, + Subject: official.Subject, + HTML: official.HTML, + Placeholders: append([]string(nil), info.Placeholders...), + } + + raw, err := s.settingRepo.GetValue(ctx, notificationEmailTemplateKey(normalizedEvent, normalizedLocale)) + if err != nil { + if errors.Is(err, ErrSettingNotFound) { + return tmpl, nil + } + return NotificationEmailTemplate{}, err + } + if strings.TrimSpace(raw) == "" { + return tmpl, nil + } + + var stored notificationEmailStoredTemplate + if err := json.Unmarshal([]byte(raw), &stored); err != nil { + return NotificationEmailTemplate{}, fmt.Errorf("decode email template override: %w", err) + } + if err := validateNotificationEmailTemplate(normalizedEvent, stored.Subject, stored.HTML); err != nil { + return NotificationEmailTemplate{}, err + } + tmpl.Subject = stored.Subject + tmpl.HTML = stored.HTML + tmpl.IsCustom = true + updatedAt := stored.UpdatedAt + tmpl.UpdatedAt = &updatedAt + return tmpl, nil +} + +func (s *NotificationEmailService) UpdateTemplate(ctx context.Context, event, locale, subject, htmlBody string) (NotificationEmailTemplate, error) { + _, normalizedEvent, err := s.eventInfo(event) + if err != nil { + return NotificationEmailTemplate{}, err + } + normalizedLocale := normalizeNotificationLocale(locale) + if err := validateNotificationEmailTemplate(normalizedEvent, subject, htmlBody); err != nil { + return NotificationEmailTemplate{}, err + } + stored := notificationEmailStoredTemplate{ + Subject: strings.TrimSpace(subject), + HTML: htmlBody, + UpdatedAt: time.Now().UTC(), + } + payload, err := json.Marshal(stored) + if err != nil { + return NotificationEmailTemplate{}, err + } + if err := s.settingRepo.Set(ctx, notificationEmailTemplateKey(normalizedEvent, normalizedLocale), string(payload)); err != nil { + return NotificationEmailTemplate{}, err + } + return s.GetTemplate(ctx, normalizedEvent, normalizedLocale) +} + +func (s *NotificationEmailService) RestoreOfficialTemplate(ctx context.Context, event, locale string) (NotificationEmailTemplate, error) { + _, normalizedEvent, err := s.eventInfo(event) + if err != nil { + return NotificationEmailTemplate{}, err + } + normalizedLocale := normalizeNotificationLocale(locale) + if err := s.settingRepo.Delete(ctx, notificationEmailTemplateKey(normalizedEvent, normalizedLocale)); err != nil && !errors.Is(err, ErrSettingNotFound) { + return NotificationEmailTemplate{}, err + } + return s.GetTemplate(ctx, normalizedEvent, normalizedLocale) +} + +func (s *NotificationEmailService) PreviewTemplate(ctx context.Context, input NotificationEmailPreviewInput) (NotificationEmailPreview, error) { + _, normalizedEvent, err := s.eventInfo(input.Event) + if err != nil { + return NotificationEmailPreview{}, err + } + normalizedLocale := normalizeNotificationLocale(input.Locale) + subject := input.Subject + htmlBody := input.HTML + if strings.TrimSpace(subject) == "" || strings.TrimSpace(htmlBody) == "" { + tmpl, err := s.GetTemplate(ctx, normalizedEvent, normalizedLocale) + if err != nil { + return NotificationEmailPreview{}, err + } + if strings.TrimSpace(subject) == "" { + subject = tmpl.Subject + } + if strings.TrimSpace(htmlBody) == "" { + htmlBody = tmpl.HTML + } + } + if err := validateNotificationEmailTemplate(normalizedEvent, subject, htmlBody); err != nil { + return NotificationEmailPreview{}, err + } + variables := s.sampleVariables(ctx, normalizedEvent, normalizedLocale) + for key, value := range input.Variables { + variables[key] = value + } + return renderNotificationEmail(normalizedEvent, subject, htmlBody, variables, nil) +} + +func (s *NotificationEmailService) Send(ctx context.Context, input NotificationEmailSendInput) error { + info, normalizedEvent, err := s.eventInfo(input.Event) + if err != nil { + return notificationEmailTemplateErr(err) + } + recipient := strings.TrimSpace(input.RecipientEmail) + if recipient == "" { + return nil + } + if info.Optional { + unsubscribed, err := s.IsUnsubscribed(ctx, recipient, normalizedEvent) + if err != nil { + return err + } + if unsubscribed { + slog.Info("notification email suppressed by unsubscribe preference", "event", normalizedEvent, "recipient_hash", notificationEmailHash(recipient)) + return nil + } + } + + locale := normalizeNotificationLocale(input.Locale) + if strings.TrimSpace(input.Locale) == "" { + locale = s.ResolveRecipientLocale(ctx, input.UserID, recipient) + } + tmpl, err := s.GetTemplate(ctx, normalizedEvent, locale) + if err != nil { + return notificationEmailTemplateErr(err) + } + variables := s.runtimeVariables(ctx, normalizedEvent, locale, input) + rendered, err := renderNotificationEmail(normalizedEvent, tmpl.Subject, tmpl.HTML, variables, input.RawHTMLVariables) + if err != nil { + return notificationEmailTemplateErr(err) + } + + deliveryKey := notificationEmailDeliveryKey(normalizedEvent, input.SourceType, input.SourceID, recipient, input.ReminderKey) + if deliveryKey != "" { + sent, err := s.deliveryExists(ctx, deliveryKey, legacyNotificationEmailDeliveryKey(normalizedEvent, input.SourceType, input.SourceID, recipient, input.ReminderKey)) + if err != nil { + return err + } + if sent { + return nil + } + } + + if s.emailService == nil { + return notificationEmailConfigErr(errors.New("email service is not configured")) + } + if err := s.emailService.SendEmail(ctx, recipient, rendered.Subject, rendered.HTML); err != nil { + return notificationEmailDeliveryErr(err) + } + if deliveryKey != "" { + if err := s.settingRepo.Set(ctx, deliveryKey, time.Now().UTC().Format(time.RFC3339Nano)); err != nil { + return err + } + } + return nil +} + +func (s *NotificationEmailService) RememberRecipientLocale(ctx context.Context, userID int64, email, acceptLanguage string) { + locale := normalizeNotificationLocale(acceptLanguage) + if strings.TrimSpace(acceptLanguage) == "" || s == nil || s.settingRepo == nil { + return + } + if userID > 0 { + _ = s.settingRepo.Set(ctx, notificationEmailLocaleUserKeyPrefix+strconv.FormatInt(userID, 10), locale) + } + if emailHash := notificationEmailHash(email); emailHash != "" { + _ = s.settingRepo.Set(ctx, notificationEmailLocaleEmailKeyPrefix+emailHash, locale) + } +} + +func (s *NotificationEmailService) ResolveRecipientLocale(ctx context.Context, userID int64, email string) string { + if s == nil || s.settingRepo == nil { + return notificationEmailDefaultLocale + } + if userID > 0 { + if locale, err := s.settingRepo.GetValue(ctx, notificationEmailLocaleUserKeyPrefix+strconv.FormatInt(userID, 10)); err == nil && strings.TrimSpace(locale) != "" { + return normalizeNotificationLocale(locale) + } + } + if emailHash := notificationEmailHash(email); emailHash != "" { + if locale, err := s.settingRepo.GetValue(ctx, notificationEmailLocaleEmailKeyPrefix+emailHash); err == nil && strings.TrimSpace(locale) != "" { + return normalizeNotificationLocale(locale) + } + } + return notificationEmailDefaultLocale +} + +func (s *NotificationEmailService) IsUnsubscribed(ctx context.Context, email, event string) (bool, error) { + info, normalizedEvent, err := s.eventInfo(event) + if err != nil { + return false, err + } + if !info.Optional { + return false, nil + } + for _, key := range []string{notificationEmailPreferenceKey(normalizedEvent, email), legacyNotificationEmailPreferenceKey(normalizedEvent, email)} { + if strings.TrimSpace(key) == "" { + continue + } + value, err := s.settingRepo.GetValue(ctx, key) + if err == nil { + return strings.EqualFold(strings.TrimSpace(value), "unsubscribed"), nil + } + if !errors.Is(err, ErrSettingNotFound) { + return false, err + } + } + return false, nil +} + +func (s *NotificationEmailService) Unsubscribe(ctx context.Context, token string) (NotificationEmailUnsubscribeResult, error) { + claims, err := s.parseUnsubscribeToken(ctx, token) + if err != nil { + return NotificationEmailUnsubscribeResult{}, err + } + info, normalizedEvent, err := s.eventInfo(claims.Event) + if err != nil { + return NotificationEmailUnsubscribeResult{}, err + } + if !info.Optional { + return NotificationEmailUnsubscribeResult{}, fmt.Errorf("%s is transactional and cannot be unsubscribed", normalizedEvent) + } + if err := s.settingRepo.Set(ctx, notificationEmailPreferenceKey(normalizedEvent, claims.Email), "unsubscribed"); err != nil { + return NotificationEmailUnsubscribeResult{}, err + } + return NotificationEmailUnsubscribeResult{Event: normalizedEvent, Email: claims.Email, Done: true}, nil +} + +func (s *NotificationEmailService) eventInfo(event string) (NotificationEmailEventInfo, string, error) { + normalized := strings.ToLower(strings.TrimSpace(event)) + info, ok := notificationEmailEventDefinitions[normalized] + if !ok { + return NotificationEmailEventInfo{}, "", fmt.Errorf("unsupported email template event: %s", event) + } + return info, normalized, nil +} + +func (s *NotificationEmailService) sampleVariables(ctx context.Context, event, locale string) map[string]string { + info := notificationEmailEventDefinitions[event] + variables := make(map[string]string, len(info.Placeholders)) + for key, value := range notificationEmailSampleVariables(locale) { + variables[key] = value + } + variables["site_name"] = s.siteName(ctx) + if variables["unsubscribe_url"] == "" && info.Optional { + variables["unsubscribe_url"] = "https://example.com/unsubscribe" + } + return variables +} + +func (s *NotificationEmailService) runtimeVariables(ctx context.Context, event, locale string, input NotificationEmailSendInput) map[string]string { + variables := s.sampleVariables(ctx, event, locale) + for key, value := range input.Variables { + variables[key] = value + } + variables["site_name"] = s.siteName(ctx) + variables["recipient_email"] = input.RecipientEmail + if strings.TrimSpace(input.RecipientName) != "" { + variables["recipient_name"] = input.RecipientName + } + if notificationEmailEventDefinitions[event].Optional { + if unsubscribeURL, err := s.buildUnsubscribeURL(ctx, input.RecipientEmail, event); err == nil { + variables["unsubscribe_url"] = unsubscribeURL + } + } + return variables +} + +func (s *NotificationEmailService) siteName(ctx context.Context) string { + if s == nil || s.settingRepo == nil { + return defaultSiteName + } + name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName) + if err != nil || strings.TrimSpace(name) == "" { + return defaultSiteName + } + return strings.TrimSpace(name) +} + +func (s *NotificationEmailService) baseURL(ctx context.Context) string { + if s == nil || s.settingRepo == nil { + return "" + } + for _, key := range []string{SettingKeyAPIBaseURL, SettingKeyFrontendURL} { + value, err := s.settingRepo.GetValue(ctx, key) + if err == nil && strings.TrimSpace(value) != "" { + return strings.TrimRight(strings.TrimSpace(value), "/") + } + } + return "" +} + +func (s *NotificationEmailService) buildUnsubscribeURL(ctx context.Context, email, event string) (string, error) { + token, err := s.createUnsubscribeToken(ctx, email, event) + if err != nil { + return "", err + } + path := "/api/v1/settings/email-unsubscribe?token=" + url.QueryEscape(token) + baseURL := s.baseURL(ctx) + if baseURL == "" { + return path, nil + } + return baseURL + path, nil +} + +func (s *NotificationEmailService) createUnsubscribeToken(ctx context.Context, email, event string) (string, error) { + secret, err := s.unsubscribeSecret(ctx) + if err != nil { + return "", err + } + claims := notificationEmailUnsubscribeClaims{Email: strings.TrimSpace(email), Event: event, Exp: time.Now().Add(notificationEmailUnsubscribeTTL).Unix()} + payload, err := json.Marshal(claims) + if err != nil { + return "", err + } + encodedPayload := base64.RawURLEncoding.EncodeToString(payload) + signature := signNotificationEmailToken(secret, encodedPayload) + return encodedPayload + "." + signature, nil +} + +func (s *NotificationEmailService) parseUnsubscribeToken(ctx context.Context, token string) (notificationEmailUnsubscribeClaims, error) { + parts := strings.Split(strings.TrimSpace(token), ".") + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return notificationEmailUnsubscribeClaims{}, errors.New("invalid unsubscribe token") + } + secret, err := s.unsubscribeSecret(ctx) + if err != nil { + return notificationEmailUnsubscribeClaims{}, err + } + expected := signNotificationEmailToken(secret, parts[0]) + if !hmac.Equal([]byte(expected), []byte(parts[1])) { + return notificationEmailUnsubscribeClaims{}, errors.New("invalid unsubscribe token signature") + } + payload, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + return notificationEmailUnsubscribeClaims{}, errors.New("invalid unsubscribe token payload") + } + var claims notificationEmailUnsubscribeClaims + if err := json.Unmarshal(payload, &claims); err != nil { + return notificationEmailUnsubscribeClaims{}, errors.New("invalid unsubscribe token payload") + } + if strings.TrimSpace(claims.Email) == "" || strings.TrimSpace(claims.Event) == "" { + return notificationEmailUnsubscribeClaims{}, errors.New("invalid unsubscribe token claims") + } + if claims.Exp <= time.Now().Unix() { + return notificationEmailUnsubscribeClaims{}, errors.New("unsubscribe token expired") + } + return claims, nil +} + +func (s *NotificationEmailService) unsubscribeSecret(ctx context.Context) (string, error) { + secret, err := s.settingRepo.GetValue(ctx, notificationEmailUnsubscribeSecretKey) + if err == nil && strings.TrimSpace(secret) != "" { + return strings.TrimSpace(secret), nil + } + if err != nil && !errors.Is(err, ErrSettingNotFound) { + return "", err + } + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + return "", err + } + secret = base64.RawURLEncoding.EncodeToString(buf) + if err := s.settingRepo.Set(ctx, notificationEmailUnsubscribeSecretKey, secret); err != nil { + return "", err + } + return secret, nil +} + +func (s *NotificationEmailService) deliveryExists(ctx context.Context, keys ...string) (bool, error) { + for _, key := range keys { + if strings.TrimSpace(key) == "" { + continue + } + _, err := s.settingRepo.GetValue(ctx, key) + if err == nil { + return true, nil + } + if !errors.Is(err, ErrSettingNotFound) { + return false, err + } + } + return false, nil +} + +func validateNotificationEmailTemplate(event, subject, htmlBody string) error { + subject = strings.TrimSpace(subject) + if subject == "" { + return errors.New("email subject cannot be empty") + } + if len([]rune(subject)) > notificationEmailMaxSubjectLength { + return fmt.Errorf("email subject cannot exceed %d characters", notificationEmailMaxSubjectLength) + } + if strings.TrimSpace(htmlBody) == "" { + return errors.New("email html cannot be empty") + } + if len([]byte(htmlBody)) > notificationEmailMaxHTMLLength { + return fmt.Errorf("email html cannot exceed %d bytes", notificationEmailMaxHTMLLength) + } + allowed := notificationEmailAllowedPlaceholderSet(event) + for _, placeholder := range notificationEmailPlaceholdersIn(subject + "\n" + htmlBody) { + if _, ok := allowed[placeholder]; !ok { + return fmt.Errorf("unsupported placeholder {{%s}} for event %s", placeholder, event) + } + } + return nil +} + +func renderNotificationEmail(event, subject, htmlBody string, variables map[string]string, rawHTMLVariables map[string]string) (NotificationEmailPreview, error) { + if err := validateNotificationEmailTemplate(event, subject, htmlBody); err != nil { + return NotificationEmailPreview{}, err + } + renderedSubject, err := renderNotificationEmailString(event, subject, variables, nil, false) + if err != nil { + return NotificationEmailPreview{}, err + } + renderedHTML, err := renderNotificationEmailString(event, htmlBody, variables, rawHTMLVariables, true) + if err != nil { + return NotificationEmailPreview{}, err + } + return NotificationEmailPreview{Subject: sanitizeEmailHeader(renderedSubject), HTML: renderedHTML}, nil +} + +func renderNotificationEmailString(event, raw string, variables map[string]string, rawHTMLVariables map[string]string, escapeHTML bool) (string, error) { + allowed := notificationEmailAllowedPlaceholderSet(event) + var renderErr error + rendered := notificationEmailPlaceholderPattern.ReplaceAllStringFunc(raw, func(match string) string { + if renderErr != nil { + return "" + } + parts := notificationEmailPlaceholderPattern.FindStringSubmatch(match) + if len(parts) != 2 { + return "" + } + name := parts[1] + if _, ok := allowed[name]; !ok { + renderErr = fmt.Errorf("unsupported placeholder {{%s}} for event %s", name, event) + return "" + } + value := variables[name] + if escapeHTML && notificationEmailRawHTMLAllowed(event, name) { + if rawHTMLVariables != nil { + if rawValue, ok := rawHTMLVariables[name]; ok { + return rawValue + } + } + } + if strings.HasSuffix(name, "_url") && !isSafeNotificationEmailURL(value) { + value = "" + } + if escapeHTML { + return html.EscapeString(value) + } + return sanitizeEmailHeader(value) + }) + if renderErr != nil { + return "", renderErr + } + return rendered, nil +} + +func notificationEmailRawHTMLAllowed(event, placeholder string) bool { + return event == NotificationEmailEventOpsScheduledReport && placeholder == "report_html" +} + +func notificationEmailAllowedPlaceholderSet(event string) map[string]struct{} { + info := notificationEmailEventDefinitions[event] + allowed := make(map[string]struct{}, len(info.Placeholders)) + for _, placeholder := range info.Placeholders { + allowed[placeholder] = struct{}{} + } + return allowed +} + +func notificationEmailPlaceholdersIn(raw string) []string { + matches := notificationEmailPlaceholderPattern.FindAllStringSubmatch(raw, -1) + seen := make(map[string]struct{}, len(matches)) + out := make([]string, 0, len(matches)) + for _, match := range matches { + if len(match) != 2 { + continue + } + if _, exists := seen[match[1]]; exists { + continue + } + seen[match[1]] = struct{}{} + out = append(out, match[1]) + } + return out +} + +func normalizeNotificationLocale(raw string) string { + trimmed := strings.ToLower(strings.TrimSpace(raw)) + if trimmed == "" { + return notificationEmailDefaultLocale + } + for _, part := range strings.Split(trimmed, ",") { + tag := strings.TrimSpace(strings.Split(part, ";")[0]) + if strings.HasPrefix(tag, "zh") || tag == "cn" { + return notificationEmailLocaleChinese + } + if strings.HasPrefix(tag, "en") { + return notificationEmailDefaultLocale + } + } + return notificationEmailDefaultLocale +} + +func notificationEmailTemplateKey(event, locale string) string { + return notificationEmailTemplateKeyPrefix + event + ":" + locale +} + +func notificationEmailPreferenceKey(event, email string) string { + if strings.TrimSpace(event) == "" || strings.TrimSpace(email) == "" { + return "" + } + identity := strings.TrimSpace(event) + "\x00" + strings.ToLower(strings.TrimSpace(email)) + return notificationEmailPreferenceKeyPrefix + "v2:" + notificationEmailHash(identity) +} + +func legacyNotificationEmailPreferenceKey(event, email string) string { + return notificationEmailPreferenceKeyPrefix + event + ":" + notificationEmailHash(email) +} + +func notificationEmailDeliveryKey(event, sourceType, sourceID, recipient, reminderKey string) string { + if strings.TrimSpace(sourceType) == "" || strings.TrimSpace(sourceID) == "" || strings.TrimSpace(recipient) == "" { + return "" + } + identity := strings.Join([]string{ + strings.ToLower(strings.TrimSpace(event)), + safeNotificationEmailKeyPart(sourceType), + safeNotificationEmailKeyPart(sourceID), + strings.ToLower(strings.TrimSpace(recipient)), + safeNotificationEmailKeyPart(reminderKey), + }, "\x00") + return notificationEmailDeliveryKeyPrefix + "v2:" + notificationEmailHash(identity) +} + +func legacyNotificationEmailDeliveryKey(event, sourceType, sourceID, recipient, reminderKey string) string { + if strings.TrimSpace(sourceType) == "" || strings.TrimSpace(sourceID) == "" || strings.TrimSpace(recipient) == "" { + return "" + } + parts := []string{notificationEmailDeliveryKeyPrefix, event, ":", safeNotificationEmailKeyPart(sourceType), ":", safeNotificationEmailKeyPart(sourceID), ":", notificationEmailHash(recipient)} + if strings.TrimSpace(reminderKey) != "" { + parts = append(parts, ":", safeNotificationEmailKeyPart(reminderKey)) + } + return strings.Join(parts, "") +} + +func notificationEmailHash(value string) string { + trimmed := strings.ToLower(strings.TrimSpace(value)) + if trimmed == "" { + return "" + } + sum := sha256.Sum256([]byte(trimmed)) + return hex.EncodeToString(sum[:]) +} + +func safeNotificationEmailKeyPart(value string) string { + value = strings.ToLower(strings.TrimSpace(value)) + var builder strings.Builder + for _, r := range value { + if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '_' || r == '-' || r == '.' { + _, _ = builder.WriteRune(r) + } else { + _, _ = builder.WriteRune('_') + } + } + return builder.String() +} + +func signNotificationEmailToken(secret, payload string) string { + mac := hmac.New(sha256.New, []byte(secret)) + _, _ = mac.Write([]byte(payload)) + return base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) +} + +func isSafeNotificationEmailURL(raw string) bool { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return true + } + parsed, err := url.Parse(trimmed) + if err != nil { + return false + } + if parsed.IsAbs() { + scheme := strings.ToLower(parsed.Scheme) + return scheme == "http" || scheme == "https" || scheme == "mailto" + } + return strings.HasPrefix(trimmed, "/") +} + +func notificationEmailSampleVariables(locale string) map[string]string { + if normalizeNotificationLocale(locale) == notificationEmailLocaleChinese { + return map[string]string{ + "site_name": defaultSiteName, + "recipient_name": "张三", + "recipient_email": "user@example.com", + "verification_code": "123456", + "expires_in_minutes": "15", + "reset_url": "https://example.com/reset-password?token=preview", + "subscription_group": "Claude Pro", + "subscription_days": "30", + "expiry_time": "2026-06-18 12:00", + "days_remaining": "3", + "current_balance": "12.34", + "threshold": "20.00", + "recharge_url": "https://example.com/recharge", + "recharge_amount": "50.00", + "order_id": "1024", + "unsubscribe_url": "https://example.com/unsubscribe", + "account_id": "1001", + "account_name": "openai-main", + "platform": "openai", + "quota_dimension": "每日额度", + "quota_used": "80.00", + "quota_limit": "100.00", + "quota_remaining": "20.00", + "quota_threshold": "20%", + "triggered_at": "2026-05-20 12:00:00", + "group_name": "默认分组", + "moderation_category": "violence", + "moderation_score": "0.982", + "violation_count": "2", + "ban_threshold": "3", + "rule_name": "错误率过高", + "severity": "critical", + "alert_status": "firing", + "metric_type": "error_rate", + "operator": ">=", + "metric_value": "12.50", + "threshold_value": "10.00", + "alert_description": "最近 10 分钟错误率超过阈值", + "report_name": "日报", + "report_type": "daily_summary", + "report_start_time": "2026-05-19 12:00", + "report_end_time": "2026-05-20 12:00", + "report_html": "请求量:1024
", + } + } + return map[string]string{ + "site_name": defaultSiteName, + "recipient_name": "Alex", + "recipient_email": "user@example.com", + "verification_code": "123456", + "expires_in_minutes": "15", + "reset_url": "https://example.com/reset-password?token=preview", + "subscription_group": "Claude Pro", + "subscription_days": "30", + "expiry_time": "2026-06-18 12:00", + "days_remaining": "3", + "current_balance": "12.34", + "threshold": "20.00", + "recharge_url": "https://example.com/recharge", + "recharge_amount": "50.00", + "order_id": "1024", + "unsubscribe_url": "https://example.com/unsubscribe", + "account_id": "1001", + "account_name": "openai-main", + "platform": "openai", + "quota_dimension": "Daily quota", + "quota_used": "80.00", + "quota_limit": "100.00", + "quota_remaining": "20.00", + "quota_threshold": "20%", + "triggered_at": "2026-05-20 12:00:00", + "group_name": "Default group", + "moderation_category": "violence", + "moderation_score": "0.982", + "violation_count": "2", + "ban_threshold": "3", + "rule_name": "High error rate", + "severity": "critical", + "alert_status": "firing", + "metric_type": "error_rate", + "operator": ">=", + "metric_value": "12.50", + "threshold_value": "10.00", + "alert_description": "Error rate exceeded threshold in the last 10 minutes.", + "report_name": "Daily summary", + "report_type": "daily_summary", + "report_start_time": "2026-05-19 12:00", + "report_end_time": "2026-05-20 12:00", + "report_html": "Requests: 1024
", + } +} + +var notificationEmailEventOrder = []string{ + NotificationEmailEventAuthVerifyCode, + NotificationEmailEventAuthPasswordReset, + NotificationEmailEventNotificationEmailVerifyCode, + NotificationEmailEventSubscriptionPurchaseSuccess, + NotificationEmailEventSubscriptionExpiryReminder, + NotificationEmailEventBalanceLow, + NotificationEmailEventBalanceRechargeSuccess, + NotificationEmailEventAccountQuotaAlert, + NotificationEmailEventContentModerationViolation, + NotificationEmailEventContentModerationDisabled, + NotificationEmailEventOpsAlert, + NotificationEmailEventOpsScheduledReport, +} + +var notificationEmailEventDefinitions = map[string]NotificationEmailEventInfo{ + NotificationEmailEventAuthVerifyCode: { + Event: NotificationEmailEventAuthVerifyCode, + Label: "Email verification code", + Description: "Sent for registration, email binding, OAuth pending email, and TOTP verification flows.", + Category: "auth", + Optional: false, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), "verification_code", "expires_in_minutes"), + }, + NotificationEmailEventAuthPasswordReset: { + Event: NotificationEmailEventAuthPasswordReset, + Label: "Password reset", + Description: "Sent when a user requests a password reset link.", + Category: "auth", + Optional: false, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), "reset_url", "expires_in_minutes"), + }, + NotificationEmailEventNotificationEmailVerifyCode: { + Event: NotificationEmailEventNotificationEmailVerifyCode, + Label: "Notification email verification code", + Description: "Sent when a user verifies an extra notification email address.", + Category: "auth", + Optional: false, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), "verification_code", "expires_in_minutes"), + }, + NotificationEmailEventSubscriptionPurchaseSuccess: { + Event: NotificationEmailEventSubscriptionPurchaseSuccess, + Label: "Subscription purchase success", + Description: "Sent after a subscription purchase is fulfilled.", + Category: "subscription", + Optional: false, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), "subscription_group", "subscription_days", "expiry_time", "order_id"), + }, + NotificationEmailEventSubscriptionExpiryReminder: { + Event: NotificationEmailEventSubscriptionExpiryReminder, + Label: "Subscription expiry reminder", + Description: "Optional reminder sent before an active subscription expires.", + Category: "subscription", + Optional: true, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), "subscription_group", "expiry_time", "days_remaining", "unsubscribe_url"), + }, + NotificationEmailEventBalanceLow: { + Event: NotificationEmailEventBalanceLow, + Label: "Low balance alert", + Description: "Optional alert sent when balance crosses the configured low-balance threshold.", + Category: "billing", + Optional: true, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), "current_balance", "threshold", "recharge_url", "unsubscribe_url"), + }, + NotificationEmailEventBalanceRechargeSuccess: { + Event: NotificationEmailEventBalanceRechargeSuccess, + Label: "Balance recharge success", + Description: "Sent after a balance recharge order is fulfilled.", + Category: "billing", + Optional: false, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), "recharge_amount", "current_balance", "order_id"), + }, + NotificationEmailEventAccountQuotaAlert: { + Event: NotificationEmailEventAccountQuotaAlert, + Label: "Account quota alert", + Description: "Sent to configured admin notification emails when an upstream account quota threshold is crossed.", + Category: "admin", + Optional: false, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), + "account_id", "account_name", "platform", "quota_dimension", "quota_used", "quota_limit", "quota_remaining", "quota_threshold"), + }, + NotificationEmailEventContentModerationViolation: { + Event: NotificationEmailEventContentModerationViolation, + Label: "Risk control violation notice", + Description: "Sent to users when a request triggers content moderation/risk control rules.", + Category: "risk_control", + Optional: false, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), + "triggered_at", "group_name", "moderation_category", "moderation_score", "violation_count", "ban_threshold"), + }, + NotificationEmailEventContentModerationDisabled: { + Event: NotificationEmailEventContentModerationDisabled, + Label: "Risk control account disabled", + Description: "Sent to users when content moderation automatically disables their account.", + Category: "risk_control", + Optional: false, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), + "triggered_at", "group_name", "moderation_category", "moderation_score", "violation_count", "ban_threshold"), + }, + NotificationEmailEventOpsAlert: { + Event: NotificationEmailEventOpsAlert, + Label: "Ops alert", + Description: "Sent to configured operations recipients when an ops alert rule fires.", + Category: "ops", + Optional: false, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), + "rule_name", "severity", "alert_status", "metric_type", "operator", "metric_value", "threshold_value", "triggered_at", "alert_description"), + }, + NotificationEmailEventOpsScheduledReport: { + Event: NotificationEmailEventOpsScheduledReport, + Label: "Ops scheduled report", + Description: "Sent to configured operations recipients for scheduled daily/weekly/error/account-health reports.", + Category: "ops", + Optional: false, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), + "report_name", "report_type", "report_start_time", "report_end_time", "report_html"), + }, +} + +var notificationEmailOfficialTemplates = map[string]map[string]notificationEmailOfficialTemplate{ + NotificationEmailEventAuthVerifyCode: { + notificationEmailDefaultLocale: { + Subject: "[{{site_name}}] Email verification code", + HTML: notificationEmailCard("#4f46e5", "Email verification code", ` +Hello {{recipient_name}},
+Your verification code is:
+{{verification_code}}
+This code expires in {{expires_in_minutes}} minutes.
+If you did not request this code, please ignore this email.
`), + }, + notificationEmailLocaleChinese: { + Subject: "[{{site_name}}] 邮箱验证码", + HTML: notificationEmailCard("#4f46e5", "邮箱验证码", ` +{{recipient_name}},您好:
+您的验证码是:
+{{verification_code}}
+验证码将在 {{expires_in_minutes}} 分钟后失效。
+如果不是您本人操作,请忽略此邮件。
`), + }, + }, + NotificationEmailEventAuthPasswordReset: { + notificationEmailDefaultLocale: { + Subject: "[{{site_name}}] Password reset request", + HTML: notificationEmailCard("#7c3aed", "Password reset", ` +Hello {{recipient_name}},
+We received a request to reset your password. Click the button below to set a new password.
+ +This link expires in {{expires_in_minutes}} minutes.
+If the button does not work, copy this link into your browser:
{{reset_url}}
If you did not request this, you can safely ignore this email.
`), + }, + notificationEmailLocaleChinese: { + Subject: "[{{site_name}}] 密码重置请求", + HTML: notificationEmailCard("#7c3aed", "密码重置", ` +{{recipient_name}},您好:
+我们收到了您的密码重置请求,请点击下方按钮设置新密码。
+ +此链接将在 {{expires_in_minutes}} 分钟后失效。
+如果按钮无法点击,请复制以下链接到浏览器中打开:
{{reset_url}}
如果不是您本人操作,请忽略此邮件。
`), + }, + }, + NotificationEmailEventNotificationEmailVerifyCode: { + notificationEmailDefaultLocale: { + Subject: "[{{site_name}}] Notification email verification code", + HTML: notificationEmailCard("#0ea5e9", "Notification email verification", ` +Hello {{recipient_name}},
+You are adding this address as an extra notification email.
+Your verification code is:
+{{verification_code}}
+This code expires in {{expires_in_minutes}} minutes.
+If you did not request this code, please ignore this email.
`), + }, + notificationEmailLocaleChinese: { + Subject: "[{{site_name}}] 通知邮箱验证码", + HTML: notificationEmailCard("#0ea5e9", "通知邮箱验证", ` +{{recipient_name}},您好:
+您正在添加额外的通知邮箱,请输入以下验证码完成验证。
+{{verification_code}}
+验证码将在 {{expires_in_minutes}} 分钟后失效。
+如果不是您本人操作,请忽略此邮件。
`), + }, + }, + NotificationEmailEventSubscriptionPurchaseSuccess: { + notificationEmailDefaultLocale: { + Subject: "[{{site_name}}] Subscription purchase successful", + HTML: notificationEmailCard("#2563eb", "Subscription activated", ` +Hello {{recipient_name}},
+Your subscription for {{subscription_group}} has been activated for {{subscription_days}} days.
+Expiry time: {{expiry_time}}
+Order ID: {{order_id}}
`), + }, + notificationEmailLocaleChinese: { + Subject: "[{{site_name}}] 订阅购买成功", + HTML: notificationEmailCard("#2563eb", "订阅已开通", ` +{{recipient_name}},您好:
+您的 {{subscription_group}} 订阅已成功开通,有效期 {{subscription_days}} 天。
+到期时间:{{expiry_time}}
+订单号:{{order_id}}
`), + }, + }, + NotificationEmailEventSubscriptionExpiryReminder: { + notificationEmailDefaultLocale: { + Subject: "[{{site_name}}] Subscription expires in {{days_remaining}} day(s)", + HTML: notificationEmailCard("#f97316", "Subscription expiry reminder", ` +Hello {{recipient_name}},
+Your {{subscription_group}} subscription will expire in {{days_remaining}} day(s).
+Expiry time: {{expiry_time}}
+Unsubscribe from optional subscription reminders
`), + }, + notificationEmailLocaleChinese: { + Subject: "[{{site_name}}] 订阅将在 {{days_remaining}} 天后到期", + HTML: notificationEmailCard("#f97316", "订阅到期提醒", ` +{{recipient_name}},您好:
+您的 {{subscription_group}} 订阅将在 {{days_remaining}} 天后到期。
+到期时间:{{expiry_time}}
+`), + }, + }, + NotificationEmailEventBalanceLow: { + notificationEmailDefaultLocale: { + Subject: "[{{site_name}}] Low balance alert", + HTML: notificationEmailCard("#d97706", "Low balance alert", ` +Hello {{recipient_name}},
+Your current balance is ${{current_balance}}, below the configured alert threshold of ${{threshold}}.
+Please recharge in time to avoid service interruption.
+ +Unsubscribe from optional balance alerts
`), + }, + notificationEmailLocaleChinese: { + Subject: "[{{site_name}}] 余额不足提醒", + HTML: notificationEmailCard("#d97706", "余额不足提醒", ` +{{recipient_name}},您好:
+您当前余额为 ${{current_balance}},已低于提醒阈值 ${{threshold}}。
+请及时充值以免服务中断。
+ +`), + }, + }, + NotificationEmailEventBalanceRechargeSuccess: { + notificationEmailDefaultLocale: { + Subject: "[{{site_name}}] Balance recharge successful", + HTML: notificationEmailCard("#16a34a", "Recharge successful", ` +Hello {{recipient_name}},
+Your balance recharge of ${{recharge_amount}} has been completed.
+Current balance: ${{current_balance}}
+Order ID: {{order_id}}
`), + }, + notificationEmailLocaleChinese: { + Subject: "[{{site_name}}] 余额充值成功", + HTML: notificationEmailCard("#16a34a", "余额充值成功", ` +{{recipient_name}},您好:
+您的余额充值 ${{recharge_amount}} 已完成。
+当前余额:${{current_balance}}
+订单号:{{order_id}}
`), + }, + }, + NotificationEmailEventAccountQuotaAlert: { + notificationEmailDefaultLocale: { + Subject: "[{{site_name}}] Account quota alert - {{account_name}}", + HTML: notificationEmailCard("#dc2626", "Account quota alert", ` +The upstream account {{account_name}} has crossed its configured quota alert threshold.
+| Account ID | {{account_id}} |
| Platform | {{platform}} |
| Dimension | {{quota_dimension}} |
| Used / Limit | {{quota_used}} / {{quota_limit}} |
| Remaining | {{quota_remaining}} |
| Threshold | {{quota_threshold}} |
上游账号 {{account_name}} 已触发配置的额度告警阈值。
+| 账号 ID | {{account_id}} |
| 平台 | {{platform}} |
| 维度 | {{quota_dimension}} |
| 已用 / 限额 | {{quota_used}} / {{quota_limit}} |
| 剩余额度 | {{quota_remaining}} |
| 告警阈值 | {{quota_threshold}} |
Hello {{recipient_name}},
+Your API request triggered the platform content moderation/risk-control policy.
+| Triggered at | {{triggered_at}} |
| Group | {{group_name}} |
| Category / Score | {{moderation_category}} / {{moderation_score}} |
| Violation count | {{violation_count}} / {{ban_threshold}} |
Please review your request content to avoid future service interruptions.
`), + }, + notificationEmailLocaleChinese: { + Subject: "[{{site_name}}] 账户风控提醒", + HTML: notificationEmailCard("#ef4444", "账户风控提醒", ` +{{recipient_name}},您好:
+您的 API 请求触发了平台内容审核/风控策略。
+| 触发时间 | {{triggered_at}} |
| 所属分组 | {{group_name}} |
| 命中类别 / 分数 | {{moderation_category}} / {{moderation_score}} |
| 累计触发次数 | {{violation_count}} / {{ban_threshold}} |
请检查请求内容,避免后续服务受到影响。
`), + }, + }, + NotificationEmailEventContentModerationDisabled: { + notificationEmailDefaultLocale: { + Subject: "[{{site_name}}] Account disabled by risk control", + HTML: notificationEmailCard("#b91c1c", "Account disabled", ` +Hello {{recipient_name}},
+Your account has repeatedly triggered platform content moderation/risk-control rules and has been automatically disabled.
+| Disabled at | {{triggered_at}} |
| Group | {{group_name}} |
| Category / Score | {{moderation_category}} / {{moderation_score}} |
| Violation count | {{violation_count}} / {{ban_threshold}} |
Please contact the administrator if you need to appeal or restore access.
`), + }, + notificationEmailLocaleChinese: { + Subject: "[{{site_name}}] 账户已被禁用", + HTML: notificationEmailCard("#b91c1c", "账户已被禁用", ` +{{recipient_name}},您好:
+您的账户在统计周期内多次触发平台内容审核/风控规则,系统已自动禁用该账户。
+| 禁用时间 | {{triggered_at}} |
| 所属分组 | {{group_name}} |
| 命中类别 / 分数 | {{moderation_category}} / {{moderation_score}} |
| 累计触发次数 | {{violation_count}} / {{ban_threshold}} |
如需申诉或恢复账号,请联系平台管理员处理。
`), + }, + }, + NotificationEmailEventOpsAlert: { + notificationEmailDefaultLocale: { + Subject: "[Ops Alert][{{severity}}] {{rule_name}}", + HTML: notificationEmailCard("#ea580c", "Ops alert", ` +Rule: {{rule_name}}
+Severity: {{severity}}
+Status: {{alert_status}}
+Metric: {{metric_type}} {{operator}} {{metric_value}} (threshold {{threshold_value}})
+Fired at: {{triggered_at}}
+Description: {{alert_description}}
`), + }, + notificationEmailLocaleChinese: { + Subject: "[运维告警][{{severity}}] {{rule_name}}", + HTML: notificationEmailCard("#ea580c", "运维告警", ` +规则:{{rule_name}}
+严重级别:{{severity}}
+状态:{{alert_status}}
+指标:{{metric_type}} {{operator}} {{metric_value}}(阈值 {{threshold_value}})
+触发时间:{{triggered_at}}
+说明:{{alert_description}}
`), + }, + }, + NotificationEmailEventOpsScheduledReport: { + notificationEmailDefaultLocale: { + Subject: "[Ops Report] {{report_name}}", + HTML: notificationEmailCard("#0891b2", "Ops report", ` +Report: {{report_name}}
+Type: {{report_type}}
+Range: {{report_start_time}} - {{report_end_time}}
+报表:{{report_name}}
+类型:{{report_type}}
+时间范围:{{report_start_time}} - {{report_end_time}}
+{{recipient_name}}
Recharge`, + Variables: map[string]string{ + "recipient_name": ``, + "recharge_url": `javascript:alert(1)`, + }, + }) + require.NoError(t, err) + require.NotContains(t, preview.Subject, "\r") + require.NotContains(t, preview.Subject, "\n") + require.Contains(t, preview.Subject, `Low balance for Injected`) + require.Contains(t, preview.HTML, `<script>alert("x")</script>`) + require.NotContains(t, preview.HTML, `javascript:alert(1)`) + require.Contains(t, preview.HTML, `href=""`) +} + +func TestNotificationEmailTemplateOverrideAndRestore(t *testing.T) { + ctx := context.Background() + repo := newNotificationEmailMemorySettingRepo() + svc := NewNotificationEmailService(repo, nil) + + official, err := svc.GetTemplate(ctx, NotificationEmailEventBalanceRechargeSuccess, "en") + require.NoError(t, err) + require.False(t, official.IsCustom) + + updated, err := svc.UpdateTemplate( + ctx, + NotificationEmailEventBalanceRechargeSuccess, + "zh-Hans", + "充值完成:{{recharge_amount}}", + "{{recipient_name}} 已充值 {{recharge_amount}}
", + ) + require.NoError(t, err) + require.True(t, updated.IsCustom) + require.Equal(t, "zh", updated.Locale) + require.Equal(t, "充值完成:{{recharge_amount}}", updated.Subject) + require.NotNil(t, updated.UpdatedAt) + + restored, err := svc.RestoreOfficialTemplate(ctx, NotificationEmailEventBalanceRechargeSuccess, "zh") + require.NoError(t, err) + require.False(t, restored.IsCustom) + require.NotEqual(t, updated.Subject, restored.Subject) + _, err = repo.GetValue(ctx, notificationEmailTemplateKey(NotificationEmailEventBalanceRechargeSuccess, "zh")) + require.ErrorIs(t, err, ErrSettingNotFound) +} + +func TestNotificationEmailTemplateRejectsUnsupportedPlaceholder(t *testing.T) { + ctx := context.Background() + svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil) + + _, err := svc.UpdateTemplate( + ctx, + NotificationEmailEventSubscriptionPurchaseSuccess, + "en", + "Purchased {{not_allowed}}", + "{{subscription_group}}
", + ) + require.Error(t, err) + require.Contains(t, err.Error(), "unsupported placeholder") +} + +func TestNotificationEmailAuthTemplatesAreListedAndPreviewable(t *testing.T) { + ctx := context.Background() + svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil) + + infos := svc.ListEventInfos() + events := make(map[string]NotificationEmailEventInfo, len(infos)) + for _, info := range infos { + events[info.Event] = info + } + require.Contains(t, events, NotificationEmailEventAuthVerifyCode) + require.Contains(t, events, NotificationEmailEventAuthPasswordReset) + require.False(t, events[NotificationEmailEventAuthVerifyCode].Optional) + require.False(t, events[NotificationEmailEventAuthPasswordReset].Optional) + require.Contains(t, events[NotificationEmailEventAuthVerifyCode].Placeholders, "verification_code") + require.Contains(t, events[NotificationEmailEventAuthPasswordReset].Placeholders, "reset_url") + + verifyPreview, err := svc.PreviewTemplate(ctx, NotificationEmailPreviewInput{ + Event: NotificationEmailEventAuthVerifyCode, + Locale: "zh-CN", + Variables: map[string]string{ + "verification_code": "654321", + "expires_in_minutes": "15", + }, + }) + require.NoError(t, err) + require.Contains(t, verifyPreview.Subject, "邮箱验证码") + require.Contains(t, verifyPreview.HTML, "654321") + + resetPreview, err := svc.PreviewTemplate(ctx, NotificationEmailPreviewInput{ + Event: NotificationEmailEventAuthPasswordReset, + Locale: "en", + Variables: map[string]string{ + "reset_url": "https://example.com/reset?token=abc", + "expires_in_minutes": "30", + }, + }) + require.NoError(t, err) + require.Contains(t, resetPreview.Subject, "Password reset") + require.Contains(t, resetPreview.HTML, "https://example.com/reset?token=abc") +} + +func TestNotificationEmailAdditionalEventsAreListedAndPreviewable(t *testing.T) { + ctx := context.Background() + svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil) + + infos := svc.ListEventInfos() + events := make(map[string]NotificationEmailEventInfo, len(infos)) + for _, info := range infos { + events[info.Event] = info + } + + checks := []struct { + event string + placeholder string + }{ + {NotificationEmailEventNotificationEmailVerifyCode, "verification_code"}, + {NotificationEmailEventAccountQuotaAlert, "account_name"}, + {NotificationEmailEventContentModerationViolation, "moderation_category"}, + {NotificationEmailEventContentModerationDisabled, "violation_count"}, + {NotificationEmailEventOpsAlert, "rule_name"}, + {NotificationEmailEventOpsScheduledReport, "report_html"}, + } + + for _, check := range checks { + info, ok := events[check.event] + require.Truef(t, ok, "expected %s to be listed", check.event) + require.False(t, info.Optional) + require.Contains(t, info.Placeholders, check.placeholder) + + preview, err := svc.PreviewTemplate(ctx, NotificationEmailPreviewInput{Event: check.event, Locale: "zh"}) + require.NoError(t, err) + require.NotEmpty(t, preview.Subject) + require.NotEmpty(t, preview.HTML) + } +} + +func TestNotificationEmailRawHTMLVariablesAreTrustedOnlyForHTMLPlaceholders(t *testing.T) { + require.True(t, notificationEmailRawHTMLAllowed(NotificationEmailEventOpsScheduledReport, "report_html")) + require.False(t, notificationEmailRawHTMLAllowed(NotificationEmailEventOpsScheduledReport, "recipient_name")) + require.False(t, notificationEmailRawHTMLAllowed(NotificationEmailEventOpsAlert, "report_html")) + + preview, err := renderNotificationEmail( + NotificationEmailEventOpsScheduledReport, + "Report for {{recipient_name}}", + `{{recipient_name}}
`, + map[string]string{ + "recipient_name": ``, + "report_html": `escaped report
`, + }, + map[string]string{ + "report_html": `| trusted report |
| trusted report |
{{recipient_name}}
`, + map[string]string{"recipient_name": `escaped`}, + map[string]string{"recipient_name": `raw`}, + ) + require.NoError(t, err) + require.Contains(t, preview.HTML, `<em>escaped</em>`) + require.NotContains(t, preview.HTML, `raw`) +} + +func TestNotificationEmailFallbackClassification(t *testing.T) { + templateErr := notificationEmailTemplateErr(errors.New("bad template")) + configErr := notificationEmailConfigErr(errors.New("missing email service")) + deliveryErr := notificationEmailDeliveryErr(errors.New("smtp timeout")) + + require.True(t, shouldFallbackNotificationEmail(templateErr)) + require.True(t, shouldFallbackNotificationEmail(configErr)) + require.False(t, shouldFallbackNotificationEmail(deliveryErr)) + require.True(t, isNotificationEmailDeliveryError(deliveryErr)) + require.False(t, isNotificationEmailDeliveryError(templateErr)) + require.False(t, shouldFallbackNotificationEmail(nil)) +} + +func TestEmailQueueTasksPreserveLocaleHints(t *testing.T) { + queue := &EmailQueueService{taskChan: make(chan EmailTask, 2)} + require.NoError(t, queue.EnqueueVerifyCode("user@example.com", "Sub2API", "zh-CN")) + require.NoError(t, queue.EnqueuePasswordReset("user@example.com", "Sub2API", "https://example.com/reset", "en-US")) + + verifyTask := <-queue.taskChan + require.Equal(t, TaskTypeVerifyCode, verifyTask.TaskType) + require.Equal(t, "zh-CN", verifyTask.Locale) + + resetTask := <-queue.taskChan + require.Equal(t, TaskTypePasswordReset, resetTask.TaskType) + require.Equal(t, "en-US", resetTask.Locale) +} + +func TestOpsScheduledReportDeliverySourceIDIncludesReportIdentity(t *testing.T) { + report := &opsScheduledReport{Name: "日报", ReportType: "daily_summary", Schedule: "0 9 * * *"} + sourceID := opsScheduledReportDeliverySourceID(report) + require.Contains(t, sourceID, "daily_summary") + require.Contains(t, sourceID, "日报") + require.Contains(t, sourceID, "0 9 * * *") + require.NotEqual(t, sourceID, opsScheduledReportDeliverySourceID(&opsScheduledReport{Name: "周报", ReportType: "weekly_summary", Schedule: "0 9 * * 1"})) + require.Equal(t, "scheduled_report", opsScheduledReportDeliverySourceID(nil)) +} + +func TestNotificationEmailUnsubscribeOnlyAllowsOptionalEvents(t *testing.T) { + ctx := context.Background() + svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil) + + token, err := svc.createUnsubscribeToken(ctx, "User@Example.com", NotificationEmailEventBalanceLow) + require.NoError(t, err) + result, err := svc.Unsubscribe(ctx, token) + require.NoError(t, err) + require.True(t, result.Done) + require.Equal(t, NotificationEmailEventBalanceLow, result.Event) + unsubscribed, err := svc.IsUnsubscribed(ctx, "user@example.com", NotificationEmailEventBalanceLow) + require.NoError(t, err) + require.True(t, unsubscribed) + + transactionalToken, err := svc.createUnsubscribeToken(ctx, "user@example.com", NotificationEmailEventBalanceRechargeSuccess) + require.NoError(t, err) + _, err = svc.Unsubscribe(ctx, transactionalToken) + require.Error(t, err) + require.Contains(t, err.Error(), "transactional") + + authToken, err := svc.createUnsubscribeToken(ctx, "user@example.com", NotificationEmailEventAuthVerifyCode) + require.NoError(t, err) + _, err = svc.Unsubscribe(ctx, authToken) + require.Error(t, err) + require.Contains(t, err.Error(), "transactional") +} + +func TestNotificationEmailLocaleMemoryNormalizesAcceptLanguage(t *testing.T) { + ctx := context.Background() + svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil) + + svc.RememberRecipientLocale(ctx, 42, "User@Example.com", "zh-CN,zh;q=0.9,en;q=0.8") + require.Equal(t, "zh", svc.ResolveRecipientLocale(ctx, 42, "user@example.com")) + require.Equal(t, "zh", svc.ResolveRecipientLocale(ctx, 0, "user@example.com")) +} + +func TestNotificationEmailDeliveryKeyUsesShortStableHash(t *testing.T) { + key := notificationEmailDeliveryKey( + NotificationEmailEventSubscriptionExpiryReminder, + "user_subscription", + "1234567890", + "User@Example.com", + "7d", + ) + require.NotEmpty(t, key) + require.LessOrEqual(t, len(key), 100) + require.True(t, strings.HasPrefix(key, notificationEmailDeliveryKeyPrefix+"v2:")) + require.Equal(t, key, notificationEmailDeliveryKey( + NotificationEmailEventSubscriptionExpiryReminder, + "user_subscription", + "1234567890", + "user@example.com", + "7d", + )) + require.NotEqual(t, key, notificationEmailDeliveryKey( + NotificationEmailEventSubscriptionExpiryReminder, + "user_subscription", + "1234567890", + "user@example.com", + "3d", + )) + + legacyKey := legacyNotificationEmailDeliveryKey( + NotificationEmailEventSubscriptionExpiryReminder, + "user_subscription", + "1234567890", + "user@example.com", + "7d", + ) + require.Greater(t, len(legacyKey), 100) +} + +func TestNotificationEmailPreferenceKeyUsesShortStableHashAndReadsLegacyKey(t *testing.T) { + ctx := context.Background() + repo := newNotificationEmailMemorySettingRepo() + svc := NewNotificationEmailService(repo, nil) + + key := notificationEmailPreferenceKey(NotificationEmailEventSubscriptionExpiryReminder, "User@Example.com") + require.NotEmpty(t, key) + require.LessOrEqual(t, len(key), 100) + require.True(t, strings.HasPrefix(key, notificationEmailPreferenceKeyPrefix+"v2:")) + require.Equal(t, key, notificationEmailPreferenceKey(NotificationEmailEventSubscriptionExpiryReminder, "user@example.com")) + + legacyKey := legacyNotificationEmailPreferenceKey(NotificationEmailEventSubscriptionExpiryReminder, "user@example.com") + require.Greater(t, len(legacyKey), 100) + require.NoError(t, repo.Set(ctx, legacyKey, "unsubscribed")) + + unsubscribed, err := svc.IsUnsubscribed(ctx, "User@Example.com", NotificationEmailEventSubscriptionExpiryReminder) + require.NoError(t, err) + require.True(t, unsubscribed) +} + +func TestNotificationEmailSendDeduplicatesSubscriptionExpiryReminder(t *testing.T) { + ctx := context.Background() + repo := newNotificationEmailMemorySettingRepo() + smtpServer := startNotificationEmailTestSMTPServer(t) + require.NoError(t, repo.SetMultiple(ctx, smtpServer.settings())) + + emailSvc := NewEmailService(repo, nil) + svc := NewNotificationEmailService(repo, emailSvc) + input := NotificationEmailSendInput{ + Event: NotificationEmailEventSubscriptionExpiryReminder, + RecipientEmail: "User@Example.com", + RecipientName: "User", + UserID: 42, + SourceType: "user_subscription", + SourceID: "1234567890", + ReminderKey: "7d", + Variables: map[string]string{ + "subscription_group": "Codex", + "expiry_time": "2026-05-27 12:00", + "days_remaining": "7", + }, + } + + require.NoError(t, svc.Send(ctx, input)) + require.Equal(t, int64(1), smtpServer.messageCount()) + + key := notificationEmailDeliveryKey(input.Event, input.SourceType, input.SourceID, input.RecipientEmail, input.ReminderKey) + require.LessOrEqual(t, len(key), 100) + _, err := repo.GetValue(ctx, key) + require.NoError(t, err) + + require.NoError(t, svc.Send(ctx, input)) + require.Equal(t, int64(1), smtpServer.messageCount()) +} + +func TestNotificationEmailSendRespectsLegacyDeliveryKey(t *testing.T) { + ctx := context.Background() + repo := newNotificationEmailMemorySettingRepo() + svc := NewNotificationEmailService(repo, nil) + input := NotificationEmailSendInput{ + Event: NotificationEmailEventSubscriptionExpiryReminder, + RecipientEmail: "user@example.com", + SourceType: "user_subscription", + SourceID: "1234567890", + ReminderKey: "7d", + } + legacyKey := legacyNotificationEmailDeliveryKey(input.Event, input.SourceType, input.SourceID, input.RecipientEmail, input.ReminderKey) + require.NoError(t, repo.Set(ctx, legacyKey, "sent")) + + require.NoError(t, svc.Send(ctx, input)) +} + +type notificationEmailMemorySettingRepo struct { + mu sync.RWMutex + values map[string]string +} + +func newNotificationEmailMemorySettingRepo() *notificationEmailMemorySettingRepo { + return ¬ificationEmailMemorySettingRepo{values: make(map[string]string)} +} + +func (r *notificationEmailMemorySettingRepo) Get(_ context.Context, key string) (*Setting, error) { + r.mu.RLock() + defer r.mu.RUnlock() + value, ok := r.values[key] + if !ok { + return nil, ErrSettingNotFound + } + return &Setting{Key: key, Value: value}, nil +} + +func (r *notificationEmailMemorySettingRepo) GetValue(ctx context.Context, key string) (string, error) { + setting, err := r.Get(ctx, key) + if err != nil { + return "", err + } + return setting.Value, nil +} + +func (r *notificationEmailMemorySettingRepo) Set(_ context.Context, key, value string) error { + r.mu.Lock() + defer r.mu.Unlock() + r.values[key] = value + return nil +} + +func (r *notificationEmailMemorySettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + r.mu.RLock() + defer r.mu.RUnlock() + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := r.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (r *notificationEmailMemorySettingRepo) SetMultiple(_ context.Context, settings map[string]string) error { + r.mu.Lock() + defer r.mu.Unlock() + for key, value := range settings { + r.values[key] = value + } + return nil +} + +func (r *notificationEmailMemorySettingRepo) GetAll(_ context.Context) (map[string]string, error) { + r.mu.RLock() + defer r.mu.RUnlock() + out := make(map[string]string, len(r.values)) + for key, value := range r.values { + out[key] = value + } + return out, nil +} + +func (r *notificationEmailMemorySettingRepo) Delete(_ context.Context, key string) error { + r.mu.Lock() + defer r.mu.Unlock() + if _, ok := r.values[key]; !ok { + return ErrSettingNotFound + } + delete(r.values, key) + return nil +} + +func TestNotificationEmailMemorySettingRepoSatisfiesInterface(t *testing.T) { + var _ SettingRepository = (*notificationEmailMemorySettingRepo)(nil) + require.False(t, strings.Contains(notificationEmailPreferenceKey(NotificationEmailEventBalanceLow, "User@Example.com"), "User@Example.com")) +} + +type notificationEmailTestSMTPServer struct { + listener net.Listener + wg sync.WaitGroup + messages atomic.Int64 +} + +func startNotificationEmailTestSMTPServer(t *testing.T) *notificationEmailTestSMTPServer { + t.Helper() + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + server := ¬ificationEmailTestSMTPServer{listener: listener} + server.wg.Add(1) + go server.serve() + t.Cleanup(server.close) + return server +} + +func (s *notificationEmailTestSMTPServer) settings() map[string]string { + host, port, _ := net.SplitHostPort(s.listener.Addr().String()) + return map[string]string{ + SettingKeySMTPHost: host, + SettingKeySMTPPort: port, + SettingKeySMTPUsername: "user", + SettingKeySMTPPassword: "password", + SettingKeySMTPFrom: "noreply@example.com", + SettingKeySMTPFromName: "Sub2API", + SettingKeySMTPUseTLS: "false", + } +} + +func (s *notificationEmailTestSMTPServer) messageCount() int64 { + return s.messages.Load() +} + +func (s *notificationEmailTestSMTPServer) close() { + _ = s.listener.Close() + s.wg.Wait() +} + +func (s *notificationEmailTestSMTPServer) serve() { + defer s.wg.Done() + for { + conn, err := s.listener.Accept() + if err != nil { + return + } + s.handleConn(conn) + } +} + +func (s *notificationEmailTestSMTPServer) handleConn(conn net.Conn) { + defer func() { _ = conn.Close() }() + rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + writeLine := func(line string) bool { + if _, err := rw.WriteString(line + "\r\n"); err != nil { + return false + } + return rw.Flush() == nil + } + if !writeLine("220 localhost ESMTP") { + return + } + for { + line, err := rw.ReadString('\n') + if err != nil { + return + } + cmd := strings.ToUpper(strings.TrimRight(line, "\r\n")) + switch { + case strings.HasPrefix(cmd, "EHLO"), strings.HasPrefix(cmd, "HELO"): + if _, err := rw.WriteString("250-localhost\r\n250 AUTH PLAIN\r\n"); err != nil { + return + } + if err := rw.Flush(); err != nil { + return + } + case strings.HasPrefix(cmd, "AUTH"): + if !writeLine("235 2.7.0 Authentication successful") { + return + } + case strings.HasPrefix(cmd, "MAIL FROM:"): + if !writeLine("250 2.1.0 OK") { + return + } + case strings.HasPrefix(cmd, "RCPT TO:"): + if !writeLine("250 2.1.5 OK") { + return + } + case strings.HasPrefix(cmd, "DATA"): + if !writeLine("354 End data with| {{ t('keyUsage.date') }} | +{{ t('keyUsage.requests') }} | +{{ t('keyUsage.inputTokens') }} | +{{ t('keyUsage.outputTokens') }} | +{{ t('keyUsage.cacheReadTokens') }} | +{{ t('keyUsage.cacheWriteTokens') }} | +{{ t('keyUsage.cost') }} | +
|---|---|---|---|---|---|---|
| {{ row.date }} | +{{ fmtNum(row.requests) }} | +{{ fmtNum(row.input_tokens) }} | +{{ fmtNum(row.output_tokens) }} | +{{ fmtNum(row.cache_read_tokens) }} | +{{ fmtNum(row.cache_write_tokens) }} | +{{ usd(row.actual_cost != null ? row.actual_cost : row.cost) }} | +
+ {{ + t( + "admin.settings.gatewayForwarding.openaiCodexUserAgentHint", + ) + }} +
++ {{ t("admin.settings.emailTemplates.description") }} +
++ {{ selectedEventDescription }} +
++ {{ t("admin.settings.emailTemplates.placeholdersHelp") }} +
++ {{ t("admin.settings.emailTemplates.previewSecurityHint") }} +
+