diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 74799d81..f3185747 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.127 +0.1.128 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 9afc20ec..af535d5c 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -189,7 +189,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { channelRepository := repository.NewChannelRepository(db) channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService) modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) - balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository) + notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService) + balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService) rpmTokenBucketService := service.NewRPMTokenBucketService() gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI) @@ -204,8 +205,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey) registry := payment.ProvideRegistry() defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) - paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService) - settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService) + paymentService := service.ProvidePaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService, notificationEmailService) + settingHandler := handler.ProvideAdminSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService, notificationEmailService) requestEventBus := service.NewRequestEventBus() opsLogBroadcaster := service.ProvideOpsLogBroadcaster() opsHandler := admin.NewOpsHandler(opsService, requestEventBus, opsLogBroadcaster) @@ -253,7 +254,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, windsurfGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, userMessageQueueService, configConfig, settingService, requestEventBus) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, configConfig) - handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) + handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo, notificationEmailService) totpHandler := handler.NewTotpHandler(totpService) handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService) paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry) @@ -274,7 +275,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) - subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) + subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository, notificationEmailService) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) windsurfRefreshService := service.ProvideWindsurfRefreshService(configConfig, accountRepository, proxyRepository) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 9907d441..14f5dce0 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -56,13 +56,14 @@ func firstNonEmpty(values ...string) string { // SettingHandler 系统设置处理器 type SettingHandler struct { - settingService *service.SettingService - emailService *service.EmailService - turnstileService *service.TurnstileService - opsService *service.OpsService - paymentConfigService *service.PaymentConfigService - paymentService *service.PaymentService - userAttributeService *service.UserAttributeService + settingService *service.SettingService + emailService *service.EmailService + turnstileService *service.TurnstileService + opsService *service.OpsService + paymentConfigService *service.PaymentConfigService + paymentService *service.PaymentService + userAttributeService *service.UserAttributeService + notificationEmailService *service.NotificationEmailService } // NewSettingHandler 创建系统设置处理器 @@ -78,6 +79,12 @@ func NewSettingHandler(settingService *service.SettingService, emailService *ser } } +// SetNotificationEmailService attaches the notification template service without changing +// the constructor signature used by existing unit tests. +func (h *SettingHandler) SetNotificationEmailService(notificationEmailService *service.NotificationEmailService) { + h.notificationEmailService = notificationEmailService +} + // GetSettings 获取所有系统设置 // GET /api/v1/admin/settings func (h *SettingHandler) GetSettings(c *gin.Context) { @@ -247,6 +254,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { EnableAnthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection, RewriteMessageCacheControl: settings.RewriteMessageCacheControl, AntigravityUserAgentVersion: settings.AntigravityUserAgentVersion, + OpenAICodexUserAgent: settings.OpenAICodexUserAgent, WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled, PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource, PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource, @@ -563,6 +571,7 @@ type UpdateSettingsRequest struct { EnableAnthropicCacheTTL1hInjection *bool `json:"enable_anthropic_cache_ttl_1h_injection"` RewriteMessageCacheControl *bool `json:"rewrite_message_cache_control"` AntigravityUserAgentVersion *string `json:"antigravity_user_agent_version"` + OpenAICodexUserAgent *string `json:"openai_codex_user_agent"` // Payment visible method routing PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"` @@ -1404,6 +1413,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { return } } + if req.OpenAICodexUserAgent != nil { + normalized := strings.TrimSpace(*req.OpenAICodexUserAgent) + req.OpenAICodexUserAgent = &normalized + // 仅做长度上限保护,不限制具体格式(运维需要可自由调整 codex 版本号) + if len(normalized) > 512 { + response.Error(c, http.StatusBadRequest, "openai_codex_user_agent must be at most 512 characters") + return + } + } // 交叉验证:如果同时设置了最低和最高版本号,最高版本号必须 >= 最低版本号 if req.MinClaudeCodeVersion != "" && req.MaxClaudeCodeVersion != "" { @@ -1597,6 +1615,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } return previousSettings.AntigravityUserAgentVersion }(), + OpenAICodexUserAgent: func() string { + if req.OpenAICodexUserAgent != nil { + return *req.OpenAICodexUserAgent + } + return previousSettings.OpenAICodexUserAgent + }(), PaymentVisibleMethodAlipaySource: func() string { if req.PaymentVisibleMethodAlipaySource != nil { return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource) @@ -1956,6 +1980,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { EnableAnthropicCacheTTL1hInjection: updatedSettings.EnableAnthropicCacheTTL1hInjection, RewriteMessageCacheControl: updatedSettings.RewriteMessageCacheControl, AntigravityUserAgentVersion: updatedSettings.AntigravityUserAgentVersion, + OpenAICodexUserAgent: updatedSettings.OpenAICodexUserAgent, PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource, PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource, PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled, @@ -2411,6 +2436,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.AntigravityUserAgentVersion != after.AntigravityUserAgentVersion { changed = append(changed, "antigravity_user_agent_version") } + if before.OpenAICodexUserAgent != after.OpenAICodexUserAgent { + changed = append(changed, "openai_codex_user_agent") + } if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource { changed = append(changed, "payment_visible_method_alipay_source") } @@ -3339,3 +3367,160 @@ func (h *SettingHandler) ensureUserAttributeDefinition(ctx context.Context, key, } slog.Info("dingtalk: created user attribute definition", "key", key, "name", name, "type", attrType) } + +// ListEmailTemplates returns all editable notification email templates. +// GET /api/v1/admin/settings/email-templates +func (h *SettingHandler) ListEmailTemplates(c *gin.Context) { + if h.notificationEmailService == nil { + response.InternalError(c, "notification email service is not configured") + return + } + events := h.notificationEmailService.ListEventInfos() + templates, err := h.notificationEmailService.ListTemplates(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, dto.EmailTemplateListResponse{ + Events: emailTemplateEventOptionsToDTO(events), + Locales: h.notificationEmailService.SupportedLocales(), + Templates: emailTemplateSummariesToDTO(templates), + Placeholders: emailTemplatePlaceholderUnion(events), + }) +} + +// GetEmailTemplate returns one editable notification email template. +// GET /api/v1/admin/settings/email-templates/:event/:locale +func (h *SettingHandler) GetEmailTemplate(c *gin.Context) { + if h.notificationEmailService == nil { + response.InternalError(c, "notification email service is not configured") + return + } + tmpl, err := h.notificationEmailService.GetTemplate(c.Request.Context(), c.Param("event"), c.Param("locale")) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + response.Success(c, emailTemplateDetailToDTO(tmpl)) +} + +// UpdateEmailTemplate saves an override for one event/locale template. +// PUT /api/v1/admin/settings/email-templates/:event/:locale +func (h *SettingHandler) UpdateEmailTemplate(c *gin.Context) { + if h.notificationEmailService == nil { + response.InternalError(c, "notification email service is not configured") + return + } + var req dto.UpdateEmailTemplateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + tmpl, err := h.notificationEmailService.UpdateTemplate(c.Request.Context(), c.Param("event"), c.Param("locale"), req.Subject, req.HTML) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + response.Success(c, emailTemplateDetailToDTO(tmpl)) +} + +// RestoreOfficialEmailTemplate removes an override and returns the built-in template. +// POST /api/v1/admin/settings/email-templates/:event/:locale/restore-official +func (h *SettingHandler) RestoreOfficialEmailTemplate(c *gin.Context) { + if h.notificationEmailService == nil { + response.InternalError(c, "notification email service is not configured") + return + } + tmpl, err := h.notificationEmailService.RestoreOfficialTemplate(c.Request.Context(), c.Param("event"), c.Param("locale")) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + response.Success(c, emailTemplateDetailToDTO(tmpl)) +} + +// PreviewEmailTemplate renders a template with safe sample variables without saving it. +// POST /api/v1/admin/settings/email-templates/preview +func (h *SettingHandler) PreviewEmailTemplate(c *gin.Context) { + if h.notificationEmailService == nil { + response.InternalError(c, "notification email service is not configured") + return + } + var req dto.PreviewEmailTemplateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + preview, err := h.notificationEmailService.PreviewTemplate(c.Request.Context(), service.NotificationEmailPreviewInput{ + Event: req.Event, + Locale: req.Locale, + Subject: req.Subject, + HTML: req.HTML, + Variables: req.Variables, + }) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + response.Success(c, dto.EmailTemplatePreviewResponse{Subject: preview.Subject, HTML: preview.HTML}) +} + +func emailTemplateEventOptionsToDTO(events []service.NotificationEmailEventInfo) []dto.EmailTemplateEventOption { + items := make([]dto.EmailTemplateEventOption, 0, len(events)) + for _, event := range events { + items = append(items, dto.EmailTemplateEventOption{ + Value: event.Event, + Label: event.Label, + Description: event.Description, + }) + } + return items +} + +func emailTemplateSummariesToDTO(templates []service.NotificationEmailTemplate) []dto.EmailTemplateSummary { + items := make([]dto.EmailTemplateSummary, 0, len(templates)) + for _, tmpl := range templates { + items = append(items, dto.EmailTemplateSummary{ + Event: tmpl.Event, + Locale: tmpl.Locale, + Subject: tmpl.Subject, + IsCustom: tmpl.IsCustom, + UpdatedAt: emailTemplateUpdatedAt(tmpl), + }) + } + return items +} + +func emailTemplateDetailToDTO(tmpl service.NotificationEmailTemplate) dto.EmailTemplateDetail { + return dto.EmailTemplateDetail{ + Event: tmpl.Event, + Locale: tmpl.Locale, + Subject: tmpl.Subject, + HTML: tmpl.HTML, + IsCustom: tmpl.IsCustom, + UpdatedAt: emailTemplateUpdatedAt(tmpl), + Placeholders: tmpl.Placeholders, + } +} + +func emailTemplateUpdatedAt(tmpl service.NotificationEmailTemplate) string { + if tmpl.UpdatedAt == nil { + return "" + } + return tmpl.UpdatedAt.Format("2006-01-02T15:04:05Z07:00") +} + +func emailTemplatePlaceholderUnion(events []service.NotificationEmailEventInfo) []string { + seen := make(map[string]struct{}) + placeholders := make([]string, 0) + for _, event := range events { + for _, placeholder := range event.Placeholders { + if _, ok := seen[placeholder]; ok { + continue + } + seen[placeholder] = struct{}{} + placeholders = append(placeholders, placeholder) + } + } + return placeholders +} diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index a9af910d..592a0d82 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -203,7 +203,7 @@ func (h *AuthHandler) SendVerifyCode(c *gin.Context) { return } - result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email) + result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email, c.GetHeader("Accept-Language")) if err != nil { response.ErrorFrom(c, err) return @@ -602,7 +602,7 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) { // Request password reset (async) // Note: This returns success even if email doesn't exist (to prevent enumeration) - if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil { + if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL, c.GetHeader("Accept-Language")); err != nil { response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 1014a3e8..550363fd 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -545,7 +545,7 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) { return } - result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email) + result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email, c.GetHeader("Accept-Language")) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 45ad7a70..bdad5572 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -181,6 +181,7 @@ type SystemSettings struct { EnableAnthropicCacheTTL1hInjection bool `json:"enable_anthropic_cache_ttl_1h_injection"` RewriteMessageCacheControl bool `json:"rewrite_message_cache_control"` AntigravityUserAgentVersion string `json:"antigravity_user_agent_version"` + OpenAICodexUserAgent string `json:"openai_codex_user_agent"` // Web Search Emulation WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"` @@ -377,6 +378,62 @@ type OpenAIFastPolicySettings struct { Rules []OpenAIFastPolicyRule `json:"rules"` } +// EmailTemplateEventOption describes an editable notification email event. +type EmailTemplateEventOption struct { + Value string `json:"value"` + Label string `json:"label,omitempty"` + Description string `json:"description,omitempty"` +} + +// EmailTemplateSummary is shown in the admin email template list. +type EmailTemplateSummary struct { + Event string `json:"event"` + Locale string `json:"locale"` + Subject string `json:"subject"` + IsCustom bool `json:"is_custom,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +// EmailTemplateListResponse is returned by GET /admin/settings/email-templates. +type EmailTemplateListResponse struct { + Events []EmailTemplateEventOption `json:"events"` + Locales []string `json:"locales"` + Templates []EmailTemplateSummary `json:"templates,omitempty"` + Placeholders []string `json:"placeholders,omitempty"` +} + +// EmailTemplateDetail is returned for a specific event/locale template. +type EmailTemplateDetail struct { + Event string `json:"event"` + Locale string `json:"locale"` + Subject string `json:"subject"` + HTML string `json:"html"` + IsCustom bool `json:"is_custom,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` + Placeholders []string `json:"placeholders,omitempty"` +} + +// UpdateEmailTemplateRequest updates a template override. +type UpdateEmailTemplateRequest struct { + Subject string `json:"subject"` + HTML string `json:"html"` +} + +// PreviewEmailTemplateRequest previews a template without saving it. +type PreviewEmailTemplateRequest struct { + Event string `json:"event"` + Locale string `json:"locale"` + Subject string `json:"subject"` + HTML string `json:"html"` + Variables map[string]string `json:"variables,omitempty"` +} + +// EmailTemplatePreviewResponse is the rendered preview payload. +type EmailTemplatePreviewResponse struct { + Subject string `json:"subject"` + HTML string `json:"html"` +} + // ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem. // Returns empty slice on empty/invalid input. func ParseCustomMenuItems(raw string) []CustomMenuItem { diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 800328bb..406ed819 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -1133,9 +1133,15 @@ func (h *GatewayHandler) Usage(c *gin.Context) { // 解析可选的日期范围参数(用于 model_stats 查询) startTime, endTime := h.parseUsageDateRange(c) + days, ok := parseAPIKeyDailyUsageDays(c.DefaultQuery("days", "")) + if !ok { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Invalid days, allowed range is 1-90") + return + } // Best-effort: 获取用量统计(按当前 API Key 过滤),失败不影响基础响应 usageData := h.buildUsageData(ctx, apiKey.ID) + dailyUsage := h.buildAPIKeyDailyUsage(c, subject.UserID, apiKey.ID, days) // Best-effort: 获取模型统计 var modelStats any @@ -1149,11 +1155,11 @@ func (h *GatewayHandler) Usage(c *gin.Context) { isQuotaLimited := apiKey.Quota > 0 || apiKey.HasRateLimits() if isQuotaLimited { - h.usageQuotaLimited(c, ctx, apiKey, usageData, modelStats) + h.usageQuotaLimited(c, ctx, apiKey, usageData, dailyUsage, modelStats) return } - h.usageUnrestricted(c, ctx, apiKey, subject, usageData, modelStats) + h.usageUnrestricted(c, ctx, apiKey, subject, usageData, dailyUsage, modelStats) } // parseUsageDateRange 解析 start_date / end_date query params,默认返回近 30 天范围 @@ -1211,8 +1217,20 @@ func (h *GatewayHandler) buildUsageData(ctx context.Context, apiKeyID int64) gin } } +func (h *GatewayHandler) buildAPIKeyDailyUsage(c *gin.Context, userID, apiKeyID int64, days int) any { + if h.usageService == nil { + return nil + } + startTime, endTime := apiKeyDailyUsageRange(days, c.Query("timezone")) + stats, err := h.usageService.GetAPIKeyDailyUsage(c.Request.Context(), userID, apiKeyID, startTime, endTime) + if err != nil { + return nil + } + return stats +} + // usageQuotaLimited 处理 quota_limited 模式的响应 -func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, apiKey *service.APIKey, usageData gin.H, modelStats any) { +func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, apiKey *service.APIKey, usageData gin.H, dailyUsage any, modelStats any) { resp := gin.H{ "mode": "quota_limited", "isValid": apiKey.Status == service.StatusAPIKeyActive || apiKey.Status == service.StatusAPIKeyQuotaExhausted || apiKey.Status == service.StatusAPIKeyExpired, @@ -1294,6 +1312,9 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, if usageData != nil { resp["usage"] = usageData } + if dailyUsage != nil { + resp["daily_usage"] = dailyUsage + } if modelStats != nil { resp["model_stats"] = modelStats } @@ -1302,7 +1323,7 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, } // usageUnrestricted 处理 unrestricted 模式的响应(向后兼容) -func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, usageData gin.H, modelStats any) { +func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, usageData gin.H, dailyUsage any, modelStats any) { // 订阅模式 if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() { resp := gin.H{ @@ -1331,6 +1352,9 @@ func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context, if usageData != nil { resp["usage"] = usageData } + if dailyUsage != nil { + resp["daily_usage"] = dailyUsage + } if modelStats != nil { resp["model_stats"] = modelStats } @@ -1356,6 +1380,9 @@ func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context, if usageData != nil { resp["usage"] = usageData } + if dailyUsage != nil { + resp["daily_usage"] = dailyUsage + } if modelStats != nil { resp["model_stats"] = modelStats } diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go index 1bb81190..f3c16f5d 100644 --- a/backend/internal/handler/payment_handler.go +++ b/backend/internal/handler/payment_handler.go @@ -266,6 +266,7 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) { PaymentSource: req.PaymentSource, OrderType: req.OrderType, PlanID: req.PlanID, + Locale: c.GetHeader("Accept-Language"), }) if err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index c4ba43e4..7413b840 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -1,6 +1,10 @@ package handler import ( + "html" + "net/http" + "strings" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -10,8 +14,9 @@ import ( // SettingHandler 公开设置处理器(无需认证) type SettingHandler struct { - settingService *service.SettingService - version string + settingService *service.SettingService + notificationEmailService *service.NotificationEmailService + version string } // NewSettingHandler 创建公开设置处理器 @@ -22,6 +27,12 @@ func NewSettingHandler(settingService *service.SettingService, version string) * } } +// SetNotificationEmailService attaches the public notification email service without +// changing the constructor signature used by existing tests. +func (h *SettingHandler) SetNotificationEmailService(notificationEmailService *service.NotificationEmailService) { + h.notificationEmailService = notificationEmailService +} + // GetPublicSettings 获取公开设置 // GET /api/v1/settings/public func (h *SettingHandler) GetPublicSettings(c *gin.Context) { @@ -90,6 +101,27 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { }) } +// UnsubscribeNotificationEmail handles optional notification email opt-outs. +// GET /api/v1/settings/email-unsubscribe?token=... +func (h *SettingHandler) UnsubscribeNotificationEmail(c *gin.Context) { + if h.notificationEmailService == nil { + response.InternalError(c, "notification email service is not configured") + return + } + token := strings.TrimSpace(c.Query("token")) + if token == "" { + response.BadRequest(c, "token is required") + return + } + result, err := h.notificationEmailService.Unsubscribe(c.Request.Context(), token) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + body := "Unsubscribed

Unsubscribed

You have unsubscribed " + html.EscapeString(result.Email) + " from " + html.EscapeString(result.Event) + " emails.

" + c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(body)) +} + func publicLoginAgreementDocumentsToDTO(items []service.LoginAgreementDocument) []dto.LoginAgreementDocument { result := make([]dto.LoginAgreementDocument, 0, len(items)) for _, item := range items { diff --git a/backend/internal/handler/totp_handler.go b/backend/internal/handler/totp_handler.go index 5c5eb567..f9151dab 100644 --- a/backend/internal/handler/totp_handler.go +++ b/backend/internal/handler/totp_handler.go @@ -172,7 +172,7 @@ func (h *TotpHandler) SendVerifyCode(c *gin.Context) { return } - if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil { + if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID, c.GetHeader("Accept-Language")); err != nil { response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/usage_handler.go b/backend/internal/handler/usage_handler.go index b8506154..daa5695d 100644 --- a/backend/internal/handler/usage_handler.go +++ b/backend/internal/handler/usage_handler.go @@ -298,6 +298,29 @@ func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) { return startTime, endTime } +const ( + defaultAPIKeyDailyUsageDays = 30 + maxAPIKeyDailyUsageDays = 90 +) + +func parseAPIKeyDailyUsageDays(raw string) (int, bool) { + if strings.TrimSpace(raw) == "" { + return defaultAPIKeyDailyUsageDays, true + } + days, err := strconv.Atoi(raw) + if err != nil || days <= 0 || days > maxAPIKeyDailyUsageDays { + return 0, false + } + return days, true +} + +func apiKeyDailyUsageRange(days int, userTZ string) (time.Time, time.Time) { + now := timezone.NowInUserLocation(userTZ) + startTime := timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -(days-1)), userTZ) + endTime := timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ) + return startTime, endTime +} + // DashboardStats handles getting user dashboard statistics // GET /api/v1/usage/dashboard/stats func (h *UsageHandler) DashboardStats(c *gin.Context) { @@ -416,3 +439,55 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) { response.Success(c, gin.H{"stats": stats}) } + +// GetMyAPIKeyDailyUsage handles getting daily usage details for the current user's API key. +// GET /api/v1/user/api-keys/:id/usage/daily?days=30 +func (h *UsageHandler) GetMyAPIKeyDailyUsage(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + apiKeyID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid API key ID") + return + } + + days, ok := parseAPIKeyDailyUsageDays(c.DefaultQuery("days", "")) + if !ok { + response.BadRequest(c, "Invalid days, allowed range is 1-90") + return + } + + if h.apiKeyService == nil { + response.InternalError(c, "API key service is not configured") + return + } + + apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), apiKeyID) + if err != nil { + response.ErrorFrom(c, err) + return + } + if apiKey.UserID != subject.UserID { + response.Forbidden(c, "Not authorized to access this API key's usage") + return + } + + userTZ := c.Query("timezone") + startTime, endTime := apiKeyDailyUsageRange(days, userTZ) + items, err := h.usageService.GetAPIKeyDailyUsage(c.Request.Context(), subject.UserID, apiKeyID, startTime, endTime) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{ + "items": items, + "days": days, + "start_date": startTime.Format("2006-01-02"), + "end_date": endTime.AddDate(0, 0, -1).Format("2006-01-02"), + }) +} diff --git a/backend/internal/handler/usage_handler_daily_test.go b/backend/internal/handler/usage_handler_daily_test.go new file mode 100644 index 00000000..36311fac --- /dev/null +++ b/backend/internal/handler/usage_handler_daily_test.go @@ -0,0 +1,195 @@ +package handler + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type dailyUsageRepoStub struct { + service.UsageLogRepository + trend []usagestats.TrendDataPoint + + called bool + startTime time.Time + endTime time.Time + granularity string + userID int64 + apiKeyID int64 +} + +func (s *dailyUsageRepoStub) GetUsageTrendWithFilters( + ctx context.Context, + startTime, endTime time.Time, + granularity string, + userID, apiKeyID, accountID, groupID int64, + model string, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.TrendDataPoint, error) { + s.called = true + s.startTime = startTime + s.endTime = endTime + s.granularity = granularity + s.userID = userID + s.apiKeyID = apiKeyID + return s.trend, nil +} + +type dailyUsageAPIKeyRepoStub struct { + service.APIKeyRepository + keys map[int64]*service.APIKey +} + +func (s *dailyUsageAPIKeyRepoStub) GetByID(ctx context.Context, id int64) (*service.APIKey, error) { + key, ok := s.keys[id] + if !ok { + return nil, service.ErrAPIKeyNotFound + } + clone := *key + return &clone, nil +} + +func newDailyUsageTestRouter(usageRepo *dailyUsageRepoStub, apiKeyRepo *dailyUsageAPIKeyRepoStub, userID int64) *gin.Engine { + gin.SetMode(gin.TestMode) + usageSvc := service.NewUsageService(usageRepo, nil, nil, nil) + apiKeySvc := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, nil) + handler := NewUsageHandler(usageSvc, apiKeySvc) + router := gin.New() + router.Use(func(c *gin.Context) { + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: userID}) + c.Next() + }) + router.GET("/user/api-keys/:id/usage/daily", handler.GetMyAPIKeyDailyUsage) + return router +} + +type dailyUsageHandlerResponse struct { + Code int `json:"code"` + Data struct { + Items []usagestats.APIKeyDailyUsagePoint `json:"items"` + Days int `json:"days"` + } `json:"data"` +} + +func TestGetMyAPIKeyDailyUsageRejectsCrossUserAccess(t *testing.T) { + usageRepo := &dailyUsageRepoStub{} + apiKeyRepo := &dailyUsageAPIKeyRepoStub{ + keys: map[int64]*service.APIKey{ + 7: {ID: 7, UserID: 99, Status: service.StatusAPIKeyActive}, + }, + } + router := newDailyUsageTestRouter(usageRepo, apiKeyRepo, 42) + + req := httptest.NewRequest(http.MethodGet, "/user/api-keys/7/usage/daily?days=30", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusForbidden, rec.Code) + require.False(t, usageRepo.called) +} + +func TestGetMyAPIKeyDailyUsageRejectsInvalidDays(t *testing.T) { + for _, path := range []string{ + "/user/api-keys/7/usage/daily?days=0", + "/user/api-keys/7/usage/daily?days=91", + } { + t.Run(path, func(t *testing.T) { + usageRepo := &dailyUsageRepoStub{} + apiKeyRepo := &dailyUsageAPIKeyRepoStub{ + keys: map[int64]*service.APIKey{ + 7: {ID: 7, UserID: 42, Status: service.StatusAPIKeyActive}, + }, + } + router := newDailyUsageTestRouter(usageRepo, apiKeyRepo, 42) + + req := httptest.NewRequest(http.MethodGet, path, nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.False(t, usageRepo.called) + }) + } +} + +func TestGetMyAPIKeyDailyUsageReturnsEmptyData(t *testing.T) { + usageRepo := &dailyUsageRepoStub{trend: []usagestats.TrendDataPoint{}} + apiKeyRepo := &dailyUsageAPIKeyRepoStub{ + keys: map[int64]*service.APIKey{ + 7: {ID: 7, UserID: 42, Status: service.StatusAPIKeyActive}, + }, + } + router := newDailyUsageTestRouter(usageRepo, apiKeyRepo, 42) + + req := httptest.NewRequest(http.MethodGet, "/user/api-keys/7/usage/daily", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + var got dailyUsageHandlerResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, 30, got.Data.Days) + require.Empty(t, got.Data.Items) +} + +func TestGetMyAPIKeyDailyUsageAggregatesByDayForOwnedKey(t *testing.T) { + usageRepo := &dailyUsageRepoStub{ + trend: []usagestats.TrendDataPoint{ + { + Date: "2026-05-19", + Requests: 3, + InputTokens: 10, + OutputTokens: 20, + CacheCreationTokens: 4, + CacheReadTokens: 6, + TotalTokens: 40, + Cost: 0.5, + ActualCost: 0.4, + }, + }, + } + apiKeyRepo := &dailyUsageAPIKeyRepoStub{ + keys: map[int64]*service.APIKey{ + 7: {ID: 7, UserID: 42, Status: service.StatusAPIKeyActive}, + }, + } + router := newDailyUsageTestRouter(usageRepo, apiKeyRepo, 42) + + req := httptest.NewRequest(http.MethodGet, "/user/api-keys/7/usage/daily?days=7", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.True(t, usageRepo.called) + require.Equal(t, "day", usageRepo.granularity) + require.Equal(t, int64(42), usageRepo.userID) + require.Equal(t, int64(7), usageRepo.apiKeyID) + require.True(t, usageRepo.startTime.Before(usageRepo.endTime)) + + var got dailyUsageHandlerResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, 7, got.Data.Days) + require.Len(t, got.Data.Items, 1) + require.Equal(t, usagestats.APIKeyDailyUsagePoint{ + Date: "2026-05-19", + Requests: 3, + InputTokens: 10, + OutputTokens: 20, + CacheReadTokens: 6, + CacheWriteTokens: 4, + TotalTokens: 40, + Cost: 0.5, + ActualCost: 0.4, + }, got.Data.Items[0]) +} diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index f1dbf4e1..95cb1482 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -335,7 +335,7 @@ func (h *UserHandler) SendEmailBindingCode(c *gin.Context) { return } - if err := h.authService.SendEmailIdentityBindCode(c.Request.Context(), subject.UserID, req.Email); err != nil { + if err := h.authService.SendEmailIdentityBindCode(c.Request.Context(), subject.UserID, req.Email, c.GetHeader("Accept-Language")); err != nil { response.ErrorFrom(c, err) return } @@ -363,7 +363,7 @@ func (h *UserHandler) SendNotifyEmailCode(c *gin.Context) { return } - err := h.userService.SendNotifyEmailCode(c.Request.Context(), subject.UserID, req.Email, h.emailService, h.emailCache) + err := h.userService.SendNotifyEmailCode(c.Request.Context(), subject.UserID, req.Email, h.emailService, h.emailCache, c.GetHeader("Accept-Language")) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index cb4ab0a4..1b74d873 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -90,8 +90,17 @@ func ProvideWindsurfHandler(authService *service.WindsurfAuthService, lsService } // ProvideSettingHandler creates SettingHandler with version from BuildInfo -func ProvideSettingHandler(settingService *service.SettingService, buildInfo BuildInfo) *SettingHandler { - return NewSettingHandler(settingService, buildInfo.Version) +func ProvideSettingHandler(settingService *service.SettingService, buildInfo BuildInfo, notificationEmailService *service.NotificationEmailService) *SettingHandler { + h := NewSettingHandler(settingService, buildInfo.Version) + h.SetNotificationEmailService(notificationEmailService) + return h +} + +// ProvideAdminSettingHandler creates admin.SettingHandler with notification template APIs. +func ProvideAdminSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, paymentConfigService *service.PaymentConfigService, paymentService *service.PaymentService, userAttributeService *service.UserAttributeService, notificationEmailService *service.NotificationEmailService) *admin.SettingHandler { + h := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService) + h.SetNotificationEmailService(notificationEmailService) + return h } // ProvideHandlers creates the Handlers struct @@ -169,7 +178,7 @@ var ProviderSet = wire.NewSet( admin.NewProxyHandler, admin.NewRedeemHandler, admin.NewPromoHandler, - admin.NewSettingHandler, + ProvideAdminSettingHandler, admin.NewOpsHandler, ProvideSystemHandler, admin.NewSubscriptionHandler, diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_bridge.go b/backend/internal/pkg/apicompat/chatcompletions_responses_bridge.go new file mode 100644 index 00000000..8fb82ef4 --- /dev/null +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_bridge.go @@ -0,0 +1,719 @@ +package apicompat + +import ( + "encoding/json" + "fmt" + "strings" + "time" +) + +// ResponsesToChatCompletionsRequest converts a Responses API request into a +// Chat Completions request for upstreams that only implement +// /v1/chat/completions. +func ResponsesToChatCompletionsRequest(req *ResponsesRequest) (*ChatCompletionsRequest, error) { + if req == nil { + return nil, fmt.Errorf("responses request is nil") + } + + messages, err := responsesInputToChatMessages(req.Instructions, req.Input) + if err != nil { + return nil, err + } + + out := &ChatCompletionsRequest{ + Model: req.Model, + Messages: messages, + MaxCompletionTokens: req.MaxOutputTokens, + Temperature: req.Temperature, + TopP: req.TopP, + Stream: req.Stream, + ServiceTier: req.ServiceTier, + } + if req.Reasoning != nil { + out.ReasoningEffort = req.Reasoning.Effort + } + if len(req.Tools) > 0 { + out.Tools = responsesToolsToChatTools(req.Tools) + } + if len(req.ToolChoice) > 0 { + out.ToolChoice = responsesToolChoiceToChatToolChoice(req.ToolChoice) + } + + return out, nil +} + +func responsesInputToChatMessages(instructions string, inputRaw json.RawMessage) ([]ChatMessage, error) { + var messages []ChatMessage + if strings.TrimSpace(instructions) != "" { + content, _ := json.Marshal(instructions) + messages = append(messages, ChatMessage{ + Role: "system", + Content: content, + }) + } + + inputRaw = bytesTrimSpace(inputRaw) + if len(inputRaw) == 0 || string(inputRaw) == "null" { + return messages, nil + } + + var inputText string + if err := json.Unmarshal(inputRaw, &inputText); err == nil { + content, _ := json.Marshal(inputText) + messages = append(messages, ChatMessage{ + Role: "user", + Content: content, + }) + return messages, nil + } + + var rawItems []json.RawMessage + if err := json.Unmarshal(inputRaw, &rawItems); err != nil { + return nil, fmt.Errorf("parse responses input: %w", err) + } + + for _, raw := range rawItems { + raw = bytesTrimSpace(raw) + if len(raw) == 0 || string(raw) == "null" { + continue + } + + var item map[string]json.RawMessage + if err := json.Unmarshal(raw, &item); err != nil { + var text string + if textErr := json.Unmarshal(raw, &text); textErr == nil { + content, _ := json.Marshal(text) + messages = append(messages, ChatMessage{Role: "user", Content: content}) + continue + } + return nil, fmt.Errorf("parse responses input item: %w", err) + } + + role := rawString(item["role"]) + itemType := rawString(item["type"]) + switch itemType { + case "function_call": + arguments := rawString(item["arguments"]) + if strings.TrimSpace(arguments) == "" { + arguments = "{}" + } + messages = append(messages, ChatMessage{ + Role: "assistant", + ToolCalls: []ChatToolCall{{ + ID: rawString(item["call_id"]), + Type: "function", + Function: ChatFunctionCall{ + Name: rawString(item["name"]), + Arguments: arguments, + }, + }}, + }) + continue + case "function_call_output": + content, _ := json.Marshal(rawString(item["output"])) + messages = append(messages, ChatMessage{ + Role: "tool", + ToolCallID: rawString(item["call_id"]), + Content: content, + }) + continue + case "input_text", "text": + content, _ := json.Marshal(rawString(item["text"])) + messages = append(messages, ChatMessage{Role: "user", Content: content}) + continue + case "input_image": + content, err := chatContentFromSingleResponsesPart(itemType, item) + if err != nil { + return nil, err + } + messages = append(messages, ChatMessage{Role: "user", Content: content}) + continue + } + + if role == "" { + role = "user" + } + content := item["content"] + if len(bytesTrimSpace(content)) == 0 { + if text := rawString(item["text"]); text != "" { + content, _ = json.Marshal(text) + } + } + chatContent, err := responsesContentToChatContent(content, role) + if err != nil { + return nil, err + } + messages = append(messages, ChatMessage{ + Role: role, + Content: chatContent, + }) + } + + return messages, nil +} + +func responsesContentToChatContent(raw json.RawMessage, role string) (json.RawMessage, error) { + raw = bytesTrimSpace(raw) + if len(raw) == 0 || string(raw) == "null" { + empty, _ := json.Marshal("") + return empty, nil + } + + var text string + if err := json.Unmarshal(raw, &text); err == nil { + return raw, nil + } + + var rawParts []json.RawMessage + if err := json.Unmarshal(raw, &rawParts); err == nil { + return responsesContentPartsToChatContent(rawParts, role) + } + + var obj map[string]json.RawMessage + if err := json.Unmarshal(raw, &obj); err == nil { + return chatContentFromSingleResponsesPart(rawString(obj["type"]), obj) + } + + return raw, nil +} + +func responsesContentPartsToChatContent(rawParts []json.RawMessage, role string) (json.RawMessage, error) { + var textParts []string + var chatParts []ChatContentPart + hasNonText := false + + for _, rawPart := range rawParts { + var part map[string]json.RawMessage + if err := json.Unmarshal(rawPart, &part); err != nil { + continue + } + partType := rawString(part["type"]) + switch partType { + case "input_text", "output_text", "text", "": + text := rawString(part["text"]) + if text == "" { + continue + } + textParts = append(textParts, text) + chatParts = append(chatParts, ChatContentPart{Type: "text", Text: text}) + case "input_image", "image_url": + imageURL := rawString(part["image_url"]) + if imageURL == "" { + imageURL = rawNestedString(part["image_url"], "url") + } + if imageURL == "" { + continue + } + hasNonText = true + chatParts = append(chatParts, ChatContentPart{ + Type: "image_url", + ImageURL: &ChatImageURL{URL: imageURL}, + }) + } + } + + if !hasNonText { + joined, _ := json.Marshal(strings.Join(textParts, "\n\n")) + return joined, nil + } + if role != "user" { + joined, _ := json.Marshal(strings.Join(textParts, "\n\n")) + return joined, nil + } + if len(chatParts) == 0 { + empty, _ := json.Marshal("") + return empty, nil + } + return json.Marshal(chatParts) +} + +func chatContentFromSingleResponsesPart(partType string, part map[string]json.RawMessage) (json.RawMessage, error) { + switch partType { + case "input_image", "image_url": + imageURL := rawString(part["image_url"]) + if imageURL == "" { + imageURL = rawNestedString(part["image_url"], "url") + } + return json.Marshal([]ChatContentPart{{ + Type: "image_url", + ImageURL: &ChatImageURL{URL: imageURL}, + }}) + default: + return json.Marshal(rawString(part["text"])) + } +} + +func responsesToolsToChatTools(tools []ResponsesTool) []ChatTool { + out := make([]ChatTool, 0, len(tools)) + for _, tool := range tools { + if tool.Type != "function" { + continue + } + out = append(out, ChatTool{ + Type: "function", + Function: &ChatFunction{ + Name: tool.Name, + Description: tool.Description, + Parameters: tool.Parameters, + Strict: tool.Strict, + }, + }) + } + return out +} + +func responsesToolChoiceToChatToolChoice(raw json.RawMessage) json.RawMessage { + var choice map[string]json.RawMessage + if err := json.Unmarshal(raw, &choice); err != nil { + return raw + } + if rawString(choice["type"]) != "function" { + return raw + } + name := rawString(choice["name"]) + if name == "" { + name = rawNestedString(choice["function"], "name") + } + if name == "" { + return raw + } + out, err := json.Marshal(map[string]any{ + "type": "function", + "function": map[string]string{ + "name": name, + }, + }) + if err != nil { + return raw + } + return out +} + +// ChatCompletionsResponseToResponses converts a non-streaming Chat Completions +// response into a Responses API response. +func ChatCompletionsResponseToResponses(resp *ChatCompletionsResponse, model string) *ResponsesResponse { + id := "" + if resp != nil { + id = resp.ID + } + if id == "" { + id = generateResponsesID() + } + + out := &ResponsesResponse{ + ID: id, + Object: "response", + Model: model, + Status: "completed", + } + if resp == nil { + out.Output = []ResponsesOutput{emptyResponsesMessageOutput()} + return out + } + if out.Model == "" { + out.Model = resp.Model + } + + if len(resp.Choices) > 0 { + choice := resp.Choices[0] + out.Output = chatMessageToResponsesOutput(choice.Message) + if choice.FinishReason == "length" { + out.Status = "incomplete" + out.IncompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"} + } + } + if len(out.Output) == 0 { + out.Output = []ResponsesOutput{emptyResponsesMessageOutput()} + } + if resp.Usage != nil { + out.Usage = ChatUsageToResponsesUsage(resp.Usage) + } + return out +} + +func chatMessageToResponsesOutput(message ChatMessage) []ResponsesOutput { + var outputs []ResponsesOutput + if message.ReasoningContent != "" { + outputs = append(outputs, ResponsesOutput{ + Type: "reasoning", + ID: generateItemID(), + Summary: []ResponsesSummary{{ + Type: "summary_text", + Text: message.ReasoningContent, + }}, + }) + } + + text := chatMessageContentText(message.Content) + if text != "" || len(message.ToolCalls) == 0 { + outputs = append(outputs, ResponsesOutput{ + Type: "message", + ID: generateItemID(), + Role: "assistant", + Content: []ResponsesContentPart{{ + Type: "output_text", + Text: text, + }}, + Status: "completed", + }) + } + + for _, toolCall := range message.ToolCalls { + arguments := toolCall.Function.Arguments + if strings.TrimSpace(arguments) == "" { + arguments = "{}" + } + outputs = append(outputs, ResponsesOutput{ + Type: "function_call", + ID: generateItemID(), + CallID: toolCall.ID, + Name: toolCall.Function.Name, + Arguments: arguments, + Status: "completed", + }) + } + + return outputs +} + +func emptyResponsesMessageOutput() ResponsesOutput { + return ResponsesOutput{ + Type: "message", + ID: generateItemID(), + Role: "assistant", + Content: []ResponsesContentPart{{Type: "output_text", Text: ""}}, + Status: "completed", + } +} + +func chatMessageContentText(raw json.RawMessage) string { + raw = bytesTrimSpace(raw) + if len(raw) == 0 || string(raw) == "null" { + return "" + } + var text string + if err := json.Unmarshal(raw, &text); err == nil { + return text + } + var parts []ChatContentPart + if err := json.Unmarshal(raw, &parts); err == nil { + var texts []string + for _, part := range parts { + if part.Type == "text" && part.Text != "" { + texts = append(texts, part.Text) + } + } + return strings.Join(texts, "\n\n") + } + return "" +} + +// ChatUsageToResponsesUsage converts Chat Completions token usage to Responses +// usage shape. +func ChatUsageToResponsesUsage(usage *ChatUsage) *ResponsesUsage { + if usage == nil { + return nil + } + out := &ResponsesUsage{ + InputTokens: usage.PromptTokens, + OutputTokens: usage.CompletionTokens, + TotalTokens: usage.TotalTokens, + } + if out.TotalTokens == 0 { + out.TotalTokens = out.InputTokens + out.OutputTokens + } + if usage.PromptTokensDetails != nil && usage.PromptTokensDetails.CachedTokens > 0 { + out.InputTokensDetails = &ResponsesInputTokensDetails{ + CachedTokens: usage.PromptTokensDetails.CachedTokens, + } + } + return out +} + +// ChatCompletionsToResponsesStreamState tracks state while converting Chat +// Completions SSE chunks into Responses SSE events. +type ChatCompletionsToResponsesStreamState struct { + ResponseID string + Model string + Created int64 + SequenceNumber int + CreatedSent bool + CompletedSent bool + + MessageItemID string + Text strings.Builder + Reasoning strings.Builder + ToolCalls map[int]*ChatToolCall + + FinishReason string + Usage *ResponsesUsage +} + +// NewChatCompletionsToResponsesStreamState returns an initialized stream state. +func NewChatCompletionsToResponsesStreamState(model string) *ChatCompletionsToResponsesStreamState { + return &ChatCompletionsToResponsesStreamState{ + ResponseID: generateResponsesID(), + Model: model, + Created: time.Now().Unix(), + ToolCalls: make(map[int]*ChatToolCall), + } +} + +// ChatCompletionsChunkToResponsesEvents converts one Chat Completions stream +// chunk into zero or more Responses stream events. +func ChatCompletionsChunkToResponsesEvents( + chunk *ChatCompletionsChunk, + state *ChatCompletionsToResponsesStreamState, +) []ResponsesStreamEvent { + if chunk == nil || state == nil { + return nil + } + if chunk.ID != "" { + state.ResponseID = chunk.ID + } + if state.Model == "" && chunk.Model != "" { + state.Model = chunk.Model + } + if chunk.Usage != nil { + state.Usage = ChatUsageToResponsesUsage(chunk.Usage) + } + + var events []ResponsesStreamEvent + events = append(events, ensureChatToResponsesCreated(state)...) + + for _, choice := range chunk.Choices { + if choice.Delta.Content != nil { + events = append(events, ensureChatToResponsesMessageItem(state)...) + _, _ = state.Text.WriteString(*choice.Delta.Content) + events = append(events, chatToResponsesEvent(state, "response.output_text.delta", &ResponsesStreamEvent{ + OutputIndex: 0, + ContentIndex: 0, + Delta: *choice.Delta.Content, + ItemID: state.MessageItemID, + })) + } + if choice.Delta.ReasoningContent != nil { + _, _ = state.Reasoning.WriteString(*choice.Delta.ReasoningContent) + events = append(events, chatToResponsesEvent(state, "response.reasoning_summary_text.delta", &ResponsesStreamEvent{ + OutputIndex: 0, + SummaryIndex: 0, + Delta: *choice.Delta.ReasoningContent, + })) + } + for _, toolCall := range choice.Delta.ToolCalls { + idx := 0 + if toolCall.Index != nil { + idx = *toolCall.Index + } + stored, ok := state.ToolCalls[idx] + if !ok { + copyCall := toolCall + if copyCall.ID == "" { + copyCall.ID = generateItemID() + } + copyCall.Type = "function" + state.ToolCalls[idx] = ©Call + stored = ©Call + events = append(events, chatToResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{ + OutputIndex: idx + 1, + Item: &ResponsesOutput{ + Type: "function_call", + ID: generateItemID(), + CallID: stored.ID, + Name: stored.Function.Name, + Status: "in_progress", + }, + })) + } else { + if toolCall.ID != "" { + stored.ID = toolCall.ID + } + if toolCall.Function.Name != "" { + stored.Function.Name = toolCall.Function.Name + } + } + if toolCall.Function.Arguments != "" { + stored.Function.Arguments += toolCall.Function.Arguments + events = append(events, chatToResponsesEvent(state, "response.function_call_arguments.delta", &ResponsesStreamEvent{ + OutputIndex: idx + 1, + Delta: toolCall.Function.Arguments, + CallID: stored.ID, + Name: stored.Function.Name, + })) + } + } + if choice.FinishReason != nil && *choice.FinishReason != "" { + state.FinishReason = *choice.FinishReason + } + } + + return events +} + +// FinalizeChatCompletionsResponsesStream emits terminal Responses events. +func FinalizeChatCompletionsResponsesStream(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent { + if state == nil || state.CompletedSent { + return nil + } + var events []ResponsesStreamEvent + events = append(events, ensureChatToResponsesCreated(state)...) + if state.MessageItemID != "" { + events = append(events, chatToResponsesEvent(state, "response.output_text.done", &ResponsesStreamEvent{ + OutputIndex: 0, + ContentIndex: 0, + Text: state.Text.String(), + ItemID: state.MessageItemID, + })) + events = append(events, chatToResponsesEvent(state, "response.output_item.done", &ResponsesStreamEvent{ + OutputIndex: 0, + Item: &ResponsesOutput{ + Type: "message", + ID: state.MessageItemID, + Role: "assistant", + Status: "completed", + }, + })) + } + + status := "completed" + var incompleteDetails *ResponsesIncompleteDetails + if state.FinishReason == "length" { + status = "incomplete" + incompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"} + } + + state.CompletedSent = true + events = append(events, chatToResponsesEvent(state, "response.completed", &ResponsesStreamEvent{ + Response: &ResponsesResponse{ + ID: state.ResponseID, + Object: "response", + Model: state.Model, + Status: status, + Output: state.chatOutput(), + Usage: state.Usage, + IncompleteDetails: incompleteDetails, + }, + })) + return events +} + +func ensureChatToResponsesCreated(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent { + if state.CreatedSent { + return nil + } + state.CreatedSent = true + return []ResponsesStreamEvent{chatToResponsesEvent(state, "response.created", &ResponsesStreamEvent{ + Response: &ResponsesResponse{ + ID: state.ResponseID, + Object: "response", + Model: state.Model, + Status: "in_progress", + Output: []ResponsesOutput{}, + }, + })} +} + +func ensureChatToResponsesMessageItem(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent { + if state.MessageItemID != "" { + return nil + } + state.MessageItemID = generateItemID() + return []ResponsesStreamEvent{chatToResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{ + OutputIndex: 0, + Item: &ResponsesOutput{ + Type: "message", + ID: state.MessageItemID, + Role: "assistant", + Status: "in_progress", + }, + })} +} + +func (state *ChatCompletionsToResponsesStreamState) chatOutput() []ResponsesOutput { + var outputs []ResponsesOutput + if state.Reasoning.Len() > 0 { + outputs = append(outputs, ResponsesOutput{ + Type: "reasoning", + ID: generateItemID(), + Summary: []ResponsesSummary{{ + Type: "summary_text", + Text: state.Reasoning.String(), + }}, + }) + } + if state.MessageItemID != "" || len(state.ToolCalls) == 0 { + outputs = append(outputs, ResponsesOutput{ + Type: "message", + ID: nonEmpty(state.MessageItemID, generateItemID()), + Role: "assistant", + Content: []ResponsesContentPart{{ + Type: "output_text", + Text: state.Text.String(), + }}, + Status: "completed", + }) + } + for i := 0; i < len(state.ToolCalls); i++ { + toolCall, ok := state.ToolCalls[i] + if !ok || toolCall == nil { + continue + } + arguments := toolCall.Function.Arguments + if strings.TrimSpace(arguments) == "" { + arguments = "{}" + } + outputs = append(outputs, ResponsesOutput{ + Type: "function_call", + ID: generateItemID(), + CallID: toolCall.ID, + Name: toolCall.Function.Name, + Arguments: arguments, + Status: "completed", + }) + } + return outputs +} + +func chatToResponsesEvent( + state *ChatCompletionsToResponsesStreamState, + eventType string, + template *ResponsesStreamEvent, +) ResponsesStreamEvent { + seq := state.SequenceNumber + state.SequenceNumber++ + evt := *template + evt.Type = eventType + evt.SequenceNumber = seq + return evt +} + +func rawString(raw json.RawMessage) string { + raw = bytesTrimSpace(raw) + if len(raw) == 0 || string(raw) == "null" { + return "" + } + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return s + } + return "" +} + +func rawNestedString(raw json.RawMessage, key string) string { + var obj map[string]json.RawMessage + if err := json.Unmarshal(raw, &obj); err != nil { + return "" + } + return rawString(obj[key]) +} + +func bytesTrimSpace(raw json.RawMessage) json.RawMessage { + return json.RawMessage(strings.TrimSpace(string(raw))) +} + +func nonEmpty(value, fallback string) string { + if value != "" { + return value + } + return fallback +} diff --git a/backend/internal/pkg/openai/request.go b/backend/internal/pkg/openai/request.go index dd8fe566..ae3886d6 100644 --- a/backend/internal/pkg/openai/request.go +++ b/backend/internal/pkg/openai/request.go @@ -30,6 +30,17 @@ var CodexOfficialClientOriginatorPrefixes = []string{ "codex ", } +// IsBrowserUserAgent 判断 User-Agent 是否来自浏览器(Chrome/Firefox/Safari/Edge/Opera 等)。 +// 所有现代浏览器的 UA 均以 "Mozilla/" 作为前缀,CLI 工具(codex/claude/curl/postman/python-requests 等)不会。 +// 该判定用于避免 Cloudflare 对浏览器型 UA 在 OpenAI 上游接口上触发 JS 质询。 +func IsBrowserUserAgent(userAgent string) bool { + ua := strings.TrimSpace(userAgent) + if ua == "" { + return false + } + return strings.HasPrefix(strings.ToLower(ua), "mozilla/") +} + // IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request func IsCodexCLIRequest(userAgent string) bool { ua := normalizeCodexClientHeader(userAgent) diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index 39283d22..5307389d 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -198,6 +198,19 @@ type APIKeyUsageTrendPoint struct { Tokens int64 `json:"tokens"` } +// APIKeyDailyUsagePoint represents one day of usage for a single API key. +type APIKeyDailyUsagePoint struct { + Date string `json:"date"` + Requests int64 `json:"requests"` + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + CacheReadTokens int64 `json:"cache_read_tokens"` + CacheWriteTokens int64 `json:"cache_write_tokens"` + TotalTokens int64 `json:"total_tokens"` + Cost float64 `json:"cost"` // 标准计费 + ActualCost float64 `json:"actual_cost"` // 实际扣除 +} + // UserDashboardStats 用户仪表盘统计 type UserDashboardStats struct { // API Key 统计 diff --git a/backend/internal/repository/aes_encryptor_test.go b/backend/internal/repository/aes_encryptor_test.go new file mode 100644 index 00000000..25bff622 --- /dev/null +++ b/backend/internal/repository/aes_encryptor_test.go @@ -0,0 +1,219 @@ +//go:build unit + +package repository + +import ( + "encoding/base64" + "encoding/hex" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ── 测试辅助 ───────────────────────────────────────────────────────────────── + +// aesHexKey 构造一个全填充为 b 的 n 字节密钥并以 hex 编码返回。 +func aesHexKey(n int, b byte) string { + raw := make([]byte, n) + for i := range raw { + raw[i] = b + } + return hex.EncodeToString(raw) +} + +// aesTestCfg 用给定 hex 密钥字符串构造最小 Config。 +func aesTestCfg(keyHex string) *config.Config { + return &config.Config{ + Totp: config.TotpConfig{EncryptionKey: keyHex}, + } +} + +// aesEncryptor 创建一个持有合法 32 字节密钥的加密器,测试失败时立即终止。 +func aesEncryptor(t *testing.T) *AESEncryptor { + t.Helper() + enc, err := NewAESEncryptor(aesTestCfg(aesHexKey(32, 0x42))) + require.NoError(t, err) + require.NotNil(t, enc) + return enc.(*AESEncryptor) +} + +// ── NewAESEncryptor ────────────────────────────────────────────────────────── + +func TestNewAESEncryptor_ValidKey32Bytes(t *testing.T) { + enc, err := NewAESEncryptor(aesTestCfg(aesHexKey(32, 0x01))) + require.NoError(t, err) + require.NotNil(t, enc) +} + +// 16 / 24 字节密钥在 AES 体系内合法,但本实现仅接受 AES-256(32 字节)。 +func TestNewAESEncryptor_WrongKeyLength(t *testing.T) { + tests := []struct { + name string + keySize int + }{ + {"16_bytes_AES128", 16}, + {"24_bytes_AES192", 24}, + {"1_byte", 1}, + {"31_bytes", 31}, + {"33_bytes", 33}, + {"64_bytes", 64}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewAESEncryptor(aesTestCfg(aesHexKey(tt.keySize, 0x00))) + require.Error(t, err) + assert.Contains(t, err.Error(), "32 bytes") + }) + } +} + +// "配置缺失"场景:空字符串与非法 hex 编码。 +func TestNewAESEncryptor_MissingOrInvalidConfig(t *testing.T) { + tests := []struct { + name string + keyHex string + wantContain string + }{ + {"empty_key", "", "32 bytes"}, + {"invalid_hex_odd_length", "abcde", "invalid totp encryption key"}, + {"invalid_hex_chars", "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz", "invalid totp encryption key"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewAESEncryptor(aesTestCfg(tt.keyHex)) + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantContain) + }) + } +} + +// ── 加解密往返(Roundtrip)─────────────────────────────────────────────────── + +func TestAESEncryptor_RoundTrip(t *testing.T) { + enc := aesEncryptor(t) + + tests := []struct { + name string + plaintext string + }{ + {"ascii", "Hello, Sub2API!"}, + {"chinese_multibyte", "你好,世界!这是多字节 UTF-8 文本。"}, + {"empty_string", ""}, + {"long_string_gt_1KB", strings.Repeat("x", 2048)}, + {"special_chars", "!@#$%^&*()_+-=[]{}|;':\",./<>?"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ct, err := enc.Encrypt(tt.plaintext) + require.NoError(t, err) + require.NotEmpty(t, ct, "密文不应为空(即便明文为空字符串)") + + got, err := enc.Decrypt(ct) + require.NoError(t, err) + assert.Equal(t, tt.plaintext, got) + }) + } +} + +// ── IV/Nonce 随机性 ────────────────────────────────────────────────────────── + +func TestAESEncryptor_Encrypt_NonceRandomness(t *testing.T) { + enc := aesEncryptor(t) + const iterations = 30 + plaintext := "same plaintext for every iteration" + + seen := make(map[string]struct{}, iterations) + for i := 0; i < iterations; i++ { + ct, err := enc.Encrypt(plaintext) + require.NoError(t, err) + seen[ct] = struct{}{} + } + + // 30 次加密相同明文,每次因随机 Nonce 应产生不同密文。 + assert.Len(t, seen, iterations, + "每次加密应因随机 Nonce 产生唯一密文,共 %d 次", iterations) +} + +// ── Decrypt 错误路径 ────────────────────────────────────────────────────────── + +func TestAESDecrypt_InvalidBase64(t *testing.T) { + enc := aesEncryptor(t) + _, err := enc.Decrypt("!!!not-valid-base64!!!") + require.Error(t, err) + assert.Contains(t, err.Error(), "decode base64") +} + +func TestAESDecrypt_TooShort(t *testing.T) { + enc := aesEncryptor(t) + // GCM Nonce 为 12 字节;仅提供 2 字节,必然短于 NonceSize。 + short := base64.StdEncoding.EncodeToString([]byte{0x01, 0x02}) + _, err := enc.Decrypt(short) + require.Error(t, err) + assert.Contains(t, err.Error(), "too short") +} + +func TestAESDecrypt_TamperedCiphertext(t *testing.T) { + enc := aesEncryptor(t) + + ct, err := enc.Encrypt("sensitive payload") + require.NoError(t, err) + + raw, err := base64.StdEncoding.DecodeString(ct) + require.NoError(t, err) + + // Nonce 占前 12 字节;翻转其后第一个字节(密文体)。 + raw[12] ^= 0xFF + _, err = enc.Decrypt(base64.StdEncoding.EncodeToString(raw)) + require.Error(t, err, "篡改密文体后解密应失败") +} + +func TestAESDecrypt_TamperedTag(t *testing.T) { + enc := aesEncryptor(t) + + ct, err := enc.Encrypt("sensitive payload") + require.NoError(t, err) + + raw, err := base64.StdEncoding.DecodeString(ct) + require.NoError(t, err) + + // GCM 认证标签占最后 16 字节;翻转最后一个字节。 + raw[len(raw)-1] ^= 0xFF + _, err = enc.Decrypt(base64.StdEncoding.EncodeToString(raw)) + require.Error(t, err, "篡改 GCM 标签后解密应失败") +} + +// ── 跨实例(Cross-instance)────────────────────────────────────────────────── + +func TestAESEncryptor_CrossInstance_SameKey_CanDecrypt(t *testing.T) { + keyHex := aesHexKey(32, 0xDE) + + enc1, err := NewAESEncryptor(aesTestCfg(keyHex)) + require.NoError(t, err) + enc2, err := NewAESEncryptor(aesTestCfg(keyHex)) + require.NoError(t, err) + + plaintext := "cross-instance roundtrip" + ct, err := enc1.Encrypt(plaintext) + require.NoError(t, err) + + got, err := enc2.Decrypt(ct) + require.NoError(t, err) + assert.Equal(t, plaintext, got, "相同密钥构造的两个实例应可互相解密") +} + +func TestAESEncryptor_CrossInstance_DifferentKey_CannotDecrypt(t *testing.T) { + enc1, err := NewAESEncryptor(aesTestCfg(aesHexKey(32, 0xAA))) + require.NoError(t, err) + enc2, err := NewAESEncryptor(aesTestCfg(aesHexKey(32, 0xBB))) + require.NoError(t, err) + + ct, err := enc1.Encrypt("secret message") + require.NoError(t, err) + + _, err = enc2.Decrypt(ct) + require.Error(t, err, "不同密钥的实例不应能解密对方的密文") +} diff --git a/backend/internal/repository/allowed_groups_contract_integration_test.go b/backend/internal/repository/allowed_groups_contract_integration_test.go index b0af0d54..5aa95a91 100644 --- a/backend/internal/repository/allowed_groups_contract_integration_test.go +++ b/backend/internal/repository/allowed_groups_contract_integration_test.go @@ -80,7 +80,7 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te require.NotContains(t, u2After.AllowedGroups, targetGroup.ID) } -func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) { +func TestGroupRepository_DeleteCascade_PreservesApiKeyGroupID(t *testing.T) { ctx := context.Background() tx := testEntTx(t) entClient := tx.Client() @@ -138,8 +138,10 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t require.NotContains(t, uAfter.AllowedGroups, targetGroup.ID) require.Contains(t, uAfter.AllowedGroups, otherGroup.ID) - // API keys bound to the deleted group should have group_id cleared. + // API keys keep their group_id so auth can reject keys bound to a deleted group. keyAfter, err := apiKeyRepo.GetByID(ctx, key.ID) require.NoError(t, err) - require.Nil(t, keyAfter.GroupID) + require.NotNil(t, keyAfter.GroupID) + require.Equal(t, targetGroup.ID, *keyAfter.GroupID) + require.Nil(t, keyAfter.Group) } diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 9b6377bc..9c3b2010 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -9,7 +9,6 @@ import ( "strings" dbent "github.com/Wei-Shaw/sub2api/ent" - "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" @@ -94,9 +93,13 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group if err != nil { return nil, err } - total, active, _ := r.GetAccountCount(ctx, out.ID) - out.AccountCount = total - out.ActiveAccountCount = active + counts, err := r.loadAccountCounts(ctx, []int64{out.ID}) + if err == nil { + c := counts[out.ID] + out.AccountCount = c.Total + out.ActiveAccountCount = c.Active + out.RateLimitedAccountCount = c.RateLimited + } return out, nil } @@ -538,15 +541,12 @@ func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) { var rateLimited int64 err = scanSingleRow(ctx, r.sql, - `SELECT COUNT(*), - COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true), - COUNT(*) FILTER (WHERE a.status = 'active' AND ( - a.rate_limit_reset_at > NOW() OR - a.overload_until > NOW() OR - a.temp_unschedulable_until > NOW() - )) + fmt.Sprintf(`SELECT + COUNT(*) FILTER (WHERE a.deleted_at IS NULL), + COUNT(*) FILTER (WHERE %s), + COUNT(*) FILTER (WHERE %s) FROM account_groups ag JOIN accounts a ON a.id = ag.account_id - WHERE ag.group_id = $1`, + WHERE ag.group_id = $1`, groupAccountAvailableSQL, groupAccountTemporarilyLimitedSQL), []any{groupID}, &total, &active, &rateLimited) return } @@ -636,28 +636,18 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, } } - // 2. Clear group_id for api keys bound to this group. - // 仅更新未软删除的记录,避免修改已删除数据,保证审计与历史回溯一致性。 - // 与 APIKeyRepository 的软删除语义保持一致,减少跨模块行为差异。 - if _, err := txClient.APIKey.Update(). - Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()). - ClearGroupID(). - Save(ctx); err != nil { - return nil, err - } - - // 3. Remove the group id from user_allowed_groups join table. + // 2. Remove the group id from user_allowed_groups join table. // Legacy users.allowed_groups 列已弃用,不再同步。 if _, err := exec.ExecContext(ctx, "DELETE FROM user_allowed_groups WHERE group_id = $1", id); err != nil { return nil, err } - // 4. Delete account_groups join rows. + // 3. Delete account_groups join rows. if _, err := exec.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", id); err != nil { return nil, err } - // 5. Soft-delete group itself. + // 4. Soft-delete group itself. if _, err := txClient.Group.Delete().Where(group.IDEQ(id)).Exec(ctx); err != nil { return nil, err } @@ -680,6 +670,28 @@ type groupAccountCounts struct { RateLimited int64 } +const ( + // 分组页的"可用"账号数必须与账号仓储的 ListSchedulableByGroupID 过滤口径一致。 + groupAccountAvailableSQL = `a.deleted_at IS NULL + AND a.status = 'active' + AND a.schedulable = true + AND (a.expires_at IS NULL OR a.expires_at > NOW() OR a.auto_pause_on_expired = FALSE) + AND (a.rate_limit_reset_at IS NULL OR a.rate_limit_reset_at <= NOW()) + AND (a.overload_until IS NULL OR a.overload_until <= NOW()) + AND (a.temp_unschedulable_until IS NULL OR a.temp_unschedulable_until <= NOW())` + + // 这里沿用历史字段名 RateLimitedAccountCount,但统计的是会让账号暂时退出调度的时间窗口。 + groupAccountTemporarilyLimitedSQL = `a.deleted_at IS NULL + AND a.status = 'active' + AND a.schedulable = true + AND (a.expires_at IS NULL OR a.expires_at > NOW() OR a.auto_pause_on_expired = FALSE) + AND ( + a.rate_limit_reset_at > NOW() OR + a.overload_until > NOW() OR + a.temp_unschedulable_until > NOW() + )` +) + func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]groupAccountCounts, err error) { counts = make(map[int64]groupAccountCounts, len(groupIDs)) if len(groupIDs) == 0 { @@ -688,18 +700,14 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6 rows, err := r.sql.QueryContext( ctx, - `SELECT ag.group_id, - COUNT(*) AS total, - COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true) AS active, - COUNT(*) FILTER (WHERE a.status = 'active' AND ( - a.rate_limit_reset_at > NOW() OR - a.overload_until > NOW() OR - a.temp_unschedulable_until > NOW() - )) AS rate_limited + fmt.Sprintf(`SELECT ag.group_id, + COUNT(*) FILTER (WHERE a.deleted_at IS NULL) AS total, + COUNT(*) FILTER (WHERE %s) AS active, + COUNT(*) FILTER (WHERE %s) AS rate_limited FROM account_groups ag JOIN accounts a ON a.id = ag.account_id WHERE ag.group_id = ANY($1) - GROUP BY ag.group_id`, + GROUP BY ag.group_id`, groupAccountAvailableSQL, groupAccountTemporarilyLimitedSQL), pq.Array(groupIDs), ) if err != nil { diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index f91dae43..68183b2b 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ b/backend/internal/repository/group_repo_integration_test.go @@ -651,6 +651,164 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() { s.Require().Zero(count) } +// TestListWithFilters_ActiveAccountCount_LessThanTotal 验证 ActiveAccountCount 正确区分可用与不可用账号。 +// 当分组内存在 disabled 或 schedulable=false 的账号时,ActiveAccountCount 必须小于 AccountCount, +// 且与 GetAccountCount 返回的 active 值一致。 +func (s *GroupRepoSuite) TestListWithFilters_ActiveAccountCount_LessThanTotal() { + g := &service.Group{ + Name: "g-mixed-status", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + s.Require().NoError(s.repo.Create(s.ctx, g)) + + insertAccount := func(name, status string, schedulable bool) int64 { + var id int64 + s.Require().NoError(scanSingleRow( + s.ctx, s.tx, + "INSERT INTO accounts (name, platform, type, status, schedulable) VALUES ($1, $2, $3, $4, $5) RETURNING id", + []any{name, service.PlatformAnthropic, service.AccountTypeOAuth, status, schedulable}, + &id, + )) + return id + } + link := func(accountID int64, priority int) { + _, err := s.tx.ExecContext(s.ctx, + "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", + accountID, g.ID, priority) + s.Require().NoError(err) + } + + // account 1: active + schedulable → counts toward both total and active + link(insertAccount("acc-active-sched", service.StatusActive, true), 1) + // account 2: disabled → counts toward total only + link(insertAccount("acc-disabled", service.StatusDisabled, true), 2) + // account 3: active + not schedulable → counts toward total only + link(insertAccount("acc-unschedulable", service.StatusActive, false), 3) + + // --- ListWithFilters path --- + isExclusive := false + groups, _, err := s.repo.ListWithFilters(s.ctx, + pagination.PaginationParams{Page: 1, PageSize: 100}, + service.PlatformAnthropic, service.StatusActive, "", &isExclusive) + s.Require().NoError(err) + + var found *service.Group + for i := range groups { + if groups[i].ID == g.ID { + found = &groups[i] + break + } + } + s.Require().NotNil(found, "created group must appear in ListWithFilters result") + s.Assert().Equal(int64(3), found.AccountCount, "AccountCount must count all 3 accounts") + s.Assert().Equal(int64(1), found.ActiveAccountCount, "ActiveAccountCount must count only the active+schedulable account") + + // --- GetAccountCount must return identical values --- + total, active, err := s.repo.GetAccountCount(s.ctx, g.ID) + s.Require().NoError(err) + s.Assert().Equal(found.AccountCount, total, "GetAccountCount total must match ListWithFilters AccountCount") + s.Assert().Equal(found.ActiveAccountCount, active, "GetAccountCount active must match ListWithFilters ActiveAccountCount") +} + +// TestListWithFilters_RateLimitedAccountCount 验证临时受限账号不会计入可用账号数。 +// rate_limit / overload / temp_unschedulable 都会让账号退出当前调度池, +// 因此 ActiveAccountCount 必须与真实调度查询口径一致。 +func (s *GroupRepoSuite) TestListWithFilters_RateLimitedAccountCount() { + g := &service.Group{ + Name: "g-rate-limited", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + s.Require().NoError(s.repo.Create(s.ctx, g)) + + var normalID int64 + s.Require().NoError(scanSingleRow(s.ctx, s.tx, + "INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id", + []any{"acc-normal", service.PlatformAnthropic, service.AccountTypeOAuth}, + &normalID)) + + var rateLimitedID int64 + s.Require().NoError(scanSingleRow(s.ctx, s.tx, + "INSERT INTO accounts (name, platform, type, rate_limit_reset_at) VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour') RETURNING id", + []any{"acc-rate-limited", service.PlatformAnthropic, service.AccountTypeOAuth}, + &rateLimitedID)) + + var overloadedID int64 + s.Require().NoError(scanSingleRow(s.ctx, s.tx, + "INSERT INTO accounts (name, platform, type, overload_until) VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour') RETURNING id", + []any{"acc-overloaded", service.PlatformAnthropic, service.AccountTypeOAuth}, + &overloadedID)) + + var tempUnschedulableID int64 + s.Require().NoError(scanSingleRow(s.ctx, s.tx, + "INSERT INTO accounts (name, platform, type, temp_unschedulable_until) VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour') RETURNING id", + []any{"acc-temp-unschedulable", service.PlatformAnthropic, service.AccountTypeOAuth}, + &tempUnschedulableID)) + + var expiredID int64 + s.Require().NoError(scanSingleRow(s.ctx, s.tx, + "INSERT INTO accounts (name, platform, type, expires_at, auto_pause_on_expired) VALUES ($1, $2, $3, NOW() - INTERVAL '1 hour', TRUE) RETURNING id", + []any{"acc-expired", service.PlatformAnthropic, service.AccountTypeOAuth}, + &expiredID)) + + _, err := s.tx.ExecContext(s.ctx, + "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", + normalID, g.ID, 1) + s.Require().NoError(err) + _, err = s.tx.ExecContext(s.ctx, + "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", + rateLimitedID, g.ID, 2) + s.Require().NoError(err) + _, err = s.tx.ExecContext(s.ctx, + "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", + overloadedID, g.ID, 3) + s.Require().NoError(err) + _, err = s.tx.ExecContext(s.ctx, + "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", + tempUnschedulableID, g.ID, 4) + s.Require().NoError(err) + _, err = s.tx.ExecContext(s.ctx, + "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", + expiredID, g.ID, 5) + s.Require().NoError(err) + + isExclusive := false + groups, _, err := s.repo.ListWithFilters(s.ctx, + pagination.PaginationParams{Page: 1, PageSize: 100}, + service.PlatformAnthropic, service.StatusActive, "", &isExclusive) + s.Require().NoError(err) + + var found *service.Group + for i := range groups { + if groups[i].ID == g.ID { + found = &groups[i] + break + } + } + s.Require().NotNil(found, "created group must appear in ListWithFilters result") + s.Assert().Equal(int64(5), found.AccountCount, "AccountCount must include all linked accounts") + s.Assert().Equal(int64(1), found.ActiveAccountCount, "ActiveAccountCount must include only currently schedulable accounts") + s.Assert().Equal(int64(3), found.RateLimitedAccountCount, "RateLimitedAccountCount must include temporarily limited accounts") + + total, active, err := s.repo.GetAccountCount(s.ctx, g.ID) + s.Require().NoError(err) + s.Assert().Equal(found.AccountCount, total, "GetAccountCount total must match ListWithFilters AccountCount") + s.Assert().Equal(found.ActiveAccountCount, active, "GetAccountCount active must match ListWithFilters ActiveAccountCount") + + detail, err := s.repo.GetByID(s.ctx, g.ID) + s.Require().NoError(err) + s.Assert().Equal(found.AccountCount, detail.AccountCount, "GetByID AccountCount must match ListWithFilters") + s.Assert().Equal(found.ActiveAccountCount, detail.ActiveAccountCount, "GetByID ActiveAccountCount must match ListWithFilters") + s.Assert().Equal(found.RateLimitedAccountCount, detail.RateLimitedAccountCount, "GetByID RateLimitedAccountCount must match ListWithFilters") +} + // --- DeleteAccountGroupsByGroupID --- func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() { diff --git a/backend/internal/repository/group_repo_sort_integration_test.go b/backend/internal/repository/group_repo_sort_integration_test.go index 85b2efcc..39e1ec78 100644 --- a/backend/internal/repository/group_repo_sort_integration_test.go +++ b/backend/internal/repository/group_repo_sort_integration_test.go @@ -7,6 +7,67 @@ import ( "github.com/Wei-Shaw/sub2api/internal/service" ) +// TestListWithAccountCountSort_AttachesActiveCount 验证通过 account_count 排序时, +// ActiveAccountCount 与 AccountCount 都被正确附加到返回结果中, +// 且排序基于 total 账号数而非 active 账号数。 +func (s *GroupRepoSuite) TestListWithAccountCountSort_AttachesActiveCount() { + // Group A: 2 total, 1 active (1 disabled account) + gA := &service.Group{Name: "sort-count-a", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard} + // Group B: 1 total, 1 active + gB := &service.Group{Name: "sort-count-b", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard} + s.Require().NoError(s.repo.Create(s.ctx, gA)) + s.Require().NoError(s.repo.Create(s.ctx, gB)) + + insertAccount := func(name, status string) int64 { + var id int64 + s.Require().NoError(scanSingleRow(s.ctx, s.tx, + "INSERT INTO accounts (name, platform, type, status) VALUES ($1, $2, $3, $4) RETURNING id", + []any{name, service.PlatformAnthropic, service.AccountTypeOAuth, status}, + &id)) + return id + } + link := func(accountID, groupID int64, priority int) { + _, err := s.tx.ExecContext(s.ctx, + "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", + accountID, groupID, priority) + s.Require().NoError(err) + } + + // gA: 1 active + 1 disabled → total=2, active=1 + link(insertAccount("sa-active", service.StatusActive), gA.ID, 1) + link(insertAccount("sa-disabled", service.StatusDisabled), gA.ID, 2) + // gB: 1 active → total=1, active=1 + link(insertAccount("sb-active", service.StatusActive), gB.ID, 1) + + groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{ + Page: 1, PageSize: 100, SortBy: "account_count", SortOrder: "desc", + }, service.PlatformAnthropic, service.StatusActive, "", nil) + s.Require().NoError(err) + + byID := make(map[int64]service.Group, len(groups)) + for _, g := range groups { + byID[g.ID] = g + } + + s.Require().Contains(byID, gA.ID, "gA must appear in results") + s.Require().Contains(byID, gB.ID, "gB must appear in results") + + cA := byID[gA.ID] + s.Assert().Equal(int64(2), cA.AccountCount, "gA AccountCount must be 2") + s.Assert().Equal(int64(1), cA.ActiveAccountCount, "gA ActiveAccountCount must be 1") + + cB := byID[gB.ID] + s.Assert().Equal(int64(1), cB.AccountCount, "gB AccountCount must be 1") + s.Assert().Equal(int64(1), cB.ActiveAccountCount, "gB ActiveAccountCount must be 1") + + // Sort is by total (not active): gA (total=2) must rank higher than gB (total=1) in desc order + indexByID := make(map[int64]int, len(groups)) + for i, g := range groups { + indexByID[g.ID] = i + } + s.Assert().Less(indexByID[gA.ID], indexByID[gB.ID], "gA (total=2) must rank above gB (total=1) with account_count desc") +} + func (s *GroupRepoSuite) TestList_DefaultSortBySortOrderAsc() { g1 := &service.Group{Name: "g1", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 20} g2 := &service.Group{Name: "g2", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 10} diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 3770d585..65757b62 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -833,6 +833,7 @@ func TestAPIContracts(t *testing.T) { "payment_visible_method_alipay_enabled": true, "payment_visible_method_wxpay_enabled": false, "openai_advanced_scheduler_enabled": true, + "openai_codex_user_agent": "", "openai_fast_policy_settings": { "rules": [] }, @@ -1058,6 +1059,7 @@ func TestAPIContracts(t *testing.T) { "payment_visible_method_alipay_enabled": false, "payment_visible_method_wxpay_enabled": false, "openai_advanced_scheduler_enabled": false, + "openai_codex_user_agent": "", "openai_fast_policy_settings": { "rules": [] }, diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index c15f534e..ee439cda 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -109,6 +109,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti AbortWithError(c, 401, "USER_INACTIVE", "User account is not active") return } + if abortIfAPIKeyGroupUnavailable(c, apiKey) { + return + } // ── 4. SimpleMode → early return ───────────────────────────── @@ -251,3 +254,26 @@ func setGroupContext(c *gin.Context, group *service.Group) { ctx := context.WithValue(c.Request.Context(), ctxkey.Group, group) c.Request = c.Request.WithContext(ctx) } + +func abortIfAPIKeyGroupUnavailable(c *gin.Context, apiKey *service.APIKey) bool { + code, message, ok := validateAPIKeyGroupAvailable(apiKey) + if ok { + return false + } + AbortWithError(c, 403, code, message) + return true +} + +func validateAPIKeyGroupAvailable(apiKey *service.APIKey) (string, string, bool) { + if apiKey == nil || apiKey.GroupID == nil { + return "", "", true + } + group := apiKey.Group + if group == nil || strings.EqualFold(group.Status, "deleted") { + return "GROUP_DELETED", "API Key 所属分组已删除", false + } + if !group.IsActive() { + return "GROUP_DISABLED", "API Key 所属分组已停用", false + } + return "", "", true +} diff --git a/backend/internal/server/middleware/api_key_auth_google.go b/backend/internal/server/middleware/api_key_auth_google.go index 84d93edc..3ed71f71 100644 --- a/backend/internal/server/middleware/api_key_auth_google.go +++ b/backend/internal/server/middleware/api_key_auth_google.go @@ -54,6 +54,10 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs abortWithGoogleError(c, 401, "User account is not active") return } + if _, message, ok := validateAPIKeyGroupAvailable(apiKey); !ok { + abortWithGoogleError(c, 403, message) + return + } // 简易模式:跳过余额和订阅检查 if cfg.RunMode == config.RunModeSimple { diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index d6760d8d..a00f70c7 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -300,6 +300,104 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) { require.Equal(t, http.StatusOK, w.Code) } +func TestAPIKeyAuthRejectsUnavailableGroup(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(101) + user := &service.User{ + ID: 7, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + + tests := []struct { + name string + group *service.Group + wantStatus int + wantCode string + }{ + { + name: "active group passes", + group: &service.Group{ + ID: groupID, + Name: "active", + Status: service.StatusActive, + Platform: service.PlatformAnthropic, + Hydrated: true, + }, + wantStatus: http.StatusOK, + }, + { + name: "disabled group is forbidden", + group: &service.Group{ + ID: groupID, + Name: "disabled", + Status: service.StatusDisabled, + Platform: service.PlatformAnthropic, + Hydrated: true, + }, + wantStatus: http.StatusForbidden, + wantCode: "GROUP_DISABLED", + }, + { + name: "deleted status group is forbidden", + group: &service.Group{ + ID: groupID, + Name: "deleted", + Status: "deleted", + Platform: service.PlatformAnthropic, + Hydrated: true, + }, + wantStatus: http.StatusForbidden, + wantCode: "GROUP_DELETED", + }, + { + name: "missing group edge is forbidden", + group: nil, + wantStatus: http.StatusForbidden, + wantCode: "GROUP_DELETED", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + apiKey := &service.APIKey{ + ID: 100, + UserID: user.ID, + GroupID: &groupID, + Key: "test-key", + Status: service.StatusActive, + User: user, + Group: tt.group, + } + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + } + cfg := &config.Config{RunMode: config.RunModeStandard} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + router := newAuthTestRouter(apiKeyService, nil, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, tt.wantStatus, w.Code) + if tt.wantCode != "" { + require.Contains(t, w.Body.String(), tt.wantCode) + } + }) + } +} + func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index b8af9cc5..65eed1ee 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -422,6 +422,11 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { adminSettings.PUT("", h.Admin.Setting.UpdateSettings) adminSettings.POST("/test-smtp", h.Admin.Setting.TestSMTPConnection) adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail) + adminSettings.GET("/email-templates", h.Admin.Setting.ListEmailTemplates) + adminSettings.POST("/email-template-preview", h.Admin.Setting.PreviewEmailTemplate) + adminSettings.GET("/email-templates/:event/:locale", h.Admin.Setting.GetEmailTemplate) + adminSettings.PUT("/email-templates/:event/:locale", h.Admin.Setting.UpdateEmailTemplate) + adminSettings.POST("/email-templates/:event/:locale/restore-official", h.Admin.Setting.RestoreOfficialEmailTemplate) // Admin API Key 管理 adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey) adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey) diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 19d0fd2a..2c44a2b3 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -214,6 +214,7 @@ func RegisterAuthRoutes( settings := v1.Group("/settings") { settings.GET("/public", h.Setting.GetPublicSettings) + settings.GET("/email-unsubscribe", h.Setting.UnsubscribeNotificationEmail) } // 需要认证的当前用户信息 diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go index 9976954c..e79d3ee3 100644 --- a/backend/internal/server/routes/user.go +++ b/backend/internal/server/routes/user.go @@ -31,6 +31,7 @@ func RegisterUserRoutes( user.POST("/account-bindings/email", h.User.BindEmailIdentity) user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity) user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding) + user.GET("/api-keys/:id/usage/daily", h.Usage.GetMyAPIKeyDailyUsage) // 通知邮箱管理 notifyEmail := user.Group("/notify-email") diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index a9492a1d..c1da92d1 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -244,6 +244,21 @@ func (s *groupRepoStub) UpdateSortOrders(ctx context.Context, updates []GroupSor return nil } +type deleteGroupAPIKeyRepoStub struct { + apiKeyRepoStubForGroupUpdate + keys []string + listErr error + listGroupIDs []int64 +} + +func (s *deleteGroupAPIKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + s.listGroupIDs = append(s.listGroupIDs, groupID) + if s.listErr != nil { + return nil, s.listErr + } + return s.keys, nil +} + type proxyRepoStub struct { deleteErr error countErr error @@ -500,6 +515,23 @@ func TestAdminService_DeleteGroup_Success_WithCacheInvalidation(t *testing.T) { }, calls) } +func TestAdminService_DeleteGroup_InvalidatesAuthCacheForBoundKeys(t *testing.T) { + repo := &groupRepoStub{} + apiKeyRepo := &deleteGroupAPIKeyRepoStub{keys: []string{"k1", "k2"}} + invalidator := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{ + groupRepo: repo, + apiKeyRepo: apiKeyRepo, + authCacheInvalidator: invalidator, + } + + err := svc.DeleteGroup(context.Background(), 5) + require.NoError(t, err) + require.Equal(t, []int64{5}, repo.deleteCalls) + require.Equal(t, []int64{5}, apiKeyRepo.listGroupIDs) + require.Equal(t, []string{"k1", "k2"}, invalidator.keys) +} + func TestAdminService_DeleteGroup_NotFound(t *testing.T) { repo := &groupRepoStub{deleteErr: ErrGroupNotFound} svc := &adminServiceImpl{groupRepo: repo} diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 877888b1..c752ce28 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -14,7 +14,7 @@ import ( "github.com/dgraph-io/ristretto" ) -const apiKeyAuthSnapshotVersion = 9 // v9: added API Key name for audit logs +const apiKeyAuthSnapshotVersion = 10 // v10: reload snapshots for group availability checks type apiKeyAuthCacheConfig struct { l1Size int diff --git a/backend/internal/service/auth_email_binding.go b/backend/internal/service/auth_email_binding.go index 78f1185d..84f61d78 100644 --- a/backend/internal/service/auth_email_binding.go +++ b/backend/internal/service/auth_email_binding.go @@ -94,7 +94,7 @@ func (s *AuthService) BindEmailIdentity( } // SendEmailIdentityBindCode sends a verification code for authenticated email binding flows. -func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string) error { +func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string, locale ...string) error { if s == nil { return ErrServiceUnavailable } @@ -128,7 +128,7 @@ func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int6 if s.settingService != nil { siteName = s.settingService.GetSiteName(ctx) } - return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName) + return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName, firstEmailLocale(locale)) } func normalizeEmailForIdentityBinding(email string) (string, error) { diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go index 3478fda5..cf0be652 100644 --- a/backend/internal/service/auth_oauth_email_flow.go +++ b/backend/internal/service/auth_oauth_email_flow.go @@ -28,7 +28,7 @@ func normalizeOAuthSignupSource(signupSource string) string { // SendPendingOAuthVerifyCode sends a local verification code for pending OAuth // account-creation flows without relying on the public registration gate. -func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string) (*SendVerifyCodeResult, error) { +func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string, locale ...string) (*SendVerifyCodeResult, error) { email = strings.TrimSpace(strings.ToLower(email)) if email == "" { return nil, ErrEmailVerifyRequired @@ -47,7 +47,7 @@ func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email stri if s.settingService != nil { siteName = s.settingService.GetSiteName(ctx) } - if err := s.emailService.SendVerifyCode(ctx, email, siteName); err != nil { + if err := s.emailService.SendVerifyCode(ctx, email, siteName, firstEmailLocale(locale)); err != nil { return nil, err } return &SendVerifyCodeResult{ diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index ce2b3fa3..4e5b7b94 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -273,7 +273,7 @@ type SendVerifyCodeResult struct { } // SendVerifyCode 发送邮箱验证码(同步方式) -func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { +func (s *AuthService) SendVerifyCode(ctx context.Context, email string, locale ...string) error { // 检查是否开放注册(默认关闭) if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { return ErrRegDisabled @@ -307,11 +307,11 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { siteName = s.settingService.GetSiteName(ctx) } - return s.emailService.SendVerifyCode(ctx, email, siteName) + return s.emailService.SendVerifyCode(ctx, email, siteName, firstEmailLocale(locale)) } // SendVerifyCodeAsync 异步发送邮箱验证码并返回倒计时 -func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) { +func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string, locale ...string) (*SendVerifyCodeResult, error) { logger.LegacyPrintf("service.auth", "[Auth] SendVerifyCodeAsync called for email: %s", email) // 检查是否开放注册(默认关闭) @@ -352,7 +352,7 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S // 异步发送 logger.LegacyPrintf("service.auth", "[Auth] Enqueueing verify code for: %s", email) - if err := s.emailQueueService.EnqueueVerifyCode(email, siteName); err != nil { + if err := s.emailQueueService.EnqueueVerifyCode(email, siteName, firstEmailLocale(locale)); err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to enqueue: %v", err) return nil, fmt.Errorf("enqueue verify code: %w", err) } @@ -1251,7 +1251,7 @@ func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendB // RequestPasswordReset 请求密码重置(同步发送) // Security: Returns the same response regardless of whether the email exists (prevent user enumeration) -func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string) error { +func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string, locale ...string) error { if !s.IsPasswordResetEnabled(ctx) { return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled") } @@ -1264,7 +1264,7 @@ func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendB return nil // Silent success to prevent enumeration } - if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil { + if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL, firstEmailLocale(locale)); err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to send password reset email to %s: %v", email, err) return nil // Silent success to prevent enumeration } @@ -1275,7 +1275,7 @@ func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendB // RequestPasswordResetAsync 异步请求密码重置(队列发送) // Security: Returns the same response regardless of whether the email exists (prevent user enumeration) -func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string) error { +func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string, locale ...string) error { if !s.IsPasswordResetEnabled(ctx) { return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled") } @@ -1288,7 +1288,7 @@ func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, fron return nil // Silent success to prevent enumeration } - if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL); err != nil { + if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL, firstEmailLocale(locale)); err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to enqueue password reset email for %s: %v", email, err) return nil // Silent success to prevent enumeration } diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index 5b7e413a..26803275 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -39,9 +39,10 @@ type AccountQuotaReader interface { // BalanceNotifyService handles balance and quota threshold notifications. type BalanceNotifyService struct { - emailService *EmailService - settingRepo SettingRepository - accountRepo AccountQuotaReader + emailService *EmailService + settingRepo SettingRepository + accountRepo AccountQuotaReader + notificationEmailService *NotificationEmailService } // NewBalanceNotifyService creates a new BalanceNotifyService. @@ -53,6 +54,10 @@ func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepo } } +func (s *BalanceNotifyService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) { + s.notificationEmailService = notificationEmailService +} + // resolveBalanceThreshold returns the effective balance threshold. // For percentage type, it computes threshold = totalRecharged * percentage / 100. func resolveBalanceThreshold(threshold float64, thresholdType string, totalRecharged float64) float64 { @@ -125,7 +130,7 @@ func (s *BalanceNotifyService) dispatchBalanceLowEmail(ctx context.Context, user slog.Error("panic in balance notification", "recover", r) } }() - s.sendBalanceLowEmails(recipients, user.Username, user.Email, newBalance, threshold, siteName, rechargeURL) + s.sendBalanceLowEmails(recipients, user.ID, user.Username, user.Email, newBalance, threshold, siteName, rechargeURL) }() } @@ -342,11 +347,44 @@ func (s *BalanceNotifyService) sendEmails(recipients []string, subject, body str } // sendBalanceLowEmails sends balance low notification to all recipients. -func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userName, userEmail string, balance, threshold float64, siteName, rechargeURL string) { +func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userID int64, userName, userEmail string, balance, threshold float64, siteName, rechargeURL string) { displayName := userName if displayName == "" { displayName = userEmail } + if s.notificationEmailService != nil { + fallbackRecipients := make([]string, 0, len(recipients)) + for _, to := range recipients { + ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout) + err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{ + Event: NotificationEmailEventBalanceLow, + RecipientEmail: to, + RecipientName: displayName, + UserID: userID, + SourceType: "balance_low", + SourceID: firstNonEmpty(strconv.FormatInt(userID, 10), userEmail), + ReminderKey: time.Now().UTC().Format("2006-01-02"), + Variables: map[string]string{ + "current_balance": fmt.Sprintf("%.2f", balance), + "threshold": fmt.Sprintf("%.2f", threshold), + "recharge_url": rechargeURL, + }, + }) + cancel() + if err != nil { + if shouldFallbackNotificationEmail(err) { + slog.Warn("template balance low notification failed; falling back to built-in body", "to", to, "err", err.Error()) + fallbackRecipients = append(fallbackRecipients, to) + } else { + slog.Warn("template balance low notification delivery failed; not sending fallback to avoid duplicates", "to", to, "err", err.Error()) + } + } + } + if len(fallbackRecipients) == 0 { + return + } + recipients = fallbackRecipients + } subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", sanitizeEmailHeader(siteName)) body := s.buildBalanceLowEmailBody(html.EscapeString(displayName), balance, threshold, html.EscapeString(siteName), rechargeURL) s.sendEmails(recipients, subject, body, "user_email", userEmail, "balance", balance) @@ -369,6 +407,44 @@ func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accoun remaining = 0 } + if s.notificationEmailService != nil { + fallbackRecipients := make([]string, 0, len(adminEmails)) + for _, to := range adminEmails { + ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout) + err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{ + Event: NotificationEmailEventAccountQuotaAlert, + RecipientEmail: to, + RecipientName: emailRecipientName(to), + SourceType: "account_quota", + SourceID: fmt.Sprintf("%d-%s", accountID, dim.name), + ReminderKey: time.Now().UTC().Format("2006-01-02"), + Variables: map[string]string{ + "account_id": strconv.FormatInt(accountID, 10), + "account_name": accountName, + "platform": platform, + "quota_dimension": dimLabel, + "quota_used": fmt.Sprintf("%.2f", used), + "quota_limit": fmt.Sprintf("%.2f", dim.limit), + "quota_remaining": fmt.Sprintf("%.2f", remaining), + "quota_threshold": thresholdDisplay, + }, + }) + cancel() + if err != nil { + if shouldFallbackNotificationEmail(err) { + slog.Warn("template account quota alert failed; falling back to built-in body", "to", to, "account_id", accountID, "dimension", dim.name, "err", err.Error()) + fallbackRecipients = append(fallbackRecipients, to) + } else { + slog.Warn("template account quota alert delivery failed; not sending fallback to avoid duplicates", "to", to, "account_id", accountID, "dimension", dim.name, "err", err.Error()) + } + } + } + if len(fallbackRecipients) == 0 { + return + } + adminEmails = fallbackRecipients + } + subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", sanitizeEmailHeader(siteName), sanitizeEmailHeader(accountName)) body := s.buildQuotaAlertEmailBody(accountID, html.EscapeString(accountName), html.EscapeString(platform), html.EscapeString(dimLabel), used, dim.limit, remaining, thresholdDisplay, html.EscapeString(siteName)) s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dim.name) diff --git a/backend/internal/service/claude_code_session_id.go b/backend/internal/service/claude_code_session_id.go index f087c004..e000108e 100644 --- a/backend/internal/service/claude_code_session_id.go +++ b/backend/internal/service/claude_code_session_id.go @@ -15,16 +15,19 @@ import ( // // 行为按 tokenType / mimicClaudeCode 分两条路径: // -// OAuth mimic 路径 (tokenType == "oauth" && mimicClaudeCode): -// 1. body 中 metadata.user_id 派生的 SessionID 是合法 UUID → canonicalize 写入 -// 2. 请求 header 中已有合法 UUID → canonicalize 保留 -// 3. 否则 → 兜底生成 UUID +// OAuth 路径 (tokenType == "oauth"): +// OAuth 账号本身就是真实 Claude Code 客户端的凭证,可以信任 body 中的 +// metadata.user_id 派生 session id。 +// 1. metadata.user_id 派生 SessionID 是合法 UUID → canonical 写入 +// 2. header 已有合法 UUID → canonical 保留 +// 3. mimicClaudeCode == true → 兜底生成新 UUID +// (mimicClaudeCode == false 且无 metadata 时不强制注入) // -// API key 透传 / 非 mimic 路径: -// - 不从 body 合成 header(避免污染客户端原始语义) -// - 但若客户端在 header 中传入了 X-Claude-Code-Session-Id: -// 合法 UUID → canonicalize 保留 -// 非法值 → 删除(不向上游转发恶意值,符合 UUID 校验承诺) +// API key 透传路径 (tokenType != "oauth"): +// - 不从 body metadata 派生 header(避免污染客户端原始语义) +// - 若客户端在 header 中传入 X-Claude-Code-Session-Id: +// 合法 UUID → canonical 保留 +// 非法值 → 删除(不向上游转发恶意值) // - 不兜底生成 // // 安全说明:metadata.user_id 由客户端控制,ParseMetadataUserID 的正则仅约束字符集, @@ -37,10 +40,10 @@ func ensureClaudeCodeSessionID(req *http.Request, body []byte, tokenType string, req.Header = make(http.Header) } - isOAuthMimic := tokenType == "oauth" && mimicClaudeCode + isOAuth := tokenType == "oauth" - // OAuth mimic 路径:从 metadata 派生(仅在 mimic 场景写 header)。 - if isOAuthMimic { + // OAuth 路径:从 metadata 派生(OAuth 凭证可信任)。 + if isOAuth { if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { if parsed := ParseMetadataUserID(uid); parsed != nil { if id, err := uuid.Parse(parsed.SessionID); err == nil { @@ -65,9 +68,9 @@ func ensureClaudeCodeSessionID(req *http.Request, body []byte, tokenType string, req.Header.Del("X-Claude-Code-Session-Id") } - // OAuth mimic 兜底生成(仅 mimic 场景;API key 不污染)。 + // OAuth mimic 兜底生成(仅 mimic 场景;API key/非 mimic 不污染)。 // uuid.NewString() 走 crypto/rand。 - if isOAuthMimic { + if isOAuth && mimicClaudeCode { setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", uuid.NewString()) } } diff --git a/backend/internal/service/claude_code_session_id_test.go b/backend/internal/service/claude_code_session_id_test.go index d5c9023b..17ec1138 100644 --- a/backend/internal/service/claude_code_session_id_test.go +++ b/backend/internal/service/claude_code_session_id_test.go @@ -136,15 +136,17 @@ func TestEnsureClaudeCodeSessionID_APIKeyIgnoresMetadata(t *testing.T) { } } -// OAuth 但非 mimic 模式也不应该从 metadata 派生 header。 -func TestEnsureClaudeCodeSessionID_OAuthNonMimicIgnoresMetadata(t *testing.T) { +// OAuth 路径即使 mimic=false 也应该从 metadata 派生 header: +// OAuth 凭证本身就是 Claude Code 类型账号,metadata.user_id 可信任。 +// 这与 API key 路径不同(API key 是任意第三方调用方)。 +func TestEnsureClaudeCodeSessionID_OAuthNonMimicDerivesFromMetadata(t *testing.T) { req := newReq(t) body := []byte(`{"metadata":{"user_id":"{\"device_id\":\"abc\",\"account_uuid\":\"\",\"session_id\":\"` + testValidUUID + `\"}"}}`) ensureClaudeCodeSessionID(req, body, "oauth", false) got := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id") - if got != "" { - t.Fatalf("Non-mimic OAuth must NOT derive session-id from metadata, got %q", got) + if got != testValidUUID { + t.Fatalf("OAuth must derive session-id from metadata regardless of mimic flag, got %q want %q", got, testValidUUID) } } diff --git a/backend/internal/service/content_moderation.go b/backend/internal/service/content_moderation.go index 6a7c9904..2d066298 100644 --- a/backend/internal/service/content_moderation.go +++ b/backend/internal/service/content_moderation.go @@ -1463,6 +1463,24 @@ func (s *ContentModerationService) applyFlaggedSideEffects(ctx context.Context, func (s *ContentModerationService) sendViolationEmail(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) error { siteName := s.siteName(ctx) + if s.emailService.notificationEmailService != nil { + if err := s.emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{ + Event: NotificationEmailEventContentModerationViolation, + RecipientEmail: log.UserEmail, + RecipientName: emailRecipientName(log.UserEmail), + UserID: contentModerationEmailUserID(log), + SourceType: "content_moderation", + SourceID: contentModerationEmailSourceID(log), + Variables: contentModerationEmailVariables(log, cfg), + }); err == nil { + return nil + } else { + if !shouldFallbackNotificationEmail(err) { + return err + } + slog.Warn("template content moderation violation email failed; falling back to built-in body", "log_id", log.ID, "recipient_hash", notificationEmailHash(log.UserEmail), "err", err.Error()) + } + } subject := fmt.Sprintf("[%s] 账户风控提醒 / Risk Control Notice", sanitizeEmailHeader(siteName)) body := buildContentModerationViolationEmailBody(siteName, log, cfg) return s.emailService.SendEmail(ctx, log.UserEmail, subject, body) @@ -1470,11 +1488,71 @@ func (s *ContentModerationService) sendViolationEmail(ctx context.Context, cfg * func (s *ContentModerationService) sendAccountDisabledEmail(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) error { siteName := s.siteName(ctx) + if s.emailService.notificationEmailService != nil { + if err := s.emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{ + Event: NotificationEmailEventContentModerationDisabled, + RecipientEmail: log.UserEmail, + RecipientName: emailRecipientName(log.UserEmail), + UserID: contentModerationEmailUserID(log), + SourceType: "content_moderation", + SourceID: contentModerationEmailSourceID(log), + Variables: contentModerationEmailVariables(log, cfg), + }); err == nil { + return nil + } else { + if !shouldFallbackNotificationEmail(err) { + return err + } + slog.Warn("template content moderation disabled email failed; falling back to built-in body", "log_id", log.ID, "recipient_hash", notificationEmailHash(log.UserEmail), "err", err.Error()) + } + } subject := fmt.Sprintf("[%s] 账户已被禁用 / Account Disabled", sanitizeEmailHeader(siteName)) body := buildContentModerationAccountDisabledEmailBody(siteName, log, cfg) return s.emailService.SendEmail(ctx, log.UserEmail, subject, body) } +func contentModerationEmailUserID(log *ContentModerationLog) int64 { + if log == nil || log.UserID == nil { + return 0 + } + return *log.UserID +} + +func contentModerationEmailSourceID(log *ContentModerationLog) string { + if log == nil || log.ID <= 0 { + return "" + } + return fmt.Sprintf("%d", log.ID) +} + +func contentModerationEmailVariables(log *ContentModerationLog, cfg *ContentModerationConfig) map[string]string { + variables := map[string]string{ + "triggered_at": time.Now().UTC().Format(time.RFC3339), + "group_name": "-", + "moderation_category": "-", + "moderation_score": "0.000", + "violation_count": "0", + "ban_threshold": "0", + } + if log != nil { + if !log.CreatedAt.IsZero() { + variables["triggered_at"] = log.CreatedAt.UTC().Format(time.RFC3339) + } + if strings.TrimSpace(log.GroupName) != "" { + variables["group_name"] = strings.TrimSpace(log.GroupName) + } + if strings.TrimSpace(log.HighestCategory) != "" { + variables["moderation_category"] = strings.TrimSpace(log.HighestCategory) + } + variables["moderation_score"] = fmt.Sprintf("%.3f", log.HighestScore) + variables["violation_count"] = fmt.Sprintf("%d", log.ViolationCount) + } + if cfg != nil { + variables["ban_threshold"] = fmt.Sprintf("%d", cfg.BanThreshold) + } + return variables +} + func (s *ContentModerationService) siteName(ctx context.Context) string { if s == nil || s.settingRepo == nil { return "Sub2API" diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 4f2d40d8..8d591086 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -401,6 +401,10 @@ const ( SettingKeyRewriteMessageCacheControl = "rewrite_message_cache_control" // SettingKeyAntigravityUserAgentVersion Antigravity 上游 User-Agent 版本号(空值使用环境变量/默认值) SettingKeyAntigravityUserAgentVersion = "antigravity_user_agent_version" + // SettingKeyOpenAICodexUserAgent OpenAI Codex 完整 User-Agent(空值使用内置默认) + // 当客户端 UA 被识别为浏览器(Chrome/Firefox/Safari/Edge 等)时,转发给 OpenAI 上游前会替换为此值, + // 用于避免 Cloudflare 对浏览器型 UA 的质询拦截。 + SettingKeyOpenAICodexUserAgent = "openai_codex_user_agent" // Balance Low Notification SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关 diff --git a/backend/internal/service/email_queue_service.go b/backend/internal/service/email_queue_service.go index d8f0a518..a933e6bb 100644 --- a/backend/internal/service/email_queue_service.go +++ b/backend/internal/service/email_queue_service.go @@ -21,6 +21,7 @@ type EmailTask struct { SiteName string TaskType string // "verify_code" or "password_reset" ResetURL string // Only used for password_reset task type + Locale string // Optional Accept-Language locale hint } // EmailQueueService 异步邮件队列服务 @@ -82,13 +83,13 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) { switch task.TaskType { case TaskTypeVerifyCode: - if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil { + if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName, task.Locale); err != nil { logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err) } else { logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email) } case TaskTypePasswordReset: - if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil { + if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL, task.Locale); err != nil { logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err) } else { logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email) @@ -99,11 +100,12 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) { } // EnqueueVerifyCode 将验证码发送任务加入队列 -func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error { +func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string, locale ...string) error { task := EmailTask{ Email: email, SiteName: siteName, TaskType: TaskTypeVerifyCode, + Locale: firstEmailLocale(locale), } select { @@ -116,12 +118,13 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error { } // EnqueuePasswordReset 将密码重置邮件任务加入队列 -func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL string) error { +func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL string, locale ...string) error { task := EmailTask{ Email: email, SiteName: siteName, TaskType: TaskTypePasswordReset, ResetURL: resetURL, + Locale: firstEmailLocale(locale), } select { diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index 9a03ea30..2cf42d73 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -94,8 +94,9 @@ type SMTPConfig struct { // EmailService 邮件服务 type EmailService struct { - settingRepo SettingRepository - cache EmailCache + settingRepo SettingRepository + cache EmailCache + notificationEmailService *NotificationEmailService } // NewEmailService 创建邮件服务实例 @@ -106,6 +107,28 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ } } +func (s *EmailService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) { + s.notificationEmailService = notificationEmailService +} + +func firstEmailLocale(locales []string) string { + if len(locales) == 0 { + return "" + } + return strings.TrimSpace(locales[0]) +} + +func emailRecipientName(email string) string { + trimmed := strings.TrimSpace(email) + if trimmed == "" { + return "" + } + if at := strings.Index(trimmed, "@"); at > 0 { + return trimmed[:at] + } + return trimmed +} + // GetSMTPConfig 从数据库获取SMTP配置 func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) { keys := []string{ @@ -301,7 +324,7 @@ func (s *EmailService) GenerateVerifyCode() (string, error) { } // SendVerifyCode 发送验证码邮件 -func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string) error { +func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string, locale ...string) error { // 检查是否在冷却期内 existing, err := s.cache.GetVerificationCode(ctx, email) if err == nil && existing != nil { @@ -327,6 +350,26 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin return fmt.Errorf("save verify code: %w", err) } + if s.notificationEmailService != nil { + err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{ + Event: NotificationEmailEventAuthVerifyCode, + Locale: firstEmailLocale(locale), + RecipientEmail: email, + RecipientName: emailRecipientName(email), + Variables: map[string]string{ + "verification_code": code, + "expires_in_minutes": strconv.Itoa(int(verifyCodeTTL / time.Minute)), + }, + }) + if err == nil { + return nil + } + if !shouldFallbackNotificationEmail(err) { + return err + } + slog.Warn("failed to send templated verification email, falling back to legacy template", "recipient_hash", notificationEmailHash(email), "error", err) + } + // 构建邮件内容 subject := fmt.Sprintf("[%s] Email Verification Code", siteName) body := s.buildVerifyCodeEmailBody(code, siteName) @@ -469,7 +512,7 @@ func (s *EmailService) GeneratePasswordResetToken() (string, error) { } // SendPasswordResetEmail sends a password reset email with a reset link -func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteName, resetURL string) error { +func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteName, resetURL string, locale ...string) error { var token string var needSaveToken bool @@ -502,6 +545,26 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa // Build full reset URL with URL-encoded token and email fullResetURL := fmt.Sprintf("%s?email=%s&token=%s", resetURL, url.QueryEscape(email), url.QueryEscape(token)) + if s.notificationEmailService != nil { + err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{ + Event: NotificationEmailEventAuthPasswordReset, + Locale: firstEmailLocale(locale), + RecipientEmail: email, + RecipientName: emailRecipientName(email), + Variables: map[string]string{ + "reset_url": fullResetURL, + "expires_in_minutes": strconv.Itoa(int(passwordResetTokenTTL / time.Minute)), + }, + }) + if err == nil { + return nil + } + if !shouldFallbackNotificationEmail(err) { + return err + } + slog.Warn("failed to send templated password reset email, falling back to legacy template", "recipient_hash", notificationEmailHash(email), "error", err) + } + // Build email content subject := fmt.Sprintf("[%s] 密码重置请求", siteName) body := s.buildPasswordResetEmailBody(fullResetURL, siteName) @@ -516,7 +579,7 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa // SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker) // This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing -func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error { +func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string, locale ...string) error { // Check email cooldown to prevent email bombing if s.cache.IsPasswordResetEmailInCooldown(ctx, email) { slog.Info("password reset email skipped due to cooldown", "email", email) @@ -524,7 +587,7 @@ func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, e } // Send email using core method - if err := s.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil { + if err := s.SendPasswordResetEmail(ctx, email, siteName, resetURL, firstEmailLocale(locale)); err != nil { return err } diff --git a/backend/internal/service/notification_email_service.go b/backend/internal/service/notification_email_service.go new file mode 100644 index 00000000..d21363b6 --- /dev/null +++ b/backend/internal/service/notification_email_service.go @@ -0,0 +1,1347 @@ +package service + +import ( + "context" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "html" + "log/slog" + "net/url" + "regexp" + "strconv" + "strings" + "time" +) + +const ( + NotificationEmailEventAuthVerifyCode = "auth.verify_code" + NotificationEmailEventAuthPasswordReset = "auth.password_reset" + NotificationEmailEventNotificationEmailVerifyCode = "notification_email.verify_code" + NotificationEmailEventSubscriptionPurchaseSuccess = "subscription.purchase_success" + NotificationEmailEventSubscriptionExpiryReminder = "subscription.expiry_reminder" + NotificationEmailEventBalanceLow = "balance.low" + NotificationEmailEventBalanceRechargeSuccess = "balance.recharge_success" + NotificationEmailEventAccountQuotaAlert = "account.quota_alert" + NotificationEmailEventContentModerationViolation = "content_moderation.violation_notice" + NotificationEmailEventContentModerationDisabled = "content_moderation.account_disabled" + NotificationEmailEventOpsAlert = "ops.alert" + NotificationEmailEventOpsScheduledReport = "ops.scheduled_report" + + notificationEmailTemplateKeyPrefix = "notification_email_template:" + notificationEmailPreferenceKeyPrefix = "notification_email_preference:" + notificationEmailDeliveryKeyPrefix = "notification_email_delivery:" + notificationEmailLocaleUserKeyPrefix = "notification_email_locale:user:" + notificationEmailLocaleEmailKeyPrefix = "notification_email_locale:email:" + notificationEmailUnsubscribeSecretKey = "notification_email_unsubscribe_secret" + notificationEmailDefaultLocale = "en" + notificationEmailLocaleChinese = "zh" + notificationEmailMaxSubjectLength = 200 + notificationEmailMaxHTMLLength = 30000 + notificationEmailUnsubscribeTTL = 365 * 24 * time.Hour +) + +var ( + notificationEmailPlaceholderPattern = regexp.MustCompile(`{{\s*([a-zA-Z][a-zA-Z0-9_]*)\s*}}`) + notificationEmailLocales = []string{notificationEmailDefaultLocale, notificationEmailLocaleChinese} + notificationEmailCommonPlaceholders = []string{"site_name", "recipient_name", "recipient_email"} +) + +type NotificationEmailService struct { + settingRepo SettingRepository + emailService *EmailService +} + +type NotificationEmailEventInfo struct { + Event string `json:"event"` + Label string `json:"label"` + Description string `json:"description"` + Category string `json:"category"` + Optional bool `json:"optional"` + Placeholders []string `json:"placeholders"` +} + +type NotificationEmailTemplate struct { + Event string `json:"event"` + Locale string `json:"locale"` + Subject string `json:"subject"` + HTML string `json:"html"` + IsCustom bool `json:"is_custom"` + UpdatedAt *time.Time `json:"updated_at,omitempty"` + Placeholders []string `json:"placeholders"` +} + +type NotificationEmailPreview struct { + Subject string `json:"subject"` + HTML string `json:"html"` +} + +type NotificationEmailPreviewInput struct { + Event string `json:"event"` + Locale string `json:"locale"` + Subject string `json:"subject"` + HTML string `json:"html"` + Variables map[string]string `json:"variables,omitempty"` +} + +type NotificationEmailSendInput struct { + Event string + Locale string + RecipientEmail string + RecipientName string + UserID int64 + SourceType string + SourceID string + ReminderKey string + Variables map[string]string + RawHTMLVariables map[string]string +} + +type NotificationEmailUnsubscribeResult struct { + Event string `json:"event"` + Email string `json:"email"` + Done bool `json:"done"` +} + +type notificationEmailStoredTemplate struct { + Subject string `json:"subject"` + HTML string `json:"html"` + UpdatedAt time.Time `json:"updated_at"` +} + +type notificationEmailOfficialTemplate struct { + Subject string + HTML string +} + +type notificationEmailTemplateError struct { + Err error +} + +func (e notificationEmailTemplateError) Error() string { + return e.Err.Error() +} + +func (e notificationEmailTemplateError) Unwrap() error { + return e.Err +} + +type notificationEmailConfigError struct { + Err error +} + +func (e notificationEmailConfigError) Error() string { + return e.Err.Error() +} + +func (e notificationEmailConfigError) Unwrap() error { + return e.Err +} + +type notificationEmailDeliveryError struct { + Err error +} + +func (e notificationEmailDeliveryError) Error() string { + return e.Err.Error() +} + +func (e notificationEmailDeliveryError) Unwrap() error { + return e.Err +} + +type notificationEmailUnsubscribeClaims struct { + Email string `json:"email"` + Event string `json:"event"` + Exp int64 `json:"exp"` +} + +func NewNotificationEmailService(settingRepo SettingRepository, emailService *EmailService) *NotificationEmailService { + svc := &NotificationEmailService{settingRepo: settingRepo, emailService: emailService} + if emailService != nil { + emailService.SetNotificationEmailService(svc) + } + return svc +} + +func notificationEmailTemplateErr(err error) error { + if err == nil { + return nil + } + return notificationEmailTemplateError{Err: err} +} + +func notificationEmailConfigErr(err error) error { + if err == nil { + return nil + } + return notificationEmailConfigError{Err: err} +} + +func notificationEmailDeliveryErr(err error) error { + if err == nil { + return nil + } + return notificationEmailDeliveryError{Err: err} +} + +func shouldFallbackNotificationEmail(err error) bool { + if err == nil { + return false + } + var templateErr notificationEmailTemplateError + if errors.As(err, &templateErr) { + return true + } + var configErr notificationEmailConfigError + return errors.As(err, &configErr) +} + +func isNotificationEmailDeliveryError(err error) bool { + var deliveryErr notificationEmailDeliveryError + return errors.As(err, &deliveryErr) +} + +func (s *NotificationEmailService) ListEventInfos() []NotificationEmailEventInfo { + infos := make([]NotificationEmailEventInfo, 0, len(notificationEmailEventDefinitions)) + for _, event := range notificationEmailEventOrder { + info := notificationEmailEventDefinitions[event] + info.Placeholders = append([]string(nil), info.Placeholders...) + infos = append(infos, info) + } + return infos +} + +func (s *NotificationEmailService) SupportedLocales() []string { + return append([]string(nil), notificationEmailLocales...) +} + +func (s *NotificationEmailService) ListTemplates(ctx context.Context) ([]NotificationEmailTemplate, error) { + items := make([]NotificationEmailTemplate, 0, len(notificationEmailEventOrder)*len(notificationEmailLocales)) + for _, event := range notificationEmailEventOrder { + for _, locale := range notificationEmailLocales { + tmpl, err := s.GetTemplate(ctx, event, locale) + if err != nil { + return nil, err + } + items = append(items, tmpl) + } + } + return items, nil +} + +func (s *NotificationEmailService) GetTemplate(ctx context.Context, event, locale string) (NotificationEmailTemplate, error) { + info, normalizedEvent, err := s.eventInfo(event) + if err != nil { + return NotificationEmailTemplate{}, err + } + normalizedLocale := normalizeNotificationLocale(locale) + official, ok := notificationEmailOfficialTemplates[normalizedEvent][normalizedLocale] + if !ok { + return NotificationEmailTemplate{}, fmt.Errorf("official template not found for %s/%s", normalizedEvent, normalizedLocale) + } + + tmpl := NotificationEmailTemplate{ + Event: normalizedEvent, + Locale: normalizedLocale, + Subject: official.Subject, + HTML: official.HTML, + Placeholders: append([]string(nil), info.Placeholders...), + } + + raw, err := s.settingRepo.GetValue(ctx, notificationEmailTemplateKey(normalizedEvent, normalizedLocale)) + if err != nil { + if errors.Is(err, ErrSettingNotFound) { + return tmpl, nil + } + return NotificationEmailTemplate{}, err + } + if strings.TrimSpace(raw) == "" { + return tmpl, nil + } + + var stored notificationEmailStoredTemplate + if err := json.Unmarshal([]byte(raw), &stored); err != nil { + return NotificationEmailTemplate{}, fmt.Errorf("decode email template override: %w", err) + } + if err := validateNotificationEmailTemplate(normalizedEvent, stored.Subject, stored.HTML); err != nil { + return NotificationEmailTemplate{}, err + } + tmpl.Subject = stored.Subject + tmpl.HTML = stored.HTML + tmpl.IsCustom = true + updatedAt := stored.UpdatedAt + tmpl.UpdatedAt = &updatedAt + return tmpl, nil +} + +func (s *NotificationEmailService) UpdateTemplate(ctx context.Context, event, locale, subject, htmlBody string) (NotificationEmailTemplate, error) { + _, normalizedEvent, err := s.eventInfo(event) + if err != nil { + return NotificationEmailTemplate{}, err + } + normalizedLocale := normalizeNotificationLocale(locale) + if err := validateNotificationEmailTemplate(normalizedEvent, subject, htmlBody); err != nil { + return NotificationEmailTemplate{}, err + } + stored := notificationEmailStoredTemplate{ + Subject: strings.TrimSpace(subject), + HTML: htmlBody, + UpdatedAt: time.Now().UTC(), + } + payload, err := json.Marshal(stored) + if err != nil { + return NotificationEmailTemplate{}, err + } + if err := s.settingRepo.Set(ctx, notificationEmailTemplateKey(normalizedEvent, normalizedLocale), string(payload)); err != nil { + return NotificationEmailTemplate{}, err + } + return s.GetTemplate(ctx, normalizedEvent, normalizedLocale) +} + +func (s *NotificationEmailService) RestoreOfficialTemplate(ctx context.Context, event, locale string) (NotificationEmailTemplate, error) { + _, normalizedEvent, err := s.eventInfo(event) + if err != nil { + return NotificationEmailTemplate{}, err + } + normalizedLocale := normalizeNotificationLocale(locale) + if err := s.settingRepo.Delete(ctx, notificationEmailTemplateKey(normalizedEvent, normalizedLocale)); err != nil && !errors.Is(err, ErrSettingNotFound) { + return NotificationEmailTemplate{}, err + } + return s.GetTemplate(ctx, normalizedEvent, normalizedLocale) +} + +func (s *NotificationEmailService) PreviewTemplate(ctx context.Context, input NotificationEmailPreviewInput) (NotificationEmailPreview, error) { + _, normalizedEvent, err := s.eventInfo(input.Event) + if err != nil { + return NotificationEmailPreview{}, err + } + normalizedLocale := normalizeNotificationLocale(input.Locale) + subject := input.Subject + htmlBody := input.HTML + if strings.TrimSpace(subject) == "" || strings.TrimSpace(htmlBody) == "" { + tmpl, err := s.GetTemplate(ctx, normalizedEvent, normalizedLocale) + if err != nil { + return NotificationEmailPreview{}, err + } + if strings.TrimSpace(subject) == "" { + subject = tmpl.Subject + } + if strings.TrimSpace(htmlBody) == "" { + htmlBody = tmpl.HTML + } + } + if err := validateNotificationEmailTemplate(normalizedEvent, subject, htmlBody); err != nil { + return NotificationEmailPreview{}, err + } + variables := s.sampleVariables(ctx, normalizedEvent, normalizedLocale) + for key, value := range input.Variables { + variables[key] = value + } + return renderNotificationEmail(normalizedEvent, subject, htmlBody, variables, nil) +} + +func (s *NotificationEmailService) Send(ctx context.Context, input NotificationEmailSendInput) error { + info, normalizedEvent, err := s.eventInfo(input.Event) + if err != nil { + return notificationEmailTemplateErr(err) + } + recipient := strings.TrimSpace(input.RecipientEmail) + if recipient == "" { + return nil + } + if info.Optional { + unsubscribed, err := s.IsUnsubscribed(ctx, recipient, normalizedEvent) + if err != nil { + return err + } + if unsubscribed { + slog.Info("notification email suppressed by unsubscribe preference", "event", normalizedEvent, "recipient_hash", notificationEmailHash(recipient)) + return nil + } + } + + locale := normalizeNotificationLocale(input.Locale) + if strings.TrimSpace(input.Locale) == "" { + locale = s.ResolveRecipientLocale(ctx, input.UserID, recipient) + } + tmpl, err := s.GetTemplate(ctx, normalizedEvent, locale) + if err != nil { + return notificationEmailTemplateErr(err) + } + variables := s.runtimeVariables(ctx, normalizedEvent, locale, input) + rendered, err := renderNotificationEmail(normalizedEvent, tmpl.Subject, tmpl.HTML, variables, input.RawHTMLVariables) + if err != nil { + return notificationEmailTemplateErr(err) + } + + deliveryKey := notificationEmailDeliveryKey(normalizedEvent, input.SourceType, input.SourceID, recipient, input.ReminderKey) + if deliveryKey != "" { + sent, err := s.deliveryExists(ctx, deliveryKey, legacyNotificationEmailDeliveryKey(normalizedEvent, input.SourceType, input.SourceID, recipient, input.ReminderKey)) + if err != nil { + return err + } + if sent { + return nil + } + } + + if s.emailService == nil { + return notificationEmailConfigErr(errors.New("email service is not configured")) + } + if err := s.emailService.SendEmail(ctx, recipient, rendered.Subject, rendered.HTML); err != nil { + return notificationEmailDeliveryErr(err) + } + if deliveryKey != "" { + if err := s.settingRepo.Set(ctx, deliveryKey, time.Now().UTC().Format(time.RFC3339Nano)); err != nil { + return err + } + } + return nil +} + +func (s *NotificationEmailService) RememberRecipientLocale(ctx context.Context, userID int64, email, acceptLanguage string) { + locale := normalizeNotificationLocale(acceptLanguage) + if strings.TrimSpace(acceptLanguage) == "" || s == nil || s.settingRepo == nil { + return + } + if userID > 0 { + _ = s.settingRepo.Set(ctx, notificationEmailLocaleUserKeyPrefix+strconv.FormatInt(userID, 10), locale) + } + if emailHash := notificationEmailHash(email); emailHash != "" { + _ = s.settingRepo.Set(ctx, notificationEmailLocaleEmailKeyPrefix+emailHash, locale) + } +} + +func (s *NotificationEmailService) ResolveRecipientLocale(ctx context.Context, userID int64, email string) string { + if s == nil || s.settingRepo == nil { + return notificationEmailDefaultLocale + } + if userID > 0 { + if locale, err := s.settingRepo.GetValue(ctx, notificationEmailLocaleUserKeyPrefix+strconv.FormatInt(userID, 10)); err == nil && strings.TrimSpace(locale) != "" { + return normalizeNotificationLocale(locale) + } + } + if emailHash := notificationEmailHash(email); emailHash != "" { + if locale, err := s.settingRepo.GetValue(ctx, notificationEmailLocaleEmailKeyPrefix+emailHash); err == nil && strings.TrimSpace(locale) != "" { + return normalizeNotificationLocale(locale) + } + } + return notificationEmailDefaultLocale +} + +func (s *NotificationEmailService) IsUnsubscribed(ctx context.Context, email, event string) (bool, error) { + info, normalizedEvent, err := s.eventInfo(event) + if err != nil { + return false, err + } + if !info.Optional { + return false, nil + } + for _, key := range []string{notificationEmailPreferenceKey(normalizedEvent, email), legacyNotificationEmailPreferenceKey(normalizedEvent, email)} { + if strings.TrimSpace(key) == "" { + continue + } + value, err := s.settingRepo.GetValue(ctx, key) + if err == nil { + return strings.EqualFold(strings.TrimSpace(value), "unsubscribed"), nil + } + if !errors.Is(err, ErrSettingNotFound) { + return false, err + } + } + return false, nil +} + +func (s *NotificationEmailService) Unsubscribe(ctx context.Context, token string) (NotificationEmailUnsubscribeResult, error) { + claims, err := s.parseUnsubscribeToken(ctx, token) + if err != nil { + return NotificationEmailUnsubscribeResult{}, err + } + info, normalizedEvent, err := s.eventInfo(claims.Event) + if err != nil { + return NotificationEmailUnsubscribeResult{}, err + } + if !info.Optional { + return NotificationEmailUnsubscribeResult{}, fmt.Errorf("%s is transactional and cannot be unsubscribed", normalizedEvent) + } + if err := s.settingRepo.Set(ctx, notificationEmailPreferenceKey(normalizedEvent, claims.Email), "unsubscribed"); err != nil { + return NotificationEmailUnsubscribeResult{}, err + } + return NotificationEmailUnsubscribeResult{Event: normalizedEvent, Email: claims.Email, Done: true}, nil +} + +func (s *NotificationEmailService) eventInfo(event string) (NotificationEmailEventInfo, string, error) { + normalized := strings.ToLower(strings.TrimSpace(event)) + info, ok := notificationEmailEventDefinitions[normalized] + if !ok { + return NotificationEmailEventInfo{}, "", fmt.Errorf("unsupported email template event: %s", event) + } + return info, normalized, nil +} + +func (s *NotificationEmailService) sampleVariables(ctx context.Context, event, locale string) map[string]string { + info := notificationEmailEventDefinitions[event] + variables := make(map[string]string, len(info.Placeholders)) + for key, value := range notificationEmailSampleVariables(locale) { + variables[key] = value + } + variables["site_name"] = s.siteName(ctx) + if variables["unsubscribe_url"] == "" && info.Optional { + variables["unsubscribe_url"] = "https://example.com/unsubscribe" + } + return variables +} + +func (s *NotificationEmailService) runtimeVariables(ctx context.Context, event, locale string, input NotificationEmailSendInput) map[string]string { + variables := s.sampleVariables(ctx, event, locale) + for key, value := range input.Variables { + variables[key] = value + } + variables["site_name"] = s.siteName(ctx) + variables["recipient_email"] = input.RecipientEmail + if strings.TrimSpace(input.RecipientName) != "" { + variables["recipient_name"] = input.RecipientName + } + if notificationEmailEventDefinitions[event].Optional { + if unsubscribeURL, err := s.buildUnsubscribeURL(ctx, input.RecipientEmail, event); err == nil { + variables["unsubscribe_url"] = unsubscribeURL + } + } + return variables +} + +func (s *NotificationEmailService) siteName(ctx context.Context) string { + if s == nil || s.settingRepo == nil { + return defaultSiteName + } + name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName) + if err != nil || strings.TrimSpace(name) == "" { + return defaultSiteName + } + return strings.TrimSpace(name) +} + +func (s *NotificationEmailService) baseURL(ctx context.Context) string { + if s == nil || s.settingRepo == nil { + return "" + } + for _, key := range []string{SettingKeyAPIBaseURL, SettingKeyFrontendURL} { + value, err := s.settingRepo.GetValue(ctx, key) + if err == nil && strings.TrimSpace(value) != "" { + return strings.TrimRight(strings.TrimSpace(value), "/") + } + } + return "" +} + +func (s *NotificationEmailService) buildUnsubscribeURL(ctx context.Context, email, event string) (string, error) { + token, err := s.createUnsubscribeToken(ctx, email, event) + if err != nil { + return "", err + } + path := "/api/v1/settings/email-unsubscribe?token=" + url.QueryEscape(token) + baseURL := s.baseURL(ctx) + if baseURL == "" { + return path, nil + } + return baseURL + path, nil +} + +func (s *NotificationEmailService) createUnsubscribeToken(ctx context.Context, email, event string) (string, error) { + secret, err := s.unsubscribeSecret(ctx) + if err != nil { + return "", err + } + claims := notificationEmailUnsubscribeClaims{Email: strings.TrimSpace(email), Event: event, Exp: time.Now().Add(notificationEmailUnsubscribeTTL).Unix()} + payload, err := json.Marshal(claims) + if err != nil { + return "", err + } + encodedPayload := base64.RawURLEncoding.EncodeToString(payload) + signature := signNotificationEmailToken(secret, encodedPayload) + return encodedPayload + "." + signature, nil +} + +func (s *NotificationEmailService) parseUnsubscribeToken(ctx context.Context, token string) (notificationEmailUnsubscribeClaims, error) { + parts := strings.Split(strings.TrimSpace(token), ".") + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return notificationEmailUnsubscribeClaims{}, errors.New("invalid unsubscribe token") + } + secret, err := s.unsubscribeSecret(ctx) + if err != nil { + return notificationEmailUnsubscribeClaims{}, err + } + expected := signNotificationEmailToken(secret, parts[0]) + if !hmac.Equal([]byte(expected), []byte(parts[1])) { + return notificationEmailUnsubscribeClaims{}, errors.New("invalid unsubscribe token signature") + } + payload, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + return notificationEmailUnsubscribeClaims{}, errors.New("invalid unsubscribe token payload") + } + var claims notificationEmailUnsubscribeClaims + if err := json.Unmarshal(payload, &claims); err != nil { + return notificationEmailUnsubscribeClaims{}, errors.New("invalid unsubscribe token payload") + } + if strings.TrimSpace(claims.Email) == "" || strings.TrimSpace(claims.Event) == "" { + return notificationEmailUnsubscribeClaims{}, errors.New("invalid unsubscribe token claims") + } + if claims.Exp <= time.Now().Unix() { + return notificationEmailUnsubscribeClaims{}, errors.New("unsubscribe token expired") + } + return claims, nil +} + +func (s *NotificationEmailService) unsubscribeSecret(ctx context.Context) (string, error) { + secret, err := s.settingRepo.GetValue(ctx, notificationEmailUnsubscribeSecretKey) + if err == nil && strings.TrimSpace(secret) != "" { + return strings.TrimSpace(secret), nil + } + if err != nil && !errors.Is(err, ErrSettingNotFound) { + return "", err + } + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + return "", err + } + secret = base64.RawURLEncoding.EncodeToString(buf) + if err := s.settingRepo.Set(ctx, notificationEmailUnsubscribeSecretKey, secret); err != nil { + return "", err + } + return secret, nil +} + +func (s *NotificationEmailService) deliveryExists(ctx context.Context, keys ...string) (bool, error) { + for _, key := range keys { + if strings.TrimSpace(key) == "" { + continue + } + _, err := s.settingRepo.GetValue(ctx, key) + if err == nil { + return true, nil + } + if !errors.Is(err, ErrSettingNotFound) { + return false, err + } + } + return false, nil +} + +func validateNotificationEmailTemplate(event, subject, htmlBody string) error { + subject = strings.TrimSpace(subject) + if subject == "" { + return errors.New("email subject cannot be empty") + } + if len([]rune(subject)) > notificationEmailMaxSubjectLength { + return fmt.Errorf("email subject cannot exceed %d characters", notificationEmailMaxSubjectLength) + } + if strings.TrimSpace(htmlBody) == "" { + return errors.New("email html cannot be empty") + } + if len([]byte(htmlBody)) > notificationEmailMaxHTMLLength { + return fmt.Errorf("email html cannot exceed %d bytes", notificationEmailMaxHTMLLength) + } + allowed := notificationEmailAllowedPlaceholderSet(event) + for _, placeholder := range notificationEmailPlaceholdersIn(subject + "\n" + htmlBody) { + if _, ok := allowed[placeholder]; !ok { + return fmt.Errorf("unsupported placeholder {{%s}} for event %s", placeholder, event) + } + } + return nil +} + +func renderNotificationEmail(event, subject, htmlBody string, variables map[string]string, rawHTMLVariables map[string]string) (NotificationEmailPreview, error) { + if err := validateNotificationEmailTemplate(event, subject, htmlBody); err != nil { + return NotificationEmailPreview{}, err + } + renderedSubject, err := renderNotificationEmailString(event, subject, variables, nil, false) + if err != nil { + return NotificationEmailPreview{}, err + } + renderedHTML, err := renderNotificationEmailString(event, htmlBody, variables, rawHTMLVariables, true) + if err != nil { + return NotificationEmailPreview{}, err + } + return NotificationEmailPreview{Subject: sanitizeEmailHeader(renderedSubject), HTML: renderedHTML}, nil +} + +func renderNotificationEmailString(event, raw string, variables map[string]string, rawHTMLVariables map[string]string, escapeHTML bool) (string, error) { + allowed := notificationEmailAllowedPlaceholderSet(event) + var renderErr error + rendered := notificationEmailPlaceholderPattern.ReplaceAllStringFunc(raw, func(match string) string { + if renderErr != nil { + return "" + } + parts := notificationEmailPlaceholderPattern.FindStringSubmatch(match) + if len(parts) != 2 { + return "" + } + name := parts[1] + if _, ok := allowed[name]; !ok { + renderErr = fmt.Errorf("unsupported placeholder {{%s}} for event %s", name, event) + return "" + } + value := variables[name] + if escapeHTML && notificationEmailRawHTMLAllowed(event, name) { + if rawHTMLVariables != nil { + if rawValue, ok := rawHTMLVariables[name]; ok { + return rawValue + } + } + } + if strings.HasSuffix(name, "_url") && !isSafeNotificationEmailURL(value) { + value = "" + } + if escapeHTML { + return html.EscapeString(value) + } + return sanitizeEmailHeader(value) + }) + if renderErr != nil { + return "", renderErr + } + return rendered, nil +} + +func notificationEmailRawHTMLAllowed(event, placeholder string) bool { + return event == NotificationEmailEventOpsScheduledReport && placeholder == "report_html" +} + +func notificationEmailAllowedPlaceholderSet(event string) map[string]struct{} { + info := notificationEmailEventDefinitions[event] + allowed := make(map[string]struct{}, len(info.Placeholders)) + for _, placeholder := range info.Placeholders { + allowed[placeholder] = struct{}{} + } + return allowed +} + +func notificationEmailPlaceholdersIn(raw string) []string { + matches := notificationEmailPlaceholderPattern.FindAllStringSubmatch(raw, -1) + seen := make(map[string]struct{}, len(matches)) + out := make([]string, 0, len(matches)) + for _, match := range matches { + if len(match) != 2 { + continue + } + if _, exists := seen[match[1]]; exists { + continue + } + seen[match[1]] = struct{}{} + out = append(out, match[1]) + } + return out +} + +func normalizeNotificationLocale(raw string) string { + trimmed := strings.ToLower(strings.TrimSpace(raw)) + if trimmed == "" { + return notificationEmailDefaultLocale + } + for _, part := range strings.Split(trimmed, ",") { + tag := strings.TrimSpace(strings.Split(part, ";")[0]) + if strings.HasPrefix(tag, "zh") || tag == "cn" { + return notificationEmailLocaleChinese + } + if strings.HasPrefix(tag, "en") { + return notificationEmailDefaultLocale + } + } + return notificationEmailDefaultLocale +} + +func notificationEmailTemplateKey(event, locale string) string { + return notificationEmailTemplateKeyPrefix + event + ":" + locale +} + +func notificationEmailPreferenceKey(event, email string) string { + if strings.TrimSpace(event) == "" || strings.TrimSpace(email) == "" { + return "" + } + identity := strings.TrimSpace(event) + "\x00" + strings.ToLower(strings.TrimSpace(email)) + return notificationEmailPreferenceKeyPrefix + "v2:" + notificationEmailHash(identity) +} + +func legacyNotificationEmailPreferenceKey(event, email string) string { + return notificationEmailPreferenceKeyPrefix + event + ":" + notificationEmailHash(email) +} + +func notificationEmailDeliveryKey(event, sourceType, sourceID, recipient, reminderKey string) string { + if strings.TrimSpace(sourceType) == "" || strings.TrimSpace(sourceID) == "" || strings.TrimSpace(recipient) == "" { + return "" + } + identity := strings.Join([]string{ + strings.ToLower(strings.TrimSpace(event)), + safeNotificationEmailKeyPart(sourceType), + safeNotificationEmailKeyPart(sourceID), + strings.ToLower(strings.TrimSpace(recipient)), + safeNotificationEmailKeyPart(reminderKey), + }, "\x00") + return notificationEmailDeliveryKeyPrefix + "v2:" + notificationEmailHash(identity) +} + +func legacyNotificationEmailDeliveryKey(event, sourceType, sourceID, recipient, reminderKey string) string { + if strings.TrimSpace(sourceType) == "" || strings.TrimSpace(sourceID) == "" || strings.TrimSpace(recipient) == "" { + return "" + } + parts := []string{notificationEmailDeliveryKeyPrefix, event, ":", safeNotificationEmailKeyPart(sourceType), ":", safeNotificationEmailKeyPart(sourceID), ":", notificationEmailHash(recipient)} + if strings.TrimSpace(reminderKey) != "" { + parts = append(parts, ":", safeNotificationEmailKeyPart(reminderKey)) + } + return strings.Join(parts, "") +} + +func notificationEmailHash(value string) string { + trimmed := strings.ToLower(strings.TrimSpace(value)) + if trimmed == "" { + return "" + } + sum := sha256.Sum256([]byte(trimmed)) + return hex.EncodeToString(sum[:]) +} + +func safeNotificationEmailKeyPart(value string) string { + value = strings.ToLower(strings.TrimSpace(value)) + var builder strings.Builder + for _, r := range value { + if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '_' || r == '-' || r == '.' { + _, _ = builder.WriteRune(r) + } else { + _, _ = builder.WriteRune('_') + } + } + return builder.String() +} + +func signNotificationEmailToken(secret, payload string) string { + mac := hmac.New(sha256.New, []byte(secret)) + _, _ = mac.Write([]byte(payload)) + return base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) +} + +func isSafeNotificationEmailURL(raw string) bool { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return true + } + parsed, err := url.Parse(trimmed) + if err != nil { + return false + } + if parsed.IsAbs() { + scheme := strings.ToLower(parsed.Scheme) + return scheme == "http" || scheme == "https" || scheme == "mailto" + } + return strings.HasPrefix(trimmed, "/") +} + +func notificationEmailSampleVariables(locale string) map[string]string { + if normalizeNotificationLocale(locale) == notificationEmailLocaleChinese { + return map[string]string{ + "site_name": defaultSiteName, + "recipient_name": "张三", + "recipient_email": "user@example.com", + "verification_code": "123456", + "expires_in_minutes": "15", + "reset_url": "https://example.com/reset-password?token=preview", + "subscription_group": "Claude Pro", + "subscription_days": "30", + "expiry_time": "2026-06-18 12:00", + "days_remaining": "3", + "current_balance": "12.34", + "threshold": "20.00", + "recharge_url": "https://example.com/recharge", + "recharge_amount": "50.00", + "order_id": "1024", + "unsubscribe_url": "https://example.com/unsubscribe", + "account_id": "1001", + "account_name": "openai-main", + "platform": "openai", + "quota_dimension": "每日额度", + "quota_used": "80.00", + "quota_limit": "100.00", + "quota_remaining": "20.00", + "quota_threshold": "20%", + "triggered_at": "2026-05-20 12:00:00", + "group_name": "默认分组", + "moderation_category": "violence", + "moderation_score": "0.982", + "violation_count": "2", + "ban_threshold": "3", + "rule_name": "错误率过高", + "severity": "critical", + "alert_status": "firing", + "metric_type": "error_rate", + "operator": ">=", + "metric_value": "12.50", + "threshold_value": "10.00", + "alert_description": "最近 10 分钟错误率超过阈值", + "report_name": "日报", + "report_type": "daily_summary", + "report_start_time": "2026-05-19 12:00", + "report_end_time": "2026-05-20 12:00", + "report_html": "

日报

请求量:1024

", + } + } + return map[string]string{ + "site_name": defaultSiteName, + "recipient_name": "Alex", + "recipient_email": "user@example.com", + "verification_code": "123456", + "expires_in_minutes": "15", + "reset_url": "https://example.com/reset-password?token=preview", + "subscription_group": "Claude Pro", + "subscription_days": "30", + "expiry_time": "2026-06-18 12:00", + "days_remaining": "3", + "current_balance": "12.34", + "threshold": "20.00", + "recharge_url": "https://example.com/recharge", + "recharge_amount": "50.00", + "order_id": "1024", + "unsubscribe_url": "https://example.com/unsubscribe", + "account_id": "1001", + "account_name": "openai-main", + "platform": "openai", + "quota_dimension": "Daily quota", + "quota_used": "80.00", + "quota_limit": "100.00", + "quota_remaining": "20.00", + "quota_threshold": "20%", + "triggered_at": "2026-05-20 12:00:00", + "group_name": "Default group", + "moderation_category": "violence", + "moderation_score": "0.982", + "violation_count": "2", + "ban_threshold": "3", + "rule_name": "High error rate", + "severity": "critical", + "alert_status": "firing", + "metric_type": "error_rate", + "operator": ">=", + "metric_value": "12.50", + "threshold_value": "10.00", + "alert_description": "Error rate exceeded threshold in the last 10 minutes.", + "report_name": "Daily summary", + "report_type": "daily_summary", + "report_start_time": "2026-05-19 12:00", + "report_end_time": "2026-05-20 12:00", + "report_html": "

Daily summary

Requests: 1024

", + } +} + +var notificationEmailEventOrder = []string{ + NotificationEmailEventAuthVerifyCode, + NotificationEmailEventAuthPasswordReset, + NotificationEmailEventNotificationEmailVerifyCode, + NotificationEmailEventSubscriptionPurchaseSuccess, + NotificationEmailEventSubscriptionExpiryReminder, + NotificationEmailEventBalanceLow, + NotificationEmailEventBalanceRechargeSuccess, + NotificationEmailEventAccountQuotaAlert, + NotificationEmailEventContentModerationViolation, + NotificationEmailEventContentModerationDisabled, + NotificationEmailEventOpsAlert, + NotificationEmailEventOpsScheduledReport, +} + +var notificationEmailEventDefinitions = map[string]NotificationEmailEventInfo{ + NotificationEmailEventAuthVerifyCode: { + Event: NotificationEmailEventAuthVerifyCode, + Label: "Email verification code", + Description: "Sent for registration, email binding, OAuth pending email, and TOTP verification flows.", + Category: "auth", + Optional: false, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), "verification_code", "expires_in_minutes"), + }, + NotificationEmailEventAuthPasswordReset: { + Event: NotificationEmailEventAuthPasswordReset, + Label: "Password reset", + Description: "Sent when a user requests a password reset link.", + Category: "auth", + Optional: false, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), "reset_url", "expires_in_minutes"), + }, + NotificationEmailEventNotificationEmailVerifyCode: { + Event: NotificationEmailEventNotificationEmailVerifyCode, + Label: "Notification email verification code", + Description: "Sent when a user verifies an extra notification email address.", + Category: "auth", + Optional: false, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), "verification_code", "expires_in_minutes"), + }, + NotificationEmailEventSubscriptionPurchaseSuccess: { + Event: NotificationEmailEventSubscriptionPurchaseSuccess, + Label: "Subscription purchase success", + Description: "Sent after a subscription purchase is fulfilled.", + Category: "subscription", + Optional: false, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), "subscription_group", "subscription_days", "expiry_time", "order_id"), + }, + NotificationEmailEventSubscriptionExpiryReminder: { + Event: NotificationEmailEventSubscriptionExpiryReminder, + Label: "Subscription expiry reminder", + Description: "Optional reminder sent before an active subscription expires.", + Category: "subscription", + Optional: true, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), "subscription_group", "expiry_time", "days_remaining", "unsubscribe_url"), + }, + NotificationEmailEventBalanceLow: { + Event: NotificationEmailEventBalanceLow, + Label: "Low balance alert", + Description: "Optional alert sent when balance crosses the configured low-balance threshold.", + Category: "billing", + Optional: true, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), "current_balance", "threshold", "recharge_url", "unsubscribe_url"), + }, + NotificationEmailEventBalanceRechargeSuccess: { + Event: NotificationEmailEventBalanceRechargeSuccess, + Label: "Balance recharge success", + Description: "Sent after a balance recharge order is fulfilled.", + Category: "billing", + Optional: false, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), "recharge_amount", "current_balance", "order_id"), + }, + NotificationEmailEventAccountQuotaAlert: { + Event: NotificationEmailEventAccountQuotaAlert, + Label: "Account quota alert", + Description: "Sent to configured admin notification emails when an upstream account quota threshold is crossed.", + Category: "admin", + Optional: false, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), + "account_id", "account_name", "platform", "quota_dimension", "quota_used", "quota_limit", "quota_remaining", "quota_threshold"), + }, + NotificationEmailEventContentModerationViolation: { + Event: NotificationEmailEventContentModerationViolation, + Label: "Risk control violation notice", + Description: "Sent to users when a request triggers content moderation/risk control rules.", + Category: "risk_control", + Optional: false, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), + "triggered_at", "group_name", "moderation_category", "moderation_score", "violation_count", "ban_threshold"), + }, + NotificationEmailEventContentModerationDisabled: { + Event: NotificationEmailEventContentModerationDisabled, + Label: "Risk control account disabled", + Description: "Sent to users when content moderation automatically disables their account.", + Category: "risk_control", + Optional: false, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), + "triggered_at", "group_name", "moderation_category", "moderation_score", "violation_count", "ban_threshold"), + }, + NotificationEmailEventOpsAlert: { + Event: NotificationEmailEventOpsAlert, + Label: "Ops alert", + Description: "Sent to configured operations recipients when an ops alert rule fires.", + Category: "ops", + Optional: false, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), + "rule_name", "severity", "alert_status", "metric_type", "operator", "metric_value", "threshold_value", "triggered_at", "alert_description"), + }, + NotificationEmailEventOpsScheduledReport: { + Event: NotificationEmailEventOpsScheduledReport, + Label: "Ops scheduled report", + Description: "Sent to configured operations recipients for scheduled daily/weekly/error/account-health reports.", + Category: "ops", + Optional: false, + Placeholders: append(append([]string{}, notificationEmailCommonPlaceholders...), + "report_name", "report_type", "report_start_time", "report_end_time", "report_html"), + }, +} + +var notificationEmailOfficialTemplates = map[string]map[string]notificationEmailOfficialTemplate{ + NotificationEmailEventAuthVerifyCode: { + notificationEmailDefaultLocale: { + Subject: "[{{site_name}}] Email verification code", + HTML: notificationEmailCard("#4f46e5", "Email verification code", ` +

Hello {{recipient_name}},

+

Your verification code is:

+

{{verification_code}}

+

This code expires in {{expires_in_minutes}} minutes.

+

If you did not request this code, please ignore this email.

`), + }, + notificationEmailLocaleChinese: { + Subject: "[{{site_name}}] 邮箱验证码", + HTML: notificationEmailCard("#4f46e5", "邮箱验证码", ` +

{{recipient_name}},您好:

+

您的验证码是:

+

{{verification_code}}

+

验证码将在 {{expires_in_minutes}} 分钟后失效。

+

如果不是您本人操作,请忽略此邮件。

`), + }, + }, + NotificationEmailEventAuthPasswordReset: { + notificationEmailDefaultLocale: { + Subject: "[{{site_name}}] Password reset request", + HTML: notificationEmailCard("#7c3aed", "Password reset", ` +

Hello {{recipient_name}},

+

We received a request to reset your password. Click the button below to set a new password.

+

Reset password

+

This link expires in {{expires_in_minutes}} minutes.

+

If the button does not work, copy this link into your browser:
{{reset_url}}

+

If you did not request this, you can safely ignore this email.

`), + }, + notificationEmailLocaleChinese: { + Subject: "[{{site_name}}] 密码重置请求", + HTML: notificationEmailCard("#7c3aed", "密码重置", ` +

{{recipient_name}},您好:

+

我们收到了您的密码重置请求,请点击下方按钮设置新密码。

+

重置密码

+

此链接将在 {{expires_in_minutes}} 分钟后失效。

+

如果按钮无法点击,请复制以下链接到浏览器中打开:
{{reset_url}}

+

如果不是您本人操作,请忽略此邮件。

`), + }, + }, + NotificationEmailEventNotificationEmailVerifyCode: { + notificationEmailDefaultLocale: { + Subject: "[{{site_name}}] Notification email verification code", + HTML: notificationEmailCard("#0ea5e9", "Notification email verification", ` +

Hello {{recipient_name}},

+

You are adding this address as an extra notification email.

+

Your verification code is:

+

{{verification_code}}

+

This code expires in {{expires_in_minutes}} minutes.

+

If you did not request this code, please ignore this email.

`), + }, + notificationEmailLocaleChinese: { + Subject: "[{{site_name}}] 通知邮箱验证码", + HTML: notificationEmailCard("#0ea5e9", "通知邮箱验证", ` +

{{recipient_name}},您好:

+

您正在添加额外的通知邮箱,请输入以下验证码完成验证。

+

{{verification_code}}

+

验证码将在 {{expires_in_minutes}} 分钟后失效。

+

如果不是您本人操作,请忽略此邮件。

`), + }, + }, + NotificationEmailEventSubscriptionPurchaseSuccess: { + notificationEmailDefaultLocale: { + Subject: "[{{site_name}}] Subscription purchase successful", + HTML: notificationEmailCard("#2563eb", "Subscription activated", ` +

Hello {{recipient_name}},

+

Your subscription for {{subscription_group}} has been activated for {{subscription_days}} days.

+

Expiry time: {{expiry_time}}

+

Order ID: {{order_id}}

`), + }, + notificationEmailLocaleChinese: { + Subject: "[{{site_name}}] 订阅购买成功", + HTML: notificationEmailCard("#2563eb", "订阅已开通", ` +

{{recipient_name}},您好:

+

您的 {{subscription_group}} 订阅已成功开通,有效期 {{subscription_days}} 天。

+

到期时间:{{expiry_time}}

+

订单号:{{order_id}}

`), + }, + }, + NotificationEmailEventSubscriptionExpiryReminder: { + notificationEmailDefaultLocale: { + Subject: "[{{site_name}}] Subscription expires in {{days_remaining}} day(s)", + HTML: notificationEmailCard("#f97316", "Subscription expiry reminder", ` +

Hello {{recipient_name}},

+

Your {{subscription_group}} subscription will expire in {{days_remaining}} day(s).

+

Expiry time: {{expiry_time}}

+

Unsubscribe from optional subscription reminders

`), + }, + notificationEmailLocaleChinese: { + Subject: "[{{site_name}}] 订阅将在 {{days_remaining}} 天后到期", + HTML: notificationEmailCard("#f97316", "订阅到期提醒", ` +

{{recipient_name}},您好:

+

您的 {{subscription_group}} 订阅将在 {{days_remaining}} 天后到期。

+

到期时间:{{expiry_time}}

+

退订此类订阅提醒

`), + }, + }, + NotificationEmailEventBalanceLow: { + notificationEmailDefaultLocale: { + Subject: "[{{site_name}}] Low balance alert", + HTML: notificationEmailCard("#d97706", "Low balance alert", ` +

Hello {{recipient_name}},

+

Your current balance is ${{current_balance}}, below the configured alert threshold of ${{threshold}}.

+

Please recharge in time to avoid service interruption.

+

Recharge now

+

Unsubscribe from optional balance alerts

`), + }, + notificationEmailLocaleChinese: { + Subject: "[{{site_name}}] 余额不足提醒", + HTML: notificationEmailCard("#d97706", "余额不足提醒", ` +

{{recipient_name}},您好:

+

您当前余额为 ${{current_balance}},已低于提醒阈值 ${{threshold}}

+

请及时充值以免服务中断。

+

立即充值

+

退订此类余额提醒

`), + }, + }, + NotificationEmailEventBalanceRechargeSuccess: { + notificationEmailDefaultLocale: { + Subject: "[{{site_name}}] Balance recharge successful", + HTML: notificationEmailCard("#16a34a", "Recharge successful", ` +

Hello {{recipient_name}},

+

Your balance recharge of ${{recharge_amount}} has been completed.

+

Current balance: ${{current_balance}}

+

Order ID: {{order_id}}

`), + }, + notificationEmailLocaleChinese: { + Subject: "[{{site_name}}] 余额充值成功", + HTML: notificationEmailCard("#16a34a", "余额充值成功", ` +

{{recipient_name}},您好:

+

您的余额充值 ${{recharge_amount}} 已完成。

+

当前余额:${{current_balance}}

+

订单号:{{order_id}}

`), + }, + }, + NotificationEmailEventAccountQuotaAlert: { + notificationEmailDefaultLocale: { + Subject: "[{{site_name}}] Account quota alert - {{account_name}}", + HTML: notificationEmailCard("#dc2626", "Account quota alert", ` +

The upstream account {{account_name}} has crossed its configured quota alert threshold.

+ + + + + + + +
Account ID{{account_id}}
Platform{{platform}}
Dimension{{quota_dimension}}
Used / Limit{{quota_used}} / {{quota_limit}}
Remaining{{quota_remaining}}
Threshold{{quota_threshold}}
`), + }, + notificationEmailLocaleChinese: { + Subject: "[{{site_name}}] 账号限额告警 - {{account_name}}", + HTML: notificationEmailCard("#dc2626", "账号限额告警", ` +

上游账号 {{account_name}} 已触发配置的额度告警阈值。

+ + + + + + + +
账号 ID{{account_id}}
平台{{platform}}
维度{{quota_dimension}}
已用 / 限额{{quota_used}} / {{quota_limit}}
剩余额度{{quota_remaining}}
告警阈值{{quota_threshold}}
`), + }, + }, + NotificationEmailEventContentModerationViolation: { + notificationEmailDefaultLocale: { + Subject: "[{{site_name}}] Risk control notice", + HTML: notificationEmailCard("#ef4444", "Risk control notice", ` +

Hello {{recipient_name}},

+

Your API request triggered the platform content moderation/risk-control policy.

+ + + + + +
Triggered at{{triggered_at}}
Group{{group_name}}
Category / Score{{moderation_category}} / {{moderation_score}}
Violation count{{violation_count}} / {{ban_threshold}}
+

Please review your request content to avoid future service interruptions.

`), + }, + notificationEmailLocaleChinese: { + Subject: "[{{site_name}}] 账户风控提醒", + HTML: notificationEmailCard("#ef4444", "账户风控提醒", ` +

{{recipient_name}},您好:

+

您的 API 请求触发了平台内容审核/风控策略。

+ + + + + +
触发时间{{triggered_at}}
所属分组{{group_name}}
命中类别 / 分数{{moderation_category}} / {{moderation_score}}
累计触发次数{{violation_count}} / {{ban_threshold}}
+

请检查请求内容,避免后续服务受到影响。

`), + }, + }, + NotificationEmailEventContentModerationDisabled: { + notificationEmailDefaultLocale: { + Subject: "[{{site_name}}] Account disabled by risk control", + HTML: notificationEmailCard("#b91c1c", "Account disabled", ` +

Hello {{recipient_name}},

+

Your account has repeatedly triggered platform content moderation/risk-control rules and has been automatically disabled.

+ + + + + +
Disabled at{{triggered_at}}
Group{{group_name}}
Category / Score{{moderation_category}} / {{moderation_score}}
Violation count{{violation_count}} / {{ban_threshold}}
+

Please contact the administrator if you need to appeal or restore access.

`), + }, + notificationEmailLocaleChinese: { + Subject: "[{{site_name}}] 账户已被禁用", + HTML: notificationEmailCard("#b91c1c", "账户已被禁用", ` +

{{recipient_name}},您好:

+

您的账户在统计周期内多次触发平台内容审核/风控规则,系统已自动禁用该账户。

+ + + + + +
禁用时间{{triggered_at}}
所属分组{{group_name}}
命中类别 / 分数{{moderation_category}} / {{moderation_score}}
累计触发次数{{violation_count}} / {{ban_threshold}}
+

如需申诉或恢复账号,请联系平台管理员处理。

`), + }, + }, + NotificationEmailEventOpsAlert: { + notificationEmailDefaultLocale: { + Subject: "[Ops Alert][{{severity}}] {{rule_name}}", + HTML: notificationEmailCard("#ea580c", "Ops alert", ` +

Rule: {{rule_name}}

+

Severity: {{severity}}

+

Status: {{alert_status}}

+

Metric: {{metric_type}} {{operator}} {{metric_value}} (threshold {{threshold_value}})

+

Fired at: {{triggered_at}}

+

Description: {{alert_description}}

`), + }, + notificationEmailLocaleChinese: { + Subject: "[运维告警][{{severity}}] {{rule_name}}", + HTML: notificationEmailCard("#ea580c", "运维告警", ` +

规则:{{rule_name}}

+

严重级别:{{severity}}

+

状态:{{alert_status}}

+

指标:{{metric_type}} {{operator}} {{metric_value}}(阈值 {{threshold_value}})

+

触发时间:{{triggered_at}}

+

说明:{{alert_description}}

`), + }, + }, + NotificationEmailEventOpsScheduledReport: { + notificationEmailDefaultLocale: { + Subject: "[Ops Report] {{report_name}}", + HTML: notificationEmailCard("#0891b2", "Ops report", ` +

Report: {{report_name}}

+

Type: {{report_type}}

+

Range: {{report_start_time}} - {{report_end_time}}

+
{{report_html}}
`), + }, + notificationEmailLocaleChinese: { + Subject: "[运维报表] {{report_name}}", + HTML: notificationEmailCard("#0891b2", "运维报表", ` +

报表:{{report_name}}

+

类型:{{report_type}}

+

时间范围:{{report_start_time}} - {{report_end_time}}

+
{{report_html}}
`), + }, + }, +} + +func notificationEmailCard(accent, title, content string) string { + return ` + + + + + + + +
+

` + title + `

+
` + content + `
+ +
+ +` +} diff --git a/backend/internal/service/notification_email_service_test.go b/backend/internal/service/notification_email_service_test.go new file mode 100644 index 00000000..38987f6f --- /dev/null +++ b/backend/internal/service/notification_email_service_test.go @@ -0,0 +1,571 @@ +package service + +import ( + "bufio" + "context" + "errors" + "net" + "strings" + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNotificationEmailPreviewEscapesHTMLAndSanitizesSubject(t *testing.T) { + ctx := context.Background() + svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil) + + preview, err := svc.PreviewTemplate(ctx, NotificationEmailPreviewInput{ + Event: NotificationEmailEventBalanceLow, + Locale: "en-US,en;q=0.9", + Subject: "Low balance for {{recipient_name}}\r\nInjected", + HTML: `

{{recipient_name}}

Recharge`, + Variables: map[string]string{ + "recipient_name": ``, + "recharge_url": `javascript:alert(1)`, + }, + }) + require.NoError(t, err) + require.NotContains(t, preview.Subject, "\r") + require.NotContains(t, preview.Subject, "\n") + require.Contains(t, preview.Subject, `Low balance for Injected`) + require.Contains(t, preview.HTML, `<script>alert("x")</script>`) + require.NotContains(t, preview.HTML, `javascript:alert(1)`) + require.Contains(t, preview.HTML, `href=""`) +} + +func TestNotificationEmailTemplateOverrideAndRestore(t *testing.T) { + ctx := context.Background() + repo := newNotificationEmailMemorySettingRepo() + svc := NewNotificationEmailService(repo, nil) + + official, err := svc.GetTemplate(ctx, NotificationEmailEventBalanceRechargeSuccess, "en") + require.NoError(t, err) + require.False(t, official.IsCustom) + + updated, err := svc.UpdateTemplate( + ctx, + NotificationEmailEventBalanceRechargeSuccess, + "zh-Hans", + "充值完成:{{recharge_amount}}", + "

{{recipient_name}} 已充值 {{recharge_amount}}

", + ) + require.NoError(t, err) + require.True(t, updated.IsCustom) + require.Equal(t, "zh", updated.Locale) + require.Equal(t, "充值完成:{{recharge_amount}}", updated.Subject) + require.NotNil(t, updated.UpdatedAt) + + restored, err := svc.RestoreOfficialTemplate(ctx, NotificationEmailEventBalanceRechargeSuccess, "zh") + require.NoError(t, err) + require.False(t, restored.IsCustom) + require.NotEqual(t, updated.Subject, restored.Subject) + _, err = repo.GetValue(ctx, notificationEmailTemplateKey(NotificationEmailEventBalanceRechargeSuccess, "zh")) + require.ErrorIs(t, err, ErrSettingNotFound) +} + +func TestNotificationEmailTemplateRejectsUnsupportedPlaceholder(t *testing.T) { + ctx := context.Background() + svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil) + + _, err := svc.UpdateTemplate( + ctx, + NotificationEmailEventSubscriptionPurchaseSuccess, + "en", + "Purchased {{not_allowed}}", + "

{{subscription_group}}

", + ) + require.Error(t, err) + require.Contains(t, err.Error(), "unsupported placeholder") +} + +func TestNotificationEmailAuthTemplatesAreListedAndPreviewable(t *testing.T) { + ctx := context.Background() + svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil) + + infos := svc.ListEventInfos() + events := make(map[string]NotificationEmailEventInfo, len(infos)) + for _, info := range infos { + events[info.Event] = info + } + require.Contains(t, events, NotificationEmailEventAuthVerifyCode) + require.Contains(t, events, NotificationEmailEventAuthPasswordReset) + require.False(t, events[NotificationEmailEventAuthVerifyCode].Optional) + require.False(t, events[NotificationEmailEventAuthPasswordReset].Optional) + require.Contains(t, events[NotificationEmailEventAuthVerifyCode].Placeholders, "verification_code") + require.Contains(t, events[NotificationEmailEventAuthPasswordReset].Placeholders, "reset_url") + + verifyPreview, err := svc.PreviewTemplate(ctx, NotificationEmailPreviewInput{ + Event: NotificationEmailEventAuthVerifyCode, + Locale: "zh-CN", + Variables: map[string]string{ + "verification_code": "654321", + "expires_in_minutes": "15", + }, + }) + require.NoError(t, err) + require.Contains(t, verifyPreview.Subject, "邮箱验证码") + require.Contains(t, verifyPreview.HTML, "654321") + + resetPreview, err := svc.PreviewTemplate(ctx, NotificationEmailPreviewInput{ + Event: NotificationEmailEventAuthPasswordReset, + Locale: "en", + Variables: map[string]string{ + "reset_url": "https://example.com/reset?token=abc", + "expires_in_minutes": "30", + }, + }) + require.NoError(t, err) + require.Contains(t, resetPreview.Subject, "Password reset") + require.Contains(t, resetPreview.HTML, "https://example.com/reset?token=abc") +} + +func TestNotificationEmailAdditionalEventsAreListedAndPreviewable(t *testing.T) { + ctx := context.Background() + svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil) + + infos := svc.ListEventInfos() + events := make(map[string]NotificationEmailEventInfo, len(infos)) + for _, info := range infos { + events[info.Event] = info + } + + checks := []struct { + event string + placeholder string + }{ + {NotificationEmailEventNotificationEmailVerifyCode, "verification_code"}, + {NotificationEmailEventAccountQuotaAlert, "account_name"}, + {NotificationEmailEventContentModerationViolation, "moderation_category"}, + {NotificationEmailEventContentModerationDisabled, "violation_count"}, + {NotificationEmailEventOpsAlert, "rule_name"}, + {NotificationEmailEventOpsScheduledReport, "report_html"}, + } + + for _, check := range checks { + info, ok := events[check.event] + require.Truef(t, ok, "expected %s to be listed", check.event) + require.False(t, info.Optional) + require.Contains(t, info.Placeholders, check.placeholder) + + preview, err := svc.PreviewTemplate(ctx, NotificationEmailPreviewInput{Event: check.event, Locale: "zh"}) + require.NoError(t, err) + require.NotEmpty(t, preview.Subject) + require.NotEmpty(t, preview.HTML) + } +} + +func TestNotificationEmailRawHTMLVariablesAreTrustedOnlyForHTMLPlaceholders(t *testing.T) { + require.True(t, notificationEmailRawHTMLAllowed(NotificationEmailEventOpsScheduledReport, "report_html")) + require.False(t, notificationEmailRawHTMLAllowed(NotificationEmailEventOpsScheduledReport, "recipient_name")) + require.False(t, notificationEmailRawHTMLAllowed(NotificationEmailEventOpsAlert, "report_html")) + + preview, err := renderNotificationEmail( + NotificationEmailEventOpsScheduledReport, + "Report for {{recipient_name}}", + `
{{report_html}}

{{recipient_name}}

`, + map[string]string{ + "recipient_name": ``, + "report_html": `

escaped report

`, + }, + map[string]string{ + "report_html": `
trusted report
`, + }, + ) + require.NoError(t, err) + require.Contains(t, preview.HTML, `
trusted report
`) + require.NotContains(t, preview.HTML, `escaped report`) + require.Contains(t, preview.HTML, `<script>alert("x")</script>`) + require.Contains(t, preview.Subject, ``) + + preview, err = renderNotificationEmail( + NotificationEmailEventOpsScheduledReport, + "Recipient {{recipient_name}}", + `

{{recipient_name}}

`, + map[string]string{"recipient_name": `escaped`}, + map[string]string{"recipient_name": `raw`}, + ) + require.NoError(t, err) + require.Contains(t, preview.HTML, `<em>escaped</em>`) + require.NotContains(t, preview.HTML, `raw`) +} + +func TestNotificationEmailFallbackClassification(t *testing.T) { + templateErr := notificationEmailTemplateErr(errors.New("bad template")) + configErr := notificationEmailConfigErr(errors.New("missing email service")) + deliveryErr := notificationEmailDeliveryErr(errors.New("smtp timeout")) + + require.True(t, shouldFallbackNotificationEmail(templateErr)) + require.True(t, shouldFallbackNotificationEmail(configErr)) + require.False(t, shouldFallbackNotificationEmail(deliveryErr)) + require.True(t, isNotificationEmailDeliveryError(deliveryErr)) + require.False(t, isNotificationEmailDeliveryError(templateErr)) + require.False(t, shouldFallbackNotificationEmail(nil)) +} + +func TestEmailQueueTasksPreserveLocaleHints(t *testing.T) { + queue := &EmailQueueService{taskChan: make(chan EmailTask, 2)} + require.NoError(t, queue.EnqueueVerifyCode("user@example.com", "Sub2API", "zh-CN")) + require.NoError(t, queue.EnqueuePasswordReset("user@example.com", "Sub2API", "https://example.com/reset", "en-US")) + + verifyTask := <-queue.taskChan + require.Equal(t, TaskTypeVerifyCode, verifyTask.TaskType) + require.Equal(t, "zh-CN", verifyTask.Locale) + + resetTask := <-queue.taskChan + require.Equal(t, TaskTypePasswordReset, resetTask.TaskType) + require.Equal(t, "en-US", resetTask.Locale) +} + +func TestOpsScheduledReportDeliverySourceIDIncludesReportIdentity(t *testing.T) { + report := &opsScheduledReport{Name: "日报", ReportType: "daily_summary", Schedule: "0 9 * * *"} + sourceID := opsScheduledReportDeliverySourceID(report) + require.Contains(t, sourceID, "daily_summary") + require.Contains(t, sourceID, "日报") + require.Contains(t, sourceID, "0 9 * * *") + require.NotEqual(t, sourceID, opsScheduledReportDeliverySourceID(&opsScheduledReport{Name: "周报", ReportType: "weekly_summary", Schedule: "0 9 * * 1"})) + require.Equal(t, "scheduled_report", opsScheduledReportDeliverySourceID(nil)) +} + +func TestNotificationEmailUnsubscribeOnlyAllowsOptionalEvents(t *testing.T) { + ctx := context.Background() + svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil) + + token, err := svc.createUnsubscribeToken(ctx, "User@Example.com", NotificationEmailEventBalanceLow) + require.NoError(t, err) + result, err := svc.Unsubscribe(ctx, token) + require.NoError(t, err) + require.True(t, result.Done) + require.Equal(t, NotificationEmailEventBalanceLow, result.Event) + unsubscribed, err := svc.IsUnsubscribed(ctx, "user@example.com", NotificationEmailEventBalanceLow) + require.NoError(t, err) + require.True(t, unsubscribed) + + transactionalToken, err := svc.createUnsubscribeToken(ctx, "user@example.com", NotificationEmailEventBalanceRechargeSuccess) + require.NoError(t, err) + _, err = svc.Unsubscribe(ctx, transactionalToken) + require.Error(t, err) + require.Contains(t, err.Error(), "transactional") + + authToken, err := svc.createUnsubscribeToken(ctx, "user@example.com", NotificationEmailEventAuthVerifyCode) + require.NoError(t, err) + _, err = svc.Unsubscribe(ctx, authToken) + require.Error(t, err) + require.Contains(t, err.Error(), "transactional") +} + +func TestNotificationEmailLocaleMemoryNormalizesAcceptLanguage(t *testing.T) { + ctx := context.Background() + svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil) + + svc.RememberRecipientLocale(ctx, 42, "User@Example.com", "zh-CN,zh;q=0.9,en;q=0.8") + require.Equal(t, "zh", svc.ResolveRecipientLocale(ctx, 42, "user@example.com")) + require.Equal(t, "zh", svc.ResolveRecipientLocale(ctx, 0, "user@example.com")) +} + +func TestNotificationEmailDeliveryKeyUsesShortStableHash(t *testing.T) { + key := notificationEmailDeliveryKey( + NotificationEmailEventSubscriptionExpiryReminder, + "user_subscription", + "1234567890", + "User@Example.com", + "7d", + ) + require.NotEmpty(t, key) + require.LessOrEqual(t, len(key), 100) + require.True(t, strings.HasPrefix(key, notificationEmailDeliveryKeyPrefix+"v2:")) + require.Equal(t, key, notificationEmailDeliveryKey( + NotificationEmailEventSubscriptionExpiryReminder, + "user_subscription", + "1234567890", + "user@example.com", + "7d", + )) + require.NotEqual(t, key, notificationEmailDeliveryKey( + NotificationEmailEventSubscriptionExpiryReminder, + "user_subscription", + "1234567890", + "user@example.com", + "3d", + )) + + legacyKey := legacyNotificationEmailDeliveryKey( + NotificationEmailEventSubscriptionExpiryReminder, + "user_subscription", + "1234567890", + "user@example.com", + "7d", + ) + require.Greater(t, len(legacyKey), 100) +} + +func TestNotificationEmailPreferenceKeyUsesShortStableHashAndReadsLegacyKey(t *testing.T) { + ctx := context.Background() + repo := newNotificationEmailMemorySettingRepo() + svc := NewNotificationEmailService(repo, nil) + + key := notificationEmailPreferenceKey(NotificationEmailEventSubscriptionExpiryReminder, "User@Example.com") + require.NotEmpty(t, key) + require.LessOrEqual(t, len(key), 100) + require.True(t, strings.HasPrefix(key, notificationEmailPreferenceKeyPrefix+"v2:")) + require.Equal(t, key, notificationEmailPreferenceKey(NotificationEmailEventSubscriptionExpiryReminder, "user@example.com")) + + legacyKey := legacyNotificationEmailPreferenceKey(NotificationEmailEventSubscriptionExpiryReminder, "user@example.com") + require.Greater(t, len(legacyKey), 100) + require.NoError(t, repo.Set(ctx, legacyKey, "unsubscribed")) + + unsubscribed, err := svc.IsUnsubscribed(ctx, "User@Example.com", NotificationEmailEventSubscriptionExpiryReminder) + require.NoError(t, err) + require.True(t, unsubscribed) +} + +func TestNotificationEmailSendDeduplicatesSubscriptionExpiryReminder(t *testing.T) { + ctx := context.Background() + repo := newNotificationEmailMemorySettingRepo() + smtpServer := startNotificationEmailTestSMTPServer(t) + require.NoError(t, repo.SetMultiple(ctx, smtpServer.settings())) + + emailSvc := NewEmailService(repo, nil) + svc := NewNotificationEmailService(repo, emailSvc) + input := NotificationEmailSendInput{ + Event: NotificationEmailEventSubscriptionExpiryReminder, + RecipientEmail: "User@Example.com", + RecipientName: "User", + UserID: 42, + SourceType: "user_subscription", + SourceID: "1234567890", + ReminderKey: "7d", + Variables: map[string]string{ + "subscription_group": "Codex", + "expiry_time": "2026-05-27 12:00", + "days_remaining": "7", + }, + } + + require.NoError(t, svc.Send(ctx, input)) + require.Equal(t, int64(1), smtpServer.messageCount()) + + key := notificationEmailDeliveryKey(input.Event, input.SourceType, input.SourceID, input.RecipientEmail, input.ReminderKey) + require.LessOrEqual(t, len(key), 100) + _, err := repo.GetValue(ctx, key) + require.NoError(t, err) + + require.NoError(t, svc.Send(ctx, input)) + require.Equal(t, int64(1), smtpServer.messageCount()) +} + +func TestNotificationEmailSendRespectsLegacyDeliveryKey(t *testing.T) { + ctx := context.Background() + repo := newNotificationEmailMemorySettingRepo() + svc := NewNotificationEmailService(repo, nil) + input := NotificationEmailSendInput{ + Event: NotificationEmailEventSubscriptionExpiryReminder, + RecipientEmail: "user@example.com", + SourceType: "user_subscription", + SourceID: "1234567890", + ReminderKey: "7d", + } + legacyKey := legacyNotificationEmailDeliveryKey(input.Event, input.SourceType, input.SourceID, input.RecipientEmail, input.ReminderKey) + require.NoError(t, repo.Set(ctx, legacyKey, "sent")) + + require.NoError(t, svc.Send(ctx, input)) +} + +type notificationEmailMemorySettingRepo struct { + mu sync.RWMutex + values map[string]string +} + +func newNotificationEmailMemorySettingRepo() *notificationEmailMemorySettingRepo { + return ¬ificationEmailMemorySettingRepo{values: make(map[string]string)} +} + +func (r *notificationEmailMemorySettingRepo) Get(_ context.Context, key string) (*Setting, error) { + r.mu.RLock() + defer r.mu.RUnlock() + value, ok := r.values[key] + if !ok { + return nil, ErrSettingNotFound + } + return &Setting{Key: key, Value: value}, nil +} + +func (r *notificationEmailMemorySettingRepo) GetValue(ctx context.Context, key string) (string, error) { + setting, err := r.Get(ctx, key) + if err != nil { + return "", err + } + return setting.Value, nil +} + +func (r *notificationEmailMemorySettingRepo) Set(_ context.Context, key, value string) error { + r.mu.Lock() + defer r.mu.Unlock() + r.values[key] = value + return nil +} + +func (r *notificationEmailMemorySettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + r.mu.RLock() + defer r.mu.RUnlock() + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := r.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (r *notificationEmailMemorySettingRepo) SetMultiple(_ context.Context, settings map[string]string) error { + r.mu.Lock() + defer r.mu.Unlock() + for key, value := range settings { + r.values[key] = value + } + return nil +} + +func (r *notificationEmailMemorySettingRepo) GetAll(_ context.Context) (map[string]string, error) { + r.mu.RLock() + defer r.mu.RUnlock() + out := make(map[string]string, len(r.values)) + for key, value := range r.values { + out[key] = value + } + return out, nil +} + +func (r *notificationEmailMemorySettingRepo) Delete(_ context.Context, key string) error { + r.mu.Lock() + defer r.mu.Unlock() + if _, ok := r.values[key]; !ok { + return ErrSettingNotFound + } + delete(r.values, key) + return nil +} + +func TestNotificationEmailMemorySettingRepoSatisfiesInterface(t *testing.T) { + var _ SettingRepository = (*notificationEmailMemorySettingRepo)(nil) + require.False(t, strings.Contains(notificationEmailPreferenceKey(NotificationEmailEventBalanceLow, "User@Example.com"), "User@Example.com")) +} + +type notificationEmailTestSMTPServer struct { + listener net.Listener + wg sync.WaitGroup + messages atomic.Int64 +} + +func startNotificationEmailTestSMTPServer(t *testing.T) *notificationEmailTestSMTPServer { + t.Helper() + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + server := ¬ificationEmailTestSMTPServer{listener: listener} + server.wg.Add(1) + go server.serve() + t.Cleanup(server.close) + return server +} + +func (s *notificationEmailTestSMTPServer) settings() map[string]string { + host, port, _ := net.SplitHostPort(s.listener.Addr().String()) + return map[string]string{ + SettingKeySMTPHost: host, + SettingKeySMTPPort: port, + SettingKeySMTPUsername: "user", + SettingKeySMTPPassword: "password", + SettingKeySMTPFrom: "noreply@example.com", + SettingKeySMTPFromName: "Sub2API", + SettingKeySMTPUseTLS: "false", + } +} + +func (s *notificationEmailTestSMTPServer) messageCount() int64 { + return s.messages.Load() +} + +func (s *notificationEmailTestSMTPServer) close() { + _ = s.listener.Close() + s.wg.Wait() +} + +func (s *notificationEmailTestSMTPServer) serve() { + defer s.wg.Done() + for { + conn, err := s.listener.Accept() + if err != nil { + return + } + s.handleConn(conn) + } +} + +func (s *notificationEmailTestSMTPServer) handleConn(conn net.Conn) { + defer func() { _ = conn.Close() }() + rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + writeLine := func(line string) bool { + if _, err := rw.WriteString(line + "\r\n"); err != nil { + return false + } + return rw.Flush() == nil + } + if !writeLine("220 localhost ESMTP") { + return + } + for { + line, err := rw.ReadString('\n') + if err != nil { + return + } + cmd := strings.ToUpper(strings.TrimRight(line, "\r\n")) + switch { + case strings.HasPrefix(cmd, "EHLO"), strings.HasPrefix(cmd, "HELO"): + if _, err := rw.WriteString("250-localhost\r\n250 AUTH PLAIN\r\n"); err != nil { + return + } + if err := rw.Flush(); err != nil { + return + } + case strings.HasPrefix(cmd, "AUTH"): + if !writeLine("235 2.7.0 Authentication successful") { + return + } + case strings.HasPrefix(cmd, "MAIL FROM:"): + if !writeLine("250 2.1.0 OK") { + return + } + case strings.HasPrefix(cmd, "RCPT TO:"): + if !writeLine("250 2.1.5 OK") { + return + } + case strings.HasPrefix(cmd, "DATA"): + if !writeLine("354 End data with .") { + return + } + for { + dataLine, err := rw.ReadString('\n') + if err != nil { + return + } + if strings.TrimRight(dataLine, "\r\n") == "." { + break + } + } + s.messages.Add(1) + if !writeLine("250 2.0.0 OK") { + return + } + case strings.HasPrefix(cmd, "QUIT"): + _ = writeLine("221 2.0.0 Bye") + return + default: + if !writeLine("250 OK") { + return + } + } + } +} diff --git a/backend/internal/service/openai_gateway_responses_chat_fallback.go b/backend/internal/service/openai_gateway_responses_chat_fallback.go new file mode 100644 index 00000000..1d28a9c2 --- /dev/null +++ b/backend/internal/service/openai_gateway_responses_chat_fallback.go @@ -0,0 +1,428 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// forwardResponsesViaRawChatCompletions serves /v1/responses clients through an +// upstream that only supports /v1/chat/completions. +func (s *OpenAIGatewayService) forwardResponsesViaRawChatCompletions( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, +) (*OpenAIForwardResult, error) { + startTime := time.Now() + + var responsesReq apicompat.ResponsesRequest + if err := json.Unmarshal(body, &responsesReq); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": "Failed to parse request body", + }, + }) + return nil, fmt.Errorf("parse responses request: %w", err) + } + originalModel := strings.TrimSpace(responsesReq.Model) + if originalModel == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": "model is required", + }, + }) + return nil, fmt.Errorf("missing model in request") + } + + clientStream := responsesReq.Stream + reasoningEffort := extractOpenAIReasoningEffortFromBody(body, originalModel) + serviceTier := extractOpenAIServiceTierFromBody(body) + + chatReq, err := apicompat.ResponsesToChatCompletionsRequest(&responsesReq) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": err.Error(), + }, + }) + return nil, fmt.Errorf("convert responses to chat completions: %w", err) + } + + billingModel := resolveOpenAIForwardModel(account, originalModel, "") + upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel) + chatReq.Model = upstreamModel + if clientStream { + chatReq.StreamOptions = &apicompat.ChatStreamOptions{IncludeUsage: true} + } + + chatBody, err := json.Marshal(chatReq) + if err != nil { + return nil, fmt.Errorf("marshal chat completions fallback request: %w", err) + } + chatBody, err = s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, chatBody) + if err != nil { + var blocked *OpenAIFastBlockedError + if errors.As(err, &blocked) { + writeOpenAIFastPolicyBlockedResponse(c, blocked) + } + return nil, err + } + if serviceTier == nil { + serviceTier = extractOpenAIServiceTierFromBody(chatBody) + } + + logger.L().Debug("openai responses: forwarding via raw chat completions", + zap.Int64("account_id", account.ID), + zap.String("original_model", originalModel), + zap.String("billing_model", billingModel), + zap.String("upstream_model", upstreamModel), + zap.Bool("stream", clientStream), + ) + + apiKey := account.GetOpenAIApiKey() + if apiKey == "" { + return nil, fmt.Errorf("account %d missing api_key", account.ID) + } + baseURL := account.GetOpenAIBaseURL() + if baseURL == "" { + baseURL = "https://api.openai.com" + } + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, fmt.Errorf("invalid base_url: %w", err) + } + targetURL := buildOpenAIChatCompletionsURL(validatedURL) + + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(chatBody)) + releaseUpstreamCtx() + if err != nil { + return nil, fmt.Errorf("build upstream request: %w", err) + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+apiKey) + if clientStream { + upstreamReq.Header.Set("Accept", "text/event-stream") + } else { + upstreamReq.Header.Set("Accept", "application/json") + } + for key, values := range c.Request.Header { + lowerKey := strings.ToLower(key) + if openaiCCRawAllowedHeaders[lowerKey] { + for _, v := range values { + upstreamReq.Header.Add(key, v) + } + } + } + if customUA := account.GetOpenAIUserAgent(); customUA != "" { + upstreamReq.Header.Set("user-agent", customUA) + } + + proxyURL := "" + if account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), + } + } + return s.handleErrorResponse(ctx, resp, c, account, chatBody) + } + + if clientStream { + return s.streamChatCompletionsAsResponses(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime) + } + return s.bufferChatCompletionsAsResponses(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime) +} + +func (s *OpenAIGatewayService) bufferChatCompletionsAsResponses( + c *gin.Context, + resp *http.Response, + originalModel string, + billingModel string, + upstreamModel string, + reasoningEffort *string, + serviceTier *string, + startTime time.Time, +) (*OpenAIForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) + if err != nil { + if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "api_error", + "message": "Failed to read upstream response", + }, + }) + } + return nil, fmt.Errorf("read upstream body: %w", err) + } + + var ccResp apicompat.ChatCompletionsResponse + if err := json.Unmarshal(respBody, &ccResp); err != nil { + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "api_error", + "message": "Failed to parse upstream response", + }, + }) + return nil, fmt.Errorf("parse chat completions response: %w", err) + } + responsesResp := apicompat.ChatCompletionsResponseToResponses(&ccResp, originalModel) + + usage := OpenAIUsage{} + if parsed, ok := extractOpenAIUsageFromJSONBytes(respBody); ok { + usage = parsed + } + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.JSON(http.StatusOK, responsesResp) + + return &OpenAIForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, + ReasoningEffort: reasoningEffort, + ServiceTier: serviceTier, + Stream: false, + Duration: time.Since(startTime), + }, nil +} + +func (s *OpenAIGatewayService) streamChatCompletionsAsResponses( + c *gin.Context, + resp *http.Response, + originalModel string, + billingModel string, + upstreamModel string, + reasoningEffort *string, + serviceTier *string, + startTime time.Time, +) (*OpenAIForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + headersWritten := false + writeStreamHeaders := func() { + if headersWritten { + return + } + headersWritten = true + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + } + + state := apicompat.NewChatCompletionsToResponsesStreamState(originalModel) + var usage OpenAIUsage + var firstTokenMs *int + clientDisconnected := false + sawDone := false + + writeEvents := func(events []apicompat.ResponsesStreamEvent) { + if clientDisconnected || len(events) == 0 { + return + } + writeStreamHeaders() + for _, event := range events { + sse, err := apicompat.ResponsesEventToSSE(event) + if err != nil { + logger.L().Warn("openai responses chat fallback: failed to marshal stream event", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + clientDisconnected = true + logger.L().Debug("openai responses chat fallback: client disconnected, continuing to drain upstream for billing", + zap.Error(err), + zap.String("request_id", requestID), + ) + return + } + } + c.Writer.Flush() + } + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + for scanner.Scan() { + line := scanner.Text() + payload, ok := extractOpenAISSEDataLine(line) + if !ok { + continue + } + payload = strings.TrimSpace(payload) + if payload == "" { + continue + } + if payload == "[DONE]" { + sawDone = true + break + } + + if u := extractCCStreamUsage(payload); u != nil { + usage = *u + } + + var chunk apicompat.ChatCompletionsChunk + if err := json.Unmarshal([]byte(payload), &chunk); err != nil { + logger.L().Warn("openai responses chat fallback: failed to parse chat stream chunk", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + if firstTokenMs == nil && !isOpenAIChatUsageOnlyStreamChunk(payload) && chatChunkStartsResponsesOutput(&chunk) { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + writeEvents(apicompat.ChatCompletionsChunkToResponsesEvents(&chunk, state)) + } + + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("openai responses chat fallback: stream read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + return &OpenAIForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, + ReasoningEffort: reasoningEffort, + ServiceTier: serviceTier, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, fmt.Errorf("stream usage incomplete: %w", err) + } + + writeEvents(apicompat.FinalizeChatCompletionsResponsesStream(state)) + if !clientDisconnected { + writeStreamHeaders() + if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil { + clientDisconnected = true + } + if !clientDisconnected { + c.Writer.Flush() + } + } + if !sawDone { + logger.L().Debug("openai responses chat fallback: upstream stream ended without done sentinel", + zap.String("request_id", requestID), + ) + } + + return &OpenAIForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, + ReasoningEffort: reasoningEffort, + ServiceTier: serviceTier, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +func chatChunkStartsResponsesOutput(chunk *apicompat.ChatCompletionsChunk) bool { + if chunk == nil { + return false + } + for _, choice := range chunk.Choices { + if choice.Delta.Content != nil || choice.Delta.ReasoningContent != nil || len(choice.Delta.ToolCalls) > 0 { + return true + } + } + return false +} diff --git a/backend/internal/service/openai_gateway_responses_chat_fallback_test.go b/backend/internal/service/openai_gateway_responses_chat_fallback_test.go new file mode 100644 index 00000000..78df2202 --- /dev/null +++ b/backend/internal/service/openai_gateway_responses_chat_fallback_test.go @@ -0,0 +1,145 @@ +//go:build unit + +package service + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestForwardResponses_ForceChatCompletionsRoutesNonStreamingToChatCompletions(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := []byte(`{"model":"gpt-5.4","input":"hello","stream":false}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_resp_chat_json"}}, + Body: io.NopCloser(strings.NewReader( + `{"id":"chatcmpl_json","object":"chat.completion","model":"gpt-5.4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":3,"completion_tokens":2,"total_tokens":5,"prompt_tokens_details":{"cached_tokens":1}}}`, + )), + }} + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + + result, err := svc.Forward(context.Background(), c, forceChatResponsesFallbackAccount(), body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "http://upstream.example/v1/chat/completions", upstream.lastReq.URL.String()) + require.Equal(t, "hello", gjson.GetBytes(upstream.lastBody, "messages.0.content").String()) + require.False(t, gjson.GetBytes(upstream.lastBody, "input").Exists()) + require.Equal(t, "response", gjson.Get(rec.Body.String(), "object").String()) + require.Equal(t, "ok", gjson.Get(rec.Body.String(), "output.0.content.0.text").String()) + require.Equal(t, 3, result.Usage.InputTokens) + require.Equal(t, 2, result.Usage.OutputTokens) + require.Equal(t, 1, result.Usage.CacheReadInputTokens) + require.False(t, result.Stream) +} + +func TestForwardResponses_ForceChatCompletionsRoutesStreamingToChatCompletions(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := []byte(`{"model":"gpt-5.4","input":"hello","stream":true}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}`, + "", + `data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"he"},"finish_reason":null}]}`, + "", + `data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"llo"},"finish_reason":null}]}`, + "", + `data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`, + "", + `data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":4,"completion_tokens":3,"total_tokens":7}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_resp_chat_stream"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + + result, err := svc.Forward(context.Background(), c, forceChatResponsesFallbackAccount(), body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "http://upstream.example/v1/chat/completions", upstream.lastReq.URL.String()) + require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool()) + require.Contains(t, rec.Body.String(), "event: response.output_text.delta") + require.Contains(t, rec.Body.String(), `"delta":"he"`) + require.Contains(t, rec.Body.String(), "event: response.completed") + require.Contains(t, rec.Body.String(), `"input_tokens":4`) + require.Contains(t, rec.Body.String(), "data: [DONE]") + require.Equal(t, 4, result.Usage.InputTokens) + require.Equal(t, 3, result.Usage.OutputTokens) + require.True(t, result.Stream) + require.NotNil(t, result.FirstTokenMs) +} + +func TestForwardResponses_AutoSupportedAccountStillUsesResponsesEndpoint(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := []byte(`{"model":"gpt-5.4","input":"hello","stream":false}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_resp_native"}}, + Body: io.NopCloser(strings.NewReader( + `{"id":"resp_native","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"ok"}],"status":"completed"}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}`, + )), + }} + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + account := rawChatCompletionsTestAccount() + account.Extra = map[string]any{ + openai_compat.ExtraKeyResponsesMode: string(openai_compat.ResponsesSupportModeAuto), + openai_compat.ExtraKeyResponsesSupported: true, + } + + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "http://upstream.example/v1/responses", upstream.lastReq.URL.String()) + require.True(t, gjson.GetBytes(upstream.lastBody, "input").Exists()) + require.False(t, gjson.GetBytes(upstream.lastBody, "messages").Exists()) + require.Equal(t, "ok", gjson.Get(rec.Body.String(), "output.0.content.0.text").String()) +} + +func forceChatResponsesFallbackAccount() *Account { + account := rawChatCompletionsTestAccount() + account.Extra = map[string]any{ + openai_compat.ExtraKeyResponsesMode: string(openai_compat.ResponsesSupportModeForceChatCompletions), + } + return account +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index cfaf5bff..f13c4748 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -24,6 +24,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/cespare/xxhash/v2" @@ -2018,6 +2019,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco originalBody := body reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) originalModel := reqModel + + if account.Type == AccountTypeAPIKey && !openai_compat.ShouldUseResponsesAPI(account.Extra) { + return s.forwardResponsesViaRawChatCompletions(ctx, c, account, body) + } + compatMessagesBridge := isOpenAICompatMessagesBridgeBody(body) setOpenAICompatMessagesBridgeContext(c, compatMessagesBridge) @@ -3231,6 +3237,10 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( req.Header.Set("user-agent", codexCLIUserAgent) } + // 浏览器型 UA 兜底:仅 OAuth(ChatGPT 内部接口)账号生效,若最终 user-agent 仍为浏览器 + // (Chrome/Firefox/Safari/Edge 等),替换为后台配置的 Codex UA,避免 Cloudflare 触发 JS 质询。 + s.overrideBrowserUserAgent(ctx, account, req) + if req.Header.Get("content-type") == "" { req.Header.Set("content-type", "application/json") } @@ -3947,6 +3957,10 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. req.Header.Set("user-agent", codexCLIUserAgent) } + // 浏览器型 UA 兜底:仅 OAuth(ChatGPT 内部接口)账号生效,若最终 user-agent 仍为浏览器 + // (Chrome/Firefox/Safari/Edge 等),替换为后台配置的 Codex UA,避免 Cloudflare 触发 JS 质询。 + s.overrideBrowserUserAgent(ctx, account, req) + // Ensure required headers exist if req.Header.Get("content-type") == "" { req.Header.Set("content-type", "application/json") @@ -3955,6 +3969,30 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. return req, nil } +// overrideBrowserUserAgent 检查请求的最终 user-agent,若为浏览器 UA 则替换为后台配置的 Codex UA。 +// 用于规避 Cloudflare 对浏览器型 UA 在 ChatGPT 内部接口上的访问质询。 +// 影响范围严格限定:仅 OAuth(Codex/ChatGPT 内部接口)账号生效;API Key 等其他账号原样透传。 +// 仅在识别为浏览器(Mozilla/...)时改写,其他 CLI/工具 UA 不动。 +func (s *OpenAIGatewayService) overrideBrowserUserAgent(ctx context.Context, account *Account, req *http.Request) { + if req == nil || account == nil { + return + } + if account.Type != AccountTypeOAuth { + return + } + currentUA := req.Header.Get("user-agent") + if !openai.IsBrowserUserAgent(currentUA) { + return + } + codexUA := DefaultOpenAICodexUserAgent + if s != nil && s.settingService != nil { + if v := strings.TrimSpace(s.settingService.GetOpenAICodexUserAgent(ctx)); v != "" { + codexUA = v + } + } + req.Header.Set("user-agent", codexUA) +} + func (s *OpenAIGatewayService) handleErrorResponse( ctx context.Context, resp *http.Response, diff --git a/backend/internal/service/openai_images_responses.go b/backend/internal/service/openai_images_responses.go index c89c2aaf..56272c26 100644 --- a/backend/internal/service/openai_images_responses.go +++ b/backend/internal/service/openai_images_responses.go @@ -262,6 +262,9 @@ func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel st tool := []byte(`{"type":"image_generation","action":"","model":""}`) tool, _ = sjson.SetBytes(tool, "action", action) tool, _ = sjson.SetBytes(tool, "model", strings.TrimSpace(toolModel)) + if shouldPassOpenAIImagesN(toolModel, parsed.N) { + tool, _ = sjson.SetBytes(tool, "n", parsed.N) + } for _, field := range []struct { path string @@ -302,6 +305,13 @@ func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel st return req, nil } +func shouldPassOpenAIImagesN(model string, n int) bool { + if n <= 1 { + return false + } + return !strings.EqualFold(strings.TrimSpace(model), "dall-e-3") +} + func extractOpenAIImagesFromResponsesCompleted(payload []byte) ([]openAIResponsesImageResult, int64, []byte, openAIResponsesImageResult, error) { if gjson.GetBytes(payload, "type").String() != "response.completed" { return nil, 0, nil, openAIResponsesImageResult{}, fmt.Errorf("unexpected event type") @@ -957,16 +967,6 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( account.Type, len(parsed.Uploads), ) - if parsed.N > 1 { - logger.LegacyPrintf( - "service.openai_gateway", - "[Warning] Codex /responses image tool requested n=%d; falling back to n=1 request_model=%s endpoint=%s", - parsed.N, - requestModel, - parsed.Endpoint, - ) - } - upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) defer releaseUpstreamCtx() diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go index 35789d21..d47c52ca 100644 --- a/backend/internal/service/openai_images_test.go +++ b/backend/internal/service/openai_images_test.go @@ -474,9 +474,9 @@ func findOpenAIImageTestSSEEvent(events []openAIImageTestSSEEvent, name string) return openAIImageTestSSEEvent{}, false } -func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) { +func TestOpenAIGatewayServiceForwardImages_OAuthPassesNAndReturnsAllImages(t *testing.T) { gin.SetMode(gin.TestMode) - body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":2}`) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":3}`) req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") @@ -497,7 +497,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) { "X-Request-Id": []string{"req_img_123"}, }, Body: io.NopCloser(strings.NewReader( - "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000000,\"usage\":{\"input_tokens\":11,\"output_tokens\":22,\"input_tokens_details\":{\"cached_tokens\":3},\"output_tokens_details\":{\"image_tokens\":7}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000000,\"usage\":{\"input_tokens\":11,\"output_tokens\":22,\"input_tokens_details\":{\"cached_tokens\":3},\"output_tokens_details\":{\"image_tokens\":7}},\"tool_usage\":{\"image_gen\":{\"images\":3}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMQ==\",\"revised_prompt\":\"draw a cat 1\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"},{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMg==\",\"revised_prompt\":\"draw a cat 2\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"},{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMw==\",\"revised_prompt\":\"draw a cat 3\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" + "data: [DONE]\n\n", )), }, @@ -520,7 +520,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) { require.NotNil(t, result) require.Equal(t, "gpt-image-2", result.Model) require.Equal(t, "gpt-image-2", result.UpstreamModel) - require.Equal(t, 1, result.ImageCount) + require.Equal(t, 3, result.ImageCount) require.Equal(t, 11, result.Usage.InputTokens) require.Equal(t, 22, result.Usage.OutputTokens) require.Equal(t, 7, result.Usage.ImageOutputTokens) @@ -540,13 +540,17 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) { require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String()) require.Equal(t, "1024x1024", gjson.GetBytes(upstream.lastBody, "tools.0.size").String()) require.Equal(t, "high", gjson.GetBytes(upstream.lastBody, "tools.0.quality").String()) - require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.n").Exists()) + require.Equal(t, int64(3), gjson.GetBytes(upstream.lastBody, "tools.0.n").Int()) require.Equal(t, "draw a cat", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String()) require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, "gpt-image-2", gjson.Get(rec.Body.String(), "model").String()) - require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) - require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String()) + require.Len(t, gjson.Get(rec.Body.String(), "data").Array(), 3) + require.Equal(t, "aW1hZ2UtMQ==", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) + require.Equal(t, "aW1hZ2UtMg==", gjson.Get(rec.Body.String(), "data.1.b64_json").String()) + require.Equal(t, "aW1hZ2UtMw==", gjson.Get(rec.Body.String(), "data.2.b64_json").String()) + require.Equal(t, "draw a cat 1", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String()) + require.Equal(t, "draw a cat 3", gjson.Get(rec.Body.String(), "data.2.revised_prompt").String()) } func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseURL(t *testing.T) { @@ -1112,7 +1116,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthEditsStreamingTransformsEvents(t require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists()) } -func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *testing.T) { +func TestBuildOpenAIImagesResponsesRequest_PassesThroughNForMultiImageModels(t *testing.T) { parsed := &OpenAIImagesRequest{ Endpoint: openAIImagesGenerationsEndpoint, Model: "gpt-image-2", @@ -1123,11 +1127,26 @@ func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *t body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2") require.NoError(t, err) require.NotNil(t, body) - require.False(t, gjson.GetBytes(body, "tools.0.n").Exists()) + require.Equal(t, int64(2), gjson.GetBytes(body, "tools.0.n").Int()) require.Equal(t, "gpt-image-2", gjson.GetBytes(body, "tools.0.model").String()) require.Equal(t, "draw a cat", gjson.GetBytes(body, "input.0.content.0.text").String()) } +func TestBuildOpenAIImagesResponsesRequest_DoesNotPassNForDallE3(t *testing.T) { + parsed := &OpenAIImagesRequest{ + Endpoint: openAIImagesGenerationsEndpoint, + Model: "dall-e-3", + Prompt: "draw a cat", + N: 2, + } + + body, err := buildOpenAIImagesResponsesRequest(parsed, "dall-e-3") + require.NoError(t, err) + require.NotNil(t, body) + require.False(t, gjson.GetBytes(body, "tools.0.n").Exists()) + require.Equal(t, "dall-e-3", gjson.GetBytes(body, "tools.0.model").String()) +} + func TestBuildOpenAIImagesResponsesRequest_StripsInputFidelity(t *testing.T) { parsed := &OpenAIImagesRequest{ Endpoint: openAIImagesEditsEndpoint, diff --git a/backend/internal/service/ops_alert_evaluator_service.go b/backend/internal/service/ops_alert_evaluator_service.go index 11c5d5ce..c6a58a1b 100644 --- a/backend/internal/service/ops_alert_evaluator_service.go +++ b/backend/internal/service/ops_alert_evaluator_service.go @@ -686,6 +686,21 @@ func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runt if !s.emailLimiter.Allow(time.Now().UTC()) { continue } + if s.emailService.notificationEmailService != nil { + if err := s.emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{ + Event: NotificationEmailEventOpsAlert, + RecipientEmail: addr, + RecipientName: emailRecipientName(addr), + SourceType: "ops_alert", + SourceID: fmt.Sprintf("%d", event.ID), + Variables: opsAlertEmailVariables(rule, event), + }); err == nil { + anySent = true + continue + } else if !shouldFallbackNotificationEmail(err) { + continue + } + } if err := s.emailService.SendEmail(ctx, addr, subject, body); err != nil { // Ignore per-recipient failures; continue best-effort. continue @@ -699,6 +714,46 @@ func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runt return anySent } +func opsAlertEmailVariables(rule *OpsAlertRule, event *OpsAlertEvent) map[string]string { + variables := map[string]string{ + "rule_name": "-", + "severity": "-", + "alert_status": "-", + "metric_type": "-", + "operator": "-", + "metric_value": "-", + "threshold_value": "-", + "triggered_at": time.Now().UTC().Format(time.RFC3339), + "alert_description": "-", + } + if rule != nil { + variables["rule_name"] = strings.TrimSpace(rule.Name) + variables["severity"] = strings.TrimSpace(rule.Severity) + variables["metric_type"] = strings.TrimSpace(rule.MetricType) + variables["operator"] = strings.TrimSpace(rule.Operator) + variables["threshold_value"] = fmt.Sprintf("%.2f", rule.Threshold) + if strings.TrimSpace(rule.Description) != "" { + variables["alert_description"] = strings.TrimSpace(rule.Description) + } + } + if event != nil { + variables["alert_status"] = strings.TrimSpace(event.Status) + if event.MetricValue != nil { + variables["metric_value"] = fmt.Sprintf("%.2f", *event.MetricValue) + } + if event.ThresholdValue != nil { + variables["threshold_value"] = fmt.Sprintf("%.2f", *event.ThresholdValue) + } + if !event.FiredAt.IsZero() { + variables["triggered_at"] = event.FiredAt.UTC().Format(time.RFC3339) + } + if strings.TrimSpace(event.Description) != "" { + variables["alert_description"] = strings.TrimSpace(event.Description) + } + } + return variables +} + func buildOpsAlertEmailBody(rule *OpsAlertRule, event *OpsAlertEvent) string { if rule == nil || event == nil { return "" diff --git a/backend/internal/service/ops_scheduled_report_service.go b/backend/internal/service/ops_scheduled_report_service.go index 98b2045d..54aad114 100644 --- a/backend/internal/service/ops_scheduled_report_service.go +++ b/backend/internal/service/ops_scheduled_report_service.go @@ -337,6 +337,7 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc } subject := fmt.Sprintf("[Ops Report] %s", strings.TrimSpace(report.Name)) + templateVariables := opsScheduledReportEmailVariables(report, now) attempts := 0 for _, to := range recipients { @@ -345,6 +346,24 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc continue } attempts++ + if s.emailService.notificationEmailService != nil { + if err := s.emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{ + Event: NotificationEmailEventOpsScheduledReport, + RecipientEmail: addr, + RecipientName: emailRecipientName(addr), + SourceType: "ops_scheduled_report", + SourceID: opsScheduledReportDeliverySourceID(report), + ReminderKey: now.UTC().Format("2006-01-02T15:04"), + Variables: templateVariables, + RawHTMLVariables: map[string]string{ + "report_html": content, + }, + }); err == nil { + continue + } else if !shouldFallbackNotificationEmail(err) { + continue + } + } if err := s.emailService.SendEmail(ctx, addr, subject, content); err != nil { // Ignore per-recipient failures; continue best-effort. continue @@ -353,6 +372,46 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc return attempts, nil } +func opsScheduledReportDeliverySourceID(report *opsScheduledReport) string { + if report == nil { + return "scheduled_report" + } + parts := []string{ + strings.TrimSpace(report.ReportType), + strings.TrimSpace(report.Name), + strings.TrimSpace(report.Schedule), + } + joined := strings.Trim(strings.Join(parts, ":"), ":") + if joined == "" { + return "scheduled_report" + } + return joined +} + +func opsScheduledReportEmailVariables(report *opsScheduledReport, now time.Time) map[string]string { + end := now.UTC() + start := end + name := "Ops report" + reportType := "scheduled_report" + if report != nil { + if strings.TrimSpace(report.Name) != "" { + name = strings.TrimSpace(report.Name) + } + if strings.TrimSpace(report.ReportType) != "" { + reportType = strings.TrimSpace(report.ReportType) + } + if report.TimeRange > 0 { + start = end.Add(-report.TimeRange) + } + } + return map[string]string{ + "report_name": name, + "report_type": reportType, + "report_start_time": start.Format(time.RFC3339), + "report_end_time": end.Format(time.RFC3339), + } +} + func (s *OpsScheduledReportService) generateReportHTML(ctx context.Context, report *opsScheduledReport, now time.Time) (string, error) { if s == nil || s.opsService == nil || report == nil { return "", fmt.Errorf("service not initialized") diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 8a26e868..b6b19ca0 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -310,9 +310,87 @@ func (s *PaymentService) markCompleted(ctx context.Context, o *dbent.PaymentOrde "creditedAmount": o.Amount, "payAmount": o.PayAmount, }) + s.dispatchPaymentFulfillmentNotification(o, auditAction) return nil } +func (s *PaymentService) dispatchPaymentFulfillmentNotification(o *dbent.PaymentOrder, auditAction string) { + if s == nil || s.notificationEmailService == nil || o == nil { + return + } + go func() { + ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout) + defer cancel() + var err error + switch auditAction { + case "RECHARGE_SUCCESS": + err = s.sendBalanceRechargeSuccessNotification(ctx, o) + case "SUBSCRIPTION_SUCCESS": + err = s.sendSubscriptionPurchaseSuccessNotification(ctx, o) + default: + return + } + if err != nil { + slog.Warn("payment fulfillment notification email failed", "order_id", o.ID, "action", auditAction, "err", err.Error()) + } + }() +} + +func (s *PaymentService) sendBalanceRechargeSuccessNotification(ctx context.Context, o *dbent.PaymentOrder) error { + currentBalance := "" + if s.userRepo != nil { + if user, err := s.userRepo.GetByID(ctx, o.UserID); err == nil && user != nil { + currentBalance = fmt.Sprintf("%.2f", user.Balance) + } + } + return s.notificationEmailService.Send(ctx, NotificationEmailSendInput{ + Event: NotificationEmailEventBalanceRechargeSuccess, + RecipientEmail: o.UserEmail, + RecipientName: firstNonEmpty(o.UserName, o.UserEmail), + UserID: o.UserID, + SourceType: "payment_order", + SourceID: strconv.FormatInt(o.ID, 10), + Variables: map[string]string{ + "recharge_amount": fmt.Sprintf("%.2f", o.Amount), + "current_balance": currentBalance, + "order_id": strconv.FormatInt(o.ID, 10), + }, + }) +} + +func (s *PaymentService) sendSubscriptionPurchaseSuccessNotification(ctx context.Context, o *dbent.PaymentOrder) error { + variables := map[string]string{ + "subscription_group": "Subscription", + "subscription_days": "", + "expiry_time": "", + "order_id": strconv.FormatInt(o.ID, 10), + } + if o.SubscriptionDays != nil { + variables["subscription_days"] = strconv.Itoa(*o.SubscriptionDays) + } + if o.SubscriptionGroupID != nil { + if s.groupRepo != nil { + if group, err := s.groupRepo.GetByID(ctx, *o.SubscriptionGroupID); err == nil && group != nil && strings.TrimSpace(group.Name) != "" { + variables["subscription_group"] = group.Name + } + } + if s.subscriptionSvc != nil { + if sub, err := s.subscriptionSvc.GetActiveSubscription(ctx, o.UserID, *o.SubscriptionGroupID); err == nil && sub != nil { + variables["expiry_time"] = sub.ExpiresAt.Format("2006-01-02 15:04") + } + } + } + return s.notificationEmailService.Send(ctx, NotificationEmailSendInput{ + Event: NotificationEmailEventSubscriptionPurchaseSuccess, + RecipientEmail: o.UserEmail, + RecipientName: firstNonEmpty(o.UserName, o.UserEmail), + UserID: o.UserID, + SourceType: "payment_order", + SourceID: strconv.FormatInt(o.ID, 10), + Variables: variables, + }) +} + func (s *PaymentService) ExecuteSubscriptionFulfillment(ctx context.Context, oid int64) error { o, err := s.entClient.PaymentOrder.Get(ctx, oid) if err != nil { diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go index e6cc4b3c..83edb9e1 100644 --- a/backend/internal/service/payment_order.go +++ b/backend/internal/service/payment_order.go @@ -48,6 +48,9 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest if user.Status != payment.EntityStatusActive { return nil, infraerrors.Forbidden("USER_INACTIVE", "user account is disabled") } + if s.notificationEmailService != nil { + s.notificationEmailService.RememberRecipientLocale(ctx, req.UserID, user.Email, req.Locale) + } orderAmount := req.Amount limitAmount := req.Amount if plan != nil { diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go index 42553840..2759aba1 100644 --- a/backend/internal/service/payment_service.go +++ b/backend/internal/service/payment_service.go @@ -83,6 +83,7 @@ type CreateOrderRequest struct { PaymentSource string OrderType string PlanID int64 + Locale string } type CreateOrderResponse struct { @@ -174,18 +175,19 @@ type TopUserStat struct { // --- Service --- type PaymentService struct { - providerMu sync.Mutex - providersLoaded bool - entClient *dbent.Client - registry *payment.Registry - loadBalancer payment.LoadBalancer - redeemService *RedeemService - subscriptionSvc *SubscriptionService - configService *PaymentConfigService - userRepo UserRepository - groupRepo GroupRepository - resumeService *PaymentResumeService - affiliateService *AffiliateService + providerMu sync.Mutex + providersLoaded bool + entClient *dbent.Client + registry *payment.Registry + loadBalancer payment.LoadBalancer + redeemService *RedeemService + subscriptionSvc *SubscriptionService + configService *PaymentConfigService + userRepo UserRepository + groupRepo GroupRepository + resumeService *PaymentResumeService + affiliateService *AffiliateService + notificationEmailService *NotificationEmailService } func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository, affiliateService *AffiliateService) *PaymentService { @@ -194,6 +196,10 @@ func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, load return svc } +func (s *PaymentService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) { + s.notificationEmailService = notificationEmailService +} + // --- Provider Registry --- // EnsureProviders lazily initializes the provider registry on first call. diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index a5c16b1f..bd99e341 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -128,6 +128,19 @@ const antigravityUserAgentVersionCacheTTL = 60 * time.Second const antigravityUserAgentVersionErrorTTL = 5 * time.Second const antigravityUserAgentVersionDBTimeout = 5 * time.Second +// DefaultOpenAICodexUserAgent OpenAI Codex 默认 User-Agent(用于规避 Cloudflare 对浏览器 UA 的质询) +const DefaultOpenAICodexUserAgent = "codex-tui/0.125.0 (Ubuntu 22.4.0; x86_64) xterm-256color (codex-tui; 0.125.0)" + +// cachedOpenAICodexUserAgent 缓存 OpenAI Codex UA(进程内缓存,60s TTL) +type cachedOpenAICodexUserAgent struct { + value string + expiresAt int64 // unix nano +} + +const openAICodexUserAgentCacheTTL = 60 * time.Second +const openAICodexUserAgentErrorTTL = 5 * time.Second +const openAICodexUserAgentDBTimeout = 5 * time.Second + // DefaultSubscriptionGroupReader validates group references used by default subscriptions. type DefaultSubscriptionGroupReader interface { GetByID(ctx context.Context, id int64) (*Group, error) @@ -148,6 +161,8 @@ type SettingService struct { webSearchManagerBuilder WebSearchManagerBuilder antigravityUAVersionCache atomic.Value // *cachedAntigravityUserAgentVersion antigravityUAVersionSF singleflight.Group + openAICodexUACache atomic.Value // *cachedOpenAICodexUserAgent + openAICodexUASF singleflight.Group } type ProviderDefaultGrantSettings struct { @@ -907,6 +922,55 @@ func (s *SettingService) GetAntigravityUserAgentVersion(ctx context.Context) str return fallback } +// GetOpenAICodexUserAgent 返回 OpenAI Codex 上游请求使用的 User-Agent。 +// 后台设置优先;为空时回退到内置默认值。 +func (s *SettingService) GetOpenAICodexUserAgent(ctx context.Context) string { + fallback := DefaultOpenAICodexUserAgent + if s == nil || s.settingRepo == nil { + return fallback + } + if cached, ok := s.openAICodexUACache.Load().(*cachedOpenAICodexUserAgent); ok && cached != nil { + if time.Now().UnixNano() < cached.expiresAt { + return cached.value + } + } + + result, _, _ := s.openAICodexUASF.Do("openai_codex_user_agent", func() (any, error) { + if cached, ok := s.openAICodexUACache.Load().(*cachedOpenAICodexUserAgent); ok && cached != nil { + if time.Now().UnixNano() < cached.expiresAt { + return cached.value, nil + } + } + if ctx == nil { + ctx = context.Background() + } + dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAICodexUserAgentDBTimeout) + defer cancel() + value, err := s.settingRepo.GetValue(dbCtx, SettingKeyOpenAICodexUserAgent) + if err != nil && !errors.Is(err, ErrSettingNotFound) { + slog.Warn("failed to get openai codex user agent setting", "error", err) + s.openAICodexUACache.Store(&cachedOpenAICodexUserAgent{ + value: fallback, + expiresAt: time.Now().Add(openAICodexUserAgentErrorTTL).UnixNano(), + }) + return fallback, nil + } + ua := strings.TrimSpace(value) + if ua == "" { + ua = fallback + } + s.openAICodexUACache.Store(&cachedOpenAICodexUserAgent{ + value: ua, + expiresAt: time.Now().Add(openAICodexUserAgentCacheTTL).UnixNano(), + }) + return ua, nil + }) + if ua, ok := result.(string); ok && ua != "" { + return ua + } + return fallback +} + // SetOnUpdateCallback sets a callback function to be called when settings are updated // This is used for cache invalidation (e.g., HTML cache in frontend server) func (s *SettingService) SetOnUpdateCallback(callback func()) { @@ -1706,6 +1770,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting updates[SettingKeyEnableAnthropicCacheTTL1hInjection] = strconv.FormatBool(settings.EnableAnthropicCacheTTL1hInjection) updates[SettingKeyRewriteMessageCacheControl] = strconv.FormatBool(settings.RewriteMessageCacheControl) updates[SettingKeyAntigravityUserAgentVersion] = antigravity.NormalizeUserAgentVersion(settings.AntigravityUserAgentVersion) + updates[SettingKeyOpenAICodexUserAgent] = strings.TrimSpace(settings.OpenAICodexUserAgent) updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled) @@ -1788,6 +1853,15 @@ func (s *SettingService) refreshCachedSettings(settings *SystemSettings) { version: antigravityUserAgentVersion, expiresAt: time.Now().Add(antigravityUserAgentVersionCacheTTL).UnixNano(), }) + s.openAICodexUASF.Forget("openai_codex_user_agent") + codexUA := strings.TrimSpace(settings.OpenAICodexUserAgent) + if codexUA == "" { + codexUA = DefaultOpenAICodexUserAgent + } + s.openAICodexUACache.Store(&cachedOpenAICodexUserAgent{ + value: codexUA, + expiresAt: time.Now().Add(openAICodexUserAgentCacheTTL).UnixNano(), + }) openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey) openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{ enabled: settings.OpenAIAdvancedSchedulerEnabled, @@ -2529,6 +2603,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyEnableAnthropicCacheTTL1hInjection: "false", SettingKeyRewriteMessageCacheControl: strconv.FormatBool(s.defaultRewriteMessageCacheControl()), SettingKeyAntigravityUserAgentVersion: "", + SettingKeyOpenAICodexUserAgent: "", SettingPaymentVisibleMethodAlipaySource: "", SettingPaymentVisibleMethodWxpaySource: "", SettingPaymentVisibleMethodAlipayEnabled: "false", @@ -3041,6 +3116,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin result.RewriteMessageCacheControl = s.defaultRewriteMessageCacheControl() } result.AntigravityUserAgentVersion = antigravity.NormalizeUserAgentVersion(settings[SettingKeyAntigravityUserAgentVersion]) + result.OpenAICodexUserAgent = strings.TrimSpace(settings[SettingKeyOpenAICodexUserAgent]) // Web search emulation: quick enabled check from the JSON config if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" { diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index ea5fa57c..1e5e8b1c 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -193,6 +193,7 @@ type SystemSettings struct { EnableAnthropicCacheTTL1hInjection bool // 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false) RewriteMessageCacheControl bool // 是否改写 messages[*].content[*].cache_control(默认 false) AntigravityUserAgentVersion string // Antigravity 上游 User-Agent 版本号;空值使用配置/默认值 + OpenAICodexUserAgent string // OpenAI Codex 上游完整 User-Agent;空值使用内置默认 // Web Search Emulation WebSearchEmulationEnabled bool // 是否启用 web search 模拟 diff --git a/backend/internal/service/subscription_expiry_service.go b/backend/internal/service/subscription_expiry_service.go index ce6b32b8..9b3a0309 100644 --- a/backend/internal/service/subscription_expiry_service.go +++ b/backend/internal/service/subscription_expiry_service.go @@ -2,18 +2,23 @@ package service import ( "context" + "fmt" "log" + "strconv" "sync" "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) // SubscriptionExpiryService periodically updates expired subscription status. type SubscriptionExpiryService struct { - userSubRepo UserSubscriptionRepository - interval time.Duration - stopCh chan struct{} - stopOnce sync.Once - wg sync.WaitGroup + userSubRepo UserSubscriptionRepository + notificationEmailService *NotificationEmailService + interval time.Duration + stopCh chan struct{} + stopOnce sync.Once + wg sync.WaitGroup } func NewSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, interval time.Duration) *SubscriptionExpiryService { @@ -24,6 +29,10 @@ func NewSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, interv } } +func (s *SubscriptionExpiryService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) { + s.notificationEmailService = notificationEmailService +} + func (s *SubscriptionExpiryService) Start() { if s == nil || s.userSubRepo == nil || s.interval <= 0 { return @@ -68,4 +77,50 @@ func (s *SubscriptionExpiryService) runOnce() { if updated > 0 { log.Printf("[SubscriptionExpiry] Updated %d expired subscriptions", updated) } + s.sendExpiryReminders(ctx) +} + +func (s *SubscriptionExpiryService) sendExpiryReminders(ctx context.Context) { + if s == nil || s.userSubRepo == nil || s.notificationEmailService == nil { + return + } + for page := 1; ; page++ { + subs, pag, err := s.userSubRepo.List(ctx, pagination.PaginationParams{Page: page, PageSize: 200}, nil, nil, SubscriptionStatusActive, "", "expires_at", "asc") + if err != nil { + log.Printf("[SubscriptionExpiry] List active subscriptions for reminder failed: %v", err) + return + } + for i := range subs { + s.sendExpiryReminderIfDue(ctx, &subs[i]) + } + if pag == nil || page >= pag.Pages || len(subs) == 0 { + return + } + } +} + +func (s *SubscriptionExpiryService) sendExpiryReminderIfDue(ctx context.Context, sub *UserSubscription) { + if sub == nil || sub.User == nil || sub.Group == nil || sub.User.Email == "" { + return + } + daysRemaining := sub.DaysRemaining() + if daysRemaining != 7 && daysRemaining != 3 && daysRemaining != 1 { + return + } + if err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{ + Event: NotificationEmailEventSubscriptionExpiryReminder, + RecipientEmail: sub.User.Email, + RecipientName: firstNonEmpty(sub.User.Username, sub.User.Email), + UserID: sub.UserID, + SourceType: "user_subscription", + SourceID: strconv.FormatInt(sub.ID, 10), + ReminderKey: fmt.Sprintf("%dd", daysRemaining), + Variables: map[string]string{ + "subscription_group": sub.Group.Name, + "expiry_time": sub.ExpiresAt.Format("2006-01-02 15:04"), + "days_remaining": strconv.Itoa(daysRemaining), + }, + }); err != nil { + log.Printf("[SubscriptionExpiry] Send expiry reminder failed: subscription=%d user=%d err=%v", sub.ID, sub.UserID, err) + } } diff --git a/backend/internal/service/totp_service.go b/backend/internal/service/totp_service.go index 052739ed..6a0989c3 100644 --- a/backend/internal/service/totp_service.go +++ b/backend/internal/service/totp_service.go @@ -517,7 +517,7 @@ func (s *TotpService) GetVerificationMethod(ctx context.Context) *VerificationMe } // SendVerifyCode sends an email verification code for TOTP operations -func (s *TotpService) SendVerifyCode(ctx context.Context, userID int64) error { +func (s *TotpService) SendVerifyCode(ctx context.Context, userID int64, locale ...string) error { // Check if email verification is enabled if !s.settingService.IsEmailVerifyEnabled(ctx) { return infraerrors.BadRequest("EMAIL_VERIFY_NOT_ENABLED", "email verification is not enabled") @@ -533,5 +533,5 @@ func (s *TotpService) SendVerifyCode(ctx context.Context, userID int64) error { siteName := s.settingService.GetSiteName(ctx) // Send verification code via queue - return s.emailQueueService.EnqueueVerifyCode(user.Email, siteName) + return s.emailQueueService.EnqueueVerifyCode(user.Email, siteName, firstEmailLocale(locale)) } diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index d64f01e0..db572cc3 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -324,6 +324,30 @@ func (s *UsageService) GetAPIKeyModelStats(ctx context.Context, apiKeyID int64, return stats, nil } +// GetAPIKeyDailyUsage returns daily usage stats for a user's API key. +func (s *UsageService) GetAPIKeyDailyUsage(ctx context.Context, userID, apiKeyID int64, startTime, endTime time.Time) ([]usagestats.APIKeyDailyUsagePoint, error) { + trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, "day", userID, apiKeyID, 0, 0, "", nil, nil, nil) + if err != nil { + return nil, fmt.Errorf("get api key daily usage: %w", err) + } + + points := make([]usagestats.APIKeyDailyUsagePoint, 0, len(trend)) + for _, row := range trend { + points = append(points, usagestats.APIKeyDailyUsagePoint{ + Date: row.Date, + Requests: row.Requests, + InputTokens: row.InputTokens, + OutputTokens: row.OutputTokens, + CacheReadTokens: row.CacheReadTokens, + CacheWriteTokens: row.CacheCreationTokens, + TotalTokens: row.TotalTokens, + Cost: row.Cost, + ActualCost: row.ActualCost, + }) + } + return points, nil +} + // GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys. func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime) diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 61e9f846..0c958909 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -1122,7 +1122,7 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error { } // SendNotifyEmailCode sends a verification code to the extra notification email. -func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, email string, emailService *EmailService, cache EmailCache) error { +func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, email string, emailService *EmailService, cache EmailCache, locale ...string) error { if err := checkNotifyCodeRateLimit(ctx, cache, userID, email); err != nil { return err } @@ -1134,7 +1134,7 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema // Send email first — if SMTP fails, don't write cache or increment counters, // so the user is not locked out by cooldown/rate-limit for a code they never received. - if err := s.sendNotifyVerifyEmail(ctx, emailService, email, code); err != nil { + if err := s.sendNotifyVerifyEmail(ctx, emailService, userID, email, code, firstEmailLocale(locale)); err != nil { return err } @@ -1180,13 +1180,33 @@ func saveNotifyVerifyCode(ctx context.Context, cache EmailCache, email, code str } // sendNotifyVerifyEmail builds and sends the verification email. -func (s *UserService) sendNotifyVerifyEmail(ctx context.Context, emailService *EmailService, email, code string) error { +func (s *UserService) sendNotifyVerifyEmail(ctx context.Context, emailService *EmailService, userID int64, email, code, locale string) error { siteName := "Sub2API" if s.settingRepo != nil { if name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName); err == nil && name != "" { siteName = name } } + if emailService.notificationEmailService != nil { + if err := emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{ + Event: NotificationEmailEventNotificationEmailVerifyCode, + Locale: locale, + RecipientEmail: email, + RecipientName: emailRecipientName(email), + UserID: userID, + Variables: map[string]string{ + "verification_code": code, + "expires_in_minutes": strconv.Itoa(int(verifyCodeTTL / time.Minute)), + }, + }); err == nil { + return nil + } else { + if !shouldFallbackNotificationEmail(err) { + return err + } + slog.Warn("template notification email verification failed; falling back to built-in body", "recipient_hash", notificationEmailHash(email), "err", err.Error()) + } + } subject := fmt.Sprintf("[%s] 通知邮箱验证码 / Notification Email Verification", siteName) body := buildNotifyVerifyEmailBody(code, siteName) return emailService.SendEmail(ctx, email, subject, body) diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 71117d0b..cc303d43 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -154,8 +154,9 @@ func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpirySe } // ProvideSubscriptionExpiryService creates and starts SubscriptionExpiryService. -func ProvideSubscriptionExpiryService(userSubRepo UserSubscriptionRepository) *SubscriptionExpiryService { +func ProvideSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, notificationEmailService *NotificationEmailService) *SubscriptionExpiryService { svc := NewSubscriptionExpiryService(userSubRepo, time.Minute) + svc.SetNotificationEmailService(notificationEmailService) svc.Start() return svc } @@ -484,6 +485,7 @@ var ProviderSet = wire.NewSet( ProvideOpsCleanupService, ProvideOpsScheduledReportService, NewEmailService, + NewNotificationEmailService, ProvideEmailQueueService, NewTurnstileService, NewSubscriptionService, @@ -520,7 +522,7 @@ var ProviderSet = wire.NewSet( NewContentModerationService, NewAffiliateService, ProvidePaymentConfigService, - NewPaymentService, + ProvidePaymentService, ProvidePaymentOrderExpiryService, ProvideBalanceNotifyService, ProvideWindsurfAuthService, @@ -648,8 +650,17 @@ func ProvidePaymentConfigService(entClient *dbent.Client, settingRepo SettingRep } // ProvideBalanceNotifyService creates BalanceNotifyService -func ProvideBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountRepository) *BalanceNotifyService { - return NewBalanceNotifyService(emailService, settingRepo, accountRepo) +func ProvideBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountRepository, notificationEmailService *NotificationEmailService) *BalanceNotifyService { + svc := NewBalanceNotifyService(emailService, settingRepo, accountRepo) + svc.SetNotificationEmailService(notificationEmailService) + return svc +} + +// ProvidePaymentService creates PaymentService and attaches notification email delivery. +func ProvidePaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository, affiliateService *AffiliateService, notificationEmailService *NotificationEmailService) *PaymentService { + svc := NewPaymentService(entClient, registry, loadBalancer, redeemService, subscriptionSvc, configService, userRepo, groupRepo, affiliateService) + svc.SetNotificationEmailService(notificationEmailService) + return svc } // ProvidePaymentOrderExpiryService creates and starts PaymentOrderExpiryService. diff --git a/frontend/src/__tests__/setup.ts b/frontend/src/__tests__/setup.ts index decb2a37..b777b22e 100644 --- a/frontend/src/__tests__/setup.ts +++ b/frontend/src/__tests__/setup.ts @@ -5,6 +5,45 @@ import { config } from '@vue/test-utils' import { vi } from 'vitest' +function createMemoryStorage(): Storage { + const values = new Map() + + return { + get length() { + return values.size + }, + clear() { + values.clear() + }, + getItem(key: string) { + return values.has(key) ? values.get(key)! : null + }, + key(index: number) { + return Array.from(values.keys())[index] ?? null + }, + removeItem(key: string) { + values.delete(key) + }, + setItem(key: string, value: string) { + values.set(key, String(value)) + } + } +} + +if (typeof globalThis.localStorage === 'undefined' || typeof globalThis.localStorage.getItem !== 'function') { + Object.defineProperty(globalThis, 'localStorage', { + configurable: true, + value: createMemoryStorage() + }) +} + +if (typeof window !== 'undefined' && typeof window.localStorage.getItem !== 'function') { + Object.defineProperty(window, 'localStorage', { + configurable: true, + value: globalThis.localStorage + }) +} + // Mock requestIdleCallback (Safari < 15 不支持) if (typeof globalThis.requestIdleCallback === 'undefined') { globalThis.requestIdleCallback = ((callback: IdleRequestCallback) => { diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 8550ea3f..5632325e 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -505,6 +505,7 @@ export interface SystemSettings { enable_anthropic_cache_ttl_1h_injection: boolean; rewrite_message_cache_control: boolean; antigravity_user_agent_version: string; + openai_codex_user_agent: string; web_search_emulation_enabled?: boolean; // Payment configuration @@ -726,6 +727,7 @@ export interface UpdateSettingsRequest { enable_anthropic_cache_ttl_1h_injection?: boolean; rewrite_message_cache_control?: boolean; antigravity_user_agent_version?: string; + openai_codex_user_agent?: string; // Payment configuration payment_enabled?: boolean; risk_control_enabled?: boolean; @@ -854,6 +856,105 @@ export async function sendTestEmail( return data; } +// ==================== Email Template Settings ==================== + +export interface EmailTemplateOption { + value: string; + label?: string; + description?: string; +} + +export type EmailTemplateEventOption = string | EmailTemplateOption; + +export interface EmailTemplateSummary { + event: string; + locale: string; + subject: string; + is_custom?: boolean; + updated_at?: string; +} + +export interface EmailTemplateListResponse { + events: EmailTemplateEventOption[]; + locales: string[]; + templates?: EmailTemplateSummary[]; + placeholders?: string[]; +} + +export interface EmailTemplateDetail { + event: string; + locale: string; + subject: string; + html: string; + is_custom?: boolean; + updated_at?: string; + placeholders?: string[]; +} + +export interface UpdateEmailTemplateRequest { + subject: string; + html: string; +} + +export interface PreviewEmailTemplateRequest extends UpdateEmailTemplateRequest { + event: string; + locale: string; +} + +export interface EmailTemplatePreviewResponse { + subject: string; + html: string; +} + +export async function getEmailTemplates(): Promise { + const { data } = await apiClient.get( + "/admin/settings/email-templates", + ); + return data; +} + +export async function getEmailTemplate( + event: string, + locale: string, +): Promise { + const { data } = await apiClient.get( + `/admin/settings/email-templates/${encodeURIComponent(event)}/${encodeURIComponent(locale)}`, + ); + return data; +} + +export async function updateEmailTemplate( + event: string, + locale: string, + request: UpdateEmailTemplateRequest, +): Promise { + const { data } = await apiClient.put( + `/admin/settings/email-templates/${encodeURIComponent(event)}/${encodeURIComponent(locale)}`, + request, + ); + return data; +} + +export async function restoreOfficialEmailTemplate( + event: string, + locale: string, +): Promise { + const { data } = await apiClient.post( + `/admin/settings/email-templates/${encodeURIComponent(event)}/${encodeURIComponent(locale)}/restore-official`, + ); + return data; +} + +export async function previewEmailTemplate( + request: PreviewEmailTemplateRequest, +): Promise { + const { data } = await apiClient.post( + "/admin/settings/email-template-preview", + request, + ); + return data; +} + /** * Admin API Key status response */ @@ -1160,6 +1261,11 @@ export const settingsAPI = { updateSettings, testSmtpConnection, sendTestEmail, + getEmailTemplates, + getEmailTemplate, + updateEmailTemplate, + restoreOfficialEmailTemplate, + previewEmailTemplate, getAdminApiKey, regenerateAdminApiKey, deleteAdminApiKey, diff --git a/frontend/src/api/usage.ts b/frontend/src/api/usage.ts index 7169b698..ee08ee9d 100644 --- a/frontend/src/api/usage.ts +++ b/frontend/src/api/usage.ts @@ -69,6 +69,25 @@ export interface ModelStatsResponse { end_date: string } +export interface ApiKeyDailyUsagePoint { + date: string + requests: number + input_tokens: number + output_tokens: number + cache_read_tokens: number + cache_write_tokens: number + total_tokens: number + cost: number + actual_cost: number +} + +export interface ApiKeyDailyUsageResponse { + items: ApiKeyDailyUsagePoint[] + days: number + start_date: string + end_date: string +} + /** * List usage logs with optional filters * @param page - Page number (default: 1) @@ -234,6 +253,23 @@ export async function getDashboardModels(params?: { return data } +/** + * Get daily usage details for one API key owned by the current user. + * @param apiKeyId - API key ID + * @param days - Number of days to include (1-90) + * @returns Daily usage detail rows + */ +export async function getMyApiKeyDailyUsage( + apiKeyId: number, + days: number = 30 +): Promise { + const { data } = await apiClient.get( + `/user/api-keys/${apiKeyId}/usage/daily`, + { params: { days } } + ) + return data +} + export interface BatchApiKeyUsageStats { api_key_id: number today_actual_cost: number @@ -279,6 +315,7 @@ export const usageAPI = { getDashboardStats, getDashboardTrend, getDashboardModels, + getMyApiKeyDailyUsage, getDashboardApiKeysUsage } diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 7fb034de..0b975672 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -123,19 +123,23 @@ export default { dateRangeToday: 'Today', dateRange7d: '7 Days', dateRange30d: '30 Days', + dateRange90d: '90 Days', dateRangeCustom: 'Custom', apply: 'Apply', used: 'Used', detailInfo: 'Detail Information', tokenStats: 'Token Statistics', + dailyDetail: 'Daily Detail', modelStats: 'Model Usage Statistics', // Table headers + date: 'Date', model: 'Model', requests: 'Requests', inputTokens: 'Input Tokens', outputTokens: 'Output Tokens', cacheCreationTokens: 'Cache Creation', cacheReadTokens: 'Cache Read', + cacheWriteTokens: 'Cache Write', totalTokens: 'Total Tokens', cost: 'Cost', // Status @@ -179,6 +183,7 @@ export default { querySuccess: 'Query successful', queryFailed: 'Query failed', queryFailedRetry: 'Query failed, please try again later', + noDailyUsage: 'No daily usage data', }, // Setup Wizard @@ -4176,6 +4181,22 @@ export default { }, userPrefix: 'User #{id}', exportCsv: 'Export CSV', + batchUpdate: 'Batch Update', + batchUpdateTitle: 'Batch Update Redeem Codes', + selectedCount: '{count} redeem code(s) selected', + clearSelection: 'Clear selection', + selectCodesFirst: 'Select redeem codes first', + noBatchFieldsSelected: 'Select at least one field to update', + batchUpdateSuccess: 'Updated {count} redeem code(s)', + failedToBatchUpdate: 'Failed to batch update redeem codes', + batchFields: { + status: 'Status', + expiresAt: 'Expires At', + notes: 'Notes', + group: 'Group' + }, + batchNotesPlaceholder: 'Enter the new note, or leave blank to clear it', + clearGroup: 'Clear group', deleteAllUnused: 'Delete All Unused Codes', deleteCode: 'Delete Redeem Code', deleteCodeConfirm: @@ -5515,6 +5536,9 @@ export default { antigravityUserAgentVersion: 'Antigravity UA Version', antigravityUserAgentVersionPlaceholder: '1.23.2', antigravityUserAgentVersionHint: 'Leave empty to use ANTIGRAVITY_USER_AGENT_VERSION or the built-in default 1.23.2; when set, the admin setting takes precedence.', + openaiCodexUserAgent: 'OpenAI Codex UA', + openaiCodexUserAgentPlaceholder: 'codex-tui/0.125.0 (Ubuntu 22.4.0; x86_64) xterm-256color (codex-tui; 0.125.0)', + openaiCodexUserAgentHint: 'Used to bypass Cloudflare browser-UA challenges on the OpenAI upstream. Only applies when the client User-Agent is detected as a browser (Mozilla/...). Leave empty to use the built-in default.', }, webSearchEmulation: { title: 'Web Search Emulation', @@ -5854,6 +5878,36 @@ export default { sending: 'Sending...', enterRecipientHint: 'Please enter a recipient email address' }, + emailTemplates: { + title: 'Email Templates', + description: 'Customize notification email subjects and HTML content for each event and locale.', + event: 'Event', + locale: 'Locale', + localeEn: 'English', + localeZh: 'Chinese', + subject: 'Subject', + subjectPlaceholder: 'Enter the email subject', + html: 'HTML Template', + htmlPlaceholder: 'Edit the email HTML template', + placeholders: 'Available Placeholders', + placeholdersHelp: 'Click a placeholder to copy it. The backend replaces these values when sending emails.', + livePreview: 'Live Preview', + previewSecurityHint: 'Preview HTML is generated by the backend preview endpoint and displayed in a sandboxed iframe with scripts disabled.', + preview: 'Preview / Refresh', + previewing: 'Previewing...', + save: 'Save Template', + saving: 'Saving...', + restoreOfficial: 'Restore Official', + restoring: 'Restoring...', + restoreConfirm: 'Restore the official template for this event and locale? Your custom version will be replaced.', + restoreSuccess: 'Official template restored', + saveSuccess: 'Email template saved', + placeholderCopied: 'Placeholder copied', + validationRequired: 'Subject and HTML template are required', + empty: 'No email template events or locales are available yet.', + noPreview: 'Refresh the preview to see the rendered email subject.', + customized: 'Customized' + }, opsMonitoring: { title: 'Ops Monitoring', description: 'Enable ops monitoring for troubleshooting and health visibility', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 0c15558f..6b595bef 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -123,19 +123,23 @@ export default { dateRangeToday: '今日', dateRange7d: '7 天', dateRange30d: '30 天', + dateRange90d: '90 天', dateRangeCustom: '自定义', apply: '应用', used: '已使用', detailInfo: '详细信息', tokenStats: 'Token 统计', + dailyDetail: '按日明细', modelStats: '模型用量统计', // Table headers + date: '日期', model: '模型', requests: '请求数', inputTokens: '输入 Tokens', outputTokens: '输出 Tokens', cacheCreationTokens: '缓存创建', cacheReadTokens: '缓存读取', + cacheWriteTokens: '缓存写入', totalTokens: '总 Tokens', cost: '费用', // Status @@ -179,6 +183,7 @@ export default { querySuccess: '查询成功', queryFailed: '查询失败', queryFailedRetry: '查询失败,请稍后重试', + noDailyUsage: '暂无按日用量数据', }, // Setup Wizard @@ -4310,6 +4315,22 @@ export default { used: '已使用', searchCodes: '搜索兑换码或邮箱...', exportCsv: '导出 CSV', + batchUpdate: '批量修改', + batchUpdateTitle: '批量修改兑换码', + selectedCount: '已选择 {count} 个兑换码', + clearSelection: '清空选择', + selectCodesFirst: '请先选择兑换码', + noBatchFieldsSelected: '请至少勾选一个要修改的字段', + batchUpdateSuccess: '成功修改 {count} 个兑换码', + failedToBatchUpdate: '批量修改兑换码失败', + batchFields: { + status: '状态', + expiresAt: '过期时间', + notes: '备注', + group: '分组' + }, + batchNotesPlaceholder: '输入新的备注,留空可清空备注', + clearGroup: '清空分组', deleteAllUnused: '删除全部未使用', deleteCodeConfirm: '确定要删除此兑换码吗?此操作无法撤销。', deleteAllUnusedConfirm: '确定要删除全部未使用的兑换码吗?此操作无法撤销。', @@ -5673,6 +5694,9 @@ export default { antigravityUserAgentVersion: 'Antigravity UA 版本', antigravityUserAgentVersionPlaceholder: '1.23.2', antigravityUserAgentVersionHint: '留空时使用 ANTIGRAVITY_USER_AGENT_VERSION 或内置默认值 1.23.2;填写后后台设置优先。', + openaiCodexUserAgent: 'OpenAI Codex UA', + openaiCodexUserAgentPlaceholder: 'codex-tui/0.125.0 (Ubuntu 22.4.0; x86_64) xterm-256color (codex-tui; 0.125.0)', + openaiCodexUserAgentHint: '用于规避 OpenAI 上游 Cloudflare 对浏览器 UA 的访问质询。仅在检测到客户端 User-Agent 为浏览器(Mozilla/...)时生效,其他客户端原样透传。留空使用内置默认值。', }, webSearchEmulation: { title: 'Web Search 模拟', @@ -6014,6 +6038,36 @@ export default { sending: '发送中...', enterRecipientHint: '请输入收件人邮箱地址' }, + emailTemplates: { + title: '邮件模板', + description: '按事件和语言自定义通知邮件主题与 HTML 内容。', + event: '事件', + locale: '语言', + localeEn: '英文', + localeZh: '中文', + subject: '主题', + subjectPlaceholder: '输入邮件主题', + html: 'HTML 模板', + htmlPlaceholder: '编辑邮件 HTML 模板', + placeholders: '可用占位符', + placeholdersHelp: '点击占位符可复制。后端发送邮件时会替换这些值。', + livePreview: '实时预览', + previewSecurityHint: '预览 HTML 由后端预览接口生成,并在禁用脚本的沙盒 iframe 中展示。', + preview: '预览 / 刷新', + previewing: '预览中...', + save: '保存模板', + saving: '保存中...', + restoreOfficial: '恢复官方模板', + restoring: '恢复中...', + restoreConfirm: '确定恢复此事件和语言的官方模板吗?当前自定义版本将被替换。', + restoreSuccess: '已恢复官方模板', + saveSuccess: '邮件模板已保存', + placeholderCopied: '占位符已复制', + validationRequired: '主题和 HTML 模板不能为空', + empty: '暂无可用的邮件模板事件或语言。', + noPreview: '刷新预览后查看渲染后的邮件主题。', + customized: '已自定义' + }, opsMonitoring: { title: '运维监控', description: '启用运维监控模块,用于排障与健康可视化', diff --git a/frontend/src/views/KeyUsageView.vue b/frontend/src/views/KeyUsageView.vue index 21a35340..c3a303d4 100644 --- a/frontend/src/views/KeyUsageView.vue +++ b/frontend/src/views/KeyUsageView.vue @@ -289,6 +289,62 @@ + +
+
+

{{ t('keyUsage.dailyDetail') }}

+
+ +
+
+
+ + + + + + + + + + + + + + + + + + + + + + + +
{{ t('keyUsage.date') }}{{ t('keyUsage.requests') }}{{ t('keyUsage.inputTokens') }}{{ t('keyUsage.outputTokens') }}{{ t('keyUsage.cacheReadTokens') }}{{ t('keyUsage.cacheWriteTokens') }}{{ t('keyUsage.cost') }}
{{ row.date }}{{ fmtNum(row.requests) }}{{ fmtNum(row.input_tokens) }}{{ fmtNum(row.output_tokens) }}{{ fmtNum(row.cache_read_tokens) }}{{ fmtNum(row.cache_write_tokens) }}{{ usd(row.actual_cost != null ? row.actual_cost : row.cost) }}
+
+
+ {{ t('keyUsage.noDailyUsage') }} +
+
+
('today') const customStartDate = ref('') const customEndDate = ref('') +const dailyUsageDays = ref<7 | 30 | 90>(30) const dateRanges = computed(() => [ { key: 'today' as const, label: t('keyUsage.dateRangeToday') }, @@ -416,6 +473,12 @@ const dateRanges = computed(() => [ { key: 'custom' as const, label: t('keyUsage.dateRangeCustom') }, ]) +const dailyUsageOptions = computed(() => [ + { value: 7 as const, label: t('keyUsage.dateRange7d') }, + { value: 30 as const, label: t('keyUsage.dateRange30d') }, + { value: 90 as const, label: t('keyUsage.dateRange90d') }, +]) + function setDateRange(key: DateRangeKey) { currentRange.value = key if (key !== 'custom') { @@ -426,23 +489,36 @@ function setDateRange(key: DateRangeKey) { function getDateParams(): string { const now = new Date() const fmt = (d: Date) => d.toISOString().split('T')[0] + const params = new URLSearchParams() if (currentRange.value === 'custom') { if (customStartDate.value && customEndDate.value) { - return `start_date=${customStartDate.value}&end_date=${customEndDate.value}` + params.set('start_date', customStartDate.value) + params.set('end_date', customEndDate.value) } - return '' + } else { + const end = fmt(now) + let start: string + switch (currentRange.value) { + case 'today': start = end; break + case '7d': start = fmt(new Date(now.getTime() - 7 * 86400000)); break + case '30d': start = fmt(new Date(now.getTime() - 30 * 86400000)); break + default: start = fmt(new Date(now.getTime() - 30 * 86400000)) + } + params.set('start_date', start) + params.set('end_date', end) } + params.set('days', String(dailyUsageDays.value)) + params.set('timezone', getBrowserTimezone()) + return params.toString() +} - const end = fmt(now) - let start: string - switch (currentRange.value) { - case 'today': start = end; break - case '7d': start = fmt(new Date(now.getTime() - 7 * 86400000)); break - case '30d': start = fmt(new Date(now.getTime() - 30 * 86400000)); break - default: start = fmt(new Date(now.getTime() - 30 * 86400000)) +function setDailyUsageDays(days: 7 | 30 | 90) { + if (dailyUsageDays.value === days) return + dailyUsageDays.value = days + if (resultData.value && apiKey.value.trim()) { + queryKey() } - return `start_date=${start}&end_date=${end}` } // ==================== Ring Animation ==================== @@ -731,6 +807,24 @@ const usageStatCells = computed(() => { // eslint-disable-next-line @typescript-eslint/no-explicit-any const modelStats = computed(() => resultData.value?.model_stats || []) +interface DailyUsageRow { + date: string + requests: number + input_tokens: number + output_tokens: number + cache_read_tokens: number + cache_write_tokens: number + cost: number + actual_cost?: number +} + +const dailyUsageRows = computed(() => { + const rows = resultData.value?.daily_usage + return Array.isArray(rows) ? rows : [] +}) + +const showDailyUsage = computed(() => Boolean(resultData.value && Array.isArray(resultData.value.daily_usage))) + // ==================== Utility Functions ==================== function usd(value: number | null | undefined): string { @@ -750,6 +844,14 @@ function formatDate(iso: string | null | undefined): string { return d.toLocaleDateString(loc, { year: 'numeric', month: 'long', day: 'numeric' }) } +function getBrowserTimezone(): string { + try { + return Intl.DateTimeFormat().resolvedOptions().timeZone || 'UTC' + } catch { + return 'UTC' + } +} + // ==================== API Query ==================== async function fetchUsage(key: string) { diff --git a/frontend/src/views/__tests__/KeyUsageView.spec.ts b/frontend/src/views/__tests__/KeyUsageView.spec.ts new file mode 100644 index 00000000..c1373bc3 --- /dev/null +++ b/frontend/src/views/__tests__/KeyUsageView.spec.ts @@ -0,0 +1,208 @@ +import { describe, expect, it, beforeEach, afterEach, vi } from 'vitest' +import { flushPromises, mount } from '@vue/test-utils' +import { nextTick } from 'vue' + +import KeyUsageView from '../KeyUsageView.vue' + +const { showInfo, showSuccess, showError, fetchPublicSettings } = vi.hoisted(() => ({ + showInfo: vi.fn(), + showSuccess: vi.fn(), + showError: vi.fn(), + fetchPublicSettings: vi.fn(), +})) + +const messages: Record = { + 'keyUsage.title': 'API Key Usage', + 'keyUsage.subtitle': 'Usage status', + 'keyUsage.placeholder': 'sk-test', + 'keyUsage.query': 'Query', + 'keyUsage.querying': 'Querying...', + 'keyUsage.privacyNote': 'Privacy note', + 'keyUsage.dateRange': 'Date Range:', + 'keyUsage.dateRangeToday': 'Today', + 'keyUsage.dateRange7d': '7 Days', + 'keyUsage.dateRange30d': '30 Days', + 'keyUsage.dateRange90d': '90 Days', + 'keyUsage.dateRangeCustom': 'Custom', + 'keyUsage.apply': 'Apply', + 'keyUsage.used': 'Used', + 'keyUsage.detailInfo': 'Detail Information', + 'keyUsage.tokenStats': 'Token Statistics', + 'keyUsage.dailyDetail': 'Daily Detail', + 'keyUsage.date': 'Date', + 'keyUsage.requests': 'Requests', + 'keyUsage.inputTokens': 'Input Tokens', + 'keyUsage.outputTokens': 'Output Tokens', + 'keyUsage.cacheReadTokens': 'Cache Read', + 'keyUsage.cacheWriteTokens': 'Cache Write', + 'keyUsage.cost': 'Cost', + 'keyUsage.quotaMode': 'Key Quota Mode', + 'keyUsage.walletBalance': 'Wallet Balance', + 'keyUsage.totalQuota': 'Total Quota', + 'keyUsage.limit5h': '5-Hour Limit', + 'keyUsage.limitDaily': 'Daily Limit', + 'keyUsage.limit7d': '7-Day Limit', + 'keyUsage.limitWeekly': 'Weekly Limit', + 'keyUsage.limitMonthly': 'Monthly Limit', + 'keyUsage.remainingQuota': 'Remaining Quota', + 'keyUsage.usedQuota': 'Used Quota', + 'keyUsage.subscriptionType': 'Subscription Type', + 'keyUsage.todayRequests': 'Today Requests', + 'keyUsage.todayInputTokens': 'Today Input', + 'keyUsage.todayOutputTokens': 'Today Output', + 'keyUsage.todayTokens': 'Today Tokens', + 'keyUsage.todayCacheCreation': 'Today Cache Creation', + 'keyUsage.todayCacheRead': 'Today Cache Read', + 'keyUsage.todayCost': 'Today Cost', + 'keyUsage.rpmTpm': 'RPM / TPM', + 'keyUsage.totalRequests': 'Total Requests', + 'keyUsage.totalInputTokens': 'Total Input', + 'keyUsage.totalOutputTokens': 'Total Output', + 'keyUsage.totalTokensLabel': 'Total Tokens', + 'keyUsage.totalCacheCreation': 'Total Cache Creation', + 'keyUsage.totalCacheRead': 'Total Cache Read', + 'keyUsage.totalCost': 'Total Cost', + 'keyUsage.avgDuration': 'Avg Duration', + 'keyUsage.querySuccess': 'Query successful', + 'keyUsage.queryFailed': 'Query failed', + 'keyUsage.queryFailedRetry': 'Query failed, please try again later', + 'home.viewDocs': 'Docs', + 'home.switchToLight': 'Light', + 'home.switchToDark': 'Dark', + 'home.footer.allRightsReserved': 'All rights reserved.', +} + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string) => messages[key] ?? key, + locale: { value: 'en' }, + }), + } +}) + +vi.mock('@/stores', () => ({ + useAppStore: () => ({ + cachedPublicSettings: null, + siteName: 'Sub2API', + siteLogo: '', + docUrl: '', + publicSettingsLoaded: true, + fetchPublicSettings, + showInfo, + showSuccess, + showError, + }), +})) + +describe('KeyUsageView daily detail', () => { + beforeEach(() => { + showInfo.mockReset() + showSuccess.mockReset() + showError.mockReset() + fetchPublicSettings.mockReset() + localStorage.clear() + + Object.defineProperty(window, 'matchMedia', { + configurable: true, + value: vi.fn().mockReturnValue({ matches: false }), + }) + vi.stubGlobal('requestAnimationFrame', (cb: FrameRequestCallback) => window.setTimeout(() => cb(0), 0)) + vi.stubGlobal('fetch', vi.fn().mockResolvedValue({ + ok: true, + json: async () => ({ + mode: 'quota_limited', + isValid: true, + status: 'active', + quota: { + limit: 10, + used: 1, + remaining: 9, + unit: 'USD', + }, + usage: { + today: { + requests: 1, + input_tokens: 10, + output_tokens: 20, + cache_creation_tokens: 0, + cache_read_tokens: 0, + total_tokens: 30, + actual_cost: 0.01, + }, + total: { + requests: 12, + input_tokens: 100, + output_tokens: 200, + cache_creation_tokens: 10, + cache_read_tokens: 30, + total_tokens: 340, + actual_cost: 0.12, + }, + rpm: 0, + tpm: 0, + }, + daily_usage: [ + { + date: '2026-05-19', + requests: 12, + input_tokens: 100, + output_tokens: 200, + cache_read_tokens: 30, + cache_write_tokens: 10, + total_tokens: 340, + cost: 0.15, + actual_cost: 0.12, + }, + ], + }), + })) + }) + + afterEach(() => { + vi.unstubAllGlobals() + }) + + it('renders daily usage detail rows after a successful query', async () => { + const wrapper = mount(KeyUsageView, { + global: { + stubs: { + RouterLink: { template: '' }, + LocaleSwitcher: true, + Icon: true, + }, + }, + }) + + await wrapper.find('input').setValue('sk-test-key') + await wrapper.find('input').trigger('keydown.enter') + await flushPromises() + await nextTick() + + const fetchMock = vi.mocked(fetch) + expect(fetchMock).toHaveBeenCalledWith( + expect.stringContaining('/v1/usage?'), + expect.objectContaining({ + headers: { Authorization: 'Bearer sk-test-key' }, + }) + ) + expect(String(fetchMock.mock.calls[0][0])).toContain('days=30') + + const text = wrapper.text() + expect(text).toContain('Daily Detail') + expect(text).toContain('Date') + expect(text).toContain('Cache Read') + expect(text).toContain('Cache Write') + expect(text).toContain('2026-05-19') + expect(text).toContain('12') + expect(text).toContain('100') + expect(text).toContain('200') + expect(text).toContain('30') + expect(text).toContain('10') + expect(text).toContain('$0.12') + + wrapper.unmount() + }) +}) diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 505ba4f3..9a648796 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -3769,6 +3769,36 @@ }}

+ + +
+ + +

+ {{ + t( + "admin.settings.gatewayForwarding.openaiCodexUserAgentHint", + ) + }} +

+
@@ -6225,6 +6255,9 @@ + + +
({ enable_anthropic_cache_ttl_1h_injection: false, rewrite_message_cache_control: false, antigravity_user_agent_version: "", + openai_codex_user_agent: "", // Balance & quota notification balance_low_notify_enabled: false, balance_low_notify_threshold: 0, @@ -8044,6 +8079,8 @@ async function saveSettings() { rewrite_message_cache_control: form.rewrite_message_cache_control, antigravity_user_agent_version: form.antigravity_user_agent_version?.trim() || "", + openai_codex_user_agent: + form.openai_codex_user_agent?.trim() || "", // Payment configuration payment_enabled: form.payment_enabled, risk_control_enabled: form.risk_control_enabled, diff --git a/frontend/src/views/admin/__tests__/SettingsView.spec.ts b/frontend/src/views/admin/__tests__/SettingsView.spec.ts index 275e38c5..0d4ab7d2 100644 --- a/frontend/src/views/admin/__tests__/SettingsView.spec.ts +++ b/frontend/src/views/admin/__tests__/SettingsView.spec.ts @@ -371,6 +371,7 @@ const baseSettingsResponse = { enable_anthropic_cache_ttl_1h_injection: false, rewrite_message_cache_control: false, antigravity_user_agent_version: "", + openai_codex_user_agent: "", payment_enabled: true, payment_min_amount: 1, payment_max_amount: 10000, diff --git a/frontend/src/views/admin/settings/EmailTemplateEditor.vue b/frontend/src/views/admin/settings/EmailTemplateEditor.vue new file mode 100644 index 00000000..6643f799 --- /dev/null +++ b/frontend/src/views/admin/settings/EmailTemplateEditor.vue @@ -0,0 +1,483 @@ + + + diff --git a/frontend/vitest.config.ts b/frontend/vitest.config.ts index 39568250..8c45a9ff 100644 --- a/frontend/vitest.config.ts +++ b/frontend/vitest.config.ts @@ -13,6 +13,7 @@ export default defineConfig({ test: { globals: true, environment: 'jsdom', + setupFiles: ['./src/__tests__/setup.ts'], include: ['src/**/*.{test,spec}.{js,ts,jsx,tsx}'], exclude: ['node_modules', 'dist'], coverage: {