chore: merge upstream Wei-Shaw/sub2api v0.1.128 — keep fork customizations
Upstream 新功能 (34 commits, ~main..origin/main): - feat(email): 通知邮件模板服务、模板编辑器、订阅/余额提醒邮件 - feat(notification): NotificationEmailService 注入到 Balance/Payment/Setting - feat(payment): 支付成功通知邮件 - feat(usage): 用户 API Key 用量页支持按日明细 - feat(openai-gateway): Codex OAuth 浏览器 UA 自动改写规避 Cloudflare 质询 - feat(admin): 邮件模板管理接口 - fix(auth): 停用/删除分组后阻断 API Key - fix(group): 修正分组账号可用计数口径 - fix(openai): /v1/responses respect force chat completions, images n 参数透传 - test(repository): AES Encryptor 单元测试 - chore: VERSION 0.1.128 冲突解决 (backend/cmd/server/wire_gen.go): - 引入 upstream 新 wire providers: notificationEmailService, ProvidePaymentService(10 args), ProvideAdminSettingHandler(8 args) - 保留 fork 独有依赖: rpmTokenBucketService (RPM 平滑), NewOpsHandler 3 参数版本 (requestEventBus, opsLogBroadcaster) - ProvideBalanceNotifyService 接受 4 参数 (含 notificationEmailService) 修复 session-id helper 设计 (claude_code_session_id.go): - 发现: TestGatewayService_AnthropicOAuth_InjectsClaudeCodeSessionHeaderFromMetadata 在 OAuth + mimicClaudeCode=false 场景失败 - 重新审视设计原则: OAuth 凭证本身就是 Claude Code 客户端,可信任 metadata 派生 session_id;不应受 mimicClaudeCode 标志阻止 - 修复: metadata 派生只看 tokenType=="oauth";UUID 兜底仍需 oauth && mimic - 更新测试: OAuthNonMimicDerivesFromMetadata 取代原 IgnoresMetadata 所有 fork 独有功能保留: - Claude Code 2.1.145 mimicry bundle (上个 commit 引入) - RPM token bucket smoothing (commit 95814974) - Windsurf/Antigravity/Omniroute 定制 - claudemask/ 校验包 (upstream 已删除,我们仍在 gateway_service 中使用) 不在范围: - 不修复 baseline 既存的 2 个测试失败 (TestProxyImportData..., TestWindsurfTierAccessService_Snapshot_HappyPath) - 与 merge 无关
This commit is contained in:
commit
92433656f5
@ -1 +1 @@
|
||||
0.1.127
|
||||
0.1.128
|
||||
|
||||
@ -189,7 +189,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
channelRepository := repository.NewChannelRepository(db)
|
||||
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
|
||||
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
||||
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
|
||||
notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService)
|
||||
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService)
|
||||
rpmTokenBucketService := service.NewRPMTokenBucketService()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService)
|
||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
||||
@ -204,8 +205,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
|
||||
registry := payment.ProvideRegistry()
|
||||
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
|
||||
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService)
|
||||
paymentService := service.ProvidePaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService, notificationEmailService)
|
||||
settingHandler := handler.ProvideAdminSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService, notificationEmailService)
|
||||
requestEventBus := service.NewRequestEventBus()
|
||||
opsLogBroadcaster := service.ProvideOpsLogBroadcaster()
|
||||
opsHandler := admin.NewOpsHandler(opsService, requestEventBus, opsLogBroadcaster)
|
||||
@ -253,7 +254,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, windsurfGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, userMessageQueueService, configConfig, settingService, requestEventBus)
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, configConfig)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo, notificationEmailService)
|
||||
totpHandler := handler.NewTotpHandler(totpService)
|
||||
handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService)
|
||||
paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry)
|
||||
@ -274,7 +275,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository, notificationEmailService)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
|
||||
windsurfRefreshService := service.ProvideWindsurfRefreshService(configConfig, accountRepository, proxyRepository)
|
||||
|
||||
@ -56,13 +56,14 @@ func firstNonEmpty(values ...string) string {
|
||||
|
||||
// SettingHandler 系统设置处理器
|
||||
type SettingHandler struct {
|
||||
settingService *service.SettingService
|
||||
emailService *service.EmailService
|
||||
turnstileService *service.TurnstileService
|
||||
opsService *service.OpsService
|
||||
paymentConfigService *service.PaymentConfigService
|
||||
paymentService *service.PaymentService
|
||||
userAttributeService *service.UserAttributeService
|
||||
settingService *service.SettingService
|
||||
emailService *service.EmailService
|
||||
turnstileService *service.TurnstileService
|
||||
opsService *service.OpsService
|
||||
paymentConfigService *service.PaymentConfigService
|
||||
paymentService *service.PaymentService
|
||||
userAttributeService *service.UserAttributeService
|
||||
notificationEmailService *service.NotificationEmailService
|
||||
}
|
||||
|
||||
// NewSettingHandler 创建系统设置处理器
|
||||
@ -78,6 +79,12 @@ func NewSettingHandler(settingService *service.SettingService, emailService *ser
|
||||
}
|
||||
}
|
||||
|
||||
// SetNotificationEmailService attaches the notification template service without changing
|
||||
// the constructor signature used by existing unit tests.
|
||||
func (h *SettingHandler) SetNotificationEmailService(notificationEmailService *service.NotificationEmailService) {
|
||||
h.notificationEmailService = notificationEmailService
|
||||
}
|
||||
|
||||
// GetSettings 获取所有系统设置
|
||||
// GET /api/v1/admin/settings
|
||||
func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
@ -247,6 +254,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
EnableAnthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection,
|
||||
RewriteMessageCacheControl: settings.RewriteMessageCacheControl,
|
||||
AntigravityUserAgentVersion: settings.AntigravityUserAgentVersion,
|
||||
OpenAICodexUserAgent: settings.OpenAICodexUserAgent,
|
||||
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
|
||||
PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource,
|
||||
PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource,
|
||||
@ -563,6 +571,7 @@ type UpdateSettingsRequest struct {
|
||||
EnableAnthropicCacheTTL1hInjection *bool `json:"enable_anthropic_cache_ttl_1h_injection"`
|
||||
RewriteMessageCacheControl *bool `json:"rewrite_message_cache_control"`
|
||||
AntigravityUserAgentVersion *string `json:"antigravity_user_agent_version"`
|
||||
OpenAICodexUserAgent *string `json:"openai_codex_user_agent"`
|
||||
|
||||
// Payment visible method routing
|
||||
PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
|
||||
@ -1404,6 +1413,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if req.OpenAICodexUserAgent != nil {
|
||||
normalized := strings.TrimSpace(*req.OpenAICodexUserAgent)
|
||||
req.OpenAICodexUserAgent = &normalized
|
||||
// 仅做长度上限保护,不限制具体格式(运维需要可自由调整 codex 版本号)
|
||||
if len(normalized) > 512 {
|
||||
response.Error(c, http.StatusBadRequest, "openai_codex_user_agent must be at most 512 characters")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 交叉验证:如果同时设置了最低和最高版本号,最高版本号必须 >= 最低版本号
|
||||
if req.MinClaudeCodeVersion != "" && req.MaxClaudeCodeVersion != "" {
|
||||
@ -1597,6 +1615,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
return previousSettings.AntigravityUserAgentVersion
|
||||
}(),
|
||||
OpenAICodexUserAgent: func() string {
|
||||
if req.OpenAICodexUserAgent != nil {
|
||||
return *req.OpenAICodexUserAgent
|
||||
}
|
||||
return previousSettings.OpenAICodexUserAgent
|
||||
}(),
|
||||
PaymentVisibleMethodAlipaySource: func() string {
|
||||
if req.PaymentVisibleMethodAlipaySource != nil {
|
||||
return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource)
|
||||
@ -1956,6 +1980,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
EnableAnthropicCacheTTL1hInjection: updatedSettings.EnableAnthropicCacheTTL1hInjection,
|
||||
RewriteMessageCacheControl: updatedSettings.RewriteMessageCacheControl,
|
||||
AntigravityUserAgentVersion: updatedSettings.AntigravityUserAgentVersion,
|
||||
OpenAICodexUserAgent: updatedSettings.OpenAICodexUserAgent,
|
||||
PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource,
|
||||
PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource,
|
||||
PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled,
|
||||
@ -2411,6 +2436,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.AntigravityUserAgentVersion != after.AntigravityUserAgentVersion {
|
||||
changed = append(changed, "antigravity_user_agent_version")
|
||||
}
|
||||
if before.OpenAICodexUserAgent != after.OpenAICodexUserAgent {
|
||||
changed = append(changed, "openai_codex_user_agent")
|
||||
}
|
||||
if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource {
|
||||
changed = append(changed, "payment_visible_method_alipay_source")
|
||||
}
|
||||
@ -3339,3 +3367,160 @@ func (h *SettingHandler) ensureUserAttributeDefinition(ctx context.Context, key,
|
||||
}
|
||||
slog.Info("dingtalk: created user attribute definition", "key", key, "name", name, "type", attrType)
|
||||
}
|
||||
|
||||
// ListEmailTemplates returns all editable notification email templates.
|
||||
// GET /api/v1/admin/settings/email-templates
|
||||
func (h *SettingHandler) ListEmailTemplates(c *gin.Context) {
|
||||
if h.notificationEmailService == nil {
|
||||
response.InternalError(c, "notification email service is not configured")
|
||||
return
|
||||
}
|
||||
events := h.notificationEmailService.ListEventInfos()
|
||||
templates, err := h.notificationEmailService.ListTemplates(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, dto.EmailTemplateListResponse{
|
||||
Events: emailTemplateEventOptionsToDTO(events),
|
||||
Locales: h.notificationEmailService.SupportedLocales(),
|
||||
Templates: emailTemplateSummariesToDTO(templates),
|
||||
Placeholders: emailTemplatePlaceholderUnion(events),
|
||||
})
|
||||
}
|
||||
|
||||
// GetEmailTemplate returns one editable notification email template.
|
||||
// GET /api/v1/admin/settings/email-templates/:event/:locale
|
||||
func (h *SettingHandler) GetEmailTemplate(c *gin.Context) {
|
||||
if h.notificationEmailService == nil {
|
||||
response.InternalError(c, "notification email service is not configured")
|
||||
return
|
||||
}
|
||||
tmpl, err := h.notificationEmailService.GetTemplate(c.Request.Context(), c.Param("event"), c.Param("locale"))
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, emailTemplateDetailToDTO(tmpl))
|
||||
}
|
||||
|
||||
// UpdateEmailTemplate saves an override for one event/locale template.
|
||||
// PUT /api/v1/admin/settings/email-templates/:event/:locale
|
||||
func (h *SettingHandler) UpdateEmailTemplate(c *gin.Context) {
|
||||
if h.notificationEmailService == nil {
|
||||
response.InternalError(c, "notification email service is not configured")
|
||||
return
|
||||
}
|
||||
var req dto.UpdateEmailTemplateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
tmpl, err := h.notificationEmailService.UpdateTemplate(c.Request.Context(), c.Param("event"), c.Param("locale"), req.Subject, req.HTML)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, emailTemplateDetailToDTO(tmpl))
|
||||
}
|
||||
|
||||
// RestoreOfficialEmailTemplate removes an override and returns the built-in template.
|
||||
// POST /api/v1/admin/settings/email-templates/:event/:locale/restore-official
|
||||
func (h *SettingHandler) RestoreOfficialEmailTemplate(c *gin.Context) {
|
||||
if h.notificationEmailService == nil {
|
||||
response.InternalError(c, "notification email service is not configured")
|
||||
return
|
||||
}
|
||||
tmpl, err := h.notificationEmailService.RestoreOfficialTemplate(c.Request.Context(), c.Param("event"), c.Param("locale"))
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, emailTemplateDetailToDTO(tmpl))
|
||||
}
|
||||
|
||||
// PreviewEmailTemplate renders a template with safe sample variables without saving it.
|
||||
// POST /api/v1/admin/settings/email-templates/preview
|
||||
func (h *SettingHandler) PreviewEmailTemplate(c *gin.Context) {
|
||||
if h.notificationEmailService == nil {
|
||||
response.InternalError(c, "notification email service is not configured")
|
||||
return
|
||||
}
|
||||
var req dto.PreviewEmailTemplateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
preview, err := h.notificationEmailService.PreviewTemplate(c.Request.Context(), service.NotificationEmailPreviewInput{
|
||||
Event: req.Event,
|
||||
Locale: req.Locale,
|
||||
Subject: req.Subject,
|
||||
HTML: req.HTML,
|
||||
Variables: req.Variables,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, dto.EmailTemplatePreviewResponse{Subject: preview.Subject, HTML: preview.HTML})
|
||||
}
|
||||
|
||||
func emailTemplateEventOptionsToDTO(events []service.NotificationEmailEventInfo) []dto.EmailTemplateEventOption {
|
||||
items := make([]dto.EmailTemplateEventOption, 0, len(events))
|
||||
for _, event := range events {
|
||||
items = append(items, dto.EmailTemplateEventOption{
|
||||
Value: event.Event,
|
||||
Label: event.Label,
|
||||
Description: event.Description,
|
||||
})
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func emailTemplateSummariesToDTO(templates []service.NotificationEmailTemplate) []dto.EmailTemplateSummary {
|
||||
items := make([]dto.EmailTemplateSummary, 0, len(templates))
|
||||
for _, tmpl := range templates {
|
||||
items = append(items, dto.EmailTemplateSummary{
|
||||
Event: tmpl.Event,
|
||||
Locale: tmpl.Locale,
|
||||
Subject: tmpl.Subject,
|
||||
IsCustom: tmpl.IsCustom,
|
||||
UpdatedAt: emailTemplateUpdatedAt(tmpl),
|
||||
})
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func emailTemplateDetailToDTO(tmpl service.NotificationEmailTemplate) dto.EmailTemplateDetail {
|
||||
return dto.EmailTemplateDetail{
|
||||
Event: tmpl.Event,
|
||||
Locale: tmpl.Locale,
|
||||
Subject: tmpl.Subject,
|
||||
HTML: tmpl.HTML,
|
||||
IsCustom: tmpl.IsCustom,
|
||||
UpdatedAt: emailTemplateUpdatedAt(tmpl),
|
||||
Placeholders: tmpl.Placeholders,
|
||||
}
|
||||
}
|
||||
|
||||
func emailTemplateUpdatedAt(tmpl service.NotificationEmailTemplate) string {
|
||||
if tmpl.UpdatedAt == nil {
|
||||
return ""
|
||||
}
|
||||
return tmpl.UpdatedAt.Format("2006-01-02T15:04:05Z07:00")
|
||||
}
|
||||
|
||||
func emailTemplatePlaceholderUnion(events []service.NotificationEmailEventInfo) []string {
|
||||
seen := make(map[string]struct{})
|
||||
placeholders := make([]string, 0)
|
||||
for _, event := range events {
|
||||
for _, placeholder := range event.Placeholders {
|
||||
if _, ok := seen[placeholder]; ok {
|
||||
continue
|
||||
}
|
||||
seen[placeholder] = struct{}{}
|
||||
placeholders = append(placeholders, placeholder)
|
||||
}
|
||||
}
|
||||
return placeholders
|
||||
}
|
||||
|
||||
@ -203,7 +203,7 @@ func (h *AuthHandler) SendVerifyCode(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email)
|
||||
result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email, c.GetHeader("Accept-Language"))
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@ -602,7 +602,7 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
||||
|
||||
// Request password reset (async)
|
||||
// Note: This returns success even if email doesn't exist (to prevent enumeration)
|
||||
if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil {
|
||||
if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL, c.GetHeader("Accept-Language")); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -545,7 +545,7 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email)
|
||||
result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email, c.GetHeader("Accept-Language"))
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@ -181,6 +181,7 @@ type SystemSettings struct {
|
||||
EnableAnthropicCacheTTL1hInjection bool `json:"enable_anthropic_cache_ttl_1h_injection"`
|
||||
RewriteMessageCacheControl bool `json:"rewrite_message_cache_control"`
|
||||
AntigravityUserAgentVersion string `json:"antigravity_user_agent_version"`
|
||||
OpenAICodexUserAgent string `json:"openai_codex_user_agent"`
|
||||
|
||||
// Web Search Emulation
|
||||
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
|
||||
@ -377,6 +378,62 @@ type OpenAIFastPolicySettings struct {
|
||||
Rules []OpenAIFastPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// EmailTemplateEventOption describes an editable notification email event.
|
||||
type EmailTemplateEventOption struct {
|
||||
Value string `json:"value"`
|
||||
Label string `json:"label,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
// EmailTemplateSummary is shown in the admin email template list.
|
||||
type EmailTemplateSummary struct {
|
||||
Event string `json:"event"`
|
||||
Locale string `json:"locale"`
|
||||
Subject string `json:"subject"`
|
||||
IsCustom bool `json:"is_custom,omitempty"`
|
||||
UpdatedAt string `json:"updated_at,omitempty"`
|
||||
}
|
||||
|
||||
// EmailTemplateListResponse is returned by GET /admin/settings/email-templates.
|
||||
type EmailTemplateListResponse struct {
|
||||
Events []EmailTemplateEventOption `json:"events"`
|
||||
Locales []string `json:"locales"`
|
||||
Templates []EmailTemplateSummary `json:"templates,omitempty"`
|
||||
Placeholders []string `json:"placeholders,omitempty"`
|
||||
}
|
||||
|
||||
// EmailTemplateDetail is returned for a specific event/locale template.
|
||||
type EmailTemplateDetail struct {
|
||||
Event string `json:"event"`
|
||||
Locale string `json:"locale"`
|
||||
Subject string `json:"subject"`
|
||||
HTML string `json:"html"`
|
||||
IsCustom bool `json:"is_custom,omitempty"`
|
||||
UpdatedAt string `json:"updated_at,omitempty"`
|
||||
Placeholders []string `json:"placeholders,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateEmailTemplateRequest updates a template override.
|
||||
type UpdateEmailTemplateRequest struct {
|
||||
Subject string `json:"subject"`
|
||||
HTML string `json:"html"`
|
||||
}
|
||||
|
||||
// PreviewEmailTemplateRequest previews a template without saving it.
|
||||
type PreviewEmailTemplateRequest struct {
|
||||
Event string `json:"event"`
|
||||
Locale string `json:"locale"`
|
||||
Subject string `json:"subject"`
|
||||
HTML string `json:"html"`
|
||||
Variables map[string]string `json:"variables,omitempty"`
|
||||
}
|
||||
|
||||
// EmailTemplatePreviewResponse is the rendered preview payload.
|
||||
type EmailTemplatePreviewResponse struct {
|
||||
Subject string `json:"subject"`
|
||||
HTML string `json:"html"`
|
||||
}
|
||||
|
||||
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
|
||||
// Returns empty slice on empty/invalid input.
|
||||
func ParseCustomMenuItems(raw string) []CustomMenuItem {
|
||||
|
||||
@ -1133,9 +1133,15 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
|
||||
// 解析可选的日期范围参数(用于 model_stats 查询)
|
||||
startTime, endTime := h.parseUsageDateRange(c)
|
||||
days, ok := parseAPIKeyDailyUsageDays(c.DefaultQuery("days", ""))
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Invalid days, allowed range is 1-90")
|
||||
return
|
||||
}
|
||||
|
||||
// Best-effort: 获取用量统计(按当前 API Key 过滤),失败不影响基础响应
|
||||
usageData := h.buildUsageData(ctx, apiKey.ID)
|
||||
dailyUsage := h.buildAPIKeyDailyUsage(c, subject.UserID, apiKey.ID, days)
|
||||
|
||||
// Best-effort: 获取模型统计
|
||||
var modelStats any
|
||||
@ -1149,11 +1155,11 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
isQuotaLimited := apiKey.Quota > 0 || apiKey.HasRateLimits()
|
||||
|
||||
if isQuotaLimited {
|
||||
h.usageQuotaLimited(c, ctx, apiKey, usageData, modelStats)
|
||||
h.usageQuotaLimited(c, ctx, apiKey, usageData, dailyUsage, modelStats)
|
||||
return
|
||||
}
|
||||
|
||||
h.usageUnrestricted(c, ctx, apiKey, subject, usageData, modelStats)
|
||||
h.usageUnrestricted(c, ctx, apiKey, subject, usageData, dailyUsage, modelStats)
|
||||
}
|
||||
|
||||
// parseUsageDateRange 解析 start_date / end_date query params,默认返回近 30 天范围
|
||||
@ -1211,8 +1217,20 @@ func (h *GatewayHandler) buildUsageData(ctx context.Context, apiKeyID int64) gin
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) buildAPIKeyDailyUsage(c *gin.Context, userID, apiKeyID int64, days int) any {
|
||||
if h.usageService == nil {
|
||||
return nil
|
||||
}
|
||||
startTime, endTime := apiKeyDailyUsageRange(days, c.Query("timezone"))
|
||||
stats, err := h.usageService.GetAPIKeyDailyUsage(c.Request.Context(), userID, apiKeyID, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return stats
|
||||
}
|
||||
|
||||
// usageQuotaLimited 处理 quota_limited 模式的响应
|
||||
func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, apiKey *service.APIKey, usageData gin.H, modelStats any) {
|
||||
func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, apiKey *service.APIKey, usageData gin.H, dailyUsage any, modelStats any) {
|
||||
resp := gin.H{
|
||||
"mode": "quota_limited",
|
||||
"isValid": apiKey.Status == service.StatusAPIKeyActive || apiKey.Status == service.StatusAPIKeyQuotaExhausted || apiKey.Status == service.StatusAPIKeyExpired,
|
||||
@ -1294,6 +1312,9 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context,
|
||||
if usageData != nil {
|
||||
resp["usage"] = usageData
|
||||
}
|
||||
if dailyUsage != nil {
|
||||
resp["daily_usage"] = dailyUsage
|
||||
}
|
||||
if modelStats != nil {
|
||||
resp["model_stats"] = modelStats
|
||||
}
|
||||
@ -1302,7 +1323,7 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context,
|
||||
}
|
||||
|
||||
// usageUnrestricted 处理 unrestricted 模式的响应(向后兼容)
|
||||
func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, usageData gin.H, modelStats any) {
|
||||
func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, usageData gin.H, dailyUsage any, modelStats any) {
|
||||
// 订阅模式
|
||||
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
|
||||
resp := gin.H{
|
||||
@ -1331,6 +1352,9 @@ func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context,
|
||||
if usageData != nil {
|
||||
resp["usage"] = usageData
|
||||
}
|
||||
if dailyUsage != nil {
|
||||
resp["daily_usage"] = dailyUsage
|
||||
}
|
||||
if modelStats != nil {
|
||||
resp["model_stats"] = modelStats
|
||||
}
|
||||
@ -1356,6 +1380,9 @@ func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context,
|
||||
if usageData != nil {
|
||||
resp["usage"] = usageData
|
||||
}
|
||||
if dailyUsage != nil {
|
||||
resp["daily_usage"] = dailyUsage
|
||||
}
|
||||
if modelStats != nil {
|
||||
resp["model_stats"] = modelStats
|
||||
}
|
||||
|
||||
@ -266,6 +266,7 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
|
||||
PaymentSource: req.PaymentSource,
|
||||
OrderType: req.OrderType,
|
||||
PlanID: req.PlanID,
|
||||
Locale: c.GetHeader("Accept-Language"),
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
|
||||
@ -1,6 +1,10 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"html"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@ -10,8 +14,9 @@ import (
|
||||
|
||||
// SettingHandler 公开设置处理器(无需认证)
|
||||
type SettingHandler struct {
|
||||
settingService *service.SettingService
|
||||
version string
|
||||
settingService *service.SettingService
|
||||
notificationEmailService *service.NotificationEmailService
|
||||
version string
|
||||
}
|
||||
|
||||
// NewSettingHandler 创建公开设置处理器
|
||||
@ -22,6 +27,12 @@ func NewSettingHandler(settingService *service.SettingService, version string) *
|
||||
}
|
||||
}
|
||||
|
||||
// SetNotificationEmailService attaches the public notification email service without
|
||||
// changing the constructor signature used by existing tests.
|
||||
func (h *SettingHandler) SetNotificationEmailService(notificationEmailService *service.NotificationEmailService) {
|
||||
h.notificationEmailService = notificationEmailService
|
||||
}
|
||||
|
||||
// GetPublicSettings 获取公开设置
|
||||
// GET /api/v1/settings/public
|
||||
func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
@ -90,6 +101,27 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// UnsubscribeNotificationEmail handles optional notification email opt-outs.
|
||||
// GET /api/v1/settings/email-unsubscribe?token=...
|
||||
func (h *SettingHandler) UnsubscribeNotificationEmail(c *gin.Context) {
|
||||
if h.notificationEmailService == nil {
|
||||
response.InternalError(c, "notification email service is not configured")
|
||||
return
|
||||
}
|
||||
token := strings.TrimSpace(c.Query("token"))
|
||||
if token == "" {
|
||||
response.BadRequest(c, "token is required")
|
||||
return
|
||||
}
|
||||
result, err := h.notificationEmailService.Unsubscribe(c.Request.Context(), token)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
body := "<!doctype html><html><head><meta charset=\"utf-8\"><title>Unsubscribed</title></head><body style=\"font-family:-apple-system,BlinkMacSystemFont,Segoe UI,sans-serif;padding:32px;\"><h1>Unsubscribed</h1><p>You have unsubscribed <strong>" + html.EscapeString(result.Email) + "</strong> from <strong>" + html.EscapeString(result.Event) + "</strong> emails.</p></body></html>"
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(body))
|
||||
}
|
||||
|
||||
func publicLoginAgreementDocumentsToDTO(items []service.LoginAgreementDocument) []dto.LoginAgreementDocument {
|
||||
result := make([]dto.LoginAgreementDocument, 0, len(items))
|
||||
for _, item := range items {
|
||||
|
||||
@ -172,7 +172,7 @@ func (h *TotpHandler) SendVerifyCode(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil {
|
||||
if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID, c.GetHeader("Accept-Language")); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -298,6 +298,29 @@ func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
|
||||
return startTime, endTime
|
||||
}
|
||||
|
||||
const (
|
||||
defaultAPIKeyDailyUsageDays = 30
|
||||
maxAPIKeyDailyUsageDays = 90
|
||||
)
|
||||
|
||||
func parseAPIKeyDailyUsageDays(raw string) (int, bool) {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return defaultAPIKeyDailyUsageDays, true
|
||||
}
|
||||
days, err := strconv.Atoi(raw)
|
||||
if err != nil || days <= 0 || days > maxAPIKeyDailyUsageDays {
|
||||
return 0, false
|
||||
}
|
||||
return days, true
|
||||
}
|
||||
|
||||
func apiKeyDailyUsageRange(days int, userTZ string) (time.Time, time.Time) {
|
||||
now := timezone.NowInUserLocation(userTZ)
|
||||
startTime := timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -(days-1)), userTZ)
|
||||
endTime := timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
|
||||
return startTime, endTime
|
||||
}
|
||||
|
||||
// DashboardStats handles getting user dashboard statistics
|
||||
// GET /api/v1/usage/dashboard/stats
|
||||
func (h *UsageHandler) DashboardStats(c *gin.Context) {
|
||||
@ -416,3 +439,55 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) {
|
||||
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
}
|
||||
|
||||
// GetMyAPIKeyDailyUsage handles getting daily usage details for the current user's API key.
|
||||
// GET /api/v1/user/api-keys/:id/usage/daily?days=30
|
||||
func (h *UsageHandler) GetMyAPIKeyDailyUsage(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
apiKeyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid API key ID")
|
||||
return
|
||||
}
|
||||
|
||||
days, ok := parseAPIKeyDailyUsageDays(c.DefaultQuery("days", ""))
|
||||
if !ok {
|
||||
response.BadRequest(c, "Invalid days, allowed range is 1-90")
|
||||
return
|
||||
}
|
||||
|
||||
if h.apiKeyService == nil {
|
||||
response.InternalError(c, "API key service is not configured")
|
||||
return
|
||||
}
|
||||
|
||||
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), apiKeyID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if apiKey.UserID != subject.UserID {
|
||||
response.Forbidden(c, "Not authorized to access this API key's usage")
|
||||
return
|
||||
}
|
||||
|
||||
userTZ := c.Query("timezone")
|
||||
startTime, endTime := apiKeyDailyUsageRange(days, userTZ)
|
||||
items, err := h.usageService.GetAPIKeyDailyUsage(c.Request.Context(), subject.UserID, apiKeyID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"items": items,
|
||||
"days": days,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.AddDate(0, 0, -1).Format("2006-01-02"),
|
||||
})
|
||||
}
|
||||
|
||||
195
backend/internal/handler/usage_handler_daily_test.go
Normal file
195
backend/internal/handler/usage_handler_daily_test.go
Normal file
@ -0,0 +1,195 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type dailyUsageRepoStub struct {
|
||||
service.UsageLogRepository
|
||||
trend []usagestats.TrendDataPoint
|
||||
|
||||
called bool
|
||||
startTime time.Time
|
||||
endTime time.Time
|
||||
granularity string
|
||||
userID int64
|
||||
apiKeyID int64
|
||||
}
|
||||
|
||||
func (s *dailyUsageRepoStub) GetUsageTrendWithFilters(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
model string,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.TrendDataPoint, error) {
|
||||
s.called = true
|
||||
s.startTime = startTime
|
||||
s.endTime = endTime
|
||||
s.granularity = granularity
|
||||
s.userID = userID
|
||||
s.apiKeyID = apiKeyID
|
||||
return s.trend, nil
|
||||
}
|
||||
|
||||
type dailyUsageAPIKeyRepoStub struct {
|
||||
service.APIKeyRepository
|
||||
keys map[int64]*service.APIKey
|
||||
}
|
||||
|
||||
func (s *dailyUsageAPIKeyRepoStub) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||
key, ok := s.keys[id]
|
||||
if !ok {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *key
|
||||
return &clone, nil
|
||||
}
|
||||
|
||||
func newDailyUsageTestRouter(usageRepo *dailyUsageRepoStub, apiKeyRepo *dailyUsageAPIKeyRepoStub, userID int64) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
usageSvc := service.NewUsageService(usageRepo, nil, nil, nil)
|
||||
apiKeySvc := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUsageHandler(usageSvc, apiKeySvc)
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: userID})
|
||||
c.Next()
|
||||
})
|
||||
router.GET("/user/api-keys/:id/usage/daily", handler.GetMyAPIKeyDailyUsage)
|
||||
return router
|
||||
}
|
||||
|
||||
type dailyUsageHandlerResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data struct {
|
||||
Items []usagestats.APIKeyDailyUsagePoint `json:"items"`
|
||||
Days int `json:"days"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
func TestGetMyAPIKeyDailyUsageRejectsCrossUserAccess(t *testing.T) {
|
||||
usageRepo := &dailyUsageRepoStub{}
|
||||
apiKeyRepo := &dailyUsageAPIKeyRepoStub{
|
||||
keys: map[int64]*service.APIKey{
|
||||
7: {ID: 7, UserID: 99, Status: service.StatusAPIKeyActive},
|
||||
},
|
||||
}
|
||||
router := newDailyUsageTestRouter(usageRepo, apiKeyRepo, 42)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/user/api-keys/7/usage/daily?days=30", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusForbidden, rec.Code)
|
||||
require.False(t, usageRepo.called)
|
||||
}
|
||||
|
||||
func TestGetMyAPIKeyDailyUsageRejectsInvalidDays(t *testing.T) {
|
||||
for _, path := range []string{
|
||||
"/user/api-keys/7/usage/daily?days=0",
|
||||
"/user/api-keys/7/usage/daily?days=91",
|
||||
} {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
usageRepo := &dailyUsageRepoStub{}
|
||||
apiKeyRepo := &dailyUsageAPIKeyRepoStub{
|
||||
keys: map[int64]*service.APIKey{
|
||||
7: {ID: 7, UserID: 42, Status: service.StatusAPIKeyActive},
|
||||
},
|
||||
}
|
||||
router := newDailyUsageTestRouter(usageRepo, apiKeyRepo, 42)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
require.False(t, usageRepo.called)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMyAPIKeyDailyUsageReturnsEmptyData(t *testing.T) {
|
||||
usageRepo := &dailyUsageRepoStub{trend: []usagestats.TrendDataPoint{}}
|
||||
apiKeyRepo := &dailyUsageAPIKeyRepoStub{
|
||||
keys: map[int64]*service.APIKey{
|
||||
7: {ID: 7, UserID: 42, Status: service.StatusAPIKeyActive},
|
||||
},
|
||||
}
|
||||
router := newDailyUsageTestRouter(usageRepo, apiKeyRepo, 42)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/user/api-keys/7/usage/daily", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var got dailyUsageHandlerResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
require.Equal(t, 30, got.Data.Days)
|
||||
require.Empty(t, got.Data.Items)
|
||||
}
|
||||
|
||||
func TestGetMyAPIKeyDailyUsageAggregatesByDayForOwnedKey(t *testing.T) {
|
||||
usageRepo := &dailyUsageRepoStub{
|
||||
trend: []usagestats.TrendDataPoint{
|
||||
{
|
||||
Date: "2026-05-19",
|
||||
Requests: 3,
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
CacheCreationTokens: 4,
|
||||
CacheReadTokens: 6,
|
||||
TotalTokens: 40,
|
||||
Cost: 0.5,
|
||||
ActualCost: 0.4,
|
||||
},
|
||||
},
|
||||
}
|
||||
apiKeyRepo := &dailyUsageAPIKeyRepoStub{
|
||||
keys: map[int64]*service.APIKey{
|
||||
7: {ID: 7, UserID: 42, Status: service.StatusAPIKeyActive},
|
||||
},
|
||||
}
|
||||
router := newDailyUsageTestRouter(usageRepo, apiKeyRepo, 42)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/user/api-keys/7/usage/daily?days=7", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.True(t, usageRepo.called)
|
||||
require.Equal(t, "day", usageRepo.granularity)
|
||||
require.Equal(t, int64(42), usageRepo.userID)
|
||||
require.Equal(t, int64(7), usageRepo.apiKeyID)
|
||||
require.True(t, usageRepo.startTime.Before(usageRepo.endTime))
|
||||
|
||||
var got dailyUsageHandlerResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
require.Equal(t, 7, got.Data.Days)
|
||||
require.Len(t, got.Data.Items, 1)
|
||||
require.Equal(t, usagestats.APIKeyDailyUsagePoint{
|
||||
Date: "2026-05-19",
|
||||
Requests: 3,
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
CacheReadTokens: 6,
|
||||
CacheWriteTokens: 4,
|
||||
TotalTokens: 40,
|
||||
Cost: 0.5,
|
||||
ActualCost: 0.4,
|
||||
}, got.Data.Items[0])
|
||||
}
|
||||
@ -335,7 +335,7 @@ func (h *UserHandler) SendEmailBindingCode(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.authService.SendEmailIdentityBindCode(c.Request.Context(), subject.UserID, req.Email); err != nil {
|
||||
if err := h.authService.SendEmailIdentityBindCode(c.Request.Context(), subject.UserID, req.Email, c.GetHeader("Accept-Language")); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
@ -363,7 +363,7 @@ func (h *UserHandler) SendNotifyEmailCode(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.userService.SendNotifyEmailCode(c.Request.Context(), subject.UserID, req.Email, h.emailService, h.emailCache)
|
||||
err := h.userService.SendNotifyEmailCode(c.Request.Context(), subject.UserID, req.Email, h.emailService, h.emailCache, c.GetHeader("Accept-Language"))
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@ -90,8 +90,17 @@ func ProvideWindsurfHandler(authService *service.WindsurfAuthService, lsService
|
||||
}
|
||||
|
||||
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
|
||||
func ProvideSettingHandler(settingService *service.SettingService, buildInfo BuildInfo) *SettingHandler {
|
||||
return NewSettingHandler(settingService, buildInfo.Version)
|
||||
func ProvideSettingHandler(settingService *service.SettingService, buildInfo BuildInfo, notificationEmailService *service.NotificationEmailService) *SettingHandler {
|
||||
h := NewSettingHandler(settingService, buildInfo.Version)
|
||||
h.SetNotificationEmailService(notificationEmailService)
|
||||
return h
|
||||
}
|
||||
|
||||
// ProvideAdminSettingHandler creates admin.SettingHandler with notification template APIs.
|
||||
func ProvideAdminSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, paymentConfigService *service.PaymentConfigService, paymentService *service.PaymentService, userAttributeService *service.UserAttributeService, notificationEmailService *service.NotificationEmailService) *admin.SettingHandler {
|
||||
h := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService)
|
||||
h.SetNotificationEmailService(notificationEmailService)
|
||||
return h
|
||||
}
|
||||
|
||||
// ProvideHandlers creates the Handlers struct
|
||||
@ -169,7 +178,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewProxyHandler,
|
||||
admin.NewRedeemHandler,
|
||||
admin.NewPromoHandler,
|
||||
admin.NewSettingHandler,
|
||||
ProvideAdminSettingHandler,
|
||||
admin.NewOpsHandler,
|
||||
ProvideSystemHandler,
|
||||
admin.NewSubscriptionHandler,
|
||||
|
||||
@ -0,0 +1,719 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ResponsesToChatCompletionsRequest converts a Responses API request into a
|
||||
// Chat Completions request for upstreams that only implement
|
||||
// /v1/chat/completions.
|
||||
func ResponsesToChatCompletionsRequest(req *ResponsesRequest) (*ChatCompletionsRequest, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("responses request is nil")
|
||||
}
|
||||
|
||||
messages, err := responsesInputToChatMessages(req.Instructions, req.Input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := &ChatCompletionsRequest{
|
||||
Model: req.Model,
|
||||
Messages: messages,
|
||||
MaxCompletionTokens: req.MaxOutputTokens,
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
Stream: req.Stream,
|
||||
ServiceTier: req.ServiceTier,
|
||||
}
|
||||
if req.Reasoning != nil {
|
||||
out.ReasoningEffort = req.Reasoning.Effort
|
||||
}
|
||||
if len(req.Tools) > 0 {
|
||||
out.Tools = responsesToolsToChatTools(req.Tools)
|
||||
}
|
||||
if len(req.ToolChoice) > 0 {
|
||||
out.ToolChoice = responsesToolChoiceToChatToolChoice(req.ToolChoice)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func responsesInputToChatMessages(instructions string, inputRaw json.RawMessage) ([]ChatMessage, error) {
|
||||
var messages []ChatMessage
|
||||
if strings.TrimSpace(instructions) != "" {
|
||||
content, _ := json.Marshal(instructions)
|
||||
messages = append(messages, ChatMessage{
|
||||
Role: "system",
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
inputRaw = bytesTrimSpace(inputRaw)
|
||||
if len(inputRaw) == 0 || string(inputRaw) == "null" {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
var inputText string
|
||||
if err := json.Unmarshal(inputRaw, &inputText); err == nil {
|
||||
content, _ := json.Marshal(inputText)
|
||||
messages = append(messages, ChatMessage{
|
||||
Role: "user",
|
||||
Content: content,
|
||||
})
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
var rawItems []json.RawMessage
|
||||
if err := json.Unmarshal(inputRaw, &rawItems); err != nil {
|
||||
return nil, fmt.Errorf("parse responses input: %w", err)
|
||||
}
|
||||
|
||||
for _, raw := range rawItems {
|
||||
raw = bytesTrimSpace(raw)
|
||||
if len(raw) == 0 || string(raw) == "null" {
|
||||
continue
|
||||
}
|
||||
|
||||
var item map[string]json.RawMessage
|
||||
if err := json.Unmarshal(raw, &item); err != nil {
|
||||
var text string
|
||||
if textErr := json.Unmarshal(raw, &text); textErr == nil {
|
||||
content, _ := json.Marshal(text)
|
||||
messages = append(messages, ChatMessage{Role: "user", Content: content})
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("parse responses input item: %w", err)
|
||||
}
|
||||
|
||||
role := rawString(item["role"])
|
||||
itemType := rawString(item["type"])
|
||||
switch itemType {
|
||||
case "function_call":
|
||||
arguments := rawString(item["arguments"])
|
||||
if strings.TrimSpace(arguments) == "" {
|
||||
arguments = "{}"
|
||||
}
|
||||
messages = append(messages, ChatMessage{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ChatToolCall{{
|
||||
ID: rawString(item["call_id"]),
|
||||
Type: "function",
|
||||
Function: ChatFunctionCall{
|
||||
Name: rawString(item["name"]),
|
||||
Arguments: arguments,
|
||||
},
|
||||
}},
|
||||
})
|
||||
continue
|
||||
case "function_call_output":
|
||||
content, _ := json.Marshal(rawString(item["output"]))
|
||||
messages = append(messages, ChatMessage{
|
||||
Role: "tool",
|
||||
ToolCallID: rawString(item["call_id"]),
|
||||
Content: content,
|
||||
})
|
||||
continue
|
||||
case "input_text", "text":
|
||||
content, _ := json.Marshal(rawString(item["text"]))
|
||||
messages = append(messages, ChatMessage{Role: "user", Content: content})
|
||||
continue
|
||||
case "input_image":
|
||||
content, err := chatContentFromSingleResponsesPart(itemType, item)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages = append(messages, ChatMessage{Role: "user", Content: content})
|
||||
continue
|
||||
}
|
||||
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
content := item["content"]
|
||||
if len(bytesTrimSpace(content)) == 0 {
|
||||
if text := rawString(item["text"]); text != "" {
|
||||
content, _ = json.Marshal(text)
|
||||
}
|
||||
}
|
||||
chatContent, err := responsesContentToChatContent(content, role)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages = append(messages, ChatMessage{
|
||||
Role: role,
|
||||
Content: chatContent,
|
||||
})
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func responsesContentToChatContent(raw json.RawMessage, role string) (json.RawMessage, error) {
|
||||
raw = bytesTrimSpace(raw)
|
||||
if len(raw) == 0 || string(raw) == "null" {
|
||||
empty, _ := json.Marshal("")
|
||||
return empty, nil
|
||||
}
|
||||
|
||||
var text string
|
||||
if err := json.Unmarshal(raw, &text); err == nil {
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
var rawParts []json.RawMessage
|
||||
if err := json.Unmarshal(raw, &rawParts); err == nil {
|
||||
return responsesContentPartsToChatContent(rawParts, role)
|
||||
}
|
||||
|
||||
var obj map[string]json.RawMessage
|
||||
if err := json.Unmarshal(raw, &obj); err == nil {
|
||||
return chatContentFromSingleResponsesPart(rawString(obj["type"]), obj)
|
||||
}
|
||||
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
func responsesContentPartsToChatContent(rawParts []json.RawMessage, role string) (json.RawMessage, error) {
|
||||
var textParts []string
|
||||
var chatParts []ChatContentPart
|
||||
hasNonText := false
|
||||
|
||||
for _, rawPart := range rawParts {
|
||||
var part map[string]json.RawMessage
|
||||
if err := json.Unmarshal(rawPart, &part); err != nil {
|
||||
continue
|
||||
}
|
||||
partType := rawString(part["type"])
|
||||
switch partType {
|
||||
case "input_text", "output_text", "text", "":
|
||||
text := rawString(part["text"])
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
textParts = append(textParts, text)
|
||||
chatParts = append(chatParts, ChatContentPart{Type: "text", Text: text})
|
||||
case "input_image", "image_url":
|
||||
imageURL := rawString(part["image_url"])
|
||||
if imageURL == "" {
|
||||
imageURL = rawNestedString(part["image_url"], "url")
|
||||
}
|
||||
if imageURL == "" {
|
||||
continue
|
||||
}
|
||||
hasNonText = true
|
||||
chatParts = append(chatParts, ChatContentPart{
|
||||
Type: "image_url",
|
||||
ImageURL: &ChatImageURL{URL: imageURL},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if !hasNonText {
|
||||
joined, _ := json.Marshal(strings.Join(textParts, "\n\n"))
|
||||
return joined, nil
|
||||
}
|
||||
if role != "user" {
|
||||
joined, _ := json.Marshal(strings.Join(textParts, "\n\n"))
|
||||
return joined, nil
|
||||
}
|
||||
if len(chatParts) == 0 {
|
||||
empty, _ := json.Marshal("")
|
||||
return empty, nil
|
||||
}
|
||||
return json.Marshal(chatParts)
|
||||
}
|
||||
|
||||
func chatContentFromSingleResponsesPart(partType string, part map[string]json.RawMessage) (json.RawMessage, error) {
|
||||
switch partType {
|
||||
case "input_image", "image_url":
|
||||
imageURL := rawString(part["image_url"])
|
||||
if imageURL == "" {
|
||||
imageURL = rawNestedString(part["image_url"], "url")
|
||||
}
|
||||
return json.Marshal([]ChatContentPart{{
|
||||
Type: "image_url",
|
||||
ImageURL: &ChatImageURL{URL: imageURL},
|
||||
}})
|
||||
default:
|
||||
return json.Marshal(rawString(part["text"]))
|
||||
}
|
||||
}
|
||||
|
||||
func responsesToolsToChatTools(tools []ResponsesTool) []ChatTool {
|
||||
out := make([]ChatTool, 0, len(tools))
|
||||
for _, tool := range tools {
|
||||
if tool.Type != "function" {
|
||||
continue
|
||||
}
|
||||
out = append(out, ChatTool{
|
||||
Type: "function",
|
||||
Function: &ChatFunction{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
Parameters: tool.Parameters,
|
||||
Strict: tool.Strict,
|
||||
},
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func responsesToolChoiceToChatToolChoice(raw json.RawMessage) json.RawMessage {
|
||||
var choice map[string]json.RawMessage
|
||||
if err := json.Unmarshal(raw, &choice); err != nil {
|
||||
return raw
|
||||
}
|
||||
if rawString(choice["type"]) != "function" {
|
||||
return raw
|
||||
}
|
||||
name := rawString(choice["name"])
|
||||
if name == "" {
|
||||
name = rawNestedString(choice["function"], "name")
|
||||
}
|
||||
if name == "" {
|
||||
return raw
|
||||
}
|
||||
out, err := json.Marshal(map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]string{
|
||||
"name": name,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return raw
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ChatCompletionsResponseToResponses converts a non-streaming Chat Completions
|
||||
// response into a Responses API response.
|
||||
func ChatCompletionsResponseToResponses(resp *ChatCompletionsResponse, model string) *ResponsesResponse {
|
||||
id := ""
|
||||
if resp != nil {
|
||||
id = resp.ID
|
||||
}
|
||||
if id == "" {
|
||||
id = generateResponsesID()
|
||||
}
|
||||
|
||||
out := &ResponsesResponse{
|
||||
ID: id,
|
||||
Object: "response",
|
||||
Model: model,
|
||||
Status: "completed",
|
||||
}
|
||||
if resp == nil {
|
||||
out.Output = []ResponsesOutput{emptyResponsesMessageOutput()}
|
||||
return out
|
||||
}
|
||||
if out.Model == "" {
|
||||
out.Model = resp.Model
|
||||
}
|
||||
|
||||
if len(resp.Choices) > 0 {
|
||||
choice := resp.Choices[0]
|
||||
out.Output = chatMessageToResponsesOutput(choice.Message)
|
||||
if choice.FinishReason == "length" {
|
||||
out.Status = "incomplete"
|
||||
out.IncompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"}
|
||||
}
|
||||
}
|
||||
if len(out.Output) == 0 {
|
||||
out.Output = []ResponsesOutput{emptyResponsesMessageOutput()}
|
||||
}
|
||||
if resp.Usage != nil {
|
||||
out.Usage = ChatUsageToResponsesUsage(resp.Usage)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func chatMessageToResponsesOutput(message ChatMessage) []ResponsesOutput {
|
||||
var outputs []ResponsesOutput
|
||||
if message.ReasoningContent != "" {
|
||||
outputs = append(outputs, ResponsesOutput{
|
||||
Type: "reasoning",
|
||||
ID: generateItemID(),
|
||||
Summary: []ResponsesSummary{{
|
||||
Type: "summary_text",
|
||||
Text: message.ReasoningContent,
|
||||
}},
|
||||
})
|
||||
}
|
||||
|
||||
text := chatMessageContentText(message.Content)
|
||||
if text != "" || len(message.ToolCalls) == 0 {
|
||||
outputs = append(outputs, ResponsesOutput{
|
||||
Type: "message",
|
||||
ID: generateItemID(),
|
||||
Role: "assistant",
|
||||
Content: []ResponsesContentPart{{
|
||||
Type: "output_text",
|
||||
Text: text,
|
||||
}},
|
||||
Status: "completed",
|
||||
})
|
||||
}
|
||||
|
||||
for _, toolCall := range message.ToolCalls {
|
||||
arguments := toolCall.Function.Arguments
|
||||
if strings.TrimSpace(arguments) == "" {
|
||||
arguments = "{}"
|
||||
}
|
||||
outputs = append(outputs, ResponsesOutput{
|
||||
Type: "function_call",
|
||||
ID: generateItemID(),
|
||||
CallID: toolCall.ID,
|
||||
Name: toolCall.Function.Name,
|
||||
Arguments: arguments,
|
||||
Status: "completed",
|
||||
})
|
||||
}
|
||||
|
||||
return outputs
|
||||
}
|
||||
|
||||
func emptyResponsesMessageOutput() ResponsesOutput {
|
||||
return ResponsesOutput{
|
||||
Type: "message",
|
||||
ID: generateItemID(),
|
||||
Role: "assistant",
|
||||
Content: []ResponsesContentPart{{Type: "output_text", Text: ""}},
|
||||
Status: "completed",
|
||||
}
|
||||
}
|
||||
|
||||
func chatMessageContentText(raw json.RawMessage) string {
|
||||
raw = bytesTrimSpace(raw)
|
||||
if len(raw) == 0 || string(raw) == "null" {
|
||||
return ""
|
||||
}
|
||||
var text string
|
||||
if err := json.Unmarshal(raw, &text); err == nil {
|
||||
return text
|
||||
}
|
||||
var parts []ChatContentPart
|
||||
if err := json.Unmarshal(raw, &parts); err == nil {
|
||||
var texts []string
|
||||
for _, part := range parts {
|
||||
if part.Type == "text" && part.Text != "" {
|
||||
texts = append(texts, part.Text)
|
||||
}
|
||||
}
|
||||
return strings.Join(texts, "\n\n")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ChatUsageToResponsesUsage converts Chat Completions token usage to Responses
|
||||
// usage shape.
|
||||
func ChatUsageToResponsesUsage(usage *ChatUsage) *ResponsesUsage {
|
||||
if usage == nil {
|
||||
return nil
|
||||
}
|
||||
out := &ResponsesUsage{
|
||||
InputTokens: usage.PromptTokens,
|
||||
OutputTokens: usage.CompletionTokens,
|
||||
TotalTokens: usage.TotalTokens,
|
||||
}
|
||||
if out.TotalTokens == 0 {
|
||||
out.TotalTokens = out.InputTokens + out.OutputTokens
|
||||
}
|
||||
if usage.PromptTokensDetails != nil && usage.PromptTokensDetails.CachedTokens > 0 {
|
||||
out.InputTokensDetails = &ResponsesInputTokensDetails{
|
||||
CachedTokens: usage.PromptTokensDetails.CachedTokens,
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ChatCompletionsToResponsesStreamState tracks state while converting Chat
|
||||
// Completions SSE chunks into Responses SSE events.
|
||||
type ChatCompletionsToResponsesStreamState struct {
|
||||
ResponseID string
|
||||
Model string
|
||||
Created int64
|
||||
SequenceNumber int
|
||||
CreatedSent bool
|
||||
CompletedSent bool
|
||||
|
||||
MessageItemID string
|
||||
Text strings.Builder
|
||||
Reasoning strings.Builder
|
||||
ToolCalls map[int]*ChatToolCall
|
||||
|
||||
FinishReason string
|
||||
Usage *ResponsesUsage
|
||||
}
|
||||
|
||||
// NewChatCompletionsToResponsesStreamState returns an initialized stream state.
|
||||
func NewChatCompletionsToResponsesStreamState(model string) *ChatCompletionsToResponsesStreamState {
|
||||
return &ChatCompletionsToResponsesStreamState{
|
||||
ResponseID: generateResponsesID(),
|
||||
Model: model,
|
||||
Created: time.Now().Unix(),
|
||||
ToolCalls: make(map[int]*ChatToolCall),
|
||||
}
|
||||
}
|
||||
|
||||
// ChatCompletionsChunkToResponsesEvents converts one Chat Completions stream
|
||||
// chunk into zero or more Responses stream events.
|
||||
func ChatCompletionsChunkToResponsesEvents(
|
||||
chunk *ChatCompletionsChunk,
|
||||
state *ChatCompletionsToResponsesStreamState,
|
||||
) []ResponsesStreamEvent {
|
||||
if chunk == nil || state == nil {
|
||||
return nil
|
||||
}
|
||||
if chunk.ID != "" {
|
||||
state.ResponseID = chunk.ID
|
||||
}
|
||||
if state.Model == "" && chunk.Model != "" {
|
||||
state.Model = chunk.Model
|
||||
}
|
||||
if chunk.Usage != nil {
|
||||
state.Usage = ChatUsageToResponsesUsage(chunk.Usage)
|
||||
}
|
||||
|
||||
var events []ResponsesStreamEvent
|
||||
events = append(events, ensureChatToResponsesCreated(state)...)
|
||||
|
||||
for _, choice := range chunk.Choices {
|
||||
if choice.Delta.Content != nil {
|
||||
events = append(events, ensureChatToResponsesMessageItem(state)...)
|
||||
_, _ = state.Text.WriteString(*choice.Delta.Content)
|
||||
events = append(events, chatToResponsesEvent(state, "response.output_text.delta", &ResponsesStreamEvent{
|
||||
OutputIndex: 0,
|
||||
ContentIndex: 0,
|
||||
Delta: *choice.Delta.Content,
|
||||
ItemID: state.MessageItemID,
|
||||
}))
|
||||
}
|
||||
if choice.Delta.ReasoningContent != nil {
|
||||
_, _ = state.Reasoning.WriteString(*choice.Delta.ReasoningContent)
|
||||
events = append(events, chatToResponsesEvent(state, "response.reasoning_summary_text.delta", &ResponsesStreamEvent{
|
||||
OutputIndex: 0,
|
||||
SummaryIndex: 0,
|
||||
Delta: *choice.Delta.ReasoningContent,
|
||||
}))
|
||||
}
|
||||
for _, toolCall := range choice.Delta.ToolCalls {
|
||||
idx := 0
|
||||
if toolCall.Index != nil {
|
||||
idx = *toolCall.Index
|
||||
}
|
||||
stored, ok := state.ToolCalls[idx]
|
||||
if !ok {
|
||||
copyCall := toolCall
|
||||
if copyCall.ID == "" {
|
||||
copyCall.ID = generateItemID()
|
||||
}
|
||||
copyCall.Type = "function"
|
||||
state.ToolCalls[idx] = ©Call
|
||||
stored = ©Call
|
||||
events = append(events, chatToResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
|
||||
OutputIndex: idx + 1,
|
||||
Item: &ResponsesOutput{
|
||||
Type: "function_call",
|
||||
ID: generateItemID(),
|
||||
CallID: stored.ID,
|
||||
Name: stored.Function.Name,
|
||||
Status: "in_progress",
|
||||
},
|
||||
}))
|
||||
} else {
|
||||
if toolCall.ID != "" {
|
||||
stored.ID = toolCall.ID
|
||||
}
|
||||
if toolCall.Function.Name != "" {
|
||||
stored.Function.Name = toolCall.Function.Name
|
||||
}
|
||||
}
|
||||
if toolCall.Function.Arguments != "" {
|
||||
stored.Function.Arguments += toolCall.Function.Arguments
|
||||
events = append(events, chatToResponsesEvent(state, "response.function_call_arguments.delta", &ResponsesStreamEvent{
|
||||
OutputIndex: idx + 1,
|
||||
Delta: toolCall.Function.Arguments,
|
||||
CallID: stored.ID,
|
||||
Name: stored.Function.Name,
|
||||
}))
|
||||
}
|
||||
}
|
||||
if choice.FinishReason != nil && *choice.FinishReason != "" {
|
||||
state.FinishReason = *choice.FinishReason
|
||||
}
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
// FinalizeChatCompletionsResponsesStream emits terminal Responses events.
|
||||
func FinalizeChatCompletionsResponsesStream(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent {
|
||||
if state == nil || state.CompletedSent {
|
||||
return nil
|
||||
}
|
||||
var events []ResponsesStreamEvent
|
||||
events = append(events, ensureChatToResponsesCreated(state)...)
|
||||
if state.MessageItemID != "" {
|
||||
events = append(events, chatToResponsesEvent(state, "response.output_text.done", &ResponsesStreamEvent{
|
||||
OutputIndex: 0,
|
||||
ContentIndex: 0,
|
||||
Text: state.Text.String(),
|
||||
ItemID: state.MessageItemID,
|
||||
}))
|
||||
events = append(events, chatToResponsesEvent(state, "response.output_item.done", &ResponsesStreamEvent{
|
||||
OutputIndex: 0,
|
||||
Item: &ResponsesOutput{
|
||||
Type: "message",
|
||||
ID: state.MessageItemID,
|
||||
Role: "assistant",
|
||||
Status: "completed",
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
status := "completed"
|
||||
var incompleteDetails *ResponsesIncompleteDetails
|
||||
if state.FinishReason == "length" {
|
||||
status = "incomplete"
|
||||
incompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"}
|
||||
}
|
||||
|
||||
state.CompletedSent = true
|
||||
events = append(events, chatToResponsesEvent(state, "response.completed", &ResponsesStreamEvent{
|
||||
Response: &ResponsesResponse{
|
||||
ID: state.ResponseID,
|
||||
Object: "response",
|
||||
Model: state.Model,
|
||||
Status: status,
|
||||
Output: state.chatOutput(),
|
||||
Usage: state.Usage,
|
||||
IncompleteDetails: incompleteDetails,
|
||||
},
|
||||
}))
|
||||
return events
|
||||
}
|
||||
|
||||
func ensureChatToResponsesCreated(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent {
|
||||
if state.CreatedSent {
|
||||
return nil
|
||||
}
|
||||
state.CreatedSent = true
|
||||
return []ResponsesStreamEvent{chatToResponsesEvent(state, "response.created", &ResponsesStreamEvent{
|
||||
Response: &ResponsesResponse{
|
||||
ID: state.ResponseID,
|
||||
Object: "response",
|
||||
Model: state.Model,
|
||||
Status: "in_progress",
|
||||
Output: []ResponsesOutput{},
|
||||
},
|
||||
})}
|
||||
}
|
||||
|
||||
func ensureChatToResponsesMessageItem(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent {
|
||||
if state.MessageItemID != "" {
|
||||
return nil
|
||||
}
|
||||
state.MessageItemID = generateItemID()
|
||||
return []ResponsesStreamEvent{chatToResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
|
||||
OutputIndex: 0,
|
||||
Item: &ResponsesOutput{
|
||||
Type: "message",
|
||||
ID: state.MessageItemID,
|
||||
Role: "assistant",
|
||||
Status: "in_progress",
|
||||
},
|
||||
})}
|
||||
}
|
||||
|
||||
func (state *ChatCompletionsToResponsesStreamState) chatOutput() []ResponsesOutput {
|
||||
var outputs []ResponsesOutput
|
||||
if state.Reasoning.Len() > 0 {
|
||||
outputs = append(outputs, ResponsesOutput{
|
||||
Type: "reasoning",
|
||||
ID: generateItemID(),
|
||||
Summary: []ResponsesSummary{{
|
||||
Type: "summary_text",
|
||||
Text: state.Reasoning.String(),
|
||||
}},
|
||||
})
|
||||
}
|
||||
if state.MessageItemID != "" || len(state.ToolCalls) == 0 {
|
||||
outputs = append(outputs, ResponsesOutput{
|
||||
Type: "message",
|
||||
ID: nonEmpty(state.MessageItemID, generateItemID()),
|
||||
Role: "assistant",
|
||||
Content: []ResponsesContentPart{{
|
||||
Type: "output_text",
|
||||
Text: state.Text.String(),
|
||||
}},
|
||||
Status: "completed",
|
||||
})
|
||||
}
|
||||
for i := 0; i < len(state.ToolCalls); i++ {
|
||||
toolCall, ok := state.ToolCalls[i]
|
||||
if !ok || toolCall == nil {
|
||||
continue
|
||||
}
|
||||
arguments := toolCall.Function.Arguments
|
||||
if strings.TrimSpace(arguments) == "" {
|
||||
arguments = "{}"
|
||||
}
|
||||
outputs = append(outputs, ResponsesOutput{
|
||||
Type: "function_call",
|
||||
ID: generateItemID(),
|
||||
CallID: toolCall.ID,
|
||||
Name: toolCall.Function.Name,
|
||||
Arguments: arguments,
|
||||
Status: "completed",
|
||||
})
|
||||
}
|
||||
return outputs
|
||||
}
|
||||
|
||||
func chatToResponsesEvent(
|
||||
state *ChatCompletionsToResponsesStreamState,
|
||||
eventType string,
|
||||
template *ResponsesStreamEvent,
|
||||
) ResponsesStreamEvent {
|
||||
seq := state.SequenceNumber
|
||||
state.SequenceNumber++
|
||||
evt := *template
|
||||
evt.Type = eventType
|
||||
evt.SequenceNumber = seq
|
||||
return evt
|
||||
}
|
||||
|
||||
func rawString(raw json.RawMessage) string {
|
||||
raw = bytesTrimSpace(raw)
|
||||
if len(raw) == 0 || string(raw) == "null" {
|
||||
return ""
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return s
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func rawNestedString(raw json.RawMessage, key string) string {
|
||||
var obj map[string]json.RawMessage
|
||||
if err := json.Unmarshal(raw, &obj); err != nil {
|
||||
return ""
|
||||
}
|
||||
return rawString(obj[key])
|
||||
}
|
||||
|
||||
func bytesTrimSpace(raw json.RawMessage) json.RawMessage {
|
||||
return json.RawMessage(strings.TrimSpace(string(raw)))
|
||||
}
|
||||
|
||||
func nonEmpty(value, fallback string) string {
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
@ -30,6 +30,17 @@ var CodexOfficialClientOriginatorPrefixes = []string{
|
||||
"codex ",
|
||||
}
|
||||
|
||||
// IsBrowserUserAgent 判断 User-Agent 是否来自浏览器(Chrome/Firefox/Safari/Edge/Opera 等)。
|
||||
// 所有现代浏览器的 UA 均以 "Mozilla/" 作为前缀,CLI 工具(codex/claude/curl/postman/python-requests 等)不会。
|
||||
// 该判定用于避免 Cloudflare 对浏览器型 UA 在 OpenAI 上游接口上触发 JS 质询。
|
||||
func IsBrowserUserAgent(userAgent string) bool {
|
||||
ua := strings.TrimSpace(userAgent)
|
||||
if ua == "" {
|
||||
return false
|
||||
}
|
||||
return strings.HasPrefix(strings.ToLower(ua), "mozilla/")
|
||||
}
|
||||
|
||||
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
|
||||
func IsCodexCLIRequest(userAgent string) bool {
|
||||
ua := normalizeCodexClientHeader(userAgent)
|
||||
|
||||
@ -198,6 +198,19 @@ type APIKeyUsageTrendPoint struct {
|
||||
Tokens int64 `json:"tokens"`
|
||||
}
|
||||
|
||||
// APIKeyDailyUsagePoint represents one day of usage for a single API key.
|
||||
type APIKeyDailyUsagePoint struct {
|
||||
Date string `json:"date"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheReadTokens int64 `json:"cache_read_tokens"`
|
||||
CacheWriteTokens int64 `json:"cache_write_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// UserDashboardStats 用户仪表盘统计
|
||||
type UserDashboardStats struct {
|
||||
// API Key 统计
|
||||
|
||||
219
backend/internal/repository/aes_encryptor_test.go
Normal file
219
backend/internal/repository/aes_encryptor_test.go
Normal file
@ -0,0 +1,219 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ── 测试辅助 ─────────────────────────────────────────────────────────────────
|
||||
|
||||
// aesHexKey 构造一个全填充为 b 的 n 字节密钥并以 hex 编码返回。
|
||||
func aesHexKey(n int, b byte) string {
|
||||
raw := make([]byte, n)
|
||||
for i := range raw {
|
||||
raw[i] = b
|
||||
}
|
||||
return hex.EncodeToString(raw)
|
||||
}
|
||||
|
||||
// aesTestCfg 用给定 hex 密钥字符串构造最小 Config。
|
||||
func aesTestCfg(keyHex string) *config.Config {
|
||||
return &config.Config{
|
||||
Totp: config.TotpConfig{EncryptionKey: keyHex},
|
||||
}
|
||||
}
|
||||
|
||||
// aesEncryptor 创建一个持有合法 32 字节密钥的加密器,测试失败时立即终止。
|
||||
func aesEncryptor(t *testing.T) *AESEncryptor {
|
||||
t.Helper()
|
||||
enc, err := NewAESEncryptor(aesTestCfg(aesHexKey(32, 0x42)))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, enc)
|
||||
return enc.(*AESEncryptor)
|
||||
}
|
||||
|
||||
// ── NewAESEncryptor ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestNewAESEncryptor_ValidKey32Bytes(t *testing.T) {
|
||||
enc, err := NewAESEncryptor(aesTestCfg(aesHexKey(32, 0x01)))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, enc)
|
||||
}
|
||||
|
||||
// 16 / 24 字节密钥在 AES 体系内合法,但本实现仅接受 AES-256(32 字节)。
|
||||
func TestNewAESEncryptor_WrongKeyLength(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keySize int
|
||||
}{
|
||||
{"16_bytes_AES128", 16},
|
||||
{"24_bytes_AES192", 24},
|
||||
{"1_byte", 1},
|
||||
{"31_bytes", 31},
|
||||
{"33_bytes", 33},
|
||||
{"64_bytes", 64},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewAESEncryptor(aesTestCfg(aesHexKey(tt.keySize, 0x00)))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "32 bytes")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// "配置缺失"场景:空字符串与非法 hex 编码。
|
||||
func TestNewAESEncryptor_MissingOrInvalidConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyHex string
|
||||
wantContain string
|
||||
}{
|
||||
{"empty_key", "", "32 bytes"},
|
||||
{"invalid_hex_odd_length", "abcde", "invalid totp encryption key"},
|
||||
{"invalid_hex_chars", "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz", "invalid totp encryption key"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewAESEncryptor(aesTestCfg(tt.keyHex))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.wantContain)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── 加解密往返(Roundtrip)───────────────────────────────────────────────────
|
||||
|
||||
func TestAESEncryptor_RoundTrip(t *testing.T) {
|
||||
enc := aesEncryptor(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
plaintext string
|
||||
}{
|
||||
{"ascii", "Hello, Sub2API!"},
|
||||
{"chinese_multibyte", "你好,世界!这是多字节 UTF-8 文本。"},
|
||||
{"empty_string", ""},
|
||||
{"long_string_gt_1KB", strings.Repeat("x", 2048)},
|
||||
{"special_chars", "!@#$%^&*()_+-=[]{}|;':\",./<>?"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ct, err := enc.Encrypt(tt.plaintext)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, ct, "密文不应为空(即便明文为空字符串)")
|
||||
|
||||
got, err := enc.Decrypt(ct)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.plaintext, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── IV/Nonce 随机性 ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestAESEncryptor_Encrypt_NonceRandomness(t *testing.T) {
|
||||
enc := aesEncryptor(t)
|
||||
const iterations = 30
|
||||
plaintext := "same plaintext for every iteration"
|
||||
|
||||
seen := make(map[string]struct{}, iterations)
|
||||
for i := 0; i < iterations; i++ {
|
||||
ct, err := enc.Encrypt(plaintext)
|
||||
require.NoError(t, err)
|
||||
seen[ct] = struct{}{}
|
||||
}
|
||||
|
||||
// 30 次加密相同明文,每次因随机 Nonce 应产生不同密文。
|
||||
assert.Len(t, seen, iterations,
|
||||
"每次加密应因随机 Nonce 产生唯一密文,共 %d 次", iterations)
|
||||
}
|
||||
|
||||
// ── Decrypt 错误路径 ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestAESDecrypt_InvalidBase64(t *testing.T) {
|
||||
enc := aesEncryptor(t)
|
||||
_, err := enc.Decrypt("!!!not-valid-base64!!!")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "decode base64")
|
||||
}
|
||||
|
||||
func TestAESDecrypt_TooShort(t *testing.T) {
|
||||
enc := aesEncryptor(t)
|
||||
// GCM Nonce 为 12 字节;仅提供 2 字节,必然短于 NonceSize。
|
||||
short := base64.StdEncoding.EncodeToString([]byte{0x01, 0x02})
|
||||
_, err := enc.Decrypt(short)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "too short")
|
||||
}
|
||||
|
||||
func TestAESDecrypt_TamperedCiphertext(t *testing.T) {
|
||||
enc := aesEncryptor(t)
|
||||
|
||||
ct, err := enc.Encrypt("sensitive payload")
|
||||
require.NoError(t, err)
|
||||
|
||||
raw, err := base64.StdEncoding.DecodeString(ct)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Nonce 占前 12 字节;翻转其后第一个字节(密文体)。
|
||||
raw[12] ^= 0xFF
|
||||
_, err = enc.Decrypt(base64.StdEncoding.EncodeToString(raw))
|
||||
require.Error(t, err, "篡改密文体后解密应失败")
|
||||
}
|
||||
|
||||
func TestAESDecrypt_TamperedTag(t *testing.T) {
|
||||
enc := aesEncryptor(t)
|
||||
|
||||
ct, err := enc.Encrypt("sensitive payload")
|
||||
require.NoError(t, err)
|
||||
|
||||
raw, err := base64.StdEncoding.DecodeString(ct)
|
||||
require.NoError(t, err)
|
||||
|
||||
// GCM 认证标签占最后 16 字节;翻转最后一个字节。
|
||||
raw[len(raw)-1] ^= 0xFF
|
||||
_, err = enc.Decrypt(base64.StdEncoding.EncodeToString(raw))
|
||||
require.Error(t, err, "篡改 GCM 标签后解密应失败")
|
||||
}
|
||||
|
||||
// ── 跨实例(Cross-instance)──────────────────────────────────────────────────
|
||||
|
||||
func TestAESEncryptor_CrossInstance_SameKey_CanDecrypt(t *testing.T) {
|
||||
keyHex := aesHexKey(32, 0xDE)
|
||||
|
||||
enc1, err := NewAESEncryptor(aesTestCfg(keyHex))
|
||||
require.NoError(t, err)
|
||||
enc2, err := NewAESEncryptor(aesTestCfg(keyHex))
|
||||
require.NoError(t, err)
|
||||
|
||||
plaintext := "cross-instance roundtrip"
|
||||
ct, err := enc1.Encrypt(plaintext)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := enc2.Decrypt(ct)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, plaintext, got, "相同密钥构造的两个实例应可互相解密")
|
||||
}
|
||||
|
||||
func TestAESEncryptor_CrossInstance_DifferentKey_CannotDecrypt(t *testing.T) {
|
||||
enc1, err := NewAESEncryptor(aesTestCfg(aesHexKey(32, 0xAA)))
|
||||
require.NoError(t, err)
|
||||
enc2, err := NewAESEncryptor(aesTestCfg(aesHexKey(32, 0xBB)))
|
||||
require.NoError(t, err)
|
||||
|
||||
ct, err := enc1.Encrypt("secret message")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = enc2.Decrypt(ct)
|
||||
require.Error(t, err, "不同密钥的实例不应能解密对方的密文")
|
||||
}
|
||||
@ -80,7 +80,7 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te
|
||||
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
|
||||
}
|
||||
|
||||
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
|
||||
func TestGroupRepository_DeleteCascade_PreservesApiKeyGroupID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
entClient := tx.Client()
|
||||
@ -138,8 +138,10 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
|
||||
require.NotContains(t, uAfter.AllowedGroups, targetGroup.ID)
|
||||
require.Contains(t, uAfter.AllowedGroups, otherGroup.ID)
|
||||
|
||||
// API keys bound to the deleted group should have group_id cleared.
|
||||
// API keys keep their group_id so auth can reject keys bound to a deleted group.
|
||||
keyAfter, err := apiKeyRepo.GetByID(ctx, key.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, keyAfter.GroupID)
|
||||
require.NotNil(t, keyAfter.GroupID)
|
||||
require.Equal(t, targetGroup.ID, *keyAfter.GroupID)
|
||||
require.Nil(t, keyAfter.Group)
|
||||
}
|
||||
|
||||
@ -9,7 +9,6 @@ import (
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
@ -94,9 +93,13 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
total, active, _ := r.GetAccountCount(ctx, out.ID)
|
||||
out.AccountCount = total
|
||||
out.ActiveAccountCount = active
|
||||
counts, err := r.loadAccountCounts(ctx, []int64{out.ID})
|
||||
if err == nil {
|
||||
c := counts[out.ID]
|
||||
out.AccountCount = c.Total
|
||||
out.ActiveAccountCount = c.Active
|
||||
out.RateLimitedAccountCount = c.RateLimited
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
@ -538,15 +541,12 @@ func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int
|
||||
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) {
|
||||
var rateLimited int64
|
||||
err = scanSingleRow(ctx, r.sql,
|
||||
`SELECT COUNT(*),
|
||||
COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true),
|
||||
COUNT(*) FILTER (WHERE a.status = 'active' AND (
|
||||
a.rate_limit_reset_at > NOW() OR
|
||||
a.overload_until > NOW() OR
|
||||
a.temp_unschedulable_until > NOW()
|
||||
))
|
||||
fmt.Sprintf(`SELECT
|
||||
COUNT(*) FILTER (WHERE a.deleted_at IS NULL),
|
||||
COUNT(*) FILTER (WHERE %s),
|
||||
COUNT(*) FILTER (WHERE %s)
|
||||
FROM account_groups ag JOIN accounts a ON a.id = ag.account_id
|
||||
WHERE ag.group_id = $1`,
|
||||
WHERE ag.group_id = $1`, groupAccountAvailableSQL, groupAccountTemporarilyLimitedSQL),
|
||||
[]any{groupID}, &total, &active, &rateLimited)
|
||||
return
|
||||
}
|
||||
@ -636,28 +636,18 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Clear group_id for api keys bound to this group.
|
||||
// 仅更新未软删除的记录,避免修改已删除数据,保证审计与历史回溯一致性。
|
||||
// 与 APIKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
|
||||
if _, err := txClient.APIKey.Update().
|
||||
Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()).
|
||||
ClearGroupID().
|
||||
Save(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3. Remove the group id from user_allowed_groups join table.
|
||||
// 2. Remove the group id from user_allowed_groups join table.
|
||||
// Legacy users.allowed_groups 列已弃用,不再同步。
|
||||
if _, err := exec.ExecContext(ctx, "DELETE FROM user_allowed_groups WHERE group_id = $1", id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. Delete account_groups join rows.
|
||||
// 3. Delete account_groups join rows.
|
||||
if _, err := exec.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 5. Soft-delete group itself.
|
||||
// 4. Soft-delete group itself.
|
||||
if _, err := txClient.Group.Delete().Where(group.IDEQ(id)).Exec(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -680,6 +670,28 @@ type groupAccountCounts struct {
|
||||
RateLimited int64
|
||||
}
|
||||
|
||||
const (
|
||||
// 分组页的"可用"账号数必须与账号仓储的 ListSchedulableByGroupID 过滤口径一致。
|
||||
groupAccountAvailableSQL = `a.deleted_at IS NULL
|
||||
AND a.status = 'active'
|
||||
AND a.schedulable = true
|
||||
AND (a.expires_at IS NULL OR a.expires_at > NOW() OR a.auto_pause_on_expired = FALSE)
|
||||
AND (a.rate_limit_reset_at IS NULL OR a.rate_limit_reset_at <= NOW())
|
||||
AND (a.overload_until IS NULL OR a.overload_until <= NOW())
|
||||
AND (a.temp_unschedulable_until IS NULL OR a.temp_unschedulable_until <= NOW())`
|
||||
|
||||
// 这里沿用历史字段名 RateLimitedAccountCount,但统计的是会让账号暂时退出调度的时间窗口。
|
||||
groupAccountTemporarilyLimitedSQL = `a.deleted_at IS NULL
|
||||
AND a.status = 'active'
|
||||
AND a.schedulable = true
|
||||
AND (a.expires_at IS NULL OR a.expires_at > NOW() OR a.auto_pause_on_expired = FALSE)
|
||||
AND (
|
||||
a.rate_limit_reset_at > NOW() OR
|
||||
a.overload_until > NOW() OR
|
||||
a.temp_unschedulable_until > NOW()
|
||||
)`
|
||||
)
|
||||
|
||||
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]groupAccountCounts, err error) {
|
||||
counts = make(map[int64]groupAccountCounts, len(groupIDs))
|
||||
if len(groupIDs) == 0 {
|
||||
@ -688,18 +700,14 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
|
||||
|
||||
rows, err := r.sql.QueryContext(
|
||||
ctx,
|
||||
`SELECT ag.group_id,
|
||||
COUNT(*) AS total,
|
||||
COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true) AS active,
|
||||
COUNT(*) FILTER (WHERE a.status = 'active' AND (
|
||||
a.rate_limit_reset_at > NOW() OR
|
||||
a.overload_until > NOW() OR
|
||||
a.temp_unschedulable_until > NOW()
|
||||
)) AS rate_limited
|
||||
fmt.Sprintf(`SELECT ag.group_id,
|
||||
COUNT(*) FILTER (WHERE a.deleted_at IS NULL) AS total,
|
||||
COUNT(*) FILTER (WHERE %s) AS active,
|
||||
COUNT(*) FILTER (WHERE %s) AS rate_limited
|
||||
FROM account_groups ag
|
||||
JOIN accounts a ON a.id = ag.account_id
|
||||
WHERE ag.group_id = ANY($1)
|
||||
GROUP BY ag.group_id`,
|
||||
GROUP BY ag.group_id`, groupAccountAvailableSQL, groupAccountTemporarilyLimitedSQL),
|
||||
pq.Array(groupIDs),
|
||||
)
|
||||
if err != nil {
|
||||
|
||||
@ -651,6 +651,164 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
// TestListWithFilters_ActiveAccountCount_LessThanTotal 验证 ActiveAccountCount 正确区分可用与不可用账号。
|
||||
// 当分组内存在 disabled 或 schedulable=false 的账号时,ActiveAccountCount 必须小于 AccountCount,
|
||||
// 且与 GetAccountCount 返回的 active 值一致。
|
||||
func (s *GroupRepoSuite) TestListWithFilters_ActiveAccountCount_LessThanTotal() {
|
||||
g := &service.Group{
|
||||
Name: "g-mixed-status",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g))
|
||||
|
||||
insertAccount := func(name, status string, schedulable bool) int64 {
|
||||
var id int64
|
||||
s.Require().NoError(scanSingleRow(
|
||||
s.ctx, s.tx,
|
||||
"INSERT INTO accounts (name, platform, type, status, schedulable) VALUES ($1, $2, $3, $4, $5) RETURNING id",
|
||||
[]any{name, service.PlatformAnthropic, service.AccountTypeOAuth, status, schedulable},
|
||||
&id,
|
||||
))
|
||||
return id
|
||||
}
|
||||
link := func(accountID int64, priority int) {
|
||||
_, err := s.tx.ExecContext(s.ctx,
|
||||
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
|
||||
accountID, g.ID, priority)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
// account 1: active + schedulable → counts toward both total and active
|
||||
link(insertAccount("acc-active-sched", service.StatusActive, true), 1)
|
||||
// account 2: disabled → counts toward total only
|
||||
link(insertAccount("acc-disabled", service.StatusDisabled, true), 2)
|
||||
// account 3: active + not schedulable → counts toward total only
|
||||
link(insertAccount("acc-unschedulable", service.StatusActive, false), 3)
|
||||
|
||||
// --- ListWithFilters path ---
|
||||
isExclusive := false
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx,
|
||||
pagination.PaginationParams{Page: 1, PageSize: 100},
|
||||
service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
|
||||
s.Require().NoError(err)
|
||||
|
||||
var found *service.Group
|
||||
for i := range groups {
|
||||
if groups[i].ID == g.ID {
|
||||
found = &groups[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
s.Require().NotNil(found, "created group must appear in ListWithFilters result")
|
||||
s.Assert().Equal(int64(3), found.AccountCount, "AccountCount must count all 3 accounts")
|
||||
s.Assert().Equal(int64(1), found.ActiveAccountCount, "ActiveAccountCount must count only the active+schedulable account")
|
||||
|
||||
// --- GetAccountCount must return identical values ---
|
||||
total, active, err := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Assert().Equal(found.AccountCount, total, "GetAccountCount total must match ListWithFilters AccountCount")
|
||||
s.Assert().Equal(found.ActiveAccountCount, active, "GetAccountCount active must match ListWithFilters ActiveAccountCount")
|
||||
}
|
||||
|
||||
// TestListWithFilters_RateLimitedAccountCount 验证临时受限账号不会计入可用账号数。
|
||||
// rate_limit / overload / temp_unschedulable 都会让账号退出当前调度池,
|
||||
// 因此 ActiveAccountCount 必须与真实调度查询口径一致。
|
||||
func (s *GroupRepoSuite) TestListWithFilters_RateLimitedAccountCount() {
|
||||
g := &service.Group{
|
||||
Name: "g-rate-limited",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g))
|
||||
|
||||
var normalID int64
|
||||
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
|
||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||
[]any{"acc-normal", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||
&normalID))
|
||||
|
||||
var rateLimitedID int64
|
||||
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
|
||||
"INSERT INTO accounts (name, platform, type, rate_limit_reset_at) VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour') RETURNING id",
|
||||
[]any{"acc-rate-limited", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||
&rateLimitedID))
|
||||
|
||||
var overloadedID int64
|
||||
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
|
||||
"INSERT INTO accounts (name, platform, type, overload_until) VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour') RETURNING id",
|
||||
[]any{"acc-overloaded", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||
&overloadedID))
|
||||
|
||||
var tempUnschedulableID int64
|
||||
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
|
||||
"INSERT INTO accounts (name, platform, type, temp_unschedulable_until) VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour') RETURNING id",
|
||||
[]any{"acc-temp-unschedulable", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||
&tempUnschedulableID))
|
||||
|
||||
var expiredID int64
|
||||
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
|
||||
"INSERT INTO accounts (name, platform, type, expires_at, auto_pause_on_expired) VALUES ($1, $2, $3, NOW() - INTERVAL '1 hour', TRUE) RETURNING id",
|
||||
[]any{"acc-expired", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||
&expiredID))
|
||||
|
||||
_, err := s.tx.ExecContext(s.ctx,
|
||||
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
|
||||
normalID, g.ID, 1)
|
||||
s.Require().NoError(err)
|
||||
_, err = s.tx.ExecContext(s.ctx,
|
||||
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
|
||||
rateLimitedID, g.ID, 2)
|
||||
s.Require().NoError(err)
|
||||
_, err = s.tx.ExecContext(s.ctx,
|
||||
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
|
||||
overloadedID, g.ID, 3)
|
||||
s.Require().NoError(err)
|
||||
_, err = s.tx.ExecContext(s.ctx,
|
||||
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
|
||||
tempUnschedulableID, g.ID, 4)
|
||||
s.Require().NoError(err)
|
||||
_, err = s.tx.ExecContext(s.ctx,
|
||||
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
|
||||
expiredID, g.ID, 5)
|
||||
s.Require().NoError(err)
|
||||
|
||||
isExclusive := false
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx,
|
||||
pagination.PaginationParams{Page: 1, PageSize: 100},
|
||||
service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
|
||||
s.Require().NoError(err)
|
||||
|
||||
var found *service.Group
|
||||
for i := range groups {
|
||||
if groups[i].ID == g.ID {
|
||||
found = &groups[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
s.Require().NotNil(found, "created group must appear in ListWithFilters result")
|
||||
s.Assert().Equal(int64(5), found.AccountCount, "AccountCount must include all linked accounts")
|
||||
s.Assert().Equal(int64(1), found.ActiveAccountCount, "ActiveAccountCount must include only currently schedulable accounts")
|
||||
s.Assert().Equal(int64(3), found.RateLimitedAccountCount, "RateLimitedAccountCount must include temporarily limited accounts")
|
||||
|
||||
total, active, err := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Assert().Equal(found.AccountCount, total, "GetAccountCount total must match ListWithFilters AccountCount")
|
||||
s.Assert().Equal(found.ActiveAccountCount, active, "GetAccountCount active must match ListWithFilters ActiveAccountCount")
|
||||
|
||||
detail, err := s.repo.GetByID(s.ctx, g.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Assert().Equal(found.AccountCount, detail.AccountCount, "GetByID AccountCount must match ListWithFilters")
|
||||
s.Assert().Equal(found.ActiveAccountCount, detail.ActiveAccountCount, "GetByID ActiveAccountCount must match ListWithFilters")
|
||||
s.Assert().Equal(found.RateLimitedAccountCount, detail.RateLimitedAccountCount, "GetByID RateLimitedAccountCount must match ListWithFilters")
|
||||
}
|
||||
|
||||
// --- DeleteAccountGroupsByGroupID ---
|
||||
|
||||
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
|
||||
|
||||
@ -7,6 +7,67 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// TestListWithAccountCountSort_AttachesActiveCount 验证通过 account_count 排序时,
|
||||
// ActiveAccountCount 与 AccountCount 都被正确附加到返回结果中,
|
||||
// 且排序基于 total 账号数而非 active 账号数。
|
||||
func (s *GroupRepoSuite) TestListWithAccountCountSort_AttachesActiveCount() {
|
||||
// Group A: 2 total, 1 active (1 disabled account)
|
||||
gA := &service.Group{Name: "sort-count-a", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard}
|
||||
// Group B: 1 total, 1 active
|
||||
gB := &service.Group{Name: "sort-count-b", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, gA))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, gB))
|
||||
|
||||
insertAccount := func(name, status string) int64 {
|
||||
var id int64
|
||||
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
|
||||
"INSERT INTO accounts (name, platform, type, status) VALUES ($1, $2, $3, $4) RETURNING id",
|
||||
[]any{name, service.PlatformAnthropic, service.AccountTypeOAuth, status},
|
||||
&id))
|
||||
return id
|
||||
}
|
||||
link := func(accountID, groupID int64, priority int) {
|
||||
_, err := s.tx.ExecContext(s.ctx,
|
||||
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
|
||||
accountID, groupID, priority)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
// gA: 1 active + 1 disabled → total=2, active=1
|
||||
link(insertAccount("sa-active", service.StatusActive), gA.ID, 1)
|
||||
link(insertAccount("sa-disabled", service.StatusDisabled), gA.ID, 2)
|
||||
// gB: 1 active → total=1, active=1
|
||||
link(insertAccount("sb-active", service.StatusActive), gB.ID, 1)
|
||||
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
|
||||
Page: 1, PageSize: 100, SortBy: "account_count", SortOrder: "desc",
|
||||
}, service.PlatformAnthropic, service.StatusActive, "", nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
byID := make(map[int64]service.Group, len(groups))
|
||||
for _, g := range groups {
|
||||
byID[g.ID] = g
|
||||
}
|
||||
|
||||
s.Require().Contains(byID, gA.ID, "gA must appear in results")
|
||||
s.Require().Contains(byID, gB.ID, "gB must appear in results")
|
||||
|
||||
cA := byID[gA.ID]
|
||||
s.Assert().Equal(int64(2), cA.AccountCount, "gA AccountCount must be 2")
|
||||
s.Assert().Equal(int64(1), cA.ActiveAccountCount, "gA ActiveAccountCount must be 1")
|
||||
|
||||
cB := byID[gB.ID]
|
||||
s.Assert().Equal(int64(1), cB.AccountCount, "gB AccountCount must be 1")
|
||||
s.Assert().Equal(int64(1), cB.ActiveAccountCount, "gB ActiveAccountCount must be 1")
|
||||
|
||||
// Sort is by total (not active): gA (total=2) must rank higher than gB (total=1) in desc order
|
||||
indexByID := make(map[int64]int, len(groups))
|
||||
for i, g := range groups {
|
||||
indexByID[g.ID] = i
|
||||
}
|
||||
s.Assert().Less(indexByID[gA.ID], indexByID[gB.ID], "gA (total=2) must rank above gB (total=1) with account_count desc")
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestList_DefaultSortBySortOrderAsc() {
|
||||
g1 := &service.Group{Name: "g1", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 20}
|
||||
g2 := &service.Group{Name: "g2", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 10}
|
||||
|
||||
@ -833,6 +833,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"payment_visible_method_alipay_enabled": true,
|
||||
"payment_visible_method_wxpay_enabled": false,
|
||||
"openai_advanced_scheduler_enabled": true,
|
||||
"openai_codex_user_agent": "",
|
||||
"openai_fast_policy_settings": {
|
||||
"rules": []
|
||||
},
|
||||
@ -1058,6 +1059,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"payment_visible_method_alipay_enabled": false,
|
||||
"payment_visible_method_wxpay_enabled": false,
|
||||
"openai_advanced_scheduler_enabled": false,
|
||||
"openai_codex_user_agent": "",
|
||||
"openai_fast_policy_settings": {
|
||||
"rules": []
|
||||
},
|
||||
|
||||
@ -109,6 +109,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
|
||||
return
|
||||
}
|
||||
if abortIfAPIKeyGroupUnavailable(c, apiKey) {
|
||||
return
|
||||
}
|
||||
|
||||
// ── 4. SimpleMode → early return ─────────────────────────────
|
||||
|
||||
@ -251,3 +254,26 @@ func setGroupContext(c *gin.Context, group *service.Group) {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.Group, group)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
func abortIfAPIKeyGroupUnavailable(c *gin.Context, apiKey *service.APIKey) bool {
|
||||
code, message, ok := validateAPIKeyGroupAvailable(apiKey)
|
||||
if ok {
|
||||
return false
|
||||
}
|
||||
AbortWithError(c, 403, code, message)
|
||||
return true
|
||||
}
|
||||
|
||||
func validateAPIKeyGroupAvailable(apiKey *service.APIKey) (string, string, bool) {
|
||||
if apiKey == nil || apiKey.GroupID == nil {
|
||||
return "", "", true
|
||||
}
|
||||
group := apiKey.Group
|
||||
if group == nil || strings.EqualFold(group.Status, "deleted") {
|
||||
return "GROUP_DELETED", "API Key 所属分组已删除", false
|
||||
}
|
||||
if !group.IsActive() {
|
||||
return "GROUP_DISABLED", "API Key 所属分组已停用", false
|
||||
}
|
||||
return "", "", true
|
||||
}
|
||||
|
||||
@ -54,6 +54,10 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
|
||||
abortWithGoogleError(c, 401, "User account is not active")
|
||||
return
|
||||
}
|
||||
if _, message, ok := validateAPIKeyGroupAvailable(apiKey); !ok {
|
||||
abortWithGoogleError(c, 403, message)
|
||||
return
|
||||
}
|
||||
|
||||
// 简易模式:跳过余额和订阅检查
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
|
||||
@ -300,6 +300,104 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthRejectsUnavailableGroup(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
groupID := int64(101)
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
group *service.Group
|
||||
wantStatus int
|
||||
wantCode string
|
||||
}{
|
||||
{
|
||||
name: "active group passes",
|
||||
group: &service.Group{
|
||||
ID: groupID,
|
||||
Name: "active",
|
||||
Status: service.StatusActive,
|
||||
Platform: service.PlatformAnthropic,
|
||||
Hydrated: true,
|
||||
},
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "disabled group is forbidden",
|
||||
group: &service.Group{
|
||||
ID: groupID,
|
||||
Name: "disabled",
|
||||
Status: service.StatusDisabled,
|
||||
Platform: service.PlatformAnthropic,
|
||||
Hydrated: true,
|
||||
},
|
||||
wantStatus: http.StatusForbidden,
|
||||
wantCode: "GROUP_DISABLED",
|
||||
},
|
||||
{
|
||||
name: "deleted status group is forbidden",
|
||||
group: &service.Group{
|
||||
ID: groupID,
|
||||
Name: "deleted",
|
||||
Status: "deleted",
|
||||
Platform: service.PlatformAnthropic,
|
||||
Hydrated: true,
|
||||
},
|
||||
wantStatus: http.StatusForbidden,
|
||||
wantCode: "GROUP_DELETED",
|
||||
},
|
||||
{
|
||||
name: "missing group edge is forbidden",
|
||||
group: nil,
|
||||
wantStatus: http.StatusForbidden,
|
||||
wantCode: "GROUP_DELETED",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
GroupID: &groupID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: tt.group,
|
||||
}
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
router := newAuthTestRouter(apiKeyService, nil, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, tt.wantStatus, w.Code)
|
||||
if tt.wantCode != "" {
|
||||
require.Contains(t, w.Body.String(), tt.wantCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@ -422,6 +422,11 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
|
||||
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSMTPConnection)
|
||||
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
|
||||
adminSettings.GET("/email-templates", h.Admin.Setting.ListEmailTemplates)
|
||||
adminSettings.POST("/email-template-preview", h.Admin.Setting.PreviewEmailTemplate)
|
||||
adminSettings.GET("/email-templates/:event/:locale", h.Admin.Setting.GetEmailTemplate)
|
||||
adminSettings.PUT("/email-templates/:event/:locale", h.Admin.Setting.UpdateEmailTemplate)
|
||||
adminSettings.POST("/email-templates/:event/:locale/restore-official", h.Admin.Setting.RestoreOfficialEmailTemplate)
|
||||
// Admin API Key 管理
|
||||
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
|
||||
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
|
||||
|
||||
@ -214,6 +214,7 @@ func RegisterAuthRoutes(
|
||||
settings := v1.Group("/settings")
|
||||
{
|
||||
settings.GET("/public", h.Setting.GetPublicSettings)
|
||||
settings.GET("/email-unsubscribe", h.Setting.UnsubscribeNotificationEmail)
|
||||
}
|
||||
|
||||
// 需要认证的当前用户信息
|
||||
|
||||
@ -31,6 +31,7 @@ func RegisterUserRoutes(
|
||||
user.POST("/account-bindings/email", h.User.BindEmailIdentity)
|
||||
user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity)
|
||||
user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding)
|
||||
user.GET("/api-keys/:id/usage/daily", h.Usage.GetMyAPIKeyDailyUsage)
|
||||
|
||||
// 通知邮箱管理
|
||||
notifyEmail := user.Group("/notify-email")
|
||||
|
||||
@ -244,6 +244,21 @@ func (s *groupRepoStub) UpdateSortOrders(ctx context.Context, updates []GroupSor
|
||||
return nil
|
||||
}
|
||||
|
||||
type deleteGroupAPIKeyRepoStub struct {
|
||||
apiKeyRepoStubForGroupUpdate
|
||||
keys []string
|
||||
listErr error
|
||||
listGroupIDs []int64
|
||||
}
|
||||
|
||||
func (s *deleteGroupAPIKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||
s.listGroupIDs = append(s.listGroupIDs, groupID)
|
||||
if s.listErr != nil {
|
||||
return nil, s.listErr
|
||||
}
|
||||
return s.keys, nil
|
||||
}
|
||||
|
||||
type proxyRepoStub struct {
|
||||
deleteErr error
|
||||
countErr error
|
||||
@ -500,6 +515,23 @@ func TestAdminService_DeleteGroup_Success_WithCacheInvalidation(t *testing.T) {
|
||||
}, calls)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteGroup_InvalidatesAuthCacheForBoundKeys(t *testing.T) {
|
||||
repo := &groupRepoStub{}
|
||||
apiKeyRepo := &deleteGroupAPIKeyRepoStub{keys: []string{"k1", "k2"}}
|
||||
invalidator := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{
|
||||
groupRepo: repo,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
authCacheInvalidator: invalidator,
|
||||
}
|
||||
|
||||
err := svc.DeleteGroup(context.Background(), 5)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{5}, repo.deleteCalls)
|
||||
require.Equal(t, []int64{5}, apiKeyRepo.listGroupIDs)
|
||||
require.Equal(t, []string{"k1", "k2"}, invalidator.keys)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteGroup_NotFound(t *testing.T) {
|
||||
repo := &groupRepoStub{deleteErr: ErrGroupNotFound}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
@ -14,7 +14,7 @@ import (
|
||||
"github.com/dgraph-io/ristretto"
|
||||
)
|
||||
|
||||
const apiKeyAuthSnapshotVersion = 9 // v9: added API Key name for audit logs
|
||||
const apiKeyAuthSnapshotVersion = 10 // v10: reload snapshots for group availability checks
|
||||
|
||||
type apiKeyAuthCacheConfig struct {
|
||||
l1Size int
|
||||
|
||||
@ -94,7 +94,7 @@ func (s *AuthService) BindEmailIdentity(
|
||||
}
|
||||
|
||||
// SendEmailIdentityBindCode sends a verification code for authenticated email binding flows.
|
||||
func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string) error {
|
||||
func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string, locale ...string) error {
|
||||
if s == nil {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
@ -128,7 +128,7 @@ func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int6
|
||||
if s.settingService != nil {
|
||||
siteName = s.settingService.GetSiteName(ctx)
|
||||
}
|
||||
return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName)
|
||||
return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName, firstEmailLocale(locale))
|
||||
}
|
||||
|
||||
func normalizeEmailForIdentityBinding(email string) (string, error) {
|
||||
|
||||
@ -28,7 +28,7 @@ func normalizeOAuthSignupSource(signupSource string) string {
|
||||
|
||||
// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth
|
||||
// account-creation flows without relying on the public registration gate.
|
||||
func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
|
||||
func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string, locale ...string) (*SendVerifyCodeResult, error) {
|
||||
email = strings.TrimSpace(strings.ToLower(email))
|
||||
if email == "" {
|
||||
return nil, ErrEmailVerifyRequired
|
||||
@ -47,7 +47,7 @@ func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email stri
|
||||
if s.settingService != nil {
|
||||
siteName = s.settingService.GetSiteName(ctx)
|
||||
}
|
||||
if err := s.emailService.SendVerifyCode(ctx, email, siteName); err != nil {
|
||||
if err := s.emailService.SendVerifyCode(ctx, email, siteName, firstEmailLocale(locale)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SendVerifyCodeResult{
|
||||
|
||||
@ -273,7 +273,7 @@ type SendVerifyCodeResult struct {
|
||||
}
|
||||
|
||||
// SendVerifyCode 发送邮箱验证码(同步方式)
|
||||
func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
||||
func (s *AuthService) SendVerifyCode(ctx context.Context, email string, locale ...string) error {
|
||||
// 检查是否开放注册(默认关闭)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return ErrRegDisabled
|
||||
@ -307,11 +307,11 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
||||
siteName = s.settingService.GetSiteName(ctx)
|
||||
}
|
||||
|
||||
return s.emailService.SendVerifyCode(ctx, email, siteName)
|
||||
return s.emailService.SendVerifyCode(ctx, email, siteName, firstEmailLocale(locale))
|
||||
}
|
||||
|
||||
// SendVerifyCodeAsync 异步发送邮箱验证码并返回倒计时
|
||||
func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
|
||||
func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string, locale ...string) (*SendVerifyCodeResult, error) {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] SendVerifyCodeAsync called for email: %s", email)
|
||||
|
||||
// 检查是否开放注册(默认关闭)
|
||||
@ -352,7 +352,7 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
|
||||
|
||||
// 异步发送
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Enqueueing verify code for: %s", email)
|
||||
if err := s.emailQueueService.EnqueueVerifyCode(email, siteName); err != nil {
|
||||
if err := s.emailQueueService.EnqueueVerifyCode(email, siteName, firstEmailLocale(locale)); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to enqueue: %v", err)
|
||||
return nil, fmt.Errorf("enqueue verify code: %w", err)
|
||||
}
|
||||
@ -1251,7 +1251,7 @@ func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendB
|
||||
|
||||
// RequestPasswordReset 请求密码重置(同步发送)
|
||||
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
|
||||
func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string) error {
|
||||
func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string, locale ...string) error {
|
||||
if !s.IsPasswordResetEnabled(ctx) {
|
||||
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
|
||||
}
|
||||
@ -1264,7 +1264,7 @@ func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendB
|
||||
return nil // Silent success to prevent enumeration
|
||||
}
|
||||
|
||||
if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
|
||||
if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL, firstEmailLocale(locale)); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to send password reset email to %s: %v", email, err)
|
||||
return nil // Silent success to prevent enumeration
|
||||
}
|
||||
@ -1275,7 +1275,7 @@ func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendB
|
||||
|
||||
// RequestPasswordResetAsync 异步请求密码重置(队列发送)
|
||||
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
|
||||
func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string) error {
|
||||
func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string, locale ...string) error {
|
||||
if !s.IsPasswordResetEnabled(ctx) {
|
||||
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
|
||||
}
|
||||
@ -1288,7 +1288,7 @@ func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, fron
|
||||
return nil // Silent success to prevent enumeration
|
||||
}
|
||||
|
||||
if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL); err != nil {
|
||||
if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL, firstEmailLocale(locale)); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to enqueue password reset email for %s: %v", email, err)
|
||||
return nil // Silent success to prevent enumeration
|
||||
}
|
||||
|
||||
@ -39,9 +39,10 @@ type AccountQuotaReader interface {
|
||||
|
||||
// BalanceNotifyService handles balance and quota threshold notifications.
|
||||
type BalanceNotifyService struct {
|
||||
emailService *EmailService
|
||||
settingRepo SettingRepository
|
||||
accountRepo AccountQuotaReader
|
||||
emailService *EmailService
|
||||
settingRepo SettingRepository
|
||||
accountRepo AccountQuotaReader
|
||||
notificationEmailService *NotificationEmailService
|
||||
}
|
||||
|
||||
// NewBalanceNotifyService creates a new BalanceNotifyService.
|
||||
@ -53,6 +54,10 @@ func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepo
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BalanceNotifyService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) {
|
||||
s.notificationEmailService = notificationEmailService
|
||||
}
|
||||
|
||||
// resolveBalanceThreshold returns the effective balance threshold.
|
||||
// For percentage type, it computes threshold = totalRecharged * percentage / 100.
|
||||
func resolveBalanceThreshold(threshold float64, thresholdType string, totalRecharged float64) float64 {
|
||||
@ -125,7 +130,7 @@ func (s *BalanceNotifyService) dispatchBalanceLowEmail(ctx context.Context, user
|
||||
slog.Error("panic in balance notification", "recover", r)
|
||||
}
|
||||
}()
|
||||
s.sendBalanceLowEmails(recipients, user.Username, user.Email, newBalance, threshold, siteName, rechargeURL)
|
||||
s.sendBalanceLowEmails(recipients, user.ID, user.Username, user.Email, newBalance, threshold, siteName, rechargeURL)
|
||||
}()
|
||||
}
|
||||
|
||||
@ -342,11 +347,44 @@ func (s *BalanceNotifyService) sendEmails(recipients []string, subject, body str
|
||||
}
|
||||
|
||||
// sendBalanceLowEmails sends balance low notification to all recipients.
|
||||
func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userName, userEmail string, balance, threshold float64, siteName, rechargeURL string) {
|
||||
func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userID int64, userName, userEmail string, balance, threshold float64, siteName, rechargeURL string) {
|
||||
displayName := userName
|
||||
if displayName == "" {
|
||||
displayName = userEmail
|
||||
}
|
||||
if s.notificationEmailService != nil {
|
||||
fallbackRecipients := make([]string, 0, len(recipients))
|
||||
for _, to := range recipients {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout)
|
||||
err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||
Event: NotificationEmailEventBalanceLow,
|
||||
RecipientEmail: to,
|
||||
RecipientName: displayName,
|
||||
UserID: userID,
|
||||
SourceType: "balance_low",
|
||||
SourceID: firstNonEmpty(strconv.FormatInt(userID, 10), userEmail),
|
||||
ReminderKey: time.Now().UTC().Format("2006-01-02"),
|
||||
Variables: map[string]string{
|
||||
"current_balance": fmt.Sprintf("%.2f", balance),
|
||||
"threshold": fmt.Sprintf("%.2f", threshold),
|
||||
"recharge_url": rechargeURL,
|
||||
},
|
||||
})
|
||||
cancel()
|
||||
if err != nil {
|
||||
if shouldFallbackNotificationEmail(err) {
|
||||
slog.Warn("template balance low notification failed; falling back to built-in body", "to", to, "err", err.Error())
|
||||
fallbackRecipients = append(fallbackRecipients, to)
|
||||
} else {
|
||||
slog.Warn("template balance low notification delivery failed; not sending fallback to avoid duplicates", "to", to, "err", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(fallbackRecipients) == 0 {
|
||||
return
|
||||
}
|
||||
recipients = fallbackRecipients
|
||||
}
|
||||
subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", sanitizeEmailHeader(siteName))
|
||||
body := s.buildBalanceLowEmailBody(html.EscapeString(displayName), balance, threshold, html.EscapeString(siteName), rechargeURL)
|
||||
s.sendEmails(recipients, subject, body, "user_email", userEmail, "balance", balance)
|
||||
@ -369,6 +407,44 @@ func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accoun
|
||||
remaining = 0
|
||||
}
|
||||
|
||||
if s.notificationEmailService != nil {
|
||||
fallbackRecipients := make([]string, 0, len(adminEmails))
|
||||
for _, to := range adminEmails {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout)
|
||||
err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||
Event: NotificationEmailEventAccountQuotaAlert,
|
||||
RecipientEmail: to,
|
||||
RecipientName: emailRecipientName(to),
|
||||
SourceType: "account_quota",
|
||||
SourceID: fmt.Sprintf("%d-%s", accountID, dim.name),
|
||||
ReminderKey: time.Now().UTC().Format("2006-01-02"),
|
||||
Variables: map[string]string{
|
||||
"account_id": strconv.FormatInt(accountID, 10),
|
||||
"account_name": accountName,
|
||||
"platform": platform,
|
||||
"quota_dimension": dimLabel,
|
||||
"quota_used": fmt.Sprintf("%.2f", used),
|
||||
"quota_limit": fmt.Sprintf("%.2f", dim.limit),
|
||||
"quota_remaining": fmt.Sprintf("%.2f", remaining),
|
||||
"quota_threshold": thresholdDisplay,
|
||||
},
|
||||
})
|
||||
cancel()
|
||||
if err != nil {
|
||||
if shouldFallbackNotificationEmail(err) {
|
||||
slog.Warn("template account quota alert failed; falling back to built-in body", "to", to, "account_id", accountID, "dimension", dim.name, "err", err.Error())
|
||||
fallbackRecipients = append(fallbackRecipients, to)
|
||||
} else {
|
||||
slog.Warn("template account quota alert delivery failed; not sending fallback to avoid duplicates", "to", to, "account_id", accountID, "dimension", dim.name, "err", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(fallbackRecipients) == 0 {
|
||||
return
|
||||
}
|
||||
adminEmails = fallbackRecipients
|
||||
}
|
||||
|
||||
subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", sanitizeEmailHeader(siteName), sanitizeEmailHeader(accountName))
|
||||
body := s.buildQuotaAlertEmailBody(accountID, html.EscapeString(accountName), html.EscapeString(platform), html.EscapeString(dimLabel), used, dim.limit, remaining, thresholdDisplay, html.EscapeString(siteName))
|
||||
s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dim.name)
|
||||
|
||||
@ -15,16 +15,19 @@ import (
|
||||
//
|
||||
// 行为按 tokenType / mimicClaudeCode 分两条路径:
|
||||
//
|
||||
// OAuth mimic 路径 (tokenType == "oauth" && mimicClaudeCode):
|
||||
// 1. body 中 metadata.user_id 派生的 SessionID 是合法 UUID → canonicalize 写入
|
||||
// 2. 请求 header 中已有合法 UUID → canonicalize 保留
|
||||
// 3. 否则 → 兜底生成 UUID
|
||||
// OAuth 路径 (tokenType == "oauth"):
|
||||
// OAuth 账号本身就是真实 Claude Code 客户端的凭证,可以信任 body 中的
|
||||
// metadata.user_id 派生 session id。
|
||||
// 1. metadata.user_id 派生 SessionID 是合法 UUID → canonical 写入
|
||||
// 2. header 已有合法 UUID → canonical 保留
|
||||
// 3. mimicClaudeCode == true → 兜底生成新 UUID
|
||||
// (mimicClaudeCode == false 且无 metadata 时不强制注入)
|
||||
//
|
||||
// API key 透传 / 非 mimic 路径:
|
||||
// - 不从 body 合成 header(避免污染客户端原始语义)
|
||||
// - 但若客户端在 header 中传入了 X-Claude-Code-Session-Id:
|
||||
// 合法 UUID → canonicalize 保留
|
||||
// 非法值 → 删除(不向上游转发恶意值,符合 UUID 校验承诺)
|
||||
// API key 透传路径 (tokenType != "oauth"):
|
||||
// - 不从 body metadata 派生 header(避免污染客户端原始语义)
|
||||
// - 若客户端在 header 中传入 X-Claude-Code-Session-Id:
|
||||
// 合法 UUID → canonical 保留
|
||||
// 非法值 → 删除(不向上游转发恶意值)
|
||||
// - 不兜底生成
|
||||
//
|
||||
// 安全说明:metadata.user_id 由客户端控制,ParseMetadataUserID 的正则仅约束字符集,
|
||||
@ -37,10 +40,10 @@ func ensureClaudeCodeSessionID(req *http.Request, body []byte, tokenType string,
|
||||
req.Header = make(http.Header)
|
||||
}
|
||||
|
||||
isOAuthMimic := tokenType == "oauth" && mimicClaudeCode
|
||||
isOAuth := tokenType == "oauth"
|
||||
|
||||
// OAuth mimic 路径:从 metadata 派生(仅在 mimic 场景写 header)。
|
||||
if isOAuthMimic {
|
||||
// OAuth 路径:从 metadata 派生(OAuth 凭证可信任)。
|
||||
if isOAuth {
|
||||
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
||||
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
||||
if id, err := uuid.Parse(parsed.SessionID); err == nil {
|
||||
@ -65,9 +68,9 @@ func ensureClaudeCodeSessionID(req *http.Request, body []byte, tokenType string,
|
||||
req.Header.Del("X-Claude-Code-Session-Id")
|
||||
}
|
||||
|
||||
// OAuth mimic 兜底生成(仅 mimic 场景;API key 不污染)。
|
||||
// OAuth mimic 兜底生成(仅 mimic 场景;API key/非 mimic 不污染)。
|
||||
// uuid.NewString() 走 crypto/rand。
|
||||
if isOAuthMimic {
|
||||
if isOAuth && mimicClaudeCode {
|
||||
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", uuid.NewString())
|
||||
}
|
||||
}
|
||||
|
||||
@ -136,15 +136,17 @@ func TestEnsureClaudeCodeSessionID_APIKeyIgnoresMetadata(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// OAuth 但非 mimic 模式也不应该从 metadata 派生 header。
|
||||
func TestEnsureClaudeCodeSessionID_OAuthNonMimicIgnoresMetadata(t *testing.T) {
|
||||
// OAuth 路径即使 mimic=false 也应该从 metadata 派生 header:
|
||||
// OAuth 凭证本身就是 Claude Code 类型账号,metadata.user_id 可信任。
|
||||
// 这与 API key 路径不同(API key 是任意第三方调用方)。
|
||||
func TestEnsureClaudeCodeSessionID_OAuthNonMimicDerivesFromMetadata(t *testing.T) {
|
||||
req := newReq(t)
|
||||
body := []byte(`{"metadata":{"user_id":"{\"device_id\":\"abc\",\"account_uuid\":\"\",\"session_id\":\"` + testValidUUID + `\"}"}}`)
|
||||
ensureClaudeCodeSessionID(req, body, "oauth", false)
|
||||
|
||||
got := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id")
|
||||
if got != "" {
|
||||
t.Fatalf("Non-mimic OAuth must NOT derive session-id from metadata, got %q", got)
|
||||
if got != testValidUUID {
|
||||
t.Fatalf("OAuth must derive session-id from metadata regardless of mimic flag, got %q want %q", got, testValidUUID)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1463,6 +1463,24 @@ func (s *ContentModerationService) applyFlaggedSideEffects(ctx context.Context,
|
||||
|
||||
func (s *ContentModerationService) sendViolationEmail(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) error {
|
||||
siteName := s.siteName(ctx)
|
||||
if s.emailService.notificationEmailService != nil {
|
||||
if err := s.emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||
Event: NotificationEmailEventContentModerationViolation,
|
||||
RecipientEmail: log.UserEmail,
|
||||
RecipientName: emailRecipientName(log.UserEmail),
|
||||
UserID: contentModerationEmailUserID(log),
|
||||
SourceType: "content_moderation",
|
||||
SourceID: contentModerationEmailSourceID(log),
|
||||
Variables: contentModerationEmailVariables(log, cfg),
|
||||
}); err == nil {
|
||||
return nil
|
||||
} else {
|
||||
if !shouldFallbackNotificationEmail(err) {
|
||||
return err
|
||||
}
|
||||
slog.Warn("template content moderation violation email failed; falling back to built-in body", "log_id", log.ID, "recipient_hash", notificationEmailHash(log.UserEmail), "err", err.Error())
|
||||
}
|
||||
}
|
||||
subject := fmt.Sprintf("[%s] 账户风控提醒 / Risk Control Notice", sanitizeEmailHeader(siteName))
|
||||
body := buildContentModerationViolationEmailBody(siteName, log, cfg)
|
||||
return s.emailService.SendEmail(ctx, log.UserEmail, subject, body)
|
||||
@ -1470,11 +1488,71 @@ func (s *ContentModerationService) sendViolationEmail(ctx context.Context, cfg *
|
||||
|
||||
func (s *ContentModerationService) sendAccountDisabledEmail(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) error {
|
||||
siteName := s.siteName(ctx)
|
||||
if s.emailService.notificationEmailService != nil {
|
||||
if err := s.emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||
Event: NotificationEmailEventContentModerationDisabled,
|
||||
RecipientEmail: log.UserEmail,
|
||||
RecipientName: emailRecipientName(log.UserEmail),
|
||||
UserID: contentModerationEmailUserID(log),
|
||||
SourceType: "content_moderation",
|
||||
SourceID: contentModerationEmailSourceID(log),
|
||||
Variables: contentModerationEmailVariables(log, cfg),
|
||||
}); err == nil {
|
||||
return nil
|
||||
} else {
|
||||
if !shouldFallbackNotificationEmail(err) {
|
||||
return err
|
||||
}
|
||||
slog.Warn("template content moderation disabled email failed; falling back to built-in body", "log_id", log.ID, "recipient_hash", notificationEmailHash(log.UserEmail), "err", err.Error())
|
||||
}
|
||||
}
|
||||
subject := fmt.Sprintf("[%s] 账户已被禁用 / Account Disabled", sanitizeEmailHeader(siteName))
|
||||
body := buildContentModerationAccountDisabledEmailBody(siteName, log, cfg)
|
||||
return s.emailService.SendEmail(ctx, log.UserEmail, subject, body)
|
||||
}
|
||||
|
||||
func contentModerationEmailUserID(log *ContentModerationLog) int64 {
|
||||
if log == nil || log.UserID == nil {
|
||||
return 0
|
||||
}
|
||||
return *log.UserID
|
||||
}
|
||||
|
||||
func contentModerationEmailSourceID(log *ContentModerationLog) string {
|
||||
if log == nil || log.ID <= 0 {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%d", log.ID)
|
||||
}
|
||||
|
||||
func contentModerationEmailVariables(log *ContentModerationLog, cfg *ContentModerationConfig) map[string]string {
|
||||
variables := map[string]string{
|
||||
"triggered_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"group_name": "-",
|
||||
"moderation_category": "-",
|
||||
"moderation_score": "0.000",
|
||||
"violation_count": "0",
|
||||
"ban_threshold": "0",
|
||||
}
|
||||
if log != nil {
|
||||
if !log.CreatedAt.IsZero() {
|
||||
variables["triggered_at"] = log.CreatedAt.UTC().Format(time.RFC3339)
|
||||
}
|
||||
if strings.TrimSpace(log.GroupName) != "" {
|
||||
variables["group_name"] = strings.TrimSpace(log.GroupName)
|
||||
}
|
||||
if strings.TrimSpace(log.HighestCategory) != "" {
|
||||
variables["moderation_category"] = strings.TrimSpace(log.HighestCategory)
|
||||
}
|
||||
variables["moderation_score"] = fmt.Sprintf("%.3f", log.HighestScore)
|
||||
variables["violation_count"] = fmt.Sprintf("%d", log.ViolationCount)
|
||||
}
|
||||
if cfg != nil {
|
||||
variables["ban_threshold"] = fmt.Sprintf("%d", cfg.BanThreshold)
|
||||
}
|
||||
return variables
|
||||
}
|
||||
|
||||
func (s *ContentModerationService) siteName(ctx context.Context) string {
|
||||
if s == nil || s.settingRepo == nil {
|
||||
return "Sub2API"
|
||||
|
||||
@ -401,6 +401,10 @@ const (
|
||||
SettingKeyRewriteMessageCacheControl = "rewrite_message_cache_control"
|
||||
// SettingKeyAntigravityUserAgentVersion Antigravity 上游 User-Agent 版本号(空值使用环境变量/默认值)
|
||||
SettingKeyAntigravityUserAgentVersion = "antigravity_user_agent_version"
|
||||
// SettingKeyOpenAICodexUserAgent OpenAI Codex 完整 User-Agent(空值使用内置默认)
|
||||
// 当客户端 UA 被识别为浏览器(Chrome/Firefox/Safari/Edge 等)时,转发给 OpenAI 上游前会替换为此值,
|
||||
// 用于避免 Cloudflare 对浏览器型 UA 的质询拦截。
|
||||
SettingKeyOpenAICodexUserAgent = "openai_codex_user_agent"
|
||||
|
||||
// Balance Low Notification
|
||||
SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
|
||||
|
||||
@ -21,6 +21,7 @@ type EmailTask struct {
|
||||
SiteName string
|
||||
TaskType string // "verify_code" or "password_reset"
|
||||
ResetURL string // Only used for password_reset task type
|
||||
Locale string // Optional Accept-Language locale hint
|
||||
}
|
||||
|
||||
// EmailQueueService 异步邮件队列服务
|
||||
@ -82,13 +83,13 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
|
||||
|
||||
switch task.TaskType {
|
||||
case TaskTypeVerifyCode:
|
||||
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil {
|
||||
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName, task.Locale); err != nil {
|
||||
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
|
||||
} else {
|
||||
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
|
||||
}
|
||||
case TaskTypePasswordReset:
|
||||
if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil {
|
||||
if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL, task.Locale); err != nil {
|
||||
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err)
|
||||
} else {
|
||||
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email)
|
||||
@ -99,11 +100,12 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
|
||||
}
|
||||
|
||||
// EnqueueVerifyCode 将验证码发送任务加入队列
|
||||
func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
|
||||
func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string, locale ...string) error {
|
||||
task := EmailTask{
|
||||
Email: email,
|
||||
SiteName: siteName,
|
||||
TaskType: TaskTypeVerifyCode,
|
||||
Locale: firstEmailLocale(locale),
|
||||
}
|
||||
|
||||
select {
|
||||
@ -116,12 +118,13 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
|
||||
}
|
||||
|
||||
// EnqueuePasswordReset 将密码重置邮件任务加入队列
|
||||
func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL string) error {
|
||||
func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL string, locale ...string) error {
|
||||
task := EmailTask{
|
||||
Email: email,
|
||||
SiteName: siteName,
|
||||
TaskType: TaskTypePasswordReset,
|
||||
ResetURL: resetURL,
|
||||
Locale: firstEmailLocale(locale),
|
||||
}
|
||||
|
||||
select {
|
||||
|
||||
@ -94,8 +94,9 @@ type SMTPConfig struct {
|
||||
|
||||
// EmailService 邮件服务
|
||||
type EmailService struct {
|
||||
settingRepo SettingRepository
|
||||
cache EmailCache
|
||||
settingRepo SettingRepository
|
||||
cache EmailCache
|
||||
notificationEmailService *NotificationEmailService
|
||||
}
|
||||
|
||||
// NewEmailService 创建邮件服务实例
|
||||
@ -106,6 +107,28 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
|
||||
}
|
||||
}
|
||||
|
||||
func (s *EmailService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) {
|
||||
s.notificationEmailService = notificationEmailService
|
||||
}
|
||||
|
||||
func firstEmailLocale(locales []string) string {
|
||||
if len(locales) == 0 {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(locales[0])
|
||||
}
|
||||
|
||||
func emailRecipientName(email string) string {
|
||||
trimmed := strings.TrimSpace(email)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if at := strings.Index(trimmed, "@"); at > 0 {
|
||||
return trimmed[:at]
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
// GetSMTPConfig 从数据库获取SMTP配置
|
||||
func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
|
||||
keys := []string{
|
||||
@ -301,7 +324,7 @@ func (s *EmailService) GenerateVerifyCode() (string, error) {
|
||||
}
|
||||
|
||||
// SendVerifyCode 发送验证码邮件
|
||||
func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string) error {
|
||||
func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string, locale ...string) error {
|
||||
// 检查是否在冷却期内
|
||||
existing, err := s.cache.GetVerificationCode(ctx, email)
|
||||
if err == nil && existing != nil {
|
||||
@ -327,6 +350,26 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
|
||||
return fmt.Errorf("save verify code: %w", err)
|
||||
}
|
||||
|
||||
if s.notificationEmailService != nil {
|
||||
err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||
Event: NotificationEmailEventAuthVerifyCode,
|
||||
Locale: firstEmailLocale(locale),
|
||||
RecipientEmail: email,
|
||||
RecipientName: emailRecipientName(email),
|
||||
Variables: map[string]string{
|
||||
"verification_code": code,
|
||||
"expires_in_minutes": strconv.Itoa(int(verifyCodeTTL / time.Minute)),
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if !shouldFallbackNotificationEmail(err) {
|
||||
return err
|
||||
}
|
||||
slog.Warn("failed to send templated verification email, falling back to legacy template", "recipient_hash", notificationEmailHash(email), "error", err)
|
||||
}
|
||||
|
||||
// 构建邮件内容
|
||||
subject := fmt.Sprintf("[%s] Email Verification Code", siteName)
|
||||
body := s.buildVerifyCodeEmailBody(code, siteName)
|
||||
@ -469,7 +512,7 @@ func (s *EmailService) GeneratePasswordResetToken() (string, error) {
|
||||
}
|
||||
|
||||
// SendPasswordResetEmail sends a password reset email with a reset link
|
||||
func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteName, resetURL string) error {
|
||||
func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteName, resetURL string, locale ...string) error {
|
||||
var token string
|
||||
var needSaveToken bool
|
||||
|
||||
@ -502,6 +545,26 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa
|
||||
// Build full reset URL with URL-encoded token and email
|
||||
fullResetURL := fmt.Sprintf("%s?email=%s&token=%s", resetURL, url.QueryEscape(email), url.QueryEscape(token))
|
||||
|
||||
if s.notificationEmailService != nil {
|
||||
err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||
Event: NotificationEmailEventAuthPasswordReset,
|
||||
Locale: firstEmailLocale(locale),
|
||||
RecipientEmail: email,
|
||||
RecipientName: emailRecipientName(email),
|
||||
Variables: map[string]string{
|
||||
"reset_url": fullResetURL,
|
||||
"expires_in_minutes": strconv.Itoa(int(passwordResetTokenTTL / time.Minute)),
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if !shouldFallbackNotificationEmail(err) {
|
||||
return err
|
||||
}
|
||||
slog.Warn("failed to send templated password reset email, falling back to legacy template", "recipient_hash", notificationEmailHash(email), "error", err)
|
||||
}
|
||||
|
||||
// Build email content
|
||||
subject := fmt.Sprintf("[%s] 密码重置请求", siteName)
|
||||
body := s.buildPasswordResetEmailBody(fullResetURL, siteName)
|
||||
@ -516,7 +579,7 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa
|
||||
|
||||
// SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker)
|
||||
// This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing
|
||||
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error {
|
||||
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string, locale ...string) error {
|
||||
// Check email cooldown to prevent email bombing
|
||||
if s.cache.IsPasswordResetEmailInCooldown(ctx, email) {
|
||||
slog.Info("password reset email skipped due to cooldown", "email", email)
|
||||
@ -524,7 +587,7 @@ func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, e
|
||||
}
|
||||
|
||||
// Send email using core method
|
||||
if err := s.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
|
||||
if err := s.SendPasswordResetEmail(ctx, email, siteName, resetURL, firstEmailLocale(locale)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
1347
backend/internal/service/notification_email_service.go
Normal file
1347
backend/internal/service/notification_email_service.go
Normal file
File diff suppressed because it is too large
Load Diff
571
backend/internal/service/notification_email_service_test.go
Normal file
571
backend/internal/service/notification_email_service_test.go
Normal file
@ -0,0 +1,571 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNotificationEmailPreviewEscapesHTMLAndSanitizesSubject(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil)
|
||||
|
||||
preview, err := svc.PreviewTemplate(ctx, NotificationEmailPreviewInput{
|
||||
Event: NotificationEmailEventBalanceLow,
|
||||
Locale: "en-US,en;q=0.9",
|
||||
Subject: "Low balance for {{recipient_name}}\r\nInjected",
|
||||
HTML: `<p>{{recipient_name}}</p><a href="{{recharge_url}}">Recharge</a>`,
|
||||
Variables: map[string]string{
|
||||
"recipient_name": `<script>alert("x")</script>`,
|
||||
"recharge_url": `javascript:alert(1)`,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, preview.Subject, "\r")
|
||||
require.NotContains(t, preview.Subject, "\n")
|
||||
require.Contains(t, preview.Subject, `Low balance for <script>alert("x")</script>Injected`)
|
||||
require.Contains(t, preview.HTML, `<script>alert("x")</script>`)
|
||||
require.NotContains(t, preview.HTML, `javascript:alert(1)`)
|
||||
require.Contains(t, preview.HTML, `href=""`)
|
||||
}
|
||||
|
||||
func TestNotificationEmailTemplateOverrideAndRestore(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := newNotificationEmailMemorySettingRepo()
|
||||
svc := NewNotificationEmailService(repo, nil)
|
||||
|
||||
official, err := svc.GetTemplate(ctx, NotificationEmailEventBalanceRechargeSuccess, "en")
|
||||
require.NoError(t, err)
|
||||
require.False(t, official.IsCustom)
|
||||
|
||||
updated, err := svc.UpdateTemplate(
|
||||
ctx,
|
||||
NotificationEmailEventBalanceRechargeSuccess,
|
||||
"zh-Hans",
|
||||
"充值完成:{{recharge_amount}}",
|
||||
"<p>{{recipient_name}} 已充值 {{recharge_amount}}</p>",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.True(t, updated.IsCustom)
|
||||
require.Equal(t, "zh", updated.Locale)
|
||||
require.Equal(t, "充值完成:{{recharge_amount}}", updated.Subject)
|
||||
require.NotNil(t, updated.UpdatedAt)
|
||||
|
||||
restored, err := svc.RestoreOfficialTemplate(ctx, NotificationEmailEventBalanceRechargeSuccess, "zh")
|
||||
require.NoError(t, err)
|
||||
require.False(t, restored.IsCustom)
|
||||
require.NotEqual(t, updated.Subject, restored.Subject)
|
||||
_, err = repo.GetValue(ctx, notificationEmailTemplateKey(NotificationEmailEventBalanceRechargeSuccess, "zh"))
|
||||
require.ErrorIs(t, err, ErrSettingNotFound)
|
||||
}
|
||||
|
||||
func TestNotificationEmailTemplateRejectsUnsupportedPlaceholder(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil)
|
||||
|
||||
_, err := svc.UpdateTemplate(
|
||||
ctx,
|
||||
NotificationEmailEventSubscriptionPurchaseSuccess,
|
||||
"en",
|
||||
"Purchased {{not_allowed}}",
|
||||
"<p>{{subscription_group}}</p>",
|
||||
)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "unsupported placeholder")
|
||||
}
|
||||
|
||||
func TestNotificationEmailAuthTemplatesAreListedAndPreviewable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil)
|
||||
|
||||
infos := svc.ListEventInfos()
|
||||
events := make(map[string]NotificationEmailEventInfo, len(infos))
|
||||
for _, info := range infos {
|
||||
events[info.Event] = info
|
||||
}
|
||||
require.Contains(t, events, NotificationEmailEventAuthVerifyCode)
|
||||
require.Contains(t, events, NotificationEmailEventAuthPasswordReset)
|
||||
require.False(t, events[NotificationEmailEventAuthVerifyCode].Optional)
|
||||
require.False(t, events[NotificationEmailEventAuthPasswordReset].Optional)
|
||||
require.Contains(t, events[NotificationEmailEventAuthVerifyCode].Placeholders, "verification_code")
|
||||
require.Contains(t, events[NotificationEmailEventAuthPasswordReset].Placeholders, "reset_url")
|
||||
|
||||
verifyPreview, err := svc.PreviewTemplate(ctx, NotificationEmailPreviewInput{
|
||||
Event: NotificationEmailEventAuthVerifyCode,
|
||||
Locale: "zh-CN",
|
||||
Variables: map[string]string{
|
||||
"verification_code": "654321",
|
||||
"expires_in_minutes": "15",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, verifyPreview.Subject, "邮箱验证码")
|
||||
require.Contains(t, verifyPreview.HTML, "654321")
|
||||
|
||||
resetPreview, err := svc.PreviewTemplate(ctx, NotificationEmailPreviewInput{
|
||||
Event: NotificationEmailEventAuthPasswordReset,
|
||||
Locale: "en",
|
||||
Variables: map[string]string{
|
||||
"reset_url": "https://example.com/reset?token=abc",
|
||||
"expires_in_minutes": "30",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, resetPreview.Subject, "Password reset")
|
||||
require.Contains(t, resetPreview.HTML, "https://example.com/reset?token=abc")
|
||||
}
|
||||
|
||||
func TestNotificationEmailAdditionalEventsAreListedAndPreviewable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil)
|
||||
|
||||
infos := svc.ListEventInfos()
|
||||
events := make(map[string]NotificationEmailEventInfo, len(infos))
|
||||
for _, info := range infos {
|
||||
events[info.Event] = info
|
||||
}
|
||||
|
||||
checks := []struct {
|
||||
event string
|
||||
placeholder string
|
||||
}{
|
||||
{NotificationEmailEventNotificationEmailVerifyCode, "verification_code"},
|
||||
{NotificationEmailEventAccountQuotaAlert, "account_name"},
|
||||
{NotificationEmailEventContentModerationViolation, "moderation_category"},
|
||||
{NotificationEmailEventContentModerationDisabled, "violation_count"},
|
||||
{NotificationEmailEventOpsAlert, "rule_name"},
|
||||
{NotificationEmailEventOpsScheduledReport, "report_html"},
|
||||
}
|
||||
|
||||
for _, check := range checks {
|
||||
info, ok := events[check.event]
|
||||
require.Truef(t, ok, "expected %s to be listed", check.event)
|
||||
require.False(t, info.Optional)
|
||||
require.Contains(t, info.Placeholders, check.placeholder)
|
||||
|
||||
preview, err := svc.PreviewTemplate(ctx, NotificationEmailPreviewInput{Event: check.event, Locale: "zh"})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, preview.Subject)
|
||||
require.NotEmpty(t, preview.HTML)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotificationEmailRawHTMLVariablesAreTrustedOnlyForHTMLPlaceholders(t *testing.T) {
|
||||
require.True(t, notificationEmailRawHTMLAllowed(NotificationEmailEventOpsScheduledReport, "report_html"))
|
||||
require.False(t, notificationEmailRawHTMLAllowed(NotificationEmailEventOpsScheduledReport, "recipient_name"))
|
||||
require.False(t, notificationEmailRawHTMLAllowed(NotificationEmailEventOpsAlert, "report_html"))
|
||||
|
||||
preview, err := renderNotificationEmail(
|
||||
NotificationEmailEventOpsScheduledReport,
|
||||
"Report for {{recipient_name}}",
|
||||
`<section>{{report_html}}</section><p>{{recipient_name}}</p>`,
|
||||
map[string]string{
|
||||
"recipient_name": `<script>alert("x")</script>`,
|
||||
"report_html": `<p>escaped report</p>`,
|
||||
},
|
||||
map[string]string{
|
||||
"report_html": `<table><tr><td>trusted report</td></tr></table>`,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, preview.HTML, `<table><tr><td>trusted report</td></tr></table>`)
|
||||
require.NotContains(t, preview.HTML, `escaped report`)
|
||||
require.Contains(t, preview.HTML, `<script>alert("x")</script>`)
|
||||
require.Contains(t, preview.Subject, `<script>alert("x")</script>`)
|
||||
|
||||
preview, err = renderNotificationEmail(
|
||||
NotificationEmailEventOpsScheduledReport,
|
||||
"Recipient {{recipient_name}}",
|
||||
`<p>{{recipient_name}}</p>`,
|
||||
map[string]string{"recipient_name": `<em>escaped</em>`},
|
||||
map[string]string{"recipient_name": `<strong>raw</strong>`},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, preview.HTML, `<em>escaped</em>`)
|
||||
require.NotContains(t, preview.HTML, `<strong>raw</strong>`)
|
||||
}
|
||||
|
||||
func TestNotificationEmailFallbackClassification(t *testing.T) {
|
||||
templateErr := notificationEmailTemplateErr(errors.New("bad template"))
|
||||
configErr := notificationEmailConfigErr(errors.New("missing email service"))
|
||||
deliveryErr := notificationEmailDeliveryErr(errors.New("smtp timeout"))
|
||||
|
||||
require.True(t, shouldFallbackNotificationEmail(templateErr))
|
||||
require.True(t, shouldFallbackNotificationEmail(configErr))
|
||||
require.False(t, shouldFallbackNotificationEmail(deliveryErr))
|
||||
require.True(t, isNotificationEmailDeliveryError(deliveryErr))
|
||||
require.False(t, isNotificationEmailDeliveryError(templateErr))
|
||||
require.False(t, shouldFallbackNotificationEmail(nil))
|
||||
}
|
||||
|
||||
func TestEmailQueueTasksPreserveLocaleHints(t *testing.T) {
|
||||
queue := &EmailQueueService{taskChan: make(chan EmailTask, 2)}
|
||||
require.NoError(t, queue.EnqueueVerifyCode("user@example.com", "Sub2API", "zh-CN"))
|
||||
require.NoError(t, queue.EnqueuePasswordReset("user@example.com", "Sub2API", "https://example.com/reset", "en-US"))
|
||||
|
||||
verifyTask := <-queue.taskChan
|
||||
require.Equal(t, TaskTypeVerifyCode, verifyTask.TaskType)
|
||||
require.Equal(t, "zh-CN", verifyTask.Locale)
|
||||
|
||||
resetTask := <-queue.taskChan
|
||||
require.Equal(t, TaskTypePasswordReset, resetTask.TaskType)
|
||||
require.Equal(t, "en-US", resetTask.Locale)
|
||||
}
|
||||
|
||||
func TestOpsScheduledReportDeliverySourceIDIncludesReportIdentity(t *testing.T) {
|
||||
report := &opsScheduledReport{Name: "日报", ReportType: "daily_summary", Schedule: "0 9 * * *"}
|
||||
sourceID := opsScheduledReportDeliverySourceID(report)
|
||||
require.Contains(t, sourceID, "daily_summary")
|
||||
require.Contains(t, sourceID, "日报")
|
||||
require.Contains(t, sourceID, "0 9 * * *")
|
||||
require.NotEqual(t, sourceID, opsScheduledReportDeliverySourceID(&opsScheduledReport{Name: "周报", ReportType: "weekly_summary", Schedule: "0 9 * * 1"}))
|
||||
require.Equal(t, "scheduled_report", opsScheduledReportDeliverySourceID(nil))
|
||||
}
|
||||
|
||||
func TestNotificationEmailUnsubscribeOnlyAllowsOptionalEvents(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil)
|
||||
|
||||
token, err := svc.createUnsubscribeToken(ctx, "User@Example.com", NotificationEmailEventBalanceLow)
|
||||
require.NoError(t, err)
|
||||
result, err := svc.Unsubscribe(ctx, token)
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Done)
|
||||
require.Equal(t, NotificationEmailEventBalanceLow, result.Event)
|
||||
unsubscribed, err := svc.IsUnsubscribed(ctx, "user@example.com", NotificationEmailEventBalanceLow)
|
||||
require.NoError(t, err)
|
||||
require.True(t, unsubscribed)
|
||||
|
||||
transactionalToken, err := svc.createUnsubscribeToken(ctx, "user@example.com", NotificationEmailEventBalanceRechargeSuccess)
|
||||
require.NoError(t, err)
|
||||
_, err = svc.Unsubscribe(ctx, transactionalToken)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "transactional")
|
||||
|
||||
authToken, err := svc.createUnsubscribeToken(ctx, "user@example.com", NotificationEmailEventAuthVerifyCode)
|
||||
require.NoError(t, err)
|
||||
_, err = svc.Unsubscribe(ctx, authToken)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "transactional")
|
||||
}
|
||||
|
||||
func TestNotificationEmailLocaleMemoryNormalizesAcceptLanguage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil)
|
||||
|
||||
svc.RememberRecipientLocale(ctx, 42, "User@Example.com", "zh-CN,zh;q=0.9,en;q=0.8")
|
||||
require.Equal(t, "zh", svc.ResolveRecipientLocale(ctx, 42, "user@example.com"))
|
||||
require.Equal(t, "zh", svc.ResolveRecipientLocale(ctx, 0, "user@example.com"))
|
||||
}
|
||||
|
||||
func TestNotificationEmailDeliveryKeyUsesShortStableHash(t *testing.T) {
|
||||
key := notificationEmailDeliveryKey(
|
||||
NotificationEmailEventSubscriptionExpiryReminder,
|
||||
"user_subscription",
|
||||
"1234567890",
|
||||
"User@Example.com",
|
||||
"7d",
|
||||
)
|
||||
require.NotEmpty(t, key)
|
||||
require.LessOrEqual(t, len(key), 100)
|
||||
require.True(t, strings.HasPrefix(key, notificationEmailDeliveryKeyPrefix+"v2:"))
|
||||
require.Equal(t, key, notificationEmailDeliveryKey(
|
||||
NotificationEmailEventSubscriptionExpiryReminder,
|
||||
"user_subscription",
|
||||
"1234567890",
|
||||
"user@example.com",
|
||||
"7d",
|
||||
))
|
||||
require.NotEqual(t, key, notificationEmailDeliveryKey(
|
||||
NotificationEmailEventSubscriptionExpiryReminder,
|
||||
"user_subscription",
|
||||
"1234567890",
|
||||
"user@example.com",
|
||||
"3d",
|
||||
))
|
||||
|
||||
legacyKey := legacyNotificationEmailDeliveryKey(
|
||||
NotificationEmailEventSubscriptionExpiryReminder,
|
||||
"user_subscription",
|
||||
"1234567890",
|
||||
"user@example.com",
|
||||
"7d",
|
||||
)
|
||||
require.Greater(t, len(legacyKey), 100)
|
||||
}
|
||||
|
||||
func TestNotificationEmailPreferenceKeyUsesShortStableHashAndReadsLegacyKey(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := newNotificationEmailMemorySettingRepo()
|
||||
svc := NewNotificationEmailService(repo, nil)
|
||||
|
||||
key := notificationEmailPreferenceKey(NotificationEmailEventSubscriptionExpiryReminder, "User@Example.com")
|
||||
require.NotEmpty(t, key)
|
||||
require.LessOrEqual(t, len(key), 100)
|
||||
require.True(t, strings.HasPrefix(key, notificationEmailPreferenceKeyPrefix+"v2:"))
|
||||
require.Equal(t, key, notificationEmailPreferenceKey(NotificationEmailEventSubscriptionExpiryReminder, "user@example.com"))
|
||||
|
||||
legacyKey := legacyNotificationEmailPreferenceKey(NotificationEmailEventSubscriptionExpiryReminder, "user@example.com")
|
||||
require.Greater(t, len(legacyKey), 100)
|
||||
require.NoError(t, repo.Set(ctx, legacyKey, "unsubscribed"))
|
||||
|
||||
unsubscribed, err := svc.IsUnsubscribed(ctx, "User@Example.com", NotificationEmailEventSubscriptionExpiryReminder)
|
||||
require.NoError(t, err)
|
||||
require.True(t, unsubscribed)
|
||||
}
|
||||
|
||||
func TestNotificationEmailSendDeduplicatesSubscriptionExpiryReminder(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := newNotificationEmailMemorySettingRepo()
|
||||
smtpServer := startNotificationEmailTestSMTPServer(t)
|
||||
require.NoError(t, repo.SetMultiple(ctx, smtpServer.settings()))
|
||||
|
||||
emailSvc := NewEmailService(repo, nil)
|
||||
svc := NewNotificationEmailService(repo, emailSvc)
|
||||
input := NotificationEmailSendInput{
|
||||
Event: NotificationEmailEventSubscriptionExpiryReminder,
|
||||
RecipientEmail: "User@Example.com",
|
||||
RecipientName: "User",
|
||||
UserID: 42,
|
||||
SourceType: "user_subscription",
|
||||
SourceID: "1234567890",
|
||||
ReminderKey: "7d",
|
||||
Variables: map[string]string{
|
||||
"subscription_group": "Codex",
|
||||
"expiry_time": "2026-05-27 12:00",
|
||||
"days_remaining": "7",
|
||||
},
|
||||
}
|
||||
|
||||
require.NoError(t, svc.Send(ctx, input))
|
||||
require.Equal(t, int64(1), smtpServer.messageCount())
|
||||
|
||||
key := notificationEmailDeliveryKey(input.Event, input.SourceType, input.SourceID, input.RecipientEmail, input.ReminderKey)
|
||||
require.LessOrEqual(t, len(key), 100)
|
||||
_, err := repo.GetValue(ctx, key)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, svc.Send(ctx, input))
|
||||
require.Equal(t, int64(1), smtpServer.messageCount())
|
||||
}
|
||||
|
||||
func TestNotificationEmailSendRespectsLegacyDeliveryKey(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := newNotificationEmailMemorySettingRepo()
|
||||
svc := NewNotificationEmailService(repo, nil)
|
||||
input := NotificationEmailSendInput{
|
||||
Event: NotificationEmailEventSubscriptionExpiryReminder,
|
||||
RecipientEmail: "user@example.com",
|
||||
SourceType: "user_subscription",
|
||||
SourceID: "1234567890",
|
||||
ReminderKey: "7d",
|
||||
}
|
||||
legacyKey := legacyNotificationEmailDeliveryKey(input.Event, input.SourceType, input.SourceID, input.RecipientEmail, input.ReminderKey)
|
||||
require.NoError(t, repo.Set(ctx, legacyKey, "sent"))
|
||||
|
||||
require.NoError(t, svc.Send(ctx, input))
|
||||
}
|
||||
|
||||
type notificationEmailMemorySettingRepo struct {
|
||||
mu sync.RWMutex
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func newNotificationEmailMemorySettingRepo() *notificationEmailMemorySettingRepo {
|
||||
return ¬ificationEmailMemorySettingRepo{values: make(map[string]string)}
|
||||
}
|
||||
|
||||
func (r *notificationEmailMemorySettingRepo) Get(_ context.Context, key string) (*Setting, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
value, ok := r.values[key]
|
||||
if !ok {
|
||||
return nil, ErrSettingNotFound
|
||||
}
|
||||
return &Setting{Key: key, Value: value}, nil
|
||||
}
|
||||
|
||||
func (r *notificationEmailMemorySettingRepo) GetValue(ctx context.Context, key string) (string, error) {
|
||||
setting, err := r.Get(ctx, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return setting.Value, nil
|
||||
}
|
||||
|
||||
func (r *notificationEmailMemorySettingRepo) Set(_ context.Context, key, value string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.values[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *notificationEmailMemorySettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
out := make(map[string]string, len(keys))
|
||||
for _, key := range keys {
|
||||
if value, ok := r.values[key]; ok {
|
||||
out[key] = value
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *notificationEmailMemorySettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for key, value := range settings {
|
||||
r.values[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *notificationEmailMemorySettingRepo) GetAll(_ context.Context) (map[string]string, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
out := make(map[string]string, len(r.values))
|
||||
for key, value := range r.values {
|
||||
out[key] = value
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *notificationEmailMemorySettingRepo) Delete(_ context.Context, key string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if _, ok := r.values[key]; !ok {
|
||||
return ErrSettingNotFound
|
||||
}
|
||||
delete(r.values, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNotificationEmailMemorySettingRepoSatisfiesInterface(t *testing.T) {
|
||||
var _ SettingRepository = (*notificationEmailMemorySettingRepo)(nil)
|
||||
require.False(t, strings.Contains(notificationEmailPreferenceKey(NotificationEmailEventBalanceLow, "User@Example.com"), "User@Example.com"))
|
||||
}
|
||||
|
||||
type notificationEmailTestSMTPServer struct {
|
||||
listener net.Listener
|
||||
wg sync.WaitGroup
|
||||
messages atomic.Int64
|
||||
}
|
||||
|
||||
func startNotificationEmailTestSMTPServer(t *testing.T) *notificationEmailTestSMTPServer {
|
||||
t.Helper()
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
server := ¬ificationEmailTestSMTPServer{listener: listener}
|
||||
server.wg.Add(1)
|
||||
go server.serve()
|
||||
t.Cleanup(server.close)
|
||||
return server
|
||||
}
|
||||
|
||||
func (s *notificationEmailTestSMTPServer) settings() map[string]string {
|
||||
host, port, _ := net.SplitHostPort(s.listener.Addr().String())
|
||||
return map[string]string{
|
||||
SettingKeySMTPHost: host,
|
||||
SettingKeySMTPPort: port,
|
||||
SettingKeySMTPUsername: "user",
|
||||
SettingKeySMTPPassword: "password",
|
||||
SettingKeySMTPFrom: "noreply@example.com",
|
||||
SettingKeySMTPFromName: "Sub2API",
|
||||
SettingKeySMTPUseTLS: "false",
|
||||
}
|
||||
}
|
||||
|
||||
func (s *notificationEmailTestSMTPServer) messageCount() int64 {
|
||||
return s.messages.Load()
|
||||
}
|
||||
|
||||
func (s *notificationEmailTestSMTPServer) close() {
|
||||
_ = s.listener.Close()
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
func (s *notificationEmailTestSMTPServer) serve() {
|
||||
defer s.wg.Done()
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *notificationEmailTestSMTPServer) handleConn(conn net.Conn) {
|
||||
defer func() { _ = conn.Close() }()
|
||||
rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
|
||||
writeLine := func(line string) bool {
|
||||
if _, err := rw.WriteString(line + "\r\n"); err != nil {
|
||||
return false
|
||||
}
|
||||
return rw.Flush() == nil
|
||||
}
|
||||
if !writeLine("220 localhost ESMTP") {
|
||||
return
|
||||
}
|
||||
for {
|
||||
line, err := rw.ReadString('\n')
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cmd := strings.ToUpper(strings.TrimRight(line, "\r\n"))
|
||||
switch {
|
||||
case strings.HasPrefix(cmd, "EHLO"), strings.HasPrefix(cmd, "HELO"):
|
||||
if _, err := rw.WriteString("250-localhost\r\n250 AUTH PLAIN\r\n"); err != nil {
|
||||
return
|
||||
}
|
||||
if err := rw.Flush(); err != nil {
|
||||
return
|
||||
}
|
||||
case strings.HasPrefix(cmd, "AUTH"):
|
||||
if !writeLine("235 2.7.0 Authentication successful") {
|
||||
return
|
||||
}
|
||||
case strings.HasPrefix(cmd, "MAIL FROM:"):
|
||||
if !writeLine("250 2.1.0 OK") {
|
||||
return
|
||||
}
|
||||
case strings.HasPrefix(cmd, "RCPT TO:"):
|
||||
if !writeLine("250 2.1.5 OK") {
|
||||
return
|
||||
}
|
||||
case strings.HasPrefix(cmd, "DATA"):
|
||||
if !writeLine("354 End data with <CR><LF>.<CR><LF>") {
|
||||
return
|
||||
}
|
||||
for {
|
||||
dataLine, err := rw.ReadString('\n')
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if strings.TrimRight(dataLine, "\r\n") == "." {
|
||||
break
|
||||
}
|
||||
}
|
||||
s.messages.Add(1)
|
||||
if !writeLine("250 2.0.0 OK") {
|
||||
return
|
||||
}
|
||||
case strings.HasPrefix(cmd, "QUIT"):
|
||||
_ = writeLine("221 2.0.0 Bye")
|
||||
return
|
||||
default:
|
||||
if !writeLine("250 OK") {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,428 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// forwardResponsesViaRawChatCompletions serves /v1/responses clients through an
|
||||
// upstream that only supports /v1/chat/completions.
|
||||
func (s *OpenAIGatewayService) forwardResponsesViaRawChatCompletions(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
body []byte,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var responsesReq apicompat.ResponsesRequest
|
||||
if err := json.Unmarshal(body, &responsesReq); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "invalid_request_error",
|
||||
"message": "Failed to parse request body",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("parse responses request: %w", err)
|
||||
}
|
||||
originalModel := strings.TrimSpace(responsesReq.Model)
|
||||
if originalModel == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "invalid_request_error",
|
||||
"message": "model is required",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("missing model in request")
|
||||
}
|
||||
|
||||
clientStream := responsesReq.Stream
|
||||
reasoningEffort := extractOpenAIReasoningEffortFromBody(body, originalModel)
|
||||
serviceTier := extractOpenAIServiceTierFromBody(body)
|
||||
|
||||
chatReq, err := apicompat.ResponsesToChatCompletionsRequest(&responsesReq)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "invalid_request_error",
|
||||
"message": err.Error(),
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("convert responses to chat completions: %w", err)
|
||||
}
|
||||
|
||||
billingModel := resolveOpenAIForwardModel(account, originalModel, "")
|
||||
upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)
|
||||
chatReq.Model = upstreamModel
|
||||
if clientStream {
|
||||
chatReq.StreamOptions = &apicompat.ChatStreamOptions{IncludeUsage: true}
|
||||
}
|
||||
|
||||
chatBody, err := json.Marshal(chatReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal chat completions fallback request: %w", err)
|
||||
}
|
||||
chatBody, err = s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, chatBody)
|
||||
if err != nil {
|
||||
var blocked *OpenAIFastBlockedError
|
||||
if errors.As(err, &blocked) {
|
||||
writeOpenAIFastPolicyBlockedResponse(c, blocked)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if serviceTier == nil {
|
||||
serviceTier = extractOpenAIServiceTierFromBody(chatBody)
|
||||
}
|
||||
|
||||
logger.L().Debug("openai responses: forwarding via raw chat completions",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("original_model", originalModel),
|
||||
zap.String("billing_model", billingModel),
|
||||
zap.String("upstream_model", upstreamModel),
|
||||
zap.Bool("stream", clientStream),
|
||||
)
|
||||
|
||||
apiKey := account.GetOpenAIApiKey()
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("account %d missing api_key", account.ID)
|
||||
}
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base_url: %w", err)
|
||||
}
|
||||
targetURL := buildOpenAIChatCompletionsURL(validatedURL)
|
||||
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(chatBody))
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||||
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
if clientStream {
|
||||
upstreamReq.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
upstreamReq.Header.Set("Accept", "application/json")
|
||||
}
|
||||
for key, values := range c.Request.Header {
|
||||
lowerKey := strings.ToLower(key)
|
||||
if openaiCCRawAllowedHeaders[lowerKey] {
|
||||
for _, v := range values {
|
||||
upstreamReq.Header.Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
if customUA := account.GetOpenAIUserAgent(); customUA != "" {
|
||||
upstreamReq.Header.Set("user-agent", customUA)
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
||||
}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account, chatBody)
|
||||
}
|
||||
|
||||
if clientStream {
|
||||
return s.streamChatCompletionsAsResponses(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
|
||||
}
|
||||
return s.bufferChatCompletionsAsResponses(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) bufferChatCompletionsAsResponses(
|
||||
c *gin.Context,
|
||||
resp *http.Response,
|
||||
originalModel string,
|
||||
billingModel string,
|
||||
upstreamModel string,
|
||||
reasoningEffort *string,
|
||||
serviceTier *string,
|
||||
startTime time.Time,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "api_error",
|
||||
"message": "Failed to read upstream response",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, fmt.Errorf("read upstream body: %w", err)
|
||||
}
|
||||
|
||||
var ccResp apicompat.ChatCompletionsResponse
|
||||
if err := json.Unmarshal(respBody, &ccResp); err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "api_error",
|
||||
"message": "Failed to parse upstream response",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("parse chat completions response: %w", err)
|
||||
}
|
||||
responsesResp := apicompat.ChatCompletionsResponseToResponses(&ccResp, originalModel)
|
||||
|
||||
usage := OpenAIUsage{}
|
||||
if parsed, ok := extractOpenAIUsageFromJSONBytes(respBody); ok {
|
||||
usage = parsed
|
||||
}
|
||||
|
||||
if s.responseHeaderFilter != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
}
|
||||
c.JSON(http.StatusOK, responsesResp)
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: billingModel,
|
||||
UpstreamModel: upstreamModel,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
ServiceTier: serviceTier,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) streamChatCompletionsAsResponses(
|
||||
c *gin.Context,
|
||||
resp *http.Response,
|
||||
originalModel string,
|
||||
billingModel string,
|
||||
upstreamModel string,
|
||||
reasoningEffort *string,
|
||||
serviceTier *string,
|
||||
startTime time.Time,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
headersWritten := false
|
||||
writeStreamHeaders := func() {
|
||||
if headersWritten {
|
||||
return
|
||||
}
|
||||
headersWritten = true
|
||||
if s.responseHeaderFilter != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
state := apicompat.NewChatCompletionsToResponsesStreamState(originalModel)
|
||||
var usage OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
sawDone := false
|
||||
|
||||
writeEvents := func(events []apicompat.ResponsesStreamEvent) {
|
||||
if clientDisconnected || len(events) == 0 {
|
||||
return
|
||||
}
|
||||
writeStreamHeaders()
|
||||
for _, event := range events {
|
||||
sse, err := apicompat.ResponsesEventToSSE(event)
|
||||
if err != nil {
|
||||
logger.L().Warn("openai responses chat fallback: failed to marshal stream event",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.L().Debug("openai responses chat fallback: client disconnected, continuing to drain upstream for billing",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
payload, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
payload = strings.TrimSpace(payload)
|
||||
if payload == "" {
|
||||
continue
|
||||
}
|
||||
if payload == "[DONE]" {
|
||||
sawDone = true
|
||||
break
|
||||
}
|
||||
|
||||
if u := extractCCStreamUsage(payload); u != nil {
|
||||
usage = *u
|
||||
}
|
||||
|
||||
var chunk apicompat.ChatCompletionsChunk
|
||||
if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
|
||||
logger.L().Warn("openai responses chat fallback: failed to parse chat stream chunk",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
if firstTokenMs == nil && !isOpenAIChatUsageOnlyStreamChunk(payload) && chatChunkStartsResponsesOutput(&chunk) {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
writeEvents(apicompat.ChatCompletionsChunkToResponsesEvents(&chunk, state))
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.L().Warn("openai responses chat fallback: stream read error",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: billingModel,
|
||||
UpstreamModel: upstreamModel,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
ServiceTier: serviceTier,
|
||||
Stream: true,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, fmt.Errorf("stream usage incomplete: %w", err)
|
||||
}
|
||||
|
||||
writeEvents(apicompat.FinalizeChatCompletionsResponsesStream(state))
|
||||
if !clientDisconnected {
|
||||
writeStreamHeaders()
|
||||
if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil {
|
||||
clientDisconnected = true
|
||||
}
|
||||
if !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
if !sawDone {
|
||||
logger.L().Debug("openai responses chat fallback: upstream stream ended without done sentinel",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: billingModel,
|
||||
UpstreamModel: upstreamModel,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
ServiceTier: serviceTier,
|
||||
Stream: true,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func chatChunkStartsResponsesOutput(chunk *apicompat.ChatCompletionsChunk) bool {
|
||||
if chunk == nil {
|
||||
return false
|
||||
}
|
||||
for _, choice := range chunk.Choices {
|
||||
if choice.Delta.Content != nil || choice.Delta.ReasoningContent != nil || len(choice.Delta.ToolCalls) > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@ -0,0 +1,145 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestForwardResponses_ForceChatCompletionsRoutesNonStreamingToChatCompletions(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
body := []byte(`{"model":"gpt-5.4","input":"hello","stream":false}`)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_resp_chat_json"}},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`{"id":"chatcmpl_json","object":"chat.completion","model":"gpt-5.4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":3,"completion_tokens":2,"total_tokens":5,"prompt_tokens_details":{"cached_tokens":1}}}`,
|
||||
)),
|
||||
}}
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: rawChatCompletionsTestConfig(),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, forceChatResponsesFallbackAccount(), body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "http://upstream.example/v1/chat/completions", upstream.lastReq.URL.String())
|
||||
require.Equal(t, "hello", gjson.GetBytes(upstream.lastBody, "messages.0.content").String())
|
||||
require.False(t, gjson.GetBytes(upstream.lastBody, "input").Exists())
|
||||
require.Equal(t, "response", gjson.Get(rec.Body.String(), "object").String())
|
||||
require.Equal(t, "ok", gjson.Get(rec.Body.String(), "output.0.content.0.text").String())
|
||||
require.Equal(t, 3, result.Usage.InputTokens)
|
||||
require.Equal(t, 2, result.Usage.OutputTokens)
|
||||
require.Equal(t, 1, result.Usage.CacheReadInputTokens)
|
||||
require.False(t, result.Stream)
|
||||
}
|
||||
|
||||
func TestForwardResponses_ForceChatCompletionsRoutesStreamingToChatCompletions(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
body := []byte(`{"model":"gpt-5.4","input":"hello","stream":true}`)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}`,
|
||||
"",
|
||||
`data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"he"},"finish_reason":null}]}`,
|
||||
"",
|
||||
`data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"llo"},"finish_reason":null}]}`,
|
||||
"",
|
||||
`data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
|
||||
"",
|
||||
`data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":4,"completion_tokens":3,"total_tokens":7}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_resp_chat_stream"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: rawChatCompletionsTestConfig(),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, forceChatResponsesFallbackAccount(), body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "http://upstream.example/v1/chat/completions", upstream.lastReq.URL.String())
|
||||
require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool())
|
||||
require.Contains(t, rec.Body.String(), "event: response.output_text.delta")
|
||||
require.Contains(t, rec.Body.String(), `"delta":"he"`)
|
||||
require.Contains(t, rec.Body.String(), "event: response.completed")
|
||||
require.Contains(t, rec.Body.String(), `"input_tokens":4`)
|
||||
require.Contains(t, rec.Body.String(), "data: [DONE]")
|
||||
require.Equal(t, 4, result.Usage.InputTokens)
|
||||
require.Equal(t, 3, result.Usage.OutputTokens)
|
||||
require.True(t, result.Stream)
|
||||
require.NotNil(t, result.FirstTokenMs)
|
||||
}
|
||||
|
||||
func TestForwardResponses_AutoSupportedAccountStillUsesResponsesEndpoint(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
body := []byte(`{"model":"gpt-5.4","input":"hello","stream":false}`)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_resp_native"}},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`{"id":"resp_native","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"ok"}],"status":"completed"}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}`,
|
||||
)),
|
||||
}}
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: rawChatCompletionsTestConfig(),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := rawChatCompletionsTestAccount()
|
||||
account.Extra = map[string]any{
|
||||
openai_compat.ExtraKeyResponsesMode: string(openai_compat.ResponsesSupportModeAuto),
|
||||
openai_compat.ExtraKeyResponsesSupported: true,
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "http://upstream.example/v1/responses", upstream.lastReq.URL.String())
|
||||
require.True(t, gjson.GetBytes(upstream.lastBody, "input").Exists())
|
||||
require.False(t, gjson.GetBytes(upstream.lastBody, "messages").Exists())
|
||||
require.Equal(t, "ok", gjson.Get(rec.Body.String(), "output.0.content.0.text").String())
|
||||
}
|
||||
|
||||
func forceChatResponsesFallbackAccount() *Account {
|
||||
account := rawChatCompletionsTestAccount()
|
||||
account.Extra = map[string]any{
|
||||
openai_compat.ExtraKeyResponsesMode: string(openai_compat.ResponsesSupportModeForceChatCompletions),
|
||||
}
|
||||
return account
|
||||
}
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/cespare/xxhash/v2"
|
||||
@ -2018,6 +2019,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
originalBody := body
|
||||
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
|
||||
originalModel := reqModel
|
||||
|
||||
if account.Type == AccountTypeAPIKey && !openai_compat.ShouldUseResponsesAPI(account.Extra) {
|
||||
return s.forwardResponsesViaRawChatCompletions(ctx, c, account, body)
|
||||
}
|
||||
|
||||
compatMessagesBridge := isOpenAICompatMessagesBridgeBody(body)
|
||||
setOpenAICompatMessagesBridgeContext(c, compatMessagesBridge)
|
||||
|
||||
@ -3231,6 +3237,10 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
|
||||
req.Header.Set("user-agent", codexCLIUserAgent)
|
||||
}
|
||||
|
||||
// 浏览器型 UA 兜底:仅 OAuth(ChatGPT 内部接口)账号生效,若最终 user-agent 仍为浏览器
|
||||
// (Chrome/Firefox/Safari/Edge 等),替换为后台配置的 Codex UA,避免 Cloudflare 触发 JS 质询。
|
||||
s.overrideBrowserUserAgent(ctx, account, req)
|
||||
|
||||
if req.Header.Get("content-type") == "" {
|
||||
req.Header.Set("content-type", "application/json")
|
||||
}
|
||||
@ -3947,6 +3957,10 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
req.Header.Set("user-agent", codexCLIUserAgent)
|
||||
}
|
||||
|
||||
// 浏览器型 UA 兜底:仅 OAuth(ChatGPT 内部接口)账号生效,若最终 user-agent 仍为浏览器
|
||||
// (Chrome/Firefox/Safari/Edge 等),替换为后台配置的 Codex UA,避免 Cloudflare 触发 JS 质询。
|
||||
s.overrideBrowserUserAgent(ctx, account, req)
|
||||
|
||||
// Ensure required headers exist
|
||||
if req.Header.Get("content-type") == "" {
|
||||
req.Header.Set("content-type", "application/json")
|
||||
@ -3955,6 +3969,30 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// overrideBrowserUserAgent 检查请求的最终 user-agent,若为浏览器 UA 则替换为后台配置的 Codex UA。
|
||||
// 用于规避 Cloudflare 对浏览器型 UA 在 ChatGPT 内部接口上的访问质询。
|
||||
// 影响范围严格限定:仅 OAuth(Codex/ChatGPT 内部接口)账号生效;API Key 等其他账号原样透传。
|
||||
// 仅在识别为浏览器(Mozilla/...)时改写,其他 CLI/工具 UA 不动。
|
||||
func (s *OpenAIGatewayService) overrideBrowserUserAgent(ctx context.Context, account *Account, req *http.Request) {
|
||||
if req == nil || account == nil {
|
||||
return
|
||||
}
|
||||
if account.Type != AccountTypeOAuth {
|
||||
return
|
||||
}
|
||||
currentUA := req.Header.Get("user-agent")
|
||||
if !openai.IsBrowserUserAgent(currentUA) {
|
||||
return
|
||||
}
|
||||
codexUA := DefaultOpenAICodexUserAgent
|
||||
if s != nil && s.settingService != nil {
|
||||
if v := strings.TrimSpace(s.settingService.GetOpenAICodexUserAgent(ctx)); v != "" {
|
||||
codexUA = v
|
||||
}
|
||||
}
|
||||
req.Header.Set("user-agent", codexUA)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleErrorResponse(
|
||||
ctx context.Context,
|
||||
resp *http.Response,
|
||||
|
||||
@ -262,6 +262,9 @@ func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel st
|
||||
tool := []byte(`{"type":"image_generation","action":"","model":""}`)
|
||||
tool, _ = sjson.SetBytes(tool, "action", action)
|
||||
tool, _ = sjson.SetBytes(tool, "model", strings.TrimSpace(toolModel))
|
||||
if shouldPassOpenAIImagesN(toolModel, parsed.N) {
|
||||
tool, _ = sjson.SetBytes(tool, "n", parsed.N)
|
||||
}
|
||||
|
||||
for _, field := range []struct {
|
||||
path string
|
||||
@ -302,6 +305,13 @@ func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel st
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func shouldPassOpenAIImagesN(model string, n int) bool {
|
||||
if n <= 1 {
|
||||
return false
|
||||
}
|
||||
return !strings.EqualFold(strings.TrimSpace(model), "dall-e-3")
|
||||
}
|
||||
|
||||
func extractOpenAIImagesFromResponsesCompleted(payload []byte) ([]openAIResponsesImageResult, int64, []byte, openAIResponsesImageResult, error) {
|
||||
if gjson.GetBytes(payload, "type").String() != "response.completed" {
|
||||
return nil, 0, nil, openAIResponsesImageResult{}, fmt.Errorf("unexpected event type")
|
||||
@ -957,16 +967,6 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
||||
account.Type,
|
||||
len(parsed.Uploads),
|
||||
)
|
||||
if parsed.N > 1 {
|
||||
logger.LegacyPrintf(
|
||||
"service.openai_gateway",
|
||||
"[Warning] Codex /responses image tool requested n=%d; falling back to n=1 request_model=%s endpoint=%s",
|
||||
parsed.N,
|
||||
requestModel,
|
||||
parsed.Endpoint,
|
||||
)
|
||||
}
|
||||
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
defer releaseUpstreamCtx()
|
||||
|
||||
|
||||
@ -474,9 +474,9 @@ func findOpenAIImageTestSSEEvent(events []openAIImageTestSSEEvent, name string)
|
||||
return openAIImageTestSSEEvent{}, false
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
|
||||
func TestOpenAIGatewayServiceForwardImages_OAuthPassesNAndReturnsAllImages(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":2}`)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":3}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
@ -497,7 +497,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
|
||||
"X-Request-Id": []string{"req_img_123"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000000,\"usage\":{\"input_tokens\":11,\"output_tokens\":22,\"input_tokens_details\":{\"cached_tokens\":3},\"output_tokens_details\":{\"image_tokens\":7}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
|
||||
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000000,\"usage\":{\"input_tokens\":11,\"output_tokens\":22,\"input_tokens_details\":{\"cached_tokens\":3},\"output_tokens_details\":{\"image_tokens\":7}},\"tool_usage\":{\"image_gen\":{\"images\":3}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMQ==\",\"revised_prompt\":\"draw a cat 1\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"},{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMg==\",\"revised_prompt\":\"draw a cat 2\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"},{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMw==\",\"revised_prompt\":\"draw a cat 3\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
|
||||
"data: [DONE]\n\n",
|
||||
)),
|
||||
},
|
||||
@ -520,7 +520,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "gpt-image-2", result.Model)
|
||||
require.Equal(t, "gpt-image-2", result.UpstreamModel)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
require.Equal(t, 3, result.ImageCount)
|
||||
require.Equal(t, 11, result.Usage.InputTokens)
|
||||
require.Equal(t, 22, result.Usage.OutputTokens)
|
||||
require.Equal(t, 7, result.Usage.ImageOutputTokens)
|
||||
@ -540,13 +540,17 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
|
||||
require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String())
|
||||
require.Equal(t, "1024x1024", gjson.GetBytes(upstream.lastBody, "tools.0.size").String())
|
||||
require.Equal(t, "high", gjson.GetBytes(upstream.lastBody, "tools.0.quality").String())
|
||||
require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.n").Exists())
|
||||
require.Equal(t, int64(3), gjson.GetBytes(upstream.lastBody, "tools.0.n").Int())
|
||||
require.Equal(t, "draw a cat", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String())
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "gpt-image-2", gjson.Get(rec.Body.String(), "model").String())
|
||||
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||
require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
|
||||
require.Len(t, gjson.Get(rec.Body.String(), "data").Array(), 3)
|
||||
require.Equal(t, "aW1hZ2UtMQ==", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||
require.Equal(t, "aW1hZ2UtMg==", gjson.Get(rec.Body.String(), "data.1.b64_json").String())
|
||||
require.Equal(t, "aW1hZ2UtMw==", gjson.Get(rec.Body.String(), "data.2.b64_json").String())
|
||||
require.Equal(t, "draw a cat 1", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
|
||||
require.Equal(t, "draw a cat 3", gjson.Get(rec.Body.String(), "data.2.revised_prompt").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseURL(t *testing.T) {
|
||||
@ -1112,7 +1116,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthEditsStreamingTransformsEvents(t
|
||||
require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists())
|
||||
}
|
||||
|
||||
func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *testing.T) {
|
||||
func TestBuildOpenAIImagesResponsesRequest_PassesThroughNForMultiImageModels(t *testing.T) {
|
||||
parsed := &OpenAIImagesRequest{
|
||||
Endpoint: openAIImagesGenerationsEndpoint,
|
||||
Model: "gpt-image-2",
|
||||
@ -1123,11 +1127,26 @@ func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *t
|
||||
body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, body)
|
||||
require.False(t, gjson.GetBytes(body, "tools.0.n").Exists())
|
||||
require.Equal(t, int64(2), gjson.GetBytes(body, "tools.0.n").Int())
|
||||
require.Equal(t, "gpt-image-2", gjson.GetBytes(body, "tools.0.model").String())
|
||||
require.Equal(t, "draw a cat", gjson.GetBytes(body, "input.0.content.0.text").String())
|
||||
}
|
||||
|
||||
func TestBuildOpenAIImagesResponsesRequest_DoesNotPassNForDallE3(t *testing.T) {
|
||||
parsed := &OpenAIImagesRequest{
|
||||
Endpoint: openAIImagesGenerationsEndpoint,
|
||||
Model: "dall-e-3",
|
||||
Prompt: "draw a cat",
|
||||
N: 2,
|
||||
}
|
||||
|
||||
body, err := buildOpenAIImagesResponsesRequest(parsed, "dall-e-3")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, body)
|
||||
require.False(t, gjson.GetBytes(body, "tools.0.n").Exists())
|
||||
require.Equal(t, "dall-e-3", gjson.GetBytes(body, "tools.0.model").String())
|
||||
}
|
||||
|
||||
func TestBuildOpenAIImagesResponsesRequest_StripsInputFidelity(t *testing.T) {
|
||||
parsed := &OpenAIImagesRequest{
|
||||
Endpoint: openAIImagesEditsEndpoint,
|
||||
|
||||
@ -686,6 +686,21 @@ func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runt
|
||||
if !s.emailLimiter.Allow(time.Now().UTC()) {
|
||||
continue
|
||||
}
|
||||
if s.emailService.notificationEmailService != nil {
|
||||
if err := s.emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||
Event: NotificationEmailEventOpsAlert,
|
||||
RecipientEmail: addr,
|
||||
RecipientName: emailRecipientName(addr),
|
||||
SourceType: "ops_alert",
|
||||
SourceID: fmt.Sprintf("%d", event.ID),
|
||||
Variables: opsAlertEmailVariables(rule, event),
|
||||
}); err == nil {
|
||||
anySent = true
|
||||
continue
|
||||
} else if !shouldFallbackNotificationEmail(err) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if err := s.emailService.SendEmail(ctx, addr, subject, body); err != nil {
|
||||
// Ignore per-recipient failures; continue best-effort.
|
||||
continue
|
||||
@ -699,6 +714,46 @@ func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runt
|
||||
return anySent
|
||||
}
|
||||
|
||||
func opsAlertEmailVariables(rule *OpsAlertRule, event *OpsAlertEvent) map[string]string {
|
||||
variables := map[string]string{
|
||||
"rule_name": "-",
|
||||
"severity": "-",
|
||||
"alert_status": "-",
|
||||
"metric_type": "-",
|
||||
"operator": "-",
|
||||
"metric_value": "-",
|
||||
"threshold_value": "-",
|
||||
"triggered_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"alert_description": "-",
|
||||
}
|
||||
if rule != nil {
|
||||
variables["rule_name"] = strings.TrimSpace(rule.Name)
|
||||
variables["severity"] = strings.TrimSpace(rule.Severity)
|
||||
variables["metric_type"] = strings.TrimSpace(rule.MetricType)
|
||||
variables["operator"] = strings.TrimSpace(rule.Operator)
|
||||
variables["threshold_value"] = fmt.Sprintf("%.2f", rule.Threshold)
|
||||
if strings.TrimSpace(rule.Description) != "" {
|
||||
variables["alert_description"] = strings.TrimSpace(rule.Description)
|
||||
}
|
||||
}
|
||||
if event != nil {
|
||||
variables["alert_status"] = strings.TrimSpace(event.Status)
|
||||
if event.MetricValue != nil {
|
||||
variables["metric_value"] = fmt.Sprintf("%.2f", *event.MetricValue)
|
||||
}
|
||||
if event.ThresholdValue != nil {
|
||||
variables["threshold_value"] = fmt.Sprintf("%.2f", *event.ThresholdValue)
|
||||
}
|
||||
if !event.FiredAt.IsZero() {
|
||||
variables["triggered_at"] = event.FiredAt.UTC().Format(time.RFC3339)
|
||||
}
|
||||
if strings.TrimSpace(event.Description) != "" {
|
||||
variables["alert_description"] = strings.TrimSpace(event.Description)
|
||||
}
|
||||
}
|
||||
return variables
|
||||
}
|
||||
|
||||
func buildOpsAlertEmailBody(rule *OpsAlertRule, event *OpsAlertEvent) string {
|
||||
if rule == nil || event == nil {
|
||||
return ""
|
||||
|
||||
@ -337,6 +337,7 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc
|
||||
}
|
||||
|
||||
subject := fmt.Sprintf("[Ops Report] %s", strings.TrimSpace(report.Name))
|
||||
templateVariables := opsScheduledReportEmailVariables(report, now)
|
||||
|
||||
attempts := 0
|
||||
for _, to := range recipients {
|
||||
@ -345,6 +346,24 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc
|
||||
continue
|
||||
}
|
||||
attempts++
|
||||
if s.emailService.notificationEmailService != nil {
|
||||
if err := s.emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||
Event: NotificationEmailEventOpsScheduledReport,
|
||||
RecipientEmail: addr,
|
||||
RecipientName: emailRecipientName(addr),
|
||||
SourceType: "ops_scheduled_report",
|
||||
SourceID: opsScheduledReportDeliverySourceID(report),
|
||||
ReminderKey: now.UTC().Format("2006-01-02T15:04"),
|
||||
Variables: templateVariables,
|
||||
RawHTMLVariables: map[string]string{
|
||||
"report_html": content,
|
||||
},
|
||||
}); err == nil {
|
||||
continue
|
||||
} else if !shouldFallbackNotificationEmail(err) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if err := s.emailService.SendEmail(ctx, addr, subject, content); err != nil {
|
||||
// Ignore per-recipient failures; continue best-effort.
|
||||
continue
|
||||
@ -353,6 +372,46 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc
|
||||
return attempts, nil
|
||||
}
|
||||
|
||||
func opsScheduledReportDeliverySourceID(report *opsScheduledReport) string {
|
||||
if report == nil {
|
||||
return "scheduled_report"
|
||||
}
|
||||
parts := []string{
|
||||
strings.TrimSpace(report.ReportType),
|
||||
strings.TrimSpace(report.Name),
|
||||
strings.TrimSpace(report.Schedule),
|
||||
}
|
||||
joined := strings.Trim(strings.Join(parts, ":"), ":")
|
||||
if joined == "" {
|
||||
return "scheduled_report"
|
||||
}
|
||||
return joined
|
||||
}
|
||||
|
||||
func opsScheduledReportEmailVariables(report *opsScheduledReport, now time.Time) map[string]string {
|
||||
end := now.UTC()
|
||||
start := end
|
||||
name := "Ops report"
|
||||
reportType := "scheduled_report"
|
||||
if report != nil {
|
||||
if strings.TrimSpace(report.Name) != "" {
|
||||
name = strings.TrimSpace(report.Name)
|
||||
}
|
||||
if strings.TrimSpace(report.ReportType) != "" {
|
||||
reportType = strings.TrimSpace(report.ReportType)
|
||||
}
|
||||
if report.TimeRange > 0 {
|
||||
start = end.Add(-report.TimeRange)
|
||||
}
|
||||
}
|
||||
return map[string]string{
|
||||
"report_name": name,
|
||||
"report_type": reportType,
|
||||
"report_start_time": start.Format(time.RFC3339),
|
||||
"report_end_time": end.Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsScheduledReportService) generateReportHTML(ctx context.Context, report *opsScheduledReport, now time.Time) (string, error) {
|
||||
if s == nil || s.opsService == nil || report == nil {
|
||||
return "", fmt.Errorf("service not initialized")
|
||||
|
||||
@ -310,9 +310,87 @@ func (s *PaymentService) markCompleted(ctx context.Context, o *dbent.PaymentOrde
|
||||
"creditedAmount": o.Amount,
|
||||
"payAmount": o.PayAmount,
|
||||
})
|
||||
s.dispatchPaymentFulfillmentNotification(o, auditAction)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PaymentService) dispatchPaymentFulfillmentNotification(o *dbent.PaymentOrder, auditAction string) {
|
||||
if s == nil || s.notificationEmailService == nil || o == nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout)
|
||||
defer cancel()
|
||||
var err error
|
||||
switch auditAction {
|
||||
case "RECHARGE_SUCCESS":
|
||||
err = s.sendBalanceRechargeSuccessNotification(ctx, o)
|
||||
case "SUBSCRIPTION_SUCCESS":
|
||||
err = s.sendSubscriptionPurchaseSuccessNotification(ctx, o)
|
||||
default:
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
slog.Warn("payment fulfillment notification email failed", "order_id", o.ID, "action", auditAction, "err", err.Error())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *PaymentService) sendBalanceRechargeSuccessNotification(ctx context.Context, o *dbent.PaymentOrder) error {
|
||||
currentBalance := ""
|
||||
if s.userRepo != nil {
|
||||
if user, err := s.userRepo.GetByID(ctx, o.UserID); err == nil && user != nil {
|
||||
currentBalance = fmt.Sprintf("%.2f", user.Balance)
|
||||
}
|
||||
}
|
||||
return s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||
Event: NotificationEmailEventBalanceRechargeSuccess,
|
||||
RecipientEmail: o.UserEmail,
|
||||
RecipientName: firstNonEmpty(o.UserName, o.UserEmail),
|
||||
UserID: o.UserID,
|
||||
SourceType: "payment_order",
|
||||
SourceID: strconv.FormatInt(o.ID, 10),
|
||||
Variables: map[string]string{
|
||||
"recharge_amount": fmt.Sprintf("%.2f", o.Amount),
|
||||
"current_balance": currentBalance,
|
||||
"order_id": strconv.FormatInt(o.ID, 10),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *PaymentService) sendSubscriptionPurchaseSuccessNotification(ctx context.Context, o *dbent.PaymentOrder) error {
|
||||
variables := map[string]string{
|
||||
"subscription_group": "Subscription",
|
||||
"subscription_days": "",
|
||||
"expiry_time": "",
|
||||
"order_id": strconv.FormatInt(o.ID, 10),
|
||||
}
|
||||
if o.SubscriptionDays != nil {
|
||||
variables["subscription_days"] = strconv.Itoa(*o.SubscriptionDays)
|
||||
}
|
||||
if o.SubscriptionGroupID != nil {
|
||||
if s.groupRepo != nil {
|
||||
if group, err := s.groupRepo.GetByID(ctx, *o.SubscriptionGroupID); err == nil && group != nil && strings.TrimSpace(group.Name) != "" {
|
||||
variables["subscription_group"] = group.Name
|
||||
}
|
||||
}
|
||||
if s.subscriptionSvc != nil {
|
||||
if sub, err := s.subscriptionSvc.GetActiveSubscription(ctx, o.UserID, *o.SubscriptionGroupID); err == nil && sub != nil {
|
||||
variables["expiry_time"] = sub.ExpiresAt.Format("2006-01-02 15:04")
|
||||
}
|
||||
}
|
||||
}
|
||||
return s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||
Event: NotificationEmailEventSubscriptionPurchaseSuccess,
|
||||
RecipientEmail: o.UserEmail,
|
||||
RecipientName: firstNonEmpty(o.UserName, o.UserEmail),
|
||||
UserID: o.UserID,
|
||||
SourceType: "payment_order",
|
||||
SourceID: strconv.FormatInt(o.ID, 10),
|
||||
Variables: variables,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *PaymentService) ExecuteSubscriptionFulfillment(ctx context.Context, oid int64) error {
|
||||
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
|
||||
if err != nil {
|
||||
|
||||
@ -48,6 +48,9 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest
|
||||
if user.Status != payment.EntityStatusActive {
|
||||
return nil, infraerrors.Forbidden("USER_INACTIVE", "user account is disabled")
|
||||
}
|
||||
if s.notificationEmailService != nil {
|
||||
s.notificationEmailService.RememberRecipientLocale(ctx, req.UserID, user.Email, req.Locale)
|
||||
}
|
||||
orderAmount := req.Amount
|
||||
limitAmount := req.Amount
|
||||
if plan != nil {
|
||||
|
||||
@ -83,6 +83,7 @@ type CreateOrderRequest struct {
|
||||
PaymentSource string
|
||||
OrderType string
|
||||
PlanID int64
|
||||
Locale string
|
||||
}
|
||||
|
||||
type CreateOrderResponse struct {
|
||||
@ -174,18 +175,19 @@ type TopUserStat struct {
|
||||
// --- Service ---
|
||||
|
||||
type PaymentService struct {
|
||||
providerMu sync.Mutex
|
||||
providersLoaded bool
|
||||
entClient *dbent.Client
|
||||
registry *payment.Registry
|
||||
loadBalancer payment.LoadBalancer
|
||||
redeemService *RedeemService
|
||||
subscriptionSvc *SubscriptionService
|
||||
configService *PaymentConfigService
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
resumeService *PaymentResumeService
|
||||
affiliateService *AffiliateService
|
||||
providerMu sync.Mutex
|
||||
providersLoaded bool
|
||||
entClient *dbent.Client
|
||||
registry *payment.Registry
|
||||
loadBalancer payment.LoadBalancer
|
||||
redeemService *RedeemService
|
||||
subscriptionSvc *SubscriptionService
|
||||
configService *PaymentConfigService
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
resumeService *PaymentResumeService
|
||||
affiliateService *AffiliateService
|
||||
notificationEmailService *NotificationEmailService
|
||||
}
|
||||
|
||||
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository, affiliateService *AffiliateService) *PaymentService {
|
||||
@ -194,6 +196,10 @@ func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, load
|
||||
return svc
|
||||
}
|
||||
|
||||
func (s *PaymentService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) {
|
||||
s.notificationEmailService = notificationEmailService
|
||||
}
|
||||
|
||||
// --- Provider Registry ---
|
||||
|
||||
// EnsureProviders lazily initializes the provider registry on first call.
|
||||
|
||||
@ -128,6 +128,19 @@ const antigravityUserAgentVersionCacheTTL = 60 * time.Second
|
||||
const antigravityUserAgentVersionErrorTTL = 5 * time.Second
|
||||
const antigravityUserAgentVersionDBTimeout = 5 * time.Second
|
||||
|
||||
// DefaultOpenAICodexUserAgent OpenAI Codex 默认 User-Agent(用于规避 Cloudflare 对浏览器 UA 的质询)
|
||||
const DefaultOpenAICodexUserAgent = "codex-tui/0.125.0 (Ubuntu 22.4.0; x86_64) xterm-256color (codex-tui; 0.125.0)"
|
||||
|
||||
// cachedOpenAICodexUserAgent 缓存 OpenAI Codex UA(进程内缓存,60s TTL)
|
||||
type cachedOpenAICodexUserAgent struct {
|
||||
value string
|
||||
expiresAt int64 // unix nano
|
||||
}
|
||||
|
||||
const openAICodexUserAgentCacheTTL = 60 * time.Second
|
||||
const openAICodexUserAgentErrorTTL = 5 * time.Second
|
||||
const openAICodexUserAgentDBTimeout = 5 * time.Second
|
||||
|
||||
// DefaultSubscriptionGroupReader validates group references used by default subscriptions.
|
||||
type DefaultSubscriptionGroupReader interface {
|
||||
GetByID(ctx context.Context, id int64) (*Group, error)
|
||||
@ -148,6 +161,8 @@ type SettingService struct {
|
||||
webSearchManagerBuilder WebSearchManagerBuilder
|
||||
antigravityUAVersionCache atomic.Value // *cachedAntigravityUserAgentVersion
|
||||
antigravityUAVersionSF singleflight.Group
|
||||
openAICodexUACache atomic.Value // *cachedOpenAICodexUserAgent
|
||||
openAICodexUASF singleflight.Group
|
||||
}
|
||||
|
||||
type ProviderDefaultGrantSettings struct {
|
||||
@ -907,6 +922,55 @@ func (s *SettingService) GetAntigravityUserAgentVersion(ctx context.Context) str
|
||||
return fallback
|
||||
}
|
||||
|
||||
// GetOpenAICodexUserAgent 返回 OpenAI Codex 上游请求使用的 User-Agent。
|
||||
// 后台设置优先;为空时回退到内置默认值。
|
||||
func (s *SettingService) GetOpenAICodexUserAgent(ctx context.Context) string {
|
||||
fallback := DefaultOpenAICodexUserAgent
|
||||
if s == nil || s.settingRepo == nil {
|
||||
return fallback
|
||||
}
|
||||
if cached, ok := s.openAICodexUACache.Load().(*cachedOpenAICodexUserAgent); ok && cached != nil {
|
||||
if time.Now().UnixNano() < cached.expiresAt {
|
||||
return cached.value
|
||||
}
|
||||
}
|
||||
|
||||
result, _, _ := s.openAICodexUASF.Do("openai_codex_user_agent", func() (any, error) {
|
||||
if cached, ok := s.openAICodexUACache.Load().(*cachedOpenAICodexUserAgent); ok && cached != nil {
|
||||
if time.Now().UnixNano() < cached.expiresAt {
|
||||
return cached.value, nil
|
||||
}
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAICodexUserAgentDBTimeout)
|
||||
defer cancel()
|
||||
value, err := s.settingRepo.GetValue(dbCtx, SettingKeyOpenAICodexUserAgent)
|
||||
if err != nil && !errors.Is(err, ErrSettingNotFound) {
|
||||
slog.Warn("failed to get openai codex user agent setting", "error", err)
|
||||
s.openAICodexUACache.Store(&cachedOpenAICodexUserAgent{
|
||||
value: fallback,
|
||||
expiresAt: time.Now().Add(openAICodexUserAgentErrorTTL).UnixNano(),
|
||||
})
|
||||
return fallback, nil
|
||||
}
|
||||
ua := strings.TrimSpace(value)
|
||||
if ua == "" {
|
||||
ua = fallback
|
||||
}
|
||||
s.openAICodexUACache.Store(&cachedOpenAICodexUserAgent{
|
||||
value: ua,
|
||||
expiresAt: time.Now().Add(openAICodexUserAgentCacheTTL).UnixNano(),
|
||||
})
|
||||
return ua, nil
|
||||
})
|
||||
if ua, ok := result.(string); ok && ua != "" {
|
||||
return ua
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
// SetOnUpdateCallback sets a callback function to be called when settings are updated
|
||||
// This is used for cache invalidation (e.g., HTML cache in frontend server)
|
||||
func (s *SettingService) SetOnUpdateCallback(callback func()) {
|
||||
@ -1706,6 +1770,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
||||
updates[SettingKeyEnableAnthropicCacheTTL1hInjection] = strconv.FormatBool(settings.EnableAnthropicCacheTTL1hInjection)
|
||||
updates[SettingKeyRewriteMessageCacheControl] = strconv.FormatBool(settings.RewriteMessageCacheControl)
|
||||
updates[SettingKeyAntigravityUserAgentVersion] = antigravity.NormalizeUserAgentVersion(settings.AntigravityUserAgentVersion)
|
||||
updates[SettingKeyOpenAICodexUserAgent] = strings.TrimSpace(settings.OpenAICodexUserAgent)
|
||||
updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource
|
||||
updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource
|
||||
updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled)
|
||||
@ -1788,6 +1853,15 @@ func (s *SettingService) refreshCachedSettings(settings *SystemSettings) {
|
||||
version: antigravityUserAgentVersion,
|
||||
expiresAt: time.Now().Add(antigravityUserAgentVersionCacheTTL).UnixNano(),
|
||||
})
|
||||
s.openAICodexUASF.Forget("openai_codex_user_agent")
|
||||
codexUA := strings.TrimSpace(settings.OpenAICodexUserAgent)
|
||||
if codexUA == "" {
|
||||
codexUA = DefaultOpenAICodexUserAgent
|
||||
}
|
||||
s.openAICodexUACache.Store(&cachedOpenAICodexUserAgent{
|
||||
value: codexUA,
|
||||
expiresAt: time.Now().Add(openAICodexUserAgentCacheTTL).UnixNano(),
|
||||
})
|
||||
openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
|
||||
openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
|
||||
enabled: settings.OpenAIAdvancedSchedulerEnabled,
|
||||
@ -2529,6 +2603,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeyEnableAnthropicCacheTTL1hInjection: "false",
|
||||
SettingKeyRewriteMessageCacheControl: strconv.FormatBool(s.defaultRewriteMessageCacheControl()),
|
||||
SettingKeyAntigravityUserAgentVersion: "",
|
||||
SettingKeyOpenAICodexUserAgent: "",
|
||||
SettingPaymentVisibleMethodAlipaySource: "",
|
||||
SettingPaymentVisibleMethodWxpaySource: "",
|
||||
SettingPaymentVisibleMethodAlipayEnabled: "false",
|
||||
@ -3041,6 +3116,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result.RewriteMessageCacheControl = s.defaultRewriteMessageCacheControl()
|
||||
}
|
||||
result.AntigravityUserAgentVersion = antigravity.NormalizeUserAgentVersion(settings[SettingKeyAntigravityUserAgentVersion])
|
||||
result.OpenAICodexUserAgent = strings.TrimSpace(settings[SettingKeyOpenAICodexUserAgent])
|
||||
|
||||
// Web search emulation: quick enabled check from the JSON config
|
||||
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
|
||||
|
||||
@ -193,6 +193,7 @@ type SystemSettings struct {
|
||||
EnableAnthropicCacheTTL1hInjection bool // 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false)
|
||||
RewriteMessageCacheControl bool // 是否改写 messages[*].content[*].cache_control(默认 false)
|
||||
AntigravityUserAgentVersion string // Antigravity 上游 User-Agent 版本号;空值使用配置/默认值
|
||||
OpenAICodexUserAgent string // OpenAI Codex 上游完整 User-Agent;空值使用内置默认
|
||||
|
||||
// Web Search Emulation
|
||||
WebSearchEmulationEnabled bool // 是否启用 web search 模拟
|
||||
|
||||
@ -2,18 +2,23 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
// SubscriptionExpiryService periodically updates expired subscription status.
|
||||
type SubscriptionExpiryService struct {
|
||||
userSubRepo UserSubscriptionRepository
|
||||
interval time.Duration
|
||||
stopCh chan struct{}
|
||||
stopOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
userSubRepo UserSubscriptionRepository
|
||||
notificationEmailService *NotificationEmailService
|
||||
interval time.Duration
|
||||
stopCh chan struct{}
|
||||
stopOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, interval time.Duration) *SubscriptionExpiryService {
|
||||
@ -24,6 +29,10 @@ func NewSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, interv
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SubscriptionExpiryService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) {
|
||||
s.notificationEmailService = notificationEmailService
|
||||
}
|
||||
|
||||
func (s *SubscriptionExpiryService) Start() {
|
||||
if s == nil || s.userSubRepo == nil || s.interval <= 0 {
|
||||
return
|
||||
@ -68,4 +77,50 @@ func (s *SubscriptionExpiryService) runOnce() {
|
||||
if updated > 0 {
|
||||
log.Printf("[SubscriptionExpiry] Updated %d expired subscriptions", updated)
|
||||
}
|
||||
s.sendExpiryReminders(ctx)
|
||||
}
|
||||
|
||||
func (s *SubscriptionExpiryService) sendExpiryReminders(ctx context.Context) {
|
||||
if s == nil || s.userSubRepo == nil || s.notificationEmailService == nil {
|
||||
return
|
||||
}
|
||||
for page := 1; ; page++ {
|
||||
subs, pag, err := s.userSubRepo.List(ctx, pagination.PaginationParams{Page: page, PageSize: 200}, nil, nil, SubscriptionStatusActive, "", "expires_at", "asc")
|
||||
if err != nil {
|
||||
log.Printf("[SubscriptionExpiry] List active subscriptions for reminder failed: %v", err)
|
||||
return
|
||||
}
|
||||
for i := range subs {
|
||||
s.sendExpiryReminderIfDue(ctx, &subs[i])
|
||||
}
|
||||
if pag == nil || page >= pag.Pages || len(subs) == 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SubscriptionExpiryService) sendExpiryReminderIfDue(ctx context.Context, sub *UserSubscription) {
|
||||
if sub == nil || sub.User == nil || sub.Group == nil || sub.User.Email == "" {
|
||||
return
|
||||
}
|
||||
daysRemaining := sub.DaysRemaining()
|
||||
if daysRemaining != 7 && daysRemaining != 3 && daysRemaining != 1 {
|
||||
return
|
||||
}
|
||||
if err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||
Event: NotificationEmailEventSubscriptionExpiryReminder,
|
||||
RecipientEmail: sub.User.Email,
|
||||
RecipientName: firstNonEmpty(sub.User.Username, sub.User.Email),
|
||||
UserID: sub.UserID,
|
||||
SourceType: "user_subscription",
|
||||
SourceID: strconv.FormatInt(sub.ID, 10),
|
||||
ReminderKey: fmt.Sprintf("%dd", daysRemaining),
|
||||
Variables: map[string]string{
|
||||
"subscription_group": sub.Group.Name,
|
||||
"expiry_time": sub.ExpiresAt.Format("2006-01-02 15:04"),
|
||||
"days_remaining": strconv.Itoa(daysRemaining),
|
||||
},
|
||||
}); err != nil {
|
||||
log.Printf("[SubscriptionExpiry] Send expiry reminder failed: subscription=%d user=%d err=%v", sub.ID, sub.UserID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -517,7 +517,7 @@ func (s *TotpService) GetVerificationMethod(ctx context.Context) *VerificationMe
|
||||
}
|
||||
|
||||
// SendVerifyCode sends an email verification code for TOTP operations
|
||||
func (s *TotpService) SendVerifyCode(ctx context.Context, userID int64) error {
|
||||
func (s *TotpService) SendVerifyCode(ctx context.Context, userID int64, locale ...string) error {
|
||||
// Check if email verification is enabled
|
||||
if !s.settingService.IsEmailVerifyEnabled(ctx) {
|
||||
return infraerrors.BadRequest("EMAIL_VERIFY_NOT_ENABLED", "email verification is not enabled")
|
||||
@ -533,5 +533,5 @@ func (s *TotpService) SendVerifyCode(ctx context.Context, userID int64) error {
|
||||
siteName := s.settingService.GetSiteName(ctx)
|
||||
|
||||
// Send verification code via queue
|
||||
return s.emailQueueService.EnqueueVerifyCode(user.Email, siteName)
|
||||
return s.emailQueueService.EnqueueVerifyCode(user.Email, siteName, firstEmailLocale(locale))
|
||||
}
|
||||
|
||||
@ -324,6 +324,30 @@ func (s *UsageService) GetAPIKeyModelStats(ctx context.Context, apiKeyID int64,
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetAPIKeyDailyUsage returns daily usage stats for a user's API key.
|
||||
func (s *UsageService) GetAPIKeyDailyUsage(ctx context.Context, userID, apiKeyID int64, startTime, endTime time.Time) ([]usagestats.APIKeyDailyUsagePoint, error) {
|
||||
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, "day", userID, apiKeyID, 0, 0, "", nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key daily usage: %w", err)
|
||||
}
|
||||
|
||||
points := make([]usagestats.APIKeyDailyUsagePoint, 0, len(trend))
|
||||
for _, row := range trend {
|
||||
points = append(points, usagestats.APIKeyDailyUsagePoint{
|
||||
Date: row.Date,
|
||||
Requests: row.Requests,
|
||||
InputTokens: row.InputTokens,
|
||||
OutputTokens: row.OutputTokens,
|
||||
CacheReadTokens: row.CacheReadTokens,
|
||||
CacheWriteTokens: row.CacheCreationTokens,
|
||||
TotalTokens: row.TotalTokens,
|
||||
Cost: row.Cost,
|
||||
ActualCost: row.ActualCost,
|
||||
})
|
||||
}
|
||||
return points, nil
|
||||
}
|
||||
|
||||
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
|
||||
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)
|
||||
|
||||
@ -1122,7 +1122,7 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error {
|
||||
}
|
||||
|
||||
// SendNotifyEmailCode sends a verification code to the extra notification email.
|
||||
func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, email string, emailService *EmailService, cache EmailCache) error {
|
||||
func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, email string, emailService *EmailService, cache EmailCache, locale ...string) error {
|
||||
if err := checkNotifyCodeRateLimit(ctx, cache, userID, email); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -1134,7 +1134,7 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema
|
||||
|
||||
// Send email first — if SMTP fails, don't write cache or increment counters,
|
||||
// so the user is not locked out by cooldown/rate-limit for a code they never received.
|
||||
if err := s.sendNotifyVerifyEmail(ctx, emailService, email, code); err != nil {
|
||||
if err := s.sendNotifyVerifyEmail(ctx, emailService, userID, email, code, firstEmailLocale(locale)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -1180,13 +1180,33 @@ func saveNotifyVerifyCode(ctx context.Context, cache EmailCache, email, code str
|
||||
}
|
||||
|
||||
// sendNotifyVerifyEmail builds and sends the verification email.
|
||||
func (s *UserService) sendNotifyVerifyEmail(ctx context.Context, emailService *EmailService, email, code string) error {
|
||||
func (s *UserService) sendNotifyVerifyEmail(ctx context.Context, emailService *EmailService, userID int64, email, code, locale string) error {
|
||||
siteName := "Sub2API"
|
||||
if s.settingRepo != nil {
|
||||
if name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName); err == nil && name != "" {
|
||||
siteName = name
|
||||
}
|
||||
}
|
||||
if emailService.notificationEmailService != nil {
|
||||
if err := emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||
Event: NotificationEmailEventNotificationEmailVerifyCode,
|
||||
Locale: locale,
|
||||
RecipientEmail: email,
|
||||
RecipientName: emailRecipientName(email),
|
||||
UserID: userID,
|
||||
Variables: map[string]string{
|
||||
"verification_code": code,
|
||||
"expires_in_minutes": strconv.Itoa(int(verifyCodeTTL / time.Minute)),
|
||||
},
|
||||
}); err == nil {
|
||||
return nil
|
||||
} else {
|
||||
if !shouldFallbackNotificationEmail(err) {
|
||||
return err
|
||||
}
|
||||
slog.Warn("template notification email verification failed; falling back to built-in body", "recipient_hash", notificationEmailHash(email), "err", err.Error())
|
||||
}
|
||||
}
|
||||
subject := fmt.Sprintf("[%s] 通知邮箱验证码 / Notification Email Verification", siteName)
|
||||
body := buildNotifyVerifyEmailBody(code, siteName)
|
||||
return emailService.SendEmail(ctx, email, subject, body)
|
||||
|
||||
@ -154,8 +154,9 @@ func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpirySe
|
||||
}
|
||||
|
||||
// ProvideSubscriptionExpiryService creates and starts SubscriptionExpiryService.
|
||||
func ProvideSubscriptionExpiryService(userSubRepo UserSubscriptionRepository) *SubscriptionExpiryService {
|
||||
func ProvideSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, notificationEmailService *NotificationEmailService) *SubscriptionExpiryService {
|
||||
svc := NewSubscriptionExpiryService(userSubRepo, time.Minute)
|
||||
svc.SetNotificationEmailService(notificationEmailService)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
@ -484,6 +485,7 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideOpsCleanupService,
|
||||
ProvideOpsScheduledReportService,
|
||||
NewEmailService,
|
||||
NewNotificationEmailService,
|
||||
ProvideEmailQueueService,
|
||||
NewTurnstileService,
|
||||
NewSubscriptionService,
|
||||
@ -520,7 +522,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewContentModerationService,
|
||||
NewAffiliateService,
|
||||
ProvidePaymentConfigService,
|
||||
NewPaymentService,
|
||||
ProvidePaymentService,
|
||||
ProvidePaymentOrderExpiryService,
|
||||
ProvideBalanceNotifyService,
|
||||
ProvideWindsurfAuthService,
|
||||
@ -648,8 +650,17 @@ func ProvidePaymentConfigService(entClient *dbent.Client, settingRepo SettingRep
|
||||
}
|
||||
|
||||
// ProvideBalanceNotifyService creates BalanceNotifyService
|
||||
func ProvideBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountRepository) *BalanceNotifyService {
|
||||
return NewBalanceNotifyService(emailService, settingRepo, accountRepo)
|
||||
func ProvideBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountRepository, notificationEmailService *NotificationEmailService) *BalanceNotifyService {
|
||||
svc := NewBalanceNotifyService(emailService, settingRepo, accountRepo)
|
||||
svc.SetNotificationEmailService(notificationEmailService)
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvidePaymentService creates PaymentService and attaches notification email delivery.
|
||||
func ProvidePaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository, affiliateService *AffiliateService, notificationEmailService *NotificationEmailService) *PaymentService {
|
||||
svc := NewPaymentService(entClient, registry, loadBalancer, redeemService, subscriptionSvc, configService, userRepo, groupRepo, affiliateService)
|
||||
svc.SetNotificationEmailService(notificationEmailService)
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvidePaymentOrderExpiryService creates and starts PaymentOrderExpiryService.
|
||||
|
||||
@ -5,6 +5,45 @@
|
||||
import { config } from '@vue/test-utils'
|
||||
import { vi } from 'vitest'
|
||||
|
||||
function createMemoryStorage(): Storage {
|
||||
const values = new Map<string, string>()
|
||||
|
||||
return {
|
||||
get length() {
|
||||
return values.size
|
||||
},
|
||||
clear() {
|
||||
values.clear()
|
||||
},
|
||||
getItem(key: string) {
|
||||
return values.has(key) ? values.get(key)! : null
|
||||
},
|
||||
key(index: number) {
|
||||
return Array.from(values.keys())[index] ?? null
|
||||
},
|
||||
removeItem(key: string) {
|
||||
values.delete(key)
|
||||
},
|
||||
setItem(key: string, value: string) {
|
||||
values.set(key, String(value))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (typeof globalThis.localStorage === 'undefined' || typeof globalThis.localStorage.getItem !== 'function') {
|
||||
Object.defineProperty(globalThis, 'localStorage', {
|
||||
configurable: true,
|
||||
value: createMemoryStorage()
|
||||
})
|
||||
}
|
||||
|
||||
if (typeof window !== 'undefined' && typeof window.localStorage.getItem !== 'function') {
|
||||
Object.defineProperty(window, 'localStorage', {
|
||||
configurable: true,
|
||||
value: globalThis.localStorage
|
||||
})
|
||||
}
|
||||
|
||||
// Mock requestIdleCallback (Safari < 15 不支持)
|
||||
if (typeof globalThis.requestIdleCallback === 'undefined') {
|
||||
globalThis.requestIdleCallback = ((callback: IdleRequestCallback) => {
|
||||
|
||||
@ -505,6 +505,7 @@ export interface SystemSettings {
|
||||
enable_anthropic_cache_ttl_1h_injection: boolean;
|
||||
rewrite_message_cache_control: boolean;
|
||||
antigravity_user_agent_version: string;
|
||||
openai_codex_user_agent: string;
|
||||
web_search_emulation_enabled?: boolean;
|
||||
|
||||
// Payment configuration
|
||||
@ -726,6 +727,7 @@ export interface UpdateSettingsRequest {
|
||||
enable_anthropic_cache_ttl_1h_injection?: boolean;
|
||||
rewrite_message_cache_control?: boolean;
|
||||
antigravity_user_agent_version?: string;
|
||||
openai_codex_user_agent?: string;
|
||||
// Payment configuration
|
||||
payment_enabled?: boolean;
|
||||
risk_control_enabled?: boolean;
|
||||
@ -854,6 +856,105 @@ export async function sendTestEmail(
|
||||
return data;
|
||||
}
|
||||
|
||||
// ==================== Email Template Settings ====================
|
||||
|
||||
export interface EmailTemplateOption {
|
||||
value: string;
|
||||
label?: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
export type EmailTemplateEventOption = string | EmailTemplateOption;
|
||||
|
||||
export interface EmailTemplateSummary {
|
||||
event: string;
|
||||
locale: string;
|
||||
subject: string;
|
||||
is_custom?: boolean;
|
||||
updated_at?: string;
|
||||
}
|
||||
|
||||
export interface EmailTemplateListResponse {
|
||||
events: EmailTemplateEventOption[];
|
||||
locales: string[];
|
||||
templates?: EmailTemplateSummary[];
|
||||
placeholders?: string[];
|
||||
}
|
||||
|
||||
export interface EmailTemplateDetail {
|
||||
event: string;
|
||||
locale: string;
|
||||
subject: string;
|
||||
html: string;
|
||||
is_custom?: boolean;
|
||||
updated_at?: string;
|
||||
placeholders?: string[];
|
||||
}
|
||||
|
||||
export interface UpdateEmailTemplateRequest {
|
||||
subject: string;
|
||||
html: string;
|
||||
}
|
||||
|
||||
export interface PreviewEmailTemplateRequest extends UpdateEmailTemplateRequest {
|
||||
event: string;
|
||||
locale: string;
|
||||
}
|
||||
|
||||
export interface EmailTemplatePreviewResponse {
|
||||
subject: string;
|
||||
html: string;
|
||||
}
|
||||
|
||||
export async function getEmailTemplates(): Promise<EmailTemplateListResponse> {
|
||||
const { data } = await apiClient.get<EmailTemplateListResponse>(
|
||||
"/admin/settings/email-templates",
|
||||
);
|
||||
return data;
|
||||
}
|
||||
|
||||
export async function getEmailTemplate(
|
||||
event: string,
|
||||
locale: string,
|
||||
): Promise<EmailTemplateDetail> {
|
||||
const { data } = await apiClient.get<EmailTemplateDetail>(
|
||||
`/admin/settings/email-templates/${encodeURIComponent(event)}/${encodeURIComponent(locale)}`,
|
||||
);
|
||||
return data;
|
||||
}
|
||||
|
||||
export async function updateEmailTemplate(
|
||||
event: string,
|
||||
locale: string,
|
||||
request: UpdateEmailTemplateRequest,
|
||||
): Promise<EmailTemplateDetail> {
|
||||
const { data } = await apiClient.put<EmailTemplateDetail>(
|
||||
`/admin/settings/email-templates/${encodeURIComponent(event)}/${encodeURIComponent(locale)}`,
|
||||
request,
|
||||
);
|
||||
return data;
|
||||
}
|
||||
|
||||
export async function restoreOfficialEmailTemplate(
|
||||
event: string,
|
||||
locale: string,
|
||||
): Promise<EmailTemplateDetail> {
|
||||
const { data } = await apiClient.post<EmailTemplateDetail>(
|
||||
`/admin/settings/email-templates/${encodeURIComponent(event)}/${encodeURIComponent(locale)}/restore-official`,
|
||||
);
|
||||
return data;
|
||||
}
|
||||
|
||||
export async function previewEmailTemplate(
|
||||
request: PreviewEmailTemplateRequest,
|
||||
): Promise<EmailTemplatePreviewResponse> {
|
||||
const { data } = await apiClient.post<EmailTemplatePreviewResponse>(
|
||||
"/admin/settings/email-template-preview",
|
||||
request,
|
||||
);
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Admin API Key status response
|
||||
*/
|
||||
@ -1160,6 +1261,11 @@ export const settingsAPI = {
|
||||
updateSettings,
|
||||
testSmtpConnection,
|
||||
sendTestEmail,
|
||||
getEmailTemplates,
|
||||
getEmailTemplate,
|
||||
updateEmailTemplate,
|
||||
restoreOfficialEmailTemplate,
|
||||
previewEmailTemplate,
|
||||
getAdminApiKey,
|
||||
regenerateAdminApiKey,
|
||||
deleteAdminApiKey,
|
||||
|
||||
@ -69,6 +69,25 @@ export interface ModelStatsResponse {
|
||||
end_date: string
|
||||
}
|
||||
|
||||
export interface ApiKeyDailyUsagePoint {
|
||||
date: string
|
||||
requests: number
|
||||
input_tokens: number
|
||||
output_tokens: number
|
||||
cache_read_tokens: number
|
||||
cache_write_tokens: number
|
||||
total_tokens: number
|
||||
cost: number
|
||||
actual_cost: number
|
||||
}
|
||||
|
||||
export interface ApiKeyDailyUsageResponse {
|
||||
items: ApiKeyDailyUsagePoint[]
|
||||
days: number
|
||||
start_date: string
|
||||
end_date: string
|
||||
}
|
||||
|
||||
/**
|
||||
* List usage logs with optional filters
|
||||
* @param page - Page number (default: 1)
|
||||
@ -234,6 +253,23 @@ export async function getDashboardModels(params?: {
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Get daily usage details for one API key owned by the current user.
|
||||
* @param apiKeyId - API key ID
|
||||
* @param days - Number of days to include (1-90)
|
||||
* @returns Daily usage detail rows
|
||||
*/
|
||||
export async function getMyApiKeyDailyUsage(
|
||||
apiKeyId: number,
|
||||
days: number = 30
|
||||
): Promise<ApiKeyDailyUsageResponse> {
|
||||
const { data } = await apiClient.get<ApiKeyDailyUsageResponse>(
|
||||
`/user/api-keys/${apiKeyId}/usage/daily`,
|
||||
{ params: { days } }
|
||||
)
|
||||
return data
|
||||
}
|
||||
|
||||
export interface BatchApiKeyUsageStats {
|
||||
api_key_id: number
|
||||
today_actual_cost: number
|
||||
@ -279,6 +315,7 @@ export const usageAPI = {
|
||||
getDashboardStats,
|
||||
getDashboardTrend,
|
||||
getDashboardModels,
|
||||
getMyApiKeyDailyUsage,
|
||||
getDashboardApiKeysUsage
|
||||
}
|
||||
|
||||
|
||||
@ -123,19 +123,23 @@ export default {
|
||||
dateRangeToday: 'Today',
|
||||
dateRange7d: '7 Days',
|
||||
dateRange30d: '30 Days',
|
||||
dateRange90d: '90 Days',
|
||||
dateRangeCustom: 'Custom',
|
||||
apply: 'Apply',
|
||||
used: 'Used',
|
||||
detailInfo: 'Detail Information',
|
||||
tokenStats: 'Token Statistics',
|
||||
dailyDetail: 'Daily Detail',
|
||||
modelStats: 'Model Usage Statistics',
|
||||
// Table headers
|
||||
date: 'Date',
|
||||
model: 'Model',
|
||||
requests: 'Requests',
|
||||
inputTokens: 'Input Tokens',
|
||||
outputTokens: 'Output Tokens',
|
||||
cacheCreationTokens: 'Cache Creation',
|
||||
cacheReadTokens: 'Cache Read',
|
||||
cacheWriteTokens: 'Cache Write',
|
||||
totalTokens: 'Total Tokens',
|
||||
cost: 'Cost',
|
||||
// Status
|
||||
@ -179,6 +183,7 @@ export default {
|
||||
querySuccess: 'Query successful',
|
||||
queryFailed: 'Query failed',
|
||||
queryFailedRetry: 'Query failed, please try again later',
|
||||
noDailyUsage: 'No daily usage data',
|
||||
},
|
||||
|
||||
// Setup Wizard
|
||||
@ -4176,6 +4181,22 @@ export default {
|
||||
},
|
||||
userPrefix: 'User #{id}',
|
||||
exportCsv: 'Export CSV',
|
||||
batchUpdate: 'Batch Update',
|
||||
batchUpdateTitle: 'Batch Update Redeem Codes',
|
||||
selectedCount: '{count} redeem code(s) selected',
|
||||
clearSelection: 'Clear selection',
|
||||
selectCodesFirst: 'Select redeem codes first',
|
||||
noBatchFieldsSelected: 'Select at least one field to update',
|
||||
batchUpdateSuccess: 'Updated {count} redeem code(s)',
|
||||
failedToBatchUpdate: 'Failed to batch update redeem codes',
|
||||
batchFields: {
|
||||
status: 'Status',
|
||||
expiresAt: 'Expires At',
|
||||
notes: 'Notes',
|
||||
group: 'Group'
|
||||
},
|
||||
batchNotesPlaceholder: 'Enter the new note, or leave blank to clear it',
|
||||
clearGroup: 'Clear group',
|
||||
deleteAllUnused: 'Delete All Unused Codes',
|
||||
deleteCode: 'Delete Redeem Code',
|
||||
deleteCodeConfirm:
|
||||
@ -5515,6 +5536,9 @@ export default {
|
||||
antigravityUserAgentVersion: 'Antigravity UA Version',
|
||||
antigravityUserAgentVersionPlaceholder: '1.23.2',
|
||||
antigravityUserAgentVersionHint: 'Leave empty to use ANTIGRAVITY_USER_AGENT_VERSION or the built-in default 1.23.2; when set, the admin setting takes precedence.',
|
||||
openaiCodexUserAgent: 'OpenAI Codex UA',
|
||||
openaiCodexUserAgentPlaceholder: 'codex-tui/0.125.0 (Ubuntu 22.4.0; x86_64) xterm-256color (codex-tui; 0.125.0)',
|
||||
openaiCodexUserAgentHint: 'Used to bypass Cloudflare browser-UA challenges on the OpenAI upstream. Only applies when the client User-Agent is detected as a browser (Mozilla/...). Leave empty to use the built-in default.',
|
||||
},
|
||||
webSearchEmulation: {
|
||||
title: 'Web Search Emulation',
|
||||
@ -5854,6 +5878,36 @@ export default {
|
||||
sending: 'Sending...',
|
||||
enterRecipientHint: 'Please enter a recipient email address'
|
||||
},
|
||||
emailTemplates: {
|
||||
title: 'Email Templates',
|
||||
description: 'Customize notification email subjects and HTML content for each event and locale.',
|
||||
event: 'Event',
|
||||
locale: 'Locale',
|
||||
localeEn: 'English',
|
||||
localeZh: 'Chinese',
|
||||
subject: 'Subject',
|
||||
subjectPlaceholder: 'Enter the email subject',
|
||||
html: 'HTML Template',
|
||||
htmlPlaceholder: 'Edit the email HTML template',
|
||||
placeholders: 'Available Placeholders',
|
||||
placeholdersHelp: 'Click a placeholder to copy it. The backend replaces these values when sending emails.',
|
||||
livePreview: 'Live Preview',
|
||||
previewSecurityHint: 'Preview HTML is generated by the backend preview endpoint and displayed in a sandboxed iframe with scripts disabled.',
|
||||
preview: 'Preview / Refresh',
|
||||
previewing: 'Previewing...',
|
||||
save: 'Save Template',
|
||||
saving: 'Saving...',
|
||||
restoreOfficial: 'Restore Official',
|
||||
restoring: 'Restoring...',
|
||||
restoreConfirm: 'Restore the official template for this event and locale? Your custom version will be replaced.',
|
||||
restoreSuccess: 'Official template restored',
|
||||
saveSuccess: 'Email template saved',
|
||||
placeholderCopied: 'Placeholder copied',
|
||||
validationRequired: 'Subject and HTML template are required',
|
||||
empty: 'No email template events or locales are available yet.',
|
||||
noPreview: 'Refresh the preview to see the rendered email subject.',
|
||||
customized: 'Customized'
|
||||
},
|
||||
opsMonitoring: {
|
||||
title: 'Ops Monitoring',
|
||||
description: 'Enable ops monitoring for troubleshooting and health visibility',
|
||||
|
||||
@ -123,19 +123,23 @@ export default {
|
||||
dateRangeToday: '今日',
|
||||
dateRange7d: '7 天',
|
||||
dateRange30d: '30 天',
|
||||
dateRange90d: '90 天',
|
||||
dateRangeCustom: '自定义',
|
||||
apply: '应用',
|
||||
used: '已使用',
|
||||
detailInfo: '详细信息',
|
||||
tokenStats: 'Token 统计',
|
||||
dailyDetail: '按日明细',
|
||||
modelStats: '模型用量统计',
|
||||
// Table headers
|
||||
date: '日期',
|
||||
model: '模型',
|
||||
requests: '请求数',
|
||||
inputTokens: '输入 Tokens',
|
||||
outputTokens: '输出 Tokens',
|
||||
cacheCreationTokens: '缓存创建',
|
||||
cacheReadTokens: '缓存读取',
|
||||
cacheWriteTokens: '缓存写入',
|
||||
totalTokens: '总 Tokens',
|
||||
cost: '费用',
|
||||
// Status
|
||||
@ -179,6 +183,7 @@ export default {
|
||||
querySuccess: '查询成功',
|
||||
queryFailed: '查询失败',
|
||||
queryFailedRetry: '查询失败,请稍后重试',
|
||||
noDailyUsage: '暂无按日用量数据',
|
||||
},
|
||||
|
||||
// Setup Wizard
|
||||
@ -4310,6 +4315,22 @@ export default {
|
||||
used: '已使用',
|
||||
searchCodes: '搜索兑换码或邮箱...',
|
||||
exportCsv: '导出 CSV',
|
||||
batchUpdate: '批量修改',
|
||||
batchUpdateTitle: '批量修改兑换码',
|
||||
selectedCount: '已选择 {count} 个兑换码',
|
||||
clearSelection: '清空选择',
|
||||
selectCodesFirst: '请先选择兑换码',
|
||||
noBatchFieldsSelected: '请至少勾选一个要修改的字段',
|
||||
batchUpdateSuccess: '成功修改 {count} 个兑换码',
|
||||
failedToBatchUpdate: '批量修改兑换码失败',
|
||||
batchFields: {
|
||||
status: '状态',
|
||||
expiresAt: '过期时间',
|
||||
notes: '备注',
|
||||
group: '分组'
|
||||
},
|
||||
batchNotesPlaceholder: '输入新的备注,留空可清空备注',
|
||||
clearGroup: '清空分组',
|
||||
deleteAllUnused: '删除全部未使用',
|
||||
deleteCodeConfirm: '确定要删除此兑换码吗?此操作无法撤销。',
|
||||
deleteAllUnusedConfirm: '确定要删除全部未使用的兑换码吗?此操作无法撤销。',
|
||||
@ -5673,6 +5694,9 @@ export default {
|
||||
antigravityUserAgentVersion: 'Antigravity UA 版本',
|
||||
antigravityUserAgentVersionPlaceholder: '1.23.2',
|
||||
antigravityUserAgentVersionHint: '留空时使用 ANTIGRAVITY_USER_AGENT_VERSION 或内置默认值 1.23.2;填写后后台设置优先。',
|
||||
openaiCodexUserAgent: 'OpenAI Codex UA',
|
||||
openaiCodexUserAgentPlaceholder: 'codex-tui/0.125.0 (Ubuntu 22.4.0; x86_64) xterm-256color (codex-tui; 0.125.0)',
|
||||
openaiCodexUserAgentHint: '用于规避 OpenAI 上游 Cloudflare 对浏览器 UA 的访问质询。仅在检测到客户端 User-Agent 为浏览器(Mozilla/...)时生效,其他客户端原样透传。留空使用内置默认值。',
|
||||
},
|
||||
webSearchEmulation: {
|
||||
title: 'Web Search 模拟',
|
||||
@ -6014,6 +6038,36 @@ export default {
|
||||
sending: '发送中...',
|
||||
enterRecipientHint: '请输入收件人邮箱地址'
|
||||
},
|
||||
emailTemplates: {
|
||||
title: '邮件模板',
|
||||
description: '按事件和语言自定义通知邮件主题与 HTML 内容。',
|
||||
event: '事件',
|
||||
locale: '语言',
|
||||
localeEn: '英文',
|
||||
localeZh: '中文',
|
||||
subject: '主题',
|
||||
subjectPlaceholder: '输入邮件主题',
|
||||
html: 'HTML 模板',
|
||||
htmlPlaceholder: '编辑邮件 HTML 模板',
|
||||
placeholders: '可用占位符',
|
||||
placeholdersHelp: '点击占位符可复制。后端发送邮件时会替换这些值。',
|
||||
livePreview: '实时预览',
|
||||
previewSecurityHint: '预览 HTML 由后端预览接口生成,并在禁用脚本的沙盒 iframe 中展示。',
|
||||
preview: '预览 / 刷新',
|
||||
previewing: '预览中...',
|
||||
save: '保存模板',
|
||||
saving: '保存中...',
|
||||
restoreOfficial: '恢复官方模板',
|
||||
restoring: '恢复中...',
|
||||
restoreConfirm: '确定恢复此事件和语言的官方模板吗?当前自定义版本将被替换。',
|
||||
restoreSuccess: '已恢复官方模板',
|
||||
saveSuccess: '邮件模板已保存',
|
||||
placeholderCopied: '占位符已复制',
|
||||
validationRequired: '主题和 HTML 模板不能为空',
|
||||
empty: '暂无可用的邮件模板事件或语言。',
|
||||
noPreview: '刷新预览后查看渲染后的邮件主题。',
|
||||
customized: '已自定义'
|
||||
},
|
||||
opsMonitoring: {
|
||||
title: '运维监控',
|
||||
description: '启用运维监控模块,用于排障与健康可视化',
|
||||
|
||||
@ -289,6 +289,62 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Daily Usage Table -->
|
||||
<div
|
||||
v-if="showDailyUsage"
|
||||
class="fade-up fade-up-delay-4 rounded-2xl border border-gray-200 bg-white/90 backdrop-blur-sm overflow-hidden dark:border-dark-700 dark:bg-dark-900/90"
|
||||
>
|
||||
<div class="flex flex-col gap-3 px-8 py-5 border-b border-gray-200 dark:border-dark-700 sm:flex-row sm:items-center sm:justify-between">
|
||||
<h3 class="text-sm font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.dailyDetail') }}</h3>
|
||||
<div class="inline-flex rounded-lg border border-gray-200 bg-white p-0.5 dark:border-dark-700 dark:bg-dark-950">
|
||||
<button
|
||||
v-for="option in dailyUsageOptions"
|
||||
:key="option.value"
|
||||
@click="setDailyUsageDays(option.value)"
|
||||
class="min-w-12 rounded-md px-3 py-1.5 text-xs font-medium transition-colors"
|
||||
:class="dailyUsageDays === option.value
|
||||
? 'bg-primary-500 text-white'
|
||||
: 'text-gray-600 hover:bg-gray-100 dark:text-dark-300 dark:hover:bg-dark-800'"
|
||||
>
|
||||
{{ option.label }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="dailyUsageRows.length > 0" class="overflow-x-auto">
|
||||
<table class="w-full">
|
||||
<thead>
|
||||
<tr class="border-b border-gray-200 bg-gray-50 dark:border-dark-700 dark:bg-dark-950">
|
||||
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.date') }}</th>
|
||||
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.requests') }}</th>
|
||||
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.inputTokens') }}</th>
|
||||
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.outputTokens') }}</th>
|
||||
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.cacheReadTokens') }}</th>
|
||||
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.cacheWriteTokens') }}</th>
|
||||
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.cost') }}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr
|
||||
v-for="row in dailyUsageRows"
|
||||
:key="row.date"
|
||||
class="border-b border-gray-100 last:border-b-0 dark:border-dark-800"
|
||||
>
|
||||
<td class="px-4 py-3 text-sm font-medium whitespace-nowrap text-gray-900 dark:text-white">{{ row.date }}</td>
|
||||
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(row.requests) }}</td>
|
||||
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(row.input_tokens) }}</td>
|
||||
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(row.output_tokens) }}</td>
|
||||
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(row.cache_read_tokens) }}</td>
|
||||
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(row.cache_write_tokens) }}</td>
|
||||
<td class="px-4 py-3 text-sm tabular-nums text-right font-medium text-gray-900 dark:text-white">{{ usd(row.actual_cost != null ? row.actual_cost : row.cost) }}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<div v-else class="px-8 py-8 text-center text-sm text-gray-500 dark:text-dark-400">
|
||||
{{ t('keyUsage.noDailyUsage') }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Model Stats Table -->
|
||||
<div
|
||||
v-if="modelStats.length > 0"
|
||||
@ -408,6 +464,7 @@ type DateRangeKey = 'today' | '7d' | '30d' | 'custom'
|
||||
const currentRange = ref<DateRangeKey>('today')
|
||||
const customStartDate = ref('')
|
||||
const customEndDate = ref('')
|
||||
const dailyUsageDays = ref<7 | 30 | 90>(30)
|
||||
|
||||
const dateRanges = computed(() => [
|
||||
{ key: 'today' as const, label: t('keyUsage.dateRangeToday') },
|
||||
@ -416,6 +473,12 @@ const dateRanges = computed(() => [
|
||||
{ key: 'custom' as const, label: t('keyUsage.dateRangeCustom') },
|
||||
])
|
||||
|
||||
const dailyUsageOptions = computed(() => [
|
||||
{ value: 7 as const, label: t('keyUsage.dateRange7d') },
|
||||
{ value: 30 as const, label: t('keyUsage.dateRange30d') },
|
||||
{ value: 90 as const, label: t('keyUsage.dateRange90d') },
|
||||
])
|
||||
|
||||
function setDateRange(key: DateRangeKey) {
|
||||
currentRange.value = key
|
||||
if (key !== 'custom') {
|
||||
@ -426,23 +489,36 @@ function setDateRange(key: DateRangeKey) {
|
||||
function getDateParams(): string {
|
||||
const now = new Date()
|
||||
const fmt = (d: Date) => d.toISOString().split('T')[0]
|
||||
const params = new URLSearchParams()
|
||||
|
||||
if (currentRange.value === 'custom') {
|
||||
if (customStartDate.value && customEndDate.value) {
|
||||
return `start_date=${customStartDate.value}&end_date=${customEndDate.value}`
|
||||
params.set('start_date', customStartDate.value)
|
||||
params.set('end_date', customEndDate.value)
|
||||
}
|
||||
return ''
|
||||
} else {
|
||||
const end = fmt(now)
|
||||
let start: string
|
||||
switch (currentRange.value) {
|
||||
case 'today': start = end; break
|
||||
case '7d': start = fmt(new Date(now.getTime() - 7 * 86400000)); break
|
||||
case '30d': start = fmt(new Date(now.getTime() - 30 * 86400000)); break
|
||||
default: start = fmt(new Date(now.getTime() - 30 * 86400000))
|
||||
}
|
||||
params.set('start_date', start)
|
||||
params.set('end_date', end)
|
||||
}
|
||||
params.set('days', String(dailyUsageDays.value))
|
||||
params.set('timezone', getBrowserTimezone())
|
||||
return params.toString()
|
||||
}
|
||||
|
||||
const end = fmt(now)
|
||||
let start: string
|
||||
switch (currentRange.value) {
|
||||
case 'today': start = end; break
|
||||
case '7d': start = fmt(new Date(now.getTime() - 7 * 86400000)); break
|
||||
case '30d': start = fmt(new Date(now.getTime() - 30 * 86400000)); break
|
||||
default: start = fmt(new Date(now.getTime() - 30 * 86400000))
|
||||
function setDailyUsageDays(days: 7 | 30 | 90) {
|
||||
if (dailyUsageDays.value === days) return
|
||||
dailyUsageDays.value = days
|
||||
if (resultData.value && apiKey.value.trim()) {
|
||||
queryKey()
|
||||
}
|
||||
return `start_date=${start}&end_date=${end}`
|
||||
}
|
||||
|
||||
// ==================== Ring Animation ====================
|
||||
@ -731,6 +807,24 @@ const usageStatCells = computed<StatCell[]>(() => {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const modelStats = computed<any[]>(() => resultData.value?.model_stats || [])
|
||||
|
||||
interface DailyUsageRow {
|
||||
date: string
|
||||
requests: number
|
||||
input_tokens: number
|
||||
output_tokens: number
|
||||
cache_read_tokens: number
|
||||
cache_write_tokens: number
|
||||
cost: number
|
||||
actual_cost?: number
|
||||
}
|
||||
|
||||
const dailyUsageRows = computed<DailyUsageRow[]>(() => {
|
||||
const rows = resultData.value?.daily_usage
|
||||
return Array.isArray(rows) ? rows : []
|
||||
})
|
||||
|
||||
const showDailyUsage = computed(() => Boolean(resultData.value && Array.isArray(resultData.value.daily_usage)))
|
||||
|
||||
// ==================== Utility Functions ====================
|
||||
|
||||
function usd(value: number | null | undefined): string {
|
||||
@ -750,6 +844,14 @@ function formatDate(iso: string | null | undefined): string {
|
||||
return d.toLocaleDateString(loc, { year: 'numeric', month: 'long', day: 'numeric' })
|
||||
}
|
||||
|
||||
function getBrowserTimezone(): string {
|
||||
try {
|
||||
return Intl.DateTimeFormat().resolvedOptions().timeZone || 'UTC'
|
||||
} catch {
|
||||
return 'UTC'
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== API Query ====================
|
||||
|
||||
async function fetchUsage(key: string) {
|
||||
|
||||
208
frontend/src/views/__tests__/KeyUsageView.spec.ts
Normal file
208
frontend/src/views/__tests__/KeyUsageView.spec.ts
Normal file
@ -0,0 +1,208 @@
|
||||
import { describe, expect, it, beforeEach, afterEach, vi } from 'vitest'
|
||||
import { flushPromises, mount } from '@vue/test-utils'
|
||||
import { nextTick } from 'vue'
|
||||
|
||||
import KeyUsageView from '../KeyUsageView.vue'
|
||||
|
||||
const { showInfo, showSuccess, showError, fetchPublicSettings } = vi.hoisted(() => ({
|
||||
showInfo: vi.fn(),
|
||||
showSuccess: vi.fn(),
|
||||
showError: vi.fn(),
|
||||
fetchPublicSettings: vi.fn(),
|
||||
}))
|
||||
|
||||
const messages: Record<string, string> = {
|
||||
'keyUsage.title': 'API Key Usage',
|
||||
'keyUsage.subtitle': 'Usage status',
|
||||
'keyUsage.placeholder': 'sk-test',
|
||||
'keyUsage.query': 'Query',
|
||||
'keyUsage.querying': 'Querying...',
|
||||
'keyUsage.privacyNote': 'Privacy note',
|
||||
'keyUsage.dateRange': 'Date Range:',
|
||||
'keyUsage.dateRangeToday': 'Today',
|
||||
'keyUsage.dateRange7d': '7 Days',
|
||||
'keyUsage.dateRange30d': '30 Days',
|
||||
'keyUsage.dateRange90d': '90 Days',
|
||||
'keyUsage.dateRangeCustom': 'Custom',
|
||||
'keyUsage.apply': 'Apply',
|
||||
'keyUsage.used': 'Used',
|
||||
'keyUsage.detailInfo': 'Detail Information',
|
||||
'keyUsage.tokenStats': 'Token Statistics',
|
||||
'keyUsage.dailyDetail': 'Daily Detail',
|
||||
'keyUsage.date': 'Date',
|
||||
'keyUsage.requests': 'Requests',
|
||||
'keyUsage.inputTokens': 'Input Tokens',
|
||||
'keyUsage.outputTokens': 'Output Tokens',
|
||||
'keyUsage.cacheReadTokens': 'Cache Read',
|
||||
'keyUsage.cacheWriteTokens': 'Cache Write',
|
||||
'keyUsage.cost': 'Cost',
|
||||
'keyUsage.quotaMode': 'Key Quota Mode',
|
||||
'keyUsage.walletBalance': 'Wallet Balance',
|
||||
'keyUsage.totalQuota': 'Total Quota',
|
||||
'keyUsage.limit5h': '5-Hour Limit',
|
||||
'keyUsage.limitDaily': 'Daily Limit',
|
||||
'keyUsage.limit7d': '7-Day Limit',
|
||||
'keyUsage.limitWeekly': 'Weekly Limit',
|
||||
'keyUsage.limitMonthly': 'Monthly Limit',
|
||||
'keyUsage.remainingQuota': 'Remaining Quota',
|
||||
'keyUsage.usedQuota': 'Used Quota',
|
||||
'keyUsage.subscriptionType': 'Subscription Type',
|
||||
'keyUsage.todayRequests': 'Today Requests',
|
||||
'keyUsage.todayInputTokens': 'Today Input',
|
||||
'keyUsage.todayOutputTokens': 'Today Output',
|
||||
'keyUsage.todayTokens': 'Today Tokens',
|
||||
'keyUsage.todayCacheCreation': 'Today Cache Creation',
|
||||
'keyUsage.todayCacheRead': 'Today Cache Read',
|
||||
'keyUsage.todayCost': 'Today Cost',
|
||||
'keyUsage.rpmTpm': 'RPM / TPM',
|
||||
'keyUsage.totalRequests': 'Total Requests',
|
||||
'keyUsage.totalInputTokens': 'Total Input',
|
||||
'keyUsage.totalOutputTokens': 'Total Output',
|
||||
'keyUsage.totalTokensLabel': 'Total Tokens',
|
||||
'keyUsage.totalCacheCreation': 'Total Cache Creation',
|
||||
'keyUsage.totalCacheRead': 'Total Cache Read',
|
||||
'keyUsage.totalCost': 'Total Cost',
|
||||
'keyUsage.avgDuration': 'Avg Duration',
|
||||
'keyUsage.querySuccess': 'Query successful',
|
||||
'keyUsage.queryFailed': 'Query failed',
|
||||
'keyUsage.queryFailedRetry': 'Query failed, please try again later',
|
||||
'home.viewDocs': 'Docs',
|
||||
'home.switchToLight': 'Light',
|
||||
'home.switchToDark': 'Dark',
|
||||
'home.footer.allRightsReserved': 'All rights reserved.',
|
||||
}
|
||||
|
||||
vi.mock('vue-i18n', async () => {
|
||||
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
|
||||
return {
|
||||
...actual,
|
||||
useI18n: () => ({
|
||||
t: (key: string) => messages[key] ?? key,
|
||||
locale: { value: 'en' },
|
||||
}),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/stores', () => ({
|
||||
useAppStore: () => ({
|
||||
cachedPublicSettings: null,
|
||||
siteName: 'Sub2API',
|
||||
siteLogo: '',
|
||||
docUrl: '',
|
||||
publicSettingsLoaded: true,
|
||||
fetchPublicSettings,
|
||||
showInfo,
|
||||
showSuccess,
|
||||
showError,
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('KeyUsageView daily detail', () => {
|
||||
beforeEach(() => {
|
||||
showInfo.mockReset()
|
||||
showSuccess.mockReset()
|
||||
showError.mockReset()
|
||||
fetchPublicSettings.mockReset()
|
||||
localStorage.clear()
|
||||
|
||||
Object.defineProperty(window, 'matchMedia', {
|
||||
configurable: true,
|
||||
value: vi.fn().mockReturnValue({ matches: false }),
|
||||
})
|
||||
vi.stubGlobal('requestAnimationFrame', (cb: FrameRequestCallback) => window.setTimeout(() => cb(0), 0))
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
mode: 'quota_limited',
|
||||
isValid: true,
|
||||
status: 'active',
|
||||
quota: {
|
||||
limit: 10,
|
||||
used: 1,
|
||||
remaining: 9,
|
||||
unit: 'USD',
|
||||
},
|
||||
usage: {
|
||||
today: {
|
||||
requests: 1,
|
||||
input_tokens: 10,
|
||||
output_tokens: 20,
|
||||
cache_creation_tokens: 0,
|
||||
cache_read_tokens: 0,
|
||||
total_tokens: 30,
|
||||
actual_cost: 0.01,
|
||||
},
|
||||
total: {
|
||||
requests: 12,
|
||||
input_tokens: 100,
|
||||
output_tokens: 200,
|
||||
cache_creation_tokens: 10,
|
||||
cache_read_tokens: 30,
|
||||
total_tokens: 340,
|
||||
actual_cost: 0.12,
|
||||
},
|
||||
rpm: 0,
|
||||
tpm: 0,
|
||||
},
|
||||
daily_usage: [
|
||||
{
|
||||
date: '2026-05-19',
|
||||
requests: 12,
|
||||
input_tokens: 100,
|
||||
output_tokens: 200,
|
||||
cache_read_tokens: 30,
|
||||
cache_write_tokens: 10,
|
||||
total_tokens: 340,
|
||||
cost: 0.15,
|
||||
actual_cost: 0.12,
|
||||
},
|
||||
],
|
||||
}),
|
||||
}))
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.unstubAllGlobals()
|
||||
})
|
||||
|
||||
it('renders daily usage detail rows after a successful query', async () => {
|
||||
const wrapper = mount(KeyUsageView, {
|
||||
global: {
|
||||
stubs: {
|
||||
RouterLink: { template: '<a><slot /></a>' },
|
||||
LocaleSwitcher: true,
|
||||
Icon: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
await wrapper.find('input').setValue('sk-test-key')
|
||||
await wrapper.find('input').trigger('keydown.enter')
|
||||
await flushPromises()
|
||||
await nextTick()
|
||||
|
||||
const fetchMock = vi.mocked(fetch)
|
||||
expect(fetchMock).toHaveBeenCalledWith(
|
||||
expect.stringContaining('/v1/usage?'),
|
||||
expect.objectContaining({
|
||||
headers: { Authorization: 'Bearer sk-test-key' },
|
||||
})
|
||||
)
|
||||
expect(String(fetchMock.mock.calls[0][0])).toContain('days=30')
|
||||
|
||||
const text = wrapper.text()
|
||||
expect(text).toContain('Daily Detail')
|
||||
expect(text).toContain('Date')
|
||||
expect(text).toContain('Cache Read')
|
||||
expect(text).toContain('Cache Write')
|
||||
expect(text).toContain('2026-05-19')
|
||||
expect(text).toContain('12')
|
||||
expect(text).toContain('100')
|
||||
expect(text).toContain('200')
|
||||
expect(text).toContain('30')
|
||||
expect(text).toContain('10')
|
||||
expect(text).toContain('$0.12')
|
||||
|
||||
wrapper.unmount()
|
||||
})
|
||||
})
|
||||
@ -3769,6 +3769,36 @@
|
||||
}}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- OpenAI Codex UA -->
|
||||
<div>
|
||||
<label
|
||||
class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300"
|
||||
>
|
||||
{{
|
||||
t(
|
||||
"admin.settings.gatewayForwarding.openaiCodexUserAgent",
|
||||
)
|
||||
}}
|
||||
</label>
|
||||
<input
|
||||
v-model="form.openai_codex_user_agent"
|
||||
type="text"
|
||||
class="input w-full font-mono text-sm"
|
||||
:placeholder="
|
||||
t(
|
||||
'admin.settings.gatewayForwarding.openaiCodexUserAgentPlaceholder',
|
||||
)
|
||||
"
|
||||
/>
|
||||
<p class="mt-1.5 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{
|
||||
t(
|
||||
"admin.settings.gatewayForwarding.openaiCodexUserAgentHint",
|
||||
)
|
||||
}}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<!-- Web Search Emulation -->
|
||||
@ -6225,6 +6255,9 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<EmailTemplateEditor />
|
||||
|
||||
<!-- Balance Low Notification -->
|
||||
<div class="card">
|
||||
<div
|
||||
@ -6482,6 +6515,7 @@ import Toggle from "@/components/common/Toggle.vue";
|
||||
import ProxySelector from "@/components/common/ProxySelector.vue";
|
||||
import ImageUpload from "@/components/common/ImageUpload.vue";
|
||||
import BackupSettings from "@/views/admin/BackupView.vue";
|
||||
import EmailTemplateEditor from "@/views/admin/settings/EmailTemplateEditor.vue";
|
||||
import { useClipboard } from "@/composables/useClipboard";
|
||||
import { affiliatesAPI, type AffiliateAdminEntry, type SimpleUser as AffiliateSimpleUser } from "@/api/admin/affiliates";
|
||||
import { extractApiErrorMessage, extractI18nErrorMessage } from "@/utils/apiError";
|
||||
@ -6943,6 +6977,7 @@ const form = reactive<SettingsForm>({
|
||||
enable_anthropic_cache_ttl_1h_injection: false,
|
||||
rewrite_message_cache_control: false,
|
||||
antigravity_user_agent_version: "",
|
||||
openai_codex_user_agent: "",
|
||||
// Balance & quota notification
|
||||
balance_low_notify_enabled: false,
|
||||
balance_low_notify_threshold: 0,
|
||||
@ -8044,6 +8079,8 @@ async function saveSettings() {
|
||||
rewrite_message_cache_control: form.rewrite_message_cache_control,
|
||||
antigravity_user_agent_version:
|
||||
form.antigravity_user_agent_version?.trim() || "",
|
||||
openai_codex_user_agent:
|
||||
form.openai_codex_user_agent?.trim() || "",
|
||||
// Payment configuration
|
||||
payment_enabled: form.payment_enabled,
|
||||
risk_control_enabled: form.risk_control_enabled,
|
||||
|
||||
@ -371,6 +371,7 @@ const baseSettingsResponse = {
|
||||
enable_anthropic_cache_ttl_1h_injection: false,
|
||||
rewrite_message_cache_control: false,
|
||||
antigravity_user_agent_version: "",
|
||||
openai_codex_user_agent: "",
|
||||
payment_enabled: true,
|
||||
payment_min_amount: 1,
|
||||
payment_max_amount: 10000,
|
||||
|
||||
483
frontend/src/views/admin/settings/EmailTemplateEditor.vue
Normal file
483
frontend/src/views/admin/settings/EmailTemplateEditor.vue
Normal file
@ -0,0 +1,483 @@
|
||||
<template>
|
||||
<div class="card">
|
||||
<div
|
||||
class="flex flex-col gap-3 border-b border-gray-100 px-6 py-4 dark:border-dark-700 lg:flex-row lg:items-start lg:justify-between"
|
||||
>
|
||||
<div>
|
||||
<h2 class="text-lg font-semibold text-gray-900 dark:text-white">
|
||||
{{ t("admin.settings.emailTemplates.title") }}
|
||||
</h2>
|
||||
<p class="mt-1 text-sm text-gray-500 dark:text-gray-400">
|
||||
{{ t("admin.settings.emailTemplates.description") }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<button
|
||||
type="button"
|
||||
class="btn btn-secondary btn-sm"
|
||||
:disabled="loadingTemplate || previewing || !canPreview"
|
||||
@click="refreshPreview"
|
||||
>
|
||||
{{ previewing ? t("admin.settings.emailTemplates.previewing") : t("admin.settings.emailTemplates.preview") }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="btn btn-secondary btn-sm"
|
||||
:disabled="loadingTemplate || restoring || !selectedEvent || !selectedLocale"
|
||||
@click="restoreOfficial"
|
||||
>
|
||||
{{ restoring ? t("admin.settings.emailTemplates.restoring") : t("admin.settings.emailTemplates.restoreOfficial") }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="btn btn-primary btn-sm"
|
||||
:disabled="loadingTemplate || saving || !canSave"
|
||||
@click="saveTemplate"
|
||||
>
|
||||
{{ saving ? t("admin.settings.emailTemplates.saving") : t("admin.settings.emailTemplates.save") }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="space-y-6 p-6">
|
||||
<div
|
||||
v-if="loadingList"
|
||||
class="flex items-center gap-2 text-sm text-gray-500 dark:text-gray-400"
|
||||
>
|
||||
<span
|
||||
class="h-4 w-4 animate-spin rounded-full border-b-2 border-primary-600"
|
||||
></span>
|
||||
{{ t("common.loading") }}
|
||||
</div>
|
||||
|
||||
<template v-else>
|
||||
<div class="grid grid-cols-1 gap-4 md:grid-cols-2">
|
||||
<div>
|
||||
<label class="input-label" for="email-template-event">
|
||||
{{ t("admin.settings.emailTemplates.event") }}
|
||||
</label>
|
||||
<select
|
||||
id="email-template-event"
|
||||
v-model="selectedEvent"
|
||||
class="input"
|
||||
:disabled="loadingTemplate || eventOptions.length === 0"
|
||||
>
|
||||
<option
|
||||
v-for="option in eventOptions"
|
||||
:key="option.value"
|
||||
:value="option.value"
|
||||
>
|
||||
{{ option.label || option.value }}
|
||||
</option>
|
||||
</select>
|
||||
<p v-if="selectedEventDescription" class="input-hint">
|
||||
{{ selectedEventDescription }}
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label" for="email-template-locale">
|
||||
{{ t("admin.settings.emailTemplates.locale") }}
|
||||
</label>
|
||||
<select
|
||||
id="email-template-locale"
|
||||
v-model="selectedLocale"
|
||||
class="input"
|
||||
:disabled="loadingTemplate || localeOptions.length === 0"
|
||||
>
|
||||
<option
|
||||
v-for="localeOption in localeOptions"
|
||||
:key="localeOption"
|
||||
:value="localeOption"
|
||||
>
|
||||
{{ formatLocale(localeOption) }}
|
||||
</option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-if="!eventOptions.length || !localeOptions.length"
|
||||
class="rounded-lg border border-amber-200 bg-amber-50 p-4 text-sm text-amber-700 dark:border-amber-800 dark:bg-amber-900/20 dark:text-amber-300"
|
||||
>
|
||||
{{ t("admin.settings.emailTemplates.empty") }}
|
||||
</div>
|
||||
|
||||
<div v-else class="grid grid-cols-1 gap-6 xl:grid-cols-2">
|
||||
<div class="space-y-4">
|
||||
<div>
|
||||
<label class="input-label" for="email-template-subject">
|
||||
{{ t("admin.settings.emailTemplates.subject") }}
|
||||
</label>
|
||||
<input
|
||||
id="email-template-subject"
|
||||
v-model="subject"
|
||||
type="text"
|
||||
class="input"
|
||||
:disabled="loadingTemplate"
|
||||
:placeholder="t('admin.settings.emailTemplates.subjectPlaceholder')"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label class="input-label" for="email-template-html">
|
||||
{{ t("admin.settings.emailTemplates.html") }}
|
||||
</label>
|
||||
<textarea
|
||||
id="email-template-html"
|
||||
v-model="html"
|
||||
rows="18"
|
||||
class="input min-h-[28rem] resize-y font-mono text-sm leading-6"
|
||||
:disabled="loadingTemplate"
|
||||
:placeholder="t('admin.settings.emailTemplates.htmlPlaceholder')"
|
||||
></textarea>
|
||||
</div>
|
||||
|
||||
<div
|
||||
class="rounded-lg border border-gray-200 bg-gray-50 p-4 dark:border-dark-700 dark:bg-dark-800/60"
|
||||
>
|
||||
<div class="text-sm font-medium text-gray-900 dark:text-white">
|
||||
{{ t("admin.settings.emailTemplates.placeholders") }}
|
||||
</div>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t("admin.settings.emailTemplates.placeholdersHelp") }}
|
||||
</p>
|
||||
<div class="mt-3 flex flex-wrap gap-2">
|
||||
<button
|
||||
v-for="placeholder in placeholderList"
|
||||
:key="placeholder"
|
||||
type="button"
|
||||
class="rounded-full border border-gray-200 bg-white px-3 py-1 font-mono text-xs text-gray-700 transition-colors hover:border-primary-300 hover:text-primary-600 dark:border-dark-600 dark:bg-dark-700 dark:text-gray-200 dark:hover:border-primary-500 dark:hover:text-primary-300"
|
||||
@click="copyPlaceholder(placeholder)"
|
||||
>
|
||||
{{ placeholder }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="space-y-4">
|
||||
<div
|
||||
class="rounded-lg border border-gray-200 bg-white dark:border-dark-700 dark:bg-dark-800"
|
||||
>
|
||||
<div
|
||||
class="flex items-center justify-between border-b border-gray-100 px-4 py-3 dark:border-dark-700"
|
||||
>
|
||||
<div>
|
||||
<div class="text-sm font-medium text-gray-900 dark:text-white">
|
||||
{{ t("admin.settings.emailTemplates.livePreview") }}
|
||||
</div>
|
||||
<div class="mt-0.5 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ previewSubject || t("admin.settings.emailTemplates.noPreview") }}
|
||||
</div>
|
||||
</div>
|
||||
<span
|
||||
v-if="isCustomTemplate"
|
||||
class="rounded-full bg-primary-50 px-2.5 py-1 text-xs font-medium text-primary-700 dark:bg-primary-900/30 dark:text-primary-300"
|
||||
>
|
||||
{{ t("admin.settings.emailTemplates.customized") }}
|
||||
</span>
|
||||
</div>
|
||||
<div class="bg-gray-100 p-3 dark:bg-dark-900">
|
||||
<iframe
|
||||
class="h-[36rem] w-full rounded-md border border-gray-200 bg-white dark:border-dark-700"
|
||||
sandbox=""
|
||||
:srcdoc="previewHtml"
|
||||
:title="t('admin.settings.emailTemplates.livePreview')"
|
||||
></iframe>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t("admin.settings.emailTemplates.previewSecurityHint") }}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, onMounted, ref, watch } from "vue";
|
||||
import { useI18n } from "vue-i18n";
|
||||
import { adminAPI } from "@/api";
|
||||
import type {
|
||||
EmailTemplateEventOption,
|
||||
EmailTemplateOption,
|
||||
} from "@/api/admin/settings";
|
||||
import { useAppStore } from "@/stores";
|
||||
import { extractApiErrorMessage } from "@/utils/apiError";
|
||||
|
||||
const { t, locale } = useI18n();
|
||||
const appStore = useAppStore();
|
||||
|
||||
const fallbackPlaceholders = [
|
||||
"{{site_name}}",
|
||||
"{{recipient_name}}",
|
||||
"{{recipient_email}}",
|
||||
"{{verification_code}}",
|
||||
"{{expires_in_minutes}}",
|
||||
"{{reset_url}}",
|
||||
"{{subscription_group}}",
|
||||
"{{subscription_days}}",
|
||||
"{{expiry_time}}",
|
||||
"{{days_remaining}}",
|
||||
"{{current_balance}}",
|
||||
"{{threshold}}",
|
||||
"{{recharge_url}}",
|
||||
"{{recharge_amount}}",
|
||||
"{{order_id}}",
|
||||
"{{unsubscribe_url}}",
|
||||
"{{account_id}}",
|
||||
"{{account_name}}",
|
||||
"{{platform}}",
|
||||
"{{quota_dimension}}",
|
||||
"{{quota_used}}",
|
||||
"{{quota_limit}}",
|
||||
"{{quota_remaining}}",
|
||||
"{{quota_threshold}}",
|
||||
"{{triggered_at}}",
|
||||
"{{group_name}}",
|
||||
"{{moderation_category}}",
|
||||
"{{moderation_score}}",
|
||||
"{{violation_count}}",
|
||||
"{{ban_threshold}}",
|
||||
"{{rule_name}}",
|
||||
"{{severity}}",
|
||||
"{{alert_status}}",
|
||||
"{{metric_type}}",
|
||||
"{{operator}}",
|
||||
"{{metric_value}}",
|
||||
"{{threshold_value}}",
|
||||
"{{alert_description}}",
|
||||
"{{report_name}}",
|
||||
"{{report_type}}",
|
||||
"{{report_start_time}}",
|
||||
"{{report_end_time}}",
|
||||
"{{report_html}}",
|
||||
];
|
||||
|
||||
const loadingList = ref(true);
|
||||
const loadingTemplate = ref(false);
|
||||
const saving = ref(false);
|
||||
const previewing = ref(false);
|
||||
const restoring = ref(false);
|
||||
const eventOptions = ref<EmailTemplateOption[]>([]);
|
||||
const localeOptions = ref<string[]>([]);
|
||||
const selectedEvent = ref("");
|
||||
const selectedLocale = ref("");
|
||||
const subject = ref("");
|
||||
const html = ref("");
|
||||
const isCustomTemplate = ref(false);
|
||||
const placeholders = ref<string[]>([]);
|
||||
const previewSubject = ref("");
|
||||
const previewHtml = ref("");
|
||||
const initializingSelection = ref(false);
|
||||
|
||||
function normalizeEventOption(option: EmailTemplateEventOption): EmailTemplateOption {
|
||||
if (typeof option === "string") {
|
||||
return { value: option };
|
||||
}
|
||||
return option;
|
||||
}
|
||||
|
||||
const selectedEventDescription = computed(() => {
|
||||
return (
|
||||
eventOptions.value.find((option) => option.value === selectedEvent.value)
|
||||
?.description || ""
|
||||
);
|
||||
});
|
||||
|
||||
const placeholderList = computed(() => {
|
||||
const combined = [...placeholders.value, ...fallbackPlaceholders];
|
||||
return Array.from(
|
||||
new Set(
|
||||
combined
|
||||
.map((item) => formatPlaceholder(item))
|
||||
.filter((item) => item.length > 0),
|
||||
),
|
||||
);
|
||||
});
|
||||
|
||||
function formatPlaceholder(placeholder: string): string {
|
||||
const trimmed = placeholder.trim();
|
||||
if (!trimmed) return "";
|
||||
if (trimmed.startsWith("{{") && trimmed.endsWith("}}")) return trimmed;
|
||||
return `{{${trimmed}}}`;
|
||||
}
|
||||
|
||||
const canSave = computed(
|
||||
() =>
|
||||
Boolean(selectedEvent.value && selectedLocale.value) &&
|
||||
subject.value.trim().length > 0 &&
|
||||
html.value.trim().length > 0,
|
||||
);
|
||||
|
||||
const canPreview = computed(
|
||||
() => Boolean(selectedEvent.value && selectedLocale.value) && html.value.trim().length > 0,
|
||||
);
|
||||
|
||||
function formatLocale(locale: string): string {
|
||||
const lower = locale.toLowerCase();
|
||||
if (lower === "zh" || lower.startsWith("zh-")) {
|
||||
return t("admin.settings.emailTemplates.localeZh");
|
||||
}
|
||||
if (lower === "en" || lower.startsWith("en-")) {
|
||||
return t("admin.settings.emailTemplates.localeEn");
|
||||
}
|
||||
return locale;
|
||||
}
|
||||
|
||||
function selectInitialLocale(locales: string[]): string {
|
||||
const currentLocale = locale.value.toLowerCase();
|
||||
const exactMatch = locales.find(
|
||||
(availableLocale) => availableLocale.toLowerCase() === currentLocale,
|
||||
);
|
||||
if (exactMatch) return exactMatch;
|
||||
|
||||
const currentLanguage = currentLocale.split("-")[0];
|
||||
const languageMatch = locales.find(
|
||||
(availableLocale) => availableLocale.toLowerCase().split("-")[0] === currentLanguage,
|
||||
);
|
||||
if (languageMatch) return languageMatch;
|
||||
|
||||
return locales[0] || "";
|
||||
}
|
||||
|
||||
function applyTemplate(template: {
|
||||
subject: string;
|
||||
html: string;
|
||||
is_custom?: boolean;
|
||||
placeholders?: string[];
|
||||
}) {
|
||||
subject.value = template.subject;
|
||||
html.value = template.html;
|
||||
isCustomTemplate.value = template.is_custom === true;
|
||||
placeholders.value = template.placeholders || [];
|
||||
}
|
||||
|
||||
async function loadTemplate() {
|
||||
if (!selectedEvent.value || !selectedLocale.value) return;
|
||||
loadingTemplate.value = true;
|
||||
try {
|
||||
const template = await adminAPI.settings.getEmailTemplate(
|
||||
selectedEvent.value,
|
||||
selectedLocale.value,
|
||||
);
|
||||
applyTemplate(template);
|
||||
await refreshPreview();
|
||||
} catch (err: unknown) {
|
||||
appStore.showError(extractApiErrorMessage(err, t("common.error")));
|
||||
} finally {
|
||||
loadingTemplate.value = false;
|
||||
}
|
||||
}
|
||||
|
||||
async function loadTemplateList() {
|
||||
loadingList.value = true;
|
||||
try {
|
||||
const response = await adminAPI.settings.getEmailTemplates();
|
||||
eventOptions.value = response.events.map(normalizeEventOption);
|
||||
localeOptions.value = response.locales;
|
||||
placeholders.value = response.placeholders || [];
|
||||
initializingSelection.value = true;
|
||||
selectedEvent.value = eventOptions.value[0]?.value || "";
|
||||
selectedLocale.value = selectInitialLocale(response.locales);
|
||||
await loadTemplate();
|
||||
initializingSelection.value = false;
|
||||
} catch (err: unknown) {
|
||||
initializingSelection.value = false;
|
||||
appStore.showError(extractApiErrorMessage(err, t("common.error")));
|
||||
} finally {
|
||||
loadingList.value = false;
|
||||
}
|
||||
}
|
||||
|
||||
async function saveTemplate() {
|
||||
if (!canSave.value) {
|
||||
appStore.showError(t("admin.settings.emailTemplates.validationRequired"));
|
||||
return;
|
||||
}
|
||||
saving.value = true;
|
||||
try {
|
||||
const template = await adminAPI.settings.updateEmailTemplate(
|
||||
selectedEvent.value,
|
||||
selectedLocale.value,
|
||||
{
|
||||
subject: subject.value,
|
||||
html: html.value,
|
||||
},
|
||||
);
|
||||
applyTemplate(template);
|
||||
await refreshPreview();
|
||||
appStore.showSuccess(t("admin.settings.emailTemplates.saveSuccess"));
|
||||
} catch (err: unknown) {
|
||||
appStore.showError(extractApiErrorMessage(err, t("common.error")));
|
||||
} finally {
|
||||
saving.value = false;
|
||||
}
|
||||
}
|
||||
|
||||
async function refreshPreview() {
|
||||
if (!canPreview.value) {
|
||||
previewSubject.value = "";
|
||||
previewHtml.value = "";
|
||||
return;
|
||||
}
|
||||
previewing.value = true;
|
||||
try {
|
||||
const preview = await adminAPI.settings.previewEmailTemplate({
|
||||
event: selectedEvent.value,
|
||||
locale: selectedLocale.value,
|
||||
subject: subject.value,
|
||||
html: html.value,
|
||||
});
|
||||
previewSubject.value = preview.subject;
|
||||
previewHtml.value = preview.html;
|
||||
} catch (err: unknown) {
|
||||
appStore.showError(extractApiErrorMessage(err, t("common.error")));
|
||||
} finally {
|
||||
previewing.value = false;
|
||||
}
|
||||
}
|
||||
|
||||
async function restoreOfficial() {
|
||||
if (!selectedEvent.value || !selectedLocale.value) return;
|
||||
if (!window.confirm(t("admin.settings.emailTemplates.restoreConfirm"))) return;
|
||||
|
||||
restoring.value = true;
|
||||
try {
|
||||
const template = await adminAPI.settings.restoreOfficialEmailTemplate(
|
||||
selectedEvent.value,
|
||||
selectedLocale.value,
|
||||
);
|
||||
applyTemplate(template);
|
||||
await refreshPreview();
|
||||
appStore.showSuccess(t("admin.settings.emailTemplates.restoreSuccess"));
|
||||
} catch (err: unknown) {
|
||||
appStore.showError(extractApiErrorMessage(err, t("common.error")));
|
||||
} finally {
|
||||
restoring.value = false;
|
||||
}
|
||||
}
|
||||
|
||||
async function copyPlaceholder(placeholder: string) {
|
||||
try {
|
||||
await navigator.clipboard.writeText(placeholder);
|
||||
appStore.showSuccess(t("admin.settings.emailTemplates.placeholderCopied"));
|
||||
} catch {
|
||||
appStore.showError(t("common.error"));
|
||||
}
|
||||
}
|
||||
|
||||
watch([selectedEvent, selectedLocale], ([eventValue, localeValue], [oldEvent, oldLocale]) => {
|
||||
if (initializingSelection.value) return;
|
||||
if (!eventValue || !localeValue) return;
|
||||
if (eventValue === oldEvent && localeValue === oldLocale) return;
|
||||
void loadTemplate();
|
||||
});
|
||||
|
||||
onMounted(() => {
|
||||
void loadTemplateList();
|
||||
});
|
||||
</script>
|
||||
@ -13,6 +13,7 @@ export default defineConfig({
|
||||
test: {
|
||||
globals: true,
|
||||
environment: 'jsdom',
|
||||
setupFiles: ['./src/__tests__/setup.ts'],
|
||||
include: ['src/**/*.{test,spec}.{js,ts,jsx,tsx}'],
|
||||
exclude: ['node_modules', 'dist'],
|
||||
coverage: {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user