chore: merge upstream Wei-Shaw/sub2api v0.1.128 — keep fork customizations
Upstream 新功能 (34 commits, ~main..origin/main): - feat(email): 通知邮件模板服务、模板编辑器、订阅/余额提醒邮件 - feat(notification): NotificationEmailService 注入到 Balance/Payment/Setting - feat(payment): 支付成功通知邮件 - feat(usage): 用户 API Key 用量页支持按日明细 - feat(openai-gateway): Codex OAuth 浏览器 UA 自动改写规避 Cloudflare 质询 - feat(admin): 邮件模板管理接口 - fix(auth): 停用/删除分组后阻断 API Key - fix(group): 修正分组账号可用计数口径 - fix(openai): /v1/responses respect force chat completions, images n 参数透传 - test(repository): AES Encryptor 单元测试 - chore: VERSION 0.1.128 冲突解决 (backend/cmd/server/wire_gen.go): - 引入 upstream 新 wire providers: notificationEmailService, ProvidePaymentService(10 args), ProvideAdminSettingHandler(8 args) - 保留 fork 独有依赖: rpmTokenBucketService (RPM 平滑), NewOpsHandler 3 参数版本 (requestEventBus, opsLogBroadcaster) - ProvideBalanceNotifyService 接受 4 参数 (含 notificationEmailService) 修复 session-id helper 设计 (claude_code_session_id.go): - 发现: TestGatewayService_AnthropicOAuth_InjectsClaudeCodeSessionHeaderFromMetadata 在 OAuth + mimicClaudeCode=false 场景失败 - 重新审视设计原则: OAuth 凭证本身就是 Claude Code 客户端,可信任 metadata 派生 session_id;不应受 mimicClaudeCode 标志阻止 - 修复: metadata 派生只看 tokenType=="oauth";UUID 兜底仍需 oauth && mimic - 更新测试: OAuthNonMimicDerivesFromMetadata 取代原 IgnoresMetadata 所有 fork 独有功能保留: - Claude Code 2.1.145 mimicry bundle (上个 commit 引入) - RPM token bucket smoothing (commit 95814974) - Windsurf/Antigravity/Omniroute 定制 - claudemask/ 校验包 (upstream 已删除,我们仍在 gateway_service 中使用) 不在范围: - 不修复 baseline 既存的 2 个测试失败 (TestProxyImportData..., TestWindsurfTierAccessService_Snapshot_HappyPath) - 与 merge 无关
This commit is contained in:
commit
92433656f5
@ -1 +1 @@
|
|||||||
0.1.127
|
0.1.128
|
||||||
|
|||||||
@ -189,7 +189,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
channelRepository := repository.NewChannelRepository(db)
|
channelRepository := repository.NewChannelRepository(db)
|
||||||
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
|
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
|
||||||
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
||||||
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
|
notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService)
|
||||||
|
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService)
|
||||||
rpmTokenBucketService := service.NewRPMTokenBucketService()
|
rpmTokenBucketService := service.NewRPMTokenBucketService()
|
||||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService)
|
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService)
|
||||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
||||||
@ -204,8 +205,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
|
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
|
||||||
registry := payment.ProvideRegistry()
|
registry := payment.ProvideRegistry()
|
||||||
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
|
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
|
||||||
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService)
|
paymentService := service.ProvidePaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService, notificationEmailService)
|
||||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService)
|
settingHandler := handler.ProvideAdminSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService, notificationEmailService)
|
||||||
requestEventBus := service.NewRequestEventBus()
|
requestEventBus := service.NewRequestEventBus()
|
||||||
opsLogBroadcaster := service.ProvideOpsLogBroadcaster()
|
opsLogBroadcaster := service.ProvideOpsLogBroadcaster()
|
||||||
opsHandler := admin.NewOpsHandler(opsService, requestEventBus, opsLogBroadcaster)
|
opsHandler := admin.NewOpsHandler(opsService, requestEventBus, opsLogBroadcaster)
|
||||||
@ -253,7 +254,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, windsurfGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, userMessageQueueService, configConfig, settingService, requestEventBus)
|
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, windsurfGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, userMessageQueueService, configConfig, settingService, requestEventBus)
|
||||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, configConfig)
|
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, configConfig)
|
||||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo, notificationEmailService)
|
||||||
totpHandler := handler.NewTotpHandler(totpService)
|
totpHandler := handler.NewTotpHandler(totpService)
|
||||||
handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService)
|
handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService)
|
||||||
paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry)
|
paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry)
|
||||||
@ -274,7 +275,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
|
||||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository, notificationEmailService)
|
||||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||||
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
|
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
|
||||||
windsurfRefreshService := service.ProvideWindsurfRefreshService(configConfig, accountRepository, proxyRepository)
|
windsurfRefreshService := service.ProvideWindsurfRefreshService(configConfig, accountRepository, proxyRepository)
|
||||||
|
|||||||
@ -56,13 +56,14 @@ func firstNonEmpty(values ...string) string {
|
|||||||
|
|
||||||
// SettingHandler 系统设置处理器
|
// SettingHandler 系统设置处理器
|
||||||
type SettingHandler struct {
|
type SettingHandler struct {
|
||||||
settingService *service.SettingService
|
settingService *service.SettingService
|
||||||
emailService *service.EmailService
|
emailService *service.EmailService
|
||||||
turnstileService *service.TurnstileService
|
turnstileService *service.TurnstileService
|
||||||
opsService *service.OpsService
|
opsService *service.OpsService
|
||||||
paymentConfigService *service.PaymentConfigService
|
paymentConfigService *service.PaymentConfigService
|
||||||
paymentService *service.PaymentService
|
paymentService *service.PaymentService
|
||||||
userAttributeService *service.UserAttributeService
|
userAttributeService *service.UserAttributeService
|
||||||
|
notificationEmailService *service.NotificationEmailService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSettingHandler 创建系统设置处理器
|
// NewSettingHandler 创建系统设置处理器
|
||||||
@ -78,6 +79,12 @@ func NewSettingHandler(settingService *service.SettingService, emailService *ser
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetNotificationEmailService attaches the notification template service without changing
|
||||||
|
// the constructor signature used by existing unit tests.
|
||||||
|
func (h *SettingHandler) SetNotificationEmailService(notificationEmailService *service.NotificationEmailService) {
|
||||||
|
h.notificationEmailService = notificationEmailService
|
||||||
|
}
|
||||||
|
|
||||||
// GetSettings 获取所有系统设置
|
// GetSettings 获取所有系统设置
|
||||||
// GET /api/v1/admin/settings
|
// GET /api/v1/admin/settings
|
||||||
func (h *SettingHandler) GetSettings(c *gin.Context) {
|
func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||||
@ -247,6 +254,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
EnableAnthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection,
|
EnableAnthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection,
|
||||||
RewriteMessageCacheControl: settings.RewriteMessageCacheControl,
|
RewriteMessageCacheControl: settings.RewriteMessageCacheControl,
|
||||||
AntigravityUserAgentVersion: settings.AntigravityUserAgentVersion,
|
AntigravityUserAgentVersion: settings.AntigravityUserAgentVersion,
|
||||||
|
OpenAICodexUserAgent: settings.OpenAICodexUserAgent,
|
||||||
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
|
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
|
||||||
PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource,
|
PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource,
|
||||||
PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource,
|
PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource,
|
||||||
@ -563,6 +571,7 @@ type UpdateSettingsRequest struct {
|
|||||||
EnableAnthropicCacheTTL1hInjection *bool `json:"enable_anthropic_cache_ttl_1h_injection"`
|
EnableAnthropicCacheTTL1hInjection *bool `json:"enable_anthropic_cache_ttl_1h_injection"`
|
||||||
RewriteMessageCacheControl *bool `json:"rewrite_message_cache_control"`
|
RewriteMessageCacheControl *bool `json:"rewrite_message_cache_control"`
|
||||||
AntigravityUserAgentVersion *string `json:"antigravity_user_agent_version"`
|
AntigravityUserAgentVersion *string `json:"antigravity_user_agent_version"`
|
||||||
|
OpenAICodexUserAgent *string `json:"openai_codex_user_agent"`
|
||||||
|
|
||||||
// Payment visible method routing
|
// Payment visible method routing
|
||||||
PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
|
PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
|
||||||
@ -1404,6 +1413,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if req.OpenAICodexUserAgent != nil {
|
||||||
|
normalized := strings.TrimSpace(*req.OpenAICodexUserAgent)
|
||||||
|
req.OpenAICodexUserAgent = &normalized
|
||||||
|
// 仅做长度上限保护,不限制具体格式(运维需要可自由调整 codex 版本号)
|
||||||
|
if len(normalized) > 512 {
|
||||||
|
response.Error(c, http.StatusBadRequest, "openai_codex_user_agent must be at most 512 characters")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 交叉验证:如果同时设置了最低和最高版本号,最高版本号必须 >= 最低版本号
|
// 交叉验证:如果同时设置了最低和最高版本号,最高版本号必须 >= 最低版本号
|
||||||
if req.MinClaudeCodeVersion != "" && req.MaxClaudeCodeVersion != "" {
|
if req.MinClaudeCodeVersion != "" && req.MaxClaudeCodeVersion != "" {
|
||||||
@ -1597,6 +1615,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
return previousSettings.AntigravityUserAgentVersion
|
return previousSettings.AntigravityUserAgentVersion
|
||||||
}(),
|
}(),
|
||||||
|
OpenAICodexUserAgent: func() string {
|
||||||
|
if req.OpenAICodexUserAgent != nil {
|
||||||
|
return *req.OpenAICodexUserAgent
|
||||||
|
}
|
||||||
|
return previousSettings.OpenAICodexUserAgent
|
||||||
|
}(),
|
||||||
PaymentVisibleMethodAlipaySource: func() string {
|
PaymentVisibleMethodAlipaySource: func() string {
|
||||||
if req.PaymentVisibleMethodAlipaySource != nil {
|
if req.PaymentVisibleMethodAlipaySource != nil {
|
||||||
return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource)
|
return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource)
|
||||||
@ -1956,6 +1980,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
EnableAnthropicCacheTTL1hInjection: updatedSettings.EnableAnthropicCacheTTL1hInjection,
|
EnableAnthropicCacheTTL1hInjection: updatedSettings.EnableAnthropicCacheTTL1hInjection,
|
||||||
RewriteMessageCacheControl: updatedSettings.RewriteMessageCacheControl,
|
RewriteMessageCacheControl: updatedSettings.RewriteMessageCacheControl,
|
||||||
AntigravityUserAgentVersion: updatedSettings.AntigravityUserAgentVersion,
|
AntigravityUserAgentVersion: updatedSettings.AntigravityUserAgentVersion,
|
||||||
|
OpenAICodexUserAgent: updatedSettings.OpenAICodexUserAgent,
|
||||||
PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource,
|
PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource,
|
||||||
PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource,
|
PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource,
|
||||||
PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled,
|
PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled,
|
||||||
@ -2411,6 +2436,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if before.AntigravityUserAgentVersion != after.AntigravityUserAgentVersion {
|
if before.AntigravityUserAgentVersion != after.AntigravityUserAgentVersion {
|
||||||
changed = append(changed, "antigravity_user_agent_version")
|
changed = append(changed, "antigravity_user_agent_version")
|
||||||
}
|
}
|
||||||
|
if before.OpenAICodexUserAgent != after.OpenAICodexUserAgent {
|
||||||
|
changed = append(changed, "openai_codex_user_agent")
|
||||||
|
}
|
||||||
if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource {
|
if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource {
|
||||||
changed = append(changed, "payment_visible_method_alipay_source")
|
changed = append(changed, "payment_visible_method_alipay_source")
|
||||||
}
|
}
|
||||||
@ -3339,3 +3367,160 @@ func (h *SettingHandler) ensureUserAttributeDefinition(ctx context.Context, key,
|
|||||||
}
|
}
|
||||||
slog.Info("dingtalk: created user attribute definition", "key", key, "name", name, "type", attrType)
|
slog.Info("dingtalk: created user attribute definition", "key", key, "name", name, "type", attrType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListEmailTemplates returns all editable notification email templates.
|
||||||
|
// GET /api/v1/admin/settings/email-templates
|
||||||
|
func (h *SettingHandler) ListEmailTemplates(c *gin.Context) {
|
||||||
|
if h.notificationEmailService == nil {
|
||||||
|
response.InternalError(c, "notification email service is not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
events := h.notificationEmailService.ListEventInfos()
|
||||||
|
templates, err := h.notificationEmailService.ListTemplates(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, dto.EmailTemplateListResponse{
|
||||||
|
Events: emailTemplateEventOptionsToDTO(events),
|
||||||
|
Locales: h.notificationEmailService.SupportedLocales(),
|
||||||
|
Templates: emailTemplateSummariesToDTO(templates),
|
||||||
|
Placeholders: emailTemplatePlaceholderUnion(events),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEmailTemplate returns one editable notification email template.
|
||||||
|
// GET /api/v1/admin/settings/email-templates/:event/:locale
|
||||||
|
func (h *SettingHandler) GetEmailTemplate(c *gin.Context) {
|
||||||
|
if h.notificationEmailService == nil {
|
||||||
|
response.InternalError(c, "notification email service is not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tmpl, err := h.notificationEmailService.GetTemplate(c.Request.Context(), c.Param("event"), c.Param("locale"))
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, emailTemplateDetailToDTO(tmpl))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateEmailTemplate saves an override for one event/locale template.
|
||||||
|
// PUT /api/v1/admin/settings/email-templates/:event/:locale
|
||||||
|
func (h *SettingHandler) UpdateEmailTemplate(c *gin.Context) {
|
||||||
|
if h.notificationEmailService == nil {
|
||||||
|
response.InternalError(c, "notification email service is not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var req dto.UpdateEmailTemplateRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tmpl, err := h.notificationEmailService.UpdateTemplate(c.Request.Context(), c.Param("event"), c.Param("locale"), req.Subject, req.HTML)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, emailTemplateDetailToDTO(tmpl))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestoreOfficialEmailTemplate removes an override and returns the built-in template.
|
||||||
|
// POST /api/v1/admin/settings/email-templates/:event/:locale/restore-official
|
||||||
|
func (h *SettingHandler) RestoreOfficialEmailTemplate(c *gin.Context) {
|
||||||
|
if h.notificationEmailService == nil {
|
||||||
|
response.InternalError(c, "notification email service is not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tmpl, err := h.notificationEmailService.RestoreOfficialTemplate(c.Request.Context(), c.Param("event"), c.Param("locale"))
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, emailTemplateDetailToDTO(tmpl))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PreviewEmailTemplate renders a template with safe sample variables without saving it.
|
||||||
|
// POST /api/v1/admin/settings/email-templates/preview
|
||||||
|
func (h *SettingHandler) PreviewEmailTemplate(c *gin.Context) {
|
||||||
|
if h.notificationEmailService == nil {
|
||||||
|
response.InternalError(c, "notification email service is not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var req dto.PreviewEmailTemplateRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
preview, err := h.notificationEmailService.PreviewTemplate(c.Request.Context(), service.NotificationEmailPreviewInput{
|
||||||
|
Event: req.Event,
|
||||||
|
Locale: req.Locale,
|
||||||
|
Subject: req.Subject,
|
||||||
|
HTML: req.HTML,
|
||||||
|
Variables: req.Variables,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, dto.EmailTemplatePreviewResponse{Subject: preview.Subject, HTML: preview.HTML})
|
||||||
|
}
|
||||||
|
|
||||||
|
func emailTemplateEventOptionsToDTO(events []service.NotificationEmailEventInfo) []dto.EmailTemplateEventOption {
|
||||||
|
items := make([]dto.EmailTemplateEventOption, 0, len(events))
|
||||||
|
for _, event := range events {
|
||||||
|
items = append(items, dto.EmailTemplateEventOption{
|
||||||
|
Value: event.Event,
|
||||||
|
Label: event.Label,
|
||||||
|
Description: event.Description,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return items
|
||||||
|
}
|
||||||
|
|
||||||
|
func emailTemplateSummariesToDTO(templates []service.NotificationEmailTemplate) []dto.EmailTemplateSummary {
|
||||||
|
items := make([]dto.EmailTemplateSummary, 0, len(templates))
|
||||||
|
for _, tmpl := range templates {
|
||||||
|
items = append(items, dto.EmailTemplateSummary{
|
||||||
|
Event: tmpl.Event,
|
||||||
|
Locale: tmpl.Locale,
|
||||||
|
Subject: tmpl.Subject,
|
||||||
|
IsCustom: tmpl.IsCustom,
|
||||||
|
UpdatedAt: emailTemplateUpdatedAt(tmpl),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return items
|
||||||
|
}
|
||||||
|
|
||||||
|
func emailTemplateDetailToDTO(tmpl service.NotificationEmailTemplate) dto.EmailTemplateDetail {
|
||||||
|
return dto.EmailTemplateDetail{
|
||||||
|
Event: tmpl.Event,
|
||||||
|
Locale: tmpl.Locale,
|
||||||
|
Subject: tmpl.Subject,
|
||||||
|
HTML: tmpl.HTML,
|
||||||
|
IsCustom: tmpl.IsCustom,
|
||||||
|
UpdatedAt: emailTemplateUpdatedAt(tmpl),
|
||||||
|
Placeholders: tmpl.Placeholders,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func emailTemplateUpdatedAt(tmpl service.NotificationEmailTemplate) string {
|
||||||
|
if tmpl.UpdatedAt == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return tmpl.UpdatedAt.Format("2006-01-02T15:04:05Z07:00")
|
||||||
|
}
|
||||||
|
|
||||||
|
func emailTemplatePlaceholderUnion(events []service.NotificationEmailEventInfo) []string {
|
||||||
|
seen := make(map[string]struct{})
|
||||||
|
placeholders := make([]string, 0)
|
||||||
|
for _, event := range events {
|
||||||
|
for _, placeholder := range event.Placeholders {
|
||||||
|
if _, ok := seen[placeholder]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[placeholder] = struct{}{}
|
||||||
|
placeholders = append(placeholders, placeholder)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return placeholders
|
||||||
|
}
|
||||||
|
|||||||
@ -203,7 +203,7 @@ func (h *AuthHandler) SendVerifyCode(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email)
|
result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email, c.GetHeader("Accept-Language"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
@ -602,7 +602,7 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
|||||||
|
|
||||||
// Request password reset (async)
|
// Request password reset (async)
|
||||||
// Note: This returns success even if email doesn't exist (to prevent enumeration)
|
// Note: This returns success even if email doesn't exist (to prevent enumeration)
|
||||||
if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil {
|
if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL, c.GetHeader("Accept-Language")); err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@ -545,7 +545,7 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email)
|
result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email, c.GetHeader("Accept-Language"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@ -181,6 +181,7 @@ type SystemSettings struct {
|
|||||||
EnableAnthropicCacheTTL1hInjection bool `json:"enable_anthropic_cache_ttl_1h_injection"`
|
EnableAnthropicCacheTTL1hInjection bool `json:"enable_anthropic_cache_ttl_1h_injection"`
|
||||||
RewriteMessageCacheControl bool `json:"rewrite_message_cache_control"`
|
RewriteMessageCacheControl bool `json:"rewrite_message_cache_control"`
|
||||||
AntigravityUserAgentVersion string `json:"antigravity_user_agent_version"`
|
AntigravityUserAgentVersion string `json:"antigravity_user_agent_version"`
|
||||||
|
OpenAICodexUserAgent string `json:"openai_codex_user_agent"`
|
||||||
|
|
||||||
// Web Search Emulation
|
// Web Search Emulation
|
||||||
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
|
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
|
||||||
@ -377,6 +378,62 @@ type OpenAIFastPolicySettings struct {
|
|||||||
Rules []OpenAIFastPolicyRule `json:"rules"`
|
Rules []OpenAIFastPolicyRule `json:"rules"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EmailTemplateEventOption describes an editable notification email event.
|
||||||
|
type EmailTemplateEventOption struct {
|
||||||
|
Value string `json:"value"`
|
||||||
|
Label string `json:"label,omitempty"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmailTemplateSummary is shown in the admin email template list.
|
||||||
|
type EmailTemplateSummary struct {
|
||||||
|
Event string `json:"event"`
|
||||||
|
Locale string `json:"locale"`
|
||||||
|
Subject string `json:"subject"`
|
||||||
|
IsCustom bool `json:"is_custom,omitempty"`
|
||||||
|
UpdatedAt string `json:"updated_at,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmailTemplateListResponse is returned by GET /admin/settings/email-templates.
|
||||||
|
type EmailTemplateListResponse struct {
|
||||||
|
Events []EmailTemplateEventOption `json:"events"`
|
||||||
|
Locales []string `json:"locales"`
|
||||||
|
Templates []EmailTemplateSummary `json:"templates,omitempty"`
|
||||||
|
Placeholders []string `json:"placeholders,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmailTemplateDetail is returned for a specific event/locale template.
|
||||||
|
type EmailTemplateDetail struct {
|
||||||
|
Event string `json:"event"`
|
||||||
|
Locale string `json:"locale"`
|
||||||
|
Subject string `json:"subject"`
|
||||||
|
HTML string `json:"html"`
|
||||||
|
IsCustom bool `json:"is_custom,omitempty"`
|
||||||
|
UpdatedAt string `json:"updated_at,omitempty"`
|
||||||
|
Placeholders []string `json:"placeholders,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateEmailTemplateRequest updates a template override.
|
||||||
|
type UpdateEmailTemplateRequest struct {
|
||||||
|
Subject string `json:"subject"`
|
||||||
|
HTML string `json:"html"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PreviewEmailTemplateRequest previews a template without saving it.
|
||||||
|
type PreviewEmailTemplateRequest struct {
|
||||||
|
Event string `json:"event"`
|
||||||
|
Locale string `json:"locale"`
|
||||||
|
Subject string `json:"subject"`
|
||||||
|
HTML string `json:"html"`
|
||||||
|
Variables map[string]string `json:"variables,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmailTemplatePreviewResponse is the rendered preview payload.
|
||||||
|
type EmailTemplatePreviewResponse struct {
|
||||||
|
Subject string `json:"subject"`
|
||||||
|
HTML string `json:"html"`
|
||||||
|
}
|
||||||
|
|
||||||
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
|
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
|
||||||
// Returns empty slice on empty/invalid input.
|
// Returns empty slice on empty/invalid input.
|
||||||
func ParseCustomMenuItems(raw string) []CustomMenuItem {
|
func ParseCustomMenuItems(raw string) []CustomMenuItem {
|
||||||
|
|||||||
@ -1133,9 +1133,15 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
|
|||||||
|
|
||||||
// 解析可选的日期范围参数(用于 model_stats 查询)
|
// 解析可选的日期范围参数(用于 model_stats 查询)
|
||||||
startTime, endTime := h.parseUsageDateRange(c)
|
startTime, endTime := h.parseUsageDateRange(c)
|
||||||
|
days, ok := parseAPIKeyDailyUsageDays(c.DefaultQuery("days", ""))
|
||||||
|
if !ok {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Invalid days, allowed range is 1-90")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Best-effort: 获取用量统计(按当前 API Key 过滤),失败不影响基础响应
|
// Best-effort: 获取用量统计(按当前 API Key 过滤),失败不影响基础响应
|
||||||
usageData := h.buildUsageData(ctx, apiKey.ID)
|
usageData := h.buildUsageData(ctx, apiKey.ID)
|
||||||
|
dailyUsage := h.buildAPIKeyDailyUsage(c, subject.UserID, apiKey.ID, days)
|
||||||
|
|
||||||
// Best-effort: 获取模型统计
|
// Best-effort: 获取模型统计
|
||||||
var modelStats any
|
var modelStats any
|
||||||
@ -1149,11 +1155,11 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
|
|||||||
isQuotaLimited := apiKey.Quota > 0 || apiKey.HasRateLimits()
|
isQuotaLimited := apiKey.Quota > 0 || apiKey.HasRateLimits()
|
||||||
|
|
||||||
if isQuotaLimited {
|
if isQuotaLimited {
|
||||||
h.usageQuotaLimited(c, ctx, apiKey, usageData, modelStats)
|
h.usageQuotaLimited(c, ctx, apiKey, usageData, dailyUsage, modelStats)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.usageUnrestricted(c, ctx, apiKey, subject, usageData, modelStats)
|
h.usageUnrestricted(c, ctx, apiKey, subject, usageData, dailyUsage, modelStats)
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseUsageDateRange 解析 start_date / end_date query params,默认返回近 30 天范围
|
// parseUsageDateRange 解析 start_date / end_date query params,默认返回近 30 天范围
|
||||||
@ -1211,8 +1217,20 @@ func (h *GatewayHandler) buildUsageData(ctx context.Context, apiKeyID int64) gin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *GatewayHandler) buildAPIKeyDailyUsage(c *gin.Context, userID, apiKeyID int64, days int) any {
|
||||||
|
if h.usageService == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
startTime, endTime := apiKeyDailyUsageRange(days, c.Query("timezone"))
|
||||||
|
stats, err := h.usageService.GetAPIKeyDailyUsage(c.Request.Context(), userID, apiKeyID, startTime, endTime)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return stats
|
||||||
|
}
|
||||||
|
|
||||||
// usageQuotaLimited 处理 quota_limited 模式的响应
|
// usageQuotaLimited 处理 quota_limited 模式的响应
|
||||||
func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, apiKey *service.APIKey, usageData gin.H, modelStats any) {
|
func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, apiKey *service.APIKey, usageData gin.H, dailyUsage any, modelStats any) {
|
||||||
resp := gin.H{
|
resp := gin.H{
|
||||||
"mode": "quota_limited",
|
"mode": "quota_limited",
|
||||||
"isValid": apiKey.Status == service.StatusAPIKeyActive || apiKey.Status == service.StatusAPIKeyQuotaExhausted || apiKey.Status == service.StatusAPIKeyExpired,
|
"isValid": apiKey.Status == service.StatusAPIKeyActive || apiKey.Status == service.StatusAPIKeyQuotaExhausted || apiKey.Status == service.StatusAPIKeyExpired,
|
||||||
@ -1294,6 +1312,9 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context,
|
|||||||
if usageData != nil {
|
if usageData != nil {
|
||||||
resp["usage"] = usageData
|
resp["usage"] = usageData
|
||||||
}
|
}
|
||||||
|
if dailyUsage != nil {
|
||||||
|
resp["daily_usage"] = dailyUsage
|
||||||
|
}
|
||||||
if modelStats != nil {
|
if modelStats != nil {
|
||||||
resp["model_stats"] = modelStats
|
resp["model_stats"] = modelStats
|
||||||
}
|
}
|
||||||
@ -1302,7 +1323,7 @@ func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// usageUnrestricted 处理 unrestricted 模式的响应(向后兼容)
|
// usageUnrestricted 处理 unrestricted 模式的响应(向后兼容)
|
||||||
func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, usageData gin.H, modelStats any) {
|
func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, usageData gin.H, dailyUsage any, modelStats any) {
|
||||||
// 订阅模式
|
// 订阅模式
|
||||||
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
|
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
|
||||||
resp := gin.H{
|
resp := gin.H{
|
||||||
@ -1331,6 +1352,9 @@ func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context,
|
|||||||
if usageData != nil {
|
if usageData != nil {
|
||||||
resp["usage"] = usageData
|
resp["usage"] = usageData
|
||||||
}
|
}
|
||||||
|
if dailyUsage != nil {
|
||||||
|
resp["daily_usage"] = dailyUsage
|
||||||
|
}
|
||||||
if modelStats != nil {
|
if modelStats != nil {
|
||||||
resp["model_stats"] = modelStats
|
resp["model_stats"] = modelStats
|
||||||
}
|
}
|
||||||
@ -1356,6 +1380,9 @@ func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context,
|
|||||||
if usageData != nil {
|
if usageData != nil {
|
||||||
resp["usage"] = usageData
|
resp["usage"] = usageData
|
||||||
}
|
}
|
||||||
|
if dailyUsage != nil {
|
||||||
|
resp["daily_usage"] = dailyUsage
|
||||||
|
}
|
||||||
if modelStats != nil {
|
if modelStats != nil {
|
||||||
resp["model_stats"] = modelStats
|
resp["model_stats"] = modelStats
|
||||||
}
|
}
|
||||||
|
|||||||
@ -266,6 +266,7 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
|
|||||||
PaymentSource: req.PaymentSource,
|
PaymentSource: req.PaymentSource,
|
||||||
OrderType: req.OrderType,
|
OrderType: req.OrderType,
|
||||||
PlanID: req.PlanID,
|
PlanID: req.PlanID,
|
||||||
|
Locale: c.GetHeader("Accept-Language"),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
|
|||||||
@ -1,6 +1,10 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"html"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
@ -10,8 +14,9 @@ import (
|
|||||||
|
|
||||||
// SettingHandler 公开设置处理器(无需认证)
|
// SettingHandler 公开设置处理器(无需认证)
|
||||||
type SettingHandler struct {
|
type SettingHandler struct {
|
||||||
settingService *service.SettingService
|
settingService *service.SettingService
|
||||||
version string
|
notificationEmailService *service.NotificationEmailService
|
||||||
|
version string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSettingHandler 创建公开设置处理器
|
// NewSettingHandler 创建公开设置处理器
|
||||||
@ -22,6 +27,12 @@ func NewSettingHandler(settingService *service.SettingService, version string) *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetNotificationEmailService attaches the public notification email service without
|
||||||
|
// changing the constructor signature used by existing tests.
|
||||||
|
func (h *SettingHandler) SetNotificationEmailService(notificationEmailService *service.NotificationEmailService) {
|
||||||
|
h.notificationEmailService = notificationEmailService
|
||||||
|
}
|
||||||
|
|
||||||
// GetPublicSettings 获取公开设置
|
// GetPublicSettings 获取公开设置
|
||||||
// GET /api/v1/settings/public
|
// GET /api/v1/settings/public
|
||||||
func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||||
@ -90,6 +101,27 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnsubscribeNotificationEmail handles optional notification email opt-outs.
|
||||||
|
// GET /api/v1/settings/email-unsubscribe?token=...
|
||||||
|
func (h *SettingHandler) UnsubscribeNotificationEmail(c *gin.Context) {
|
||||||
|
if h.notificationEmailService == nil {
|
||||||
|
response.InternalError(c, "notification email service is not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
token := strings.TrimSpace(c.Query("token"))
|
||||||
|
if token == "" {
|
||||||
|
response.BadRequest(c, "token is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result, err := h.notificationEmailService.Unsubscribe(c.Request.Context(), token)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
body := "<!doctype html><html><head><meta charset=\"utf-8\"><title>Unsubscribed</title></head><body style=\"font-family:-apple-system,BlinkMacSystemFont,Segoe UI,sans-serif;padding:32px;\"><h1>Unsubscribed</h1><p>You have unsubscribed <strong>" + html.EscapeString(result.Email) + "</strong> from <strong>" + html.EscapeString(result.Event) + "</strong> emails.</p></body></html>"
|
||||||
|
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(body))
|
||||||
|
}
|
||||||
|
|
||||||
func publicLoginAgreementDocumentsToDTO(items []service.LoginAgreementDocument) []dto.LoginAgreementDocument {
|
func publicLoginAgreementDocumentsToDTO(items []service.LoginAgreementDocument) []dto.LoginAgreementDocument {
|
||||||
result := make([]dto.LoginAgreementDocument, 0, len(items))
|
result := make([]dto.LoginAgreementDocument, 0, len(items))
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
|
|||||||
@ -172,7 +172,7 @@ func (h *TotpHandler) SendVerifyCode(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil {
|
if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID, c.GetHeader("Accept-Language")); err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@ -298,6 +298,29 @@ func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
|
|||||||
return startTime, endTime
|
return startTime, endTime
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultAPIKeyDailyUsageDays = 30
|
||||||
|
maxAPIKeyDailyUsageDays = 90
|
||||||
|
)
|
||||||
|
|
||||||
|
func parseAPIKeyDailyUsageDays(raw string) (int, bool) {
|
||||||
|
if strings.TrimSpace(raw) == "" {
|
||||||
|
return defaultAPIKeyDailyUsageDays, true
|
||||||
|
}
|
||||||
|
days, err := strconv.Atoi(raw)
|
||||||
|
if err != nil || days <= 0 || days > maxAPIKeyDailyUsageDays {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return days, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func apiKeyDailyUsageRange(days int, userTZ string) (time.Time, time.Time) {
|
||||||
|
now := timezone.NowInUserLocation(userTZ)
|
||||||
|
startTime := timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -(days-1)), userTZ)
|
||||||
|
endTime := timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
|
||||||
|
return startTime, endTime
|
||||||
|
}
|
||||||
|
|
||||||
// DashboardStats handles getting user dashboard statistics
|
// DashboardStats handles getting user dashboard statistics
|
||||||
// GET /api/v1/usage/dashboard/stats
|
// GET /api/v1/usage/dashboard/stats
|
||||||
func (h *UsageHandler) DashboardStats(c *gin.Context) {
|
func (h *UsageHandler) DashboardStats(c *gin.Context) {
|
||||||
@ -416,3 +439,55 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) {
|
|||||||
|
|
||||||
response.Success(c, gin.H{"stats": stats})
|
response.Success(c, gin.H{"stats": stats})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetMyAPIKeyDailyUsage handles getting daily usage details for the current user's API key.
|
||||||
|
// GET /api/v1/user/api-keys/:id/usage/daily?days=30
|
||||||
|
func (h *UsageHandler) GetMyAPIKeyDailyUsage(c *gin.Context) {
|
||||||
|
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
response.Unauthorized(c, "User not authenticated")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKeyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid API key ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
days, ok := parseAPIKeyDailyUsageDays(c.DefaultQuery("days", ""))
|
||||||
|
if !ok {
|
||||||
|
response.BadRequest(c, "Invalid days, allowed range is 1-90")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.apiKeyService == nil {
|
||||||
|
response.InternalError(c, "API key service is not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), apiKeyID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if apiKey.UserID != subject.UserID {
|
||||||
|
response.Forbidden(c, "Not authorized to access this API key's usage")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userTZ := c.Query("timezone")
|
||||||
|
startTime, endTime := apiKeyDailyUsageRange(days, userTZ)
|
||||||
|
items, err := h.usageService.GetAPIKeyDailyUsage(c.Request.Context(), subject.UserID, apiKeyID, startTime, endTime)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"items": items,
|
||||||
|
"days": days,
|
||||||
|
"start_date": startTime.Format("2006-01-02"),
|
||||||
|
"end_date": endTime.AddDate(0, 0, -1).Format("2006-01-02"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
195
backend/internal/handler/usage_handler_daily_test.go
Normal file
195
backend/internal/handler/usage_handler_daily_test.go
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dailyUsageRepoStub struct {
|
||||||
|
service.UsageLogRepository
|
||||||
|
trend []usagestats.TrendDataPoint
|
||||||
|
|
||||||
|
called bool
|
||||||
|
startTime time.Time
|
||||||
|
endTime time.Time
|
||||||
|
granularity string
|
||||||
|
userID int64
|
||||||
|
apiKeyID int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dailyUsageRepoStub) GetUsageTrendWithFilters(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
granularity string,
|
||||||
|
userID, apiKeyID, accountID, groupID int64,
|
||||||
|
model string,
|
||||||
|
requestType *int16,
|
||||||
|
stream *bool,
|
||||||
|
billingType *int8,
|
||||||
|
) ([]usagestats.TrendDataPoint, error) {
|
||||||
|
s.called = true
|
||||||
|
s.startTime = startTime
|
||||||
|
s.endTime = endTime
|
||||||
|
s.granularity = granularity
|
||||||
|
s.userID = userID
|
||||||
|
s.apiKeyID = apiKeyID
|
||||||
|
return s.trend, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type dailyUsageAPIKeyRepoStub struct {
|
||||||
|
service.APIKeyRepository
|
||||||
|
keys map[int64]*service.APIKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dailyUsageAPIKeyRepoStub) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||||
|
key, ok := s.keys[id]
|
||||||
|
if !ok {
|
||||||
|
return nil, service.ErrAPIKeyNotFound
|
||||||
|
}
|
||||||
|
clone := *key
|
||||||
|
return &clone, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDailyUsageTestRouter(usageRepo *dailyUsageRepoStub, apiKeyRepo *dailyUsageAPIKeyRepoStub, userID int64) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
usageSvc := service.NewUsageService(usageRepo, nil, nil, nil)
|
||||||
|
apiKeySvc := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, nil)
|
||||||
|
handler := NewUsageHandler(usageSvc, apiKeySvc)
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(func(c *gin.Context) {
|
||||||
|
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: userID})
|
||||||
|
c.Next()
|
||||||
|
})
|
||||||
|
router.GET("/user/api-keys/:id/usage/daily", handler.GetMyAPIKeyDailyUsage)
|
||||||
|
return router
|
||||||
|
}
|
||||||
|
|
||||||
|
type dailyUsageHandlerResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Data struct {
|
||||||
|
Items []usagestats.APIKeyDailyUsagePoint `json:"items"`
|
||||||
|
Days int `json:"days"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMyAPIKeyDailyUsageRejectsCrossUserAccess(t *testing.T) {
|
||||||
|
usageRepo := &dailyUsageRepoStub{}
|
||||||
|
apiKeyRepo := &dailyUsageAPIKeyRepoStub{
|
||||||
|
keys: map[int64]*service.APIKey{
|
||||||
|
7: {ID: 7, UserID: 99, Status: service.StatusAPIKeyActive},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
router := newDailyUsageTestRouter(usageRepo, apiKeyRepo, 42)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/user/api-keys/7/usage/daily?days=30", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusForbidden, rec.Code)
|
||||||
|
require.False(t, usageRepo.called)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMyAPIKeyDailyUsageRejectsInvalidDays(t *testing.T) {
|
||||||
|
for _, path := range []string{
|
||||||
|
"/user/api-keys/7/usage/daily?days=0",
|
||||||
|
"/user/api-keys/7/usage/daily?days=91",
|
||||||
|
} {
|
||||||
|
t.Run(path, func(t *testing.T) {
|
||||||
|
usageRepo := &dailyUsageRepoStub{}
|
||||||
|
apiKeyRepo := &dailyUsageAPIKeyRepoStub{
|
||||||
|
keys: map[int64]*service.APIKey{
|
||||||
|
7: {ID: 7, UserID: 42, Status: service.StatusAPIKeyActive},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
router := newDailyUsageTestRouter(usageRepo, apiKeyRepo, 42)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
require.False(t, usageRepo.called)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMyAPIKeyDailyUsageReturnsEmptyData(t *testing.T) {
|
||||||
|
usageRepo := &dailyUsageRepoStub{trend: []usagestats.TrendDataPoint{}}
|
||||||
|
apiKeyRepo := &dailyUsageAPIKeyRepoStub{
|
||||||
|
keys: map[int64]*service.APIKey{
|
||||||
|
7: {ID: 7, UserID: 42, Status: service.StatusAPIKeyActive},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
router := newDailyUsageTestRouter(usageRepo, apiKeyRepo, 42)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/user/api-keys/7/usage/daily", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
var got dailyUsageHandlerResponse
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||||
|
require.Equal(t, 30, got.Data.Days)
|
||||||
|
require.Empty(t, got.Data.Items)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMyAPIKeyDailyUsageAggregatesByDayForOwnedKey(t *testing.T) {
|
||||||
|
usageRepo := &dailyUsageRepoStub{
|
||||||
|
trend: []usagestats.TrendDataPoint{
|
||||||
|
{
|
||||||
|
Date: "2026-05-19",
|
||||||
|
Requests: 3,
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
CacheCreationTokens: 4,
|
||||||
|
CacheReadTokens: 6,
|
||||||
|
TotalTokens: 40,
|
||||||
|
Cost: 0.5,
|
||||||
|
ActualCost: 0.4,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
apiKeyRepo := &dailyUsageAPIKeyRepoStub{
|
||||||
|
keys: map[int64]*service.APIKey{
|
||||||
|
7: {ID: 7, UserID: 42, Status: service.StatusAPIKeyActive},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
router := newDailyUsageTestRouter(usageRepo, apiKeyRepo, 42)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/user/api-keys/7/usage/daily?days=7", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.True(t, usageRepo.called)
|
||||||
|
require.Equal(t, "day", usageRepo.granularity)
|
||||||
|
require.Equal(t, int64(42), usageRepo.userID)
|
||||||
|
require.Equal(t, int64(7), usageRepo.apiKeyID)
|
||||||
|
require.True(t, usageRepo.startTime.Before(usageRepo.endTime))
|
||||||
|
|
||||||
|
var got dailyUsageHandlerResponse
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||||
|
require.Equal(t, 7, got.Data.Days)
|
||||||
|
require.Len(t, got.Data.Items, 1)
|
||||||
|
require.Equal(t, usagestats.APIKeyDailyUsagePoint{
|
||||||
|
Date: "2026-05-19",
|
||||||
|
Requests: 3,
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
CacheReadTokens: 6,
|
||||||
|
CacheWriteTokens: 4,
|
||||||
|
TotalTokens: 40,
|
||||||
|
Cost: 0.5,
|
||||||
|
ActualCost: 0.4,
|
||||||
|
}, got.Data.Items[0])
|
||||||
|
}
|
||||||
@ -335,7 +335,7 @@ func (h *UserHandler) SendEmailBindingCode(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.authService.SendEmailIdentityBindCode(c.Request.Context(), subject.UserID, req.Email); err != nil {
|
if err := h.authService.SendEmailIdentityBindCode(c.Request.Context(), subject.UserID, req.Email, c.GetHeader("Accept-Language")); err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -363,7 +363,7 @@ func (h *UserHandler) SendNotifyEmailCode(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err := h.userService.SendNotifyEmailCode(c.Request.Context(), subject.UserID, req.Email, h.emailService, h.emailCache)
|
err := h.userService.SendNotifyEmailCode(c.Request.Context(), subject.UserID, req.Email, h.emailService, h.emailCache, c.GetHeader("Accept-Language"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@ -90,8 +90,17 @@ func ProvideWindsurfHandler(authService *service.WindsurfAuthService, lsService
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
|
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
|
||||||
func ProvideSettingHandler(settingService *service.SettingService, buildInfo BuildInfo) *SettingHandler {
|
func ProvideSettingHandler(settingService *service.SettingService, buildInfo BuildInfo, notificationEmailService *service.NotificationEmailService) *SettingHandler {
|
||||||
return NewSettingHandler(settingService, buildInfo.Version)
|
h := NewSettingHandler(settingService, buildInfo.Version)
|
||||||
|
h.SetNotificationEmailService(notificationEmailService)
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProvideAdminSettingHandler creates admin.SettingHandler with notification template APIs.
|
||||||
|
func ProvideAdminSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, paymentConfigService *service.PaymentConfigService, paymentService *service.PaymentService, userAttributeService *service.UserAttributeService, notificationEmailService *service.NotificationEmailService) *admin.SettingHandler {
|
||||||
|
h := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService)
|
||||||
|
h.SetNotificationEmailService(notificationEmailService)
|
||||||
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProvideHandlers creates the Handlers struct
|
// ProvideHandlers creates the Handlers struct
|
||||||
@ -169,7 +178,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
admin.NewProxyHandler,
|
admin.NewProxyHandler,
|
||||||
admin.NewRedeemHandler,
|
admin.NewRedeemHandler,
|
||||||
admin.NewPromoHandler,
|
admin.NewPromoHandler,
|
||||||
admin.NewSettingHandler,
|
ProvideAdminSettingHandler,
|
||||||
admin.NewOpsHandler,
|
admin.NewOpsHandler,
|
||||||
ProvideSystemHandler,
|
ProvideSystemHandler,
|
||||||
admin.NewSubscriptionHandler,
|
admin.NewSubscriptionHandler,
|
||||||
|
|||||||
@ -0,0 +1,719 @@
|
|||||||
|
package apicompat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ResponsesToChatCompletionsRequest converts a Responses API request into a
|
||||||
|
// Chat Completions request for upstreams that only implement
|
||||||
|
// /v1/chat/completions.
|
||||||
|
func ResponsesToChatCompletionsRequest(req *ResponsesRequest) (*ChatCompletionsRequest, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, fmt.Errorf("responses request is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
messages, err := responsesInputToChatMessages(req.Instructions, req.Input)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
out := &ChatCompletionsRequest{
|
||||||
|
Model: req.Model,
|
||||||
|
Messages: messages,
|
||||||
|
MaxCompletionTokens: req.MaxOutputTokens,
|
||||||
|
Temperature: req.Temperature,
|
||||||
|
TopP: req.TopP,
|
||||||
|
Stream: req.Stream,
|
||||||
|
ServiceTier: req.ServiceTier,
|
||||||
|
}
|
||||||
|
if req.Reasoning != nil {
|
||||||
|
out.ReasoningEffort = req.Reasoning.Effort
|
||||||
|
}
|
||||||
|
if len(req.Tools) > 0 {
|
||||||
|
out.Tools = responsesToolsToChatTools(req.Tools)
|
||||||
|
}
|
||||||
|
if len(req.ToolChoice) > 0 {
|
||||||
|
out.ToolChoice = responsesToolChoiceToChatToolChoice(req.ToolChoice)
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func responsesInputToChatMessages(instructions string, inputRaw json.RawMessage) ([]ChatMessage, error) {
|
||||||
|
var messages []ChatMessage
|
||||||
|
if strings.TrimSpace(instructions) != "" {
|
||||||
|
content, _ := json.Marshal(instructions)
|
||||||
|
messages = append(messages, ChatMessage{
|
||||||
|
Role: "system",
|
||||||
|
Content: content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
inputRaw = bytesTrimSpace(inputRaw)
|
||||||
|
if len(inputRaw) == 0 || string(inputRaw) == "null" {
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var inputText string
|
||||||
|
if err := json.Unmarshal(inputRaw, &inputText); err == nil {
|
||||||
|
content, _ := json.Marshal(inputText)
|
||||||
|
messages = append(messages, ChatMessage{
|
||||||
|
Role: "user",
|
||||||
|
Content: content,
|
||||||
|
})
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var rawItems []json.RawMessage
|
||||||
|
if err := json.Unmarshal(inputRaw, &rawItems); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse responses input: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, raw := range rawItems {
|
||||||
|
raw = bytesTrimSpace(raw)
|
||||||
|
if len(raw) == 0 || string(raw) == "null" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var item map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(raw, &item); err != nil {
|
||||||
|
var text string
|
||||||
|
if textErr := json.Unmarshal(raw, &text); textErr == nil {
|
||||||
|
content, _ := json.Marshal(text)
|
||||||
|
messages = append(messages, ChatMessage{Role: "user", Content: content})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("parse responses input item: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
role := rawString(item["role"])
|
||||||
|
itemType := rawString(item["type"])
|
||||||
|
switch itemType {
|
||||||
|
case "function_call":
|
||||||
|
arguments := rawString(item["arguments"])
|
||||||
|
if strings.TrimSpace(arguments) == "" {
|
||||||
|
arguments = "{}"
|
||||||
|
}
|
||||||
|
messages = append(messages, ChatMessage{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []ChatToolCall{{
|
||||||
|
ID: rawString(item["call_id"]),
|
||||||
|
Type: "function",
|
||||||
|
Function: ChatFunctionCall{
|
||||||
|
Name: rawString(item["name"]),
|
||||||
|
Arguments: arguments,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
case "function_call_output":
|
||||||
|
content, _ := json.Marshal(rawString(item["output"]))
|
||||||
|
messages = append(messages, ChatMessage{
|
||||||
|
Role: "tool",
|
||||||
|
ToolCallID: rawString(item["call_id"]),
|
||||||
|
Content: content,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
case "input_text", "text":
|
||||||
|
content, _ := json.Marshal(rawString(item["text"]))
|
||||||
|
messages = append(messages, ChatMessage{Role: "user", Content: content})
|
||||||
|
continue
|
||||||
|
case "input_image":
|
||||||
|
content, err := chatContentFromSingleResponsesPart(itemType, item)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
messages = append(messages, ChatMessage{Role: "user", Content: content})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if role == "" {
|
||||||
|
role = "user"
|
||||||
|
}
|
||||||
|
content := item["content"]
|
||||||
|
if len(bytesTrimSpace(content)) == 0 {
|
||||||
|
if text := rawString(item["text"]); text != "" {
|
||||||
|
content, _ = json.Marshal(text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
chatContent, err := responsesContentToChatContent(content, role)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
messages = append(messages, ChatMessage{
|
||||||
|
Role: role,
|
||||||
|
Content: chatContent,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func responsesContentToChatContent(raw json.RawMessage, role string) (json.RawMessage, error) {
|
||||||
|
raw = bytesTrimSpace(raw)
|
||||||
|
if len(raw) == 0 || string(raw) == "null" {
|
||||||
|
empty, _ := json.Marshal("")
|
||||||
|
return empty, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var text string
|
||||||
|
if err := json.Unmarshal(raw, &text); err == nil {
|
||||||
|
return raw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var rawParts []json.RawMessage
|
||||||
|
if err := json.Unmarshal(raw, &rawParts); err == nil {
|
||||||
|
return responsesContentPartsToChatContent(rawParts, role)
|
||||||
|
}
|
||||||
|
|
||||||
|
var obj map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(raw, &obj); err == nil {
|
||||||
|
return chatContentFromSingleResponsesPart(rawString(obj["type"]), obj)
|
||||||
|
}
|
||||||
|
|
||||||
|
return raw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func responsesContentPartsToChatContent(rawParts []json.RawMessage, role string) (json.RawMessage, error) {
|
||||||
|
var textParts []string
|
||||||
|
var chatParts []ChatContentPart
|
||||||
|
hasNonText := false
|
||||||
|
|
||||||
|
for _, rawPart := range rawParts {
|
||||||
|
var part map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(rawPart, &part); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
partType := rawString(part["type"])
|
||||||
|
switch partType {
|
||||||
|
case "input_text", "output_text", "text", "":
|
||||||
|
text := rawString(part["text"])
|
||||||
|
if text == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
textParts = append(textParts, text)
|
||||||
|
chatParts = append(chatParts, ChatContentPart{Type: "text", Text: text})
|
||||||
|
case "input_image", "image_url":
|
||||||
|
imageURL := rawString(part["image_url"])
|
||||||
|
if imageURL == "" {
|
||||||
|
imageURL = rawNestedString(part["image_url"], "url")
|
||||||
|
}
|
||||||
|
if imageURL == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
hasNonText = true
|
||||||
|
chatParts = append(chatParts, ChatContentPart{
|
||||||
|
Type: "image_url",
|
||||||
|
ImageURL: &ChatImageURL{URL: imageURL},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasNonText {
|
||||||
|
joined, _ := json.Marshal(strings.Join(textParts, "\n\n"))
|
||||||
|
return joined, nil
|
||||||
|
}
|
||||||
|
if role != "user" {
|
||||||
|
joined, _ := json.Marshal(strings.Join(textParts, "\n\n"))
|
||||||
|
return joined, nil
|
||||||
|
}
|
||||||
|
if len(chatParts) == 0 {
|
||||||
|
empty, _ := json.Marshal("")
|
||||||
|
return empty, nil
|
||||||
|
}
|
||||||
|
return json.Marshal(chatParts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func chatContentFromSingleResponsesPart(partType string, part map[string]json.RawMessage) (json.RawMessage, error) {
|
||||||
|
switch partType {
|
||||||
|
case "input_image", "image_url":
|
||||||
|
imageURL := rawString(part["image_url"])
|
||||||
|
if imageURL == "" {
|
||||||
|
imageURL = rawNestedString(part["image_url"], "url")
|
||||||
|
}
|
||||||
|
return json.Marshal([]ChatContentPart{{
|
||||||
|
Type: "image_url",
|
||||||
|
ImageURL: &ChatImageURL{URL: imageURL},
|
||||||
|
}})
|
||||||
|
default:
|
||||||
|
return json.Marshal(rawString(part["text"]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func responsesToolsToChatTools(tools []ResponsesTool) []ChatTool {
|
||||||
|
out := make([]ChatTool, 0, len(tools))
|
||||||
|
for _, tool := range tools {
|
||||||
|
if tool.Type != "function" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, ChatTool{
|
||||||
|
Type: "function",
|
||||||
|
Function: &ChatFunction{
|
||||||
|
Name: tool.Name,
|
||||||
|
Description: tool.Description,
|
||||||
|
Parameters: tool.Parameters,
|
||||||
|
Strict: tool.Strict,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func responsesToolChoiceToChatToolChoice(raw json.RawMessage) json.RawMessage {
|
||||||
|
var choice map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(raw, &choice); err != nil {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
if rawString(choice["type"]) != "function" {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
name := rawString(choice["name"])
|
||||||
|
if name == "" {
|
||||||
|
name = rawNestedString(choice["function"], "name")
|
||||||
|
}
|
||||||
|
if name == "" {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
out, err := json.Marshal(map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]string{
|
||||||
|
"name": name,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatCompletionsResponseToResponses converts a non-streaming Chat Completions
|
||||||
|
// response into a Responses API response.
|
||||||
|
func ChatCompletionsResponseToResponses(resp *ChatCompletionsResponse, model string) *ResponsesResponse {
|
||||||
|
id := ""
|
||||||
|
if resp != nil {
|
||||||
|
id = resp.ID
|
||||||
|
}
|
||||||
|
if id == "" {
|
||||||
|
id = generateResponsesID()
|
||||||
|
}
|
||||||
|
|
||||||
|
out := &ResponsesResponse{
|
||||||
|
ID: id,
|
||||||
|
Object: "response",
|
||||||
|
Model: model,
|
||||||
|
Status: "completed",
|
||||||
|
}
|
||||||
|
if resp == nil {
|
||||||
|
out.Output = []ResponsesOutput{emptyResponsesMessageOutput()}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
if out.Model == "" {
|
||||||
|
out.Model = resp.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resp.Choices) > 0 {
|
||||||
|
choice := resp.Choices[0]
|
||||||
|
out.Output = chatMessageToResponsesOutput(choice.Message)
|
||||||
|
if choice.FinishReason == "length" {
|
||||||
|
out.Status = "incomplete"
|
||||||
|
out.IncompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(out.Output) == 0 {
|
||||||
|
out.Output = []ResponsesOutput{emptyResponsesMessageOutput()}
|
||||||
|
}
|
||||||
|
if resp.Usage != nil {
|
||||||
|
out.Usage = ChatUsageToResponsesUsage(resp.Usage)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func chatMessageToResponsesOutput(message ChatMessage) []ResponsesOutput {
|
||||||
|
var outputs []ResponsesOutput
|
||||||
|
if message.ReasoningContent != "" {
|
||||||
|
outputs = append(outputs, ResponsesOutput{
|
||||||
|
Type: "reasoning",
|
||||||
|
ID: generateItemID(),
|
||||||
|
Summary: []ResponsesSummary{{
|
||||||
|
Type: "summary_text",
|
||||||
|
Text: message.ReasoningContent,
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
text := chatMessageContentText(message.Content)
|
||||||
|
if text != "" || len(message.ToolCalls) == 0 {
|
||||||
|
outputs = append(outputs, ResponsesOutput{
|
||||||
|
Type: "message",
|
||||||
|
ID: generateItemID(),
|
||||||
|
Role: "assistant",
|
||||||
|
Content: []ResponsesContentPart{{
|
||||||
|
Type: "output_text",
|
||||||
|
Text: text,
|
||||||
|
}},
|
||||||
|
Status: "completed",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, toolCall := range message.ToolCalls {
|
||||||
|
arguments := toolCall.Function.Arguments
|
||||||
|
if strings.TrimSpace(arguments) == "" {
|
||||||
|
arguments = "{}"
|
||||||
|
}
|
||||||
|
outputs = append(outputs, ResponsesOutput{
|
||||||
|
Type: "function_call",
|
||||||
|
ID: generateItemID(),
|
||||||
|
CallID: toolCall.ID,
|
||||||
|
Name: toolCall.Function.Name,
|
||||||
|
Arguments: arguments,
|
||||||
|
Status: "completed",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
}
|
||||||
|
|
||||||
|
func emptyResponsesMessageOutput() ResponsesOutput {
|
||||||
|
return ResponsesOutput{
|
||||||
|
Type: "message",
|
||||||
|
ID: generateItemID(),
|
||||||
|
Role: "assistant",
|
||||||
|
Content: []ResponsesContentPart{{Type: "output_text", Text: ""}},
|
||||||
|
Status: "completed",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func chatMessageContentText(raw json.RawMessage) string {
|
||||||
|
raw = bytesTrimSpace(raw)
|
||||||
|
if len(raw) == 0 || string(raw) == "null" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var text string
|
||||||
|
if err := json.Unmarshal(raw, &text); err == nil {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
var parts []ChatContentPart
|
||||||
|
if err := json.Unmarshal(raw, &parts); err == nil {
|
||||||
|
var texts []string
|
||||||
|
for _, part := range parts {
|
||||||
|
if part.Type == "text" && part.Text != "" {
|
||||||
|
texts = append(texts, part.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(texts, "\n\n")
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatUsageToResponsesUsage converts Chat Completions token usage to Responses
|
||||||
|
// usage shape.
|
||||||
|
func ChatUsageToResponsesUsage(usage *ChatUsage) *ResponsesUsage {
|
||||||
|
if usage == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := &ResponsesUsage{
|
||||||
|
InputTokens: usage.PromptTokens,
|
||||||
|
OutputTokens: usage.CompletionTokens,
|
||||||
|
TotalTokens: usage.TotalTokens,
|
||||||
|
}
|
||||||
|
if out.TotalTokens == 0 {
|
||||||
|
out.TotalTokens = out.InputTokens + out.OutputTokens
|
||||||
|
}
|
||||||
|
if usage.PromptTokensDetails != nil && usage.PromptTokensDetails.CachedTokens > 0 {
|
||||||
|
out.InputTokensDetails = &ResponsesInputTokensDetails{
|
||||||
|
CachedTokens: usage.PromptTokensDetails.CachedTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatCompletionsToResponsesStreamState tracks state while converting Chat
|
||||||
|
// Completions SSE chunks into Responses SSE events.
|
||||||
|
type ChatCompletionsToResponsesStreamState struct {
|
||||||
|
ResponseID string
|
||||||
|
Model string
|
||||||
|
Created int64
|
||||||
|
SequenceNumber int
|
||||||
|
CreatedSent bool
|
||||||
|
CompletedSent bool
|
||||||
|
|
||||||
|
MessageItemID string
|
||||||
|
Text strings.Builder
|
||||||
|
Reasoning strings.Builder
|
||||||
|
ToolCalls map[int]*ChatToolCall
|
||||||
|
|
||||||
|
FinishReason string
|
||||||
|
Usage *ResponsesUsage
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewChatCompletionsToResponsesStreamState returns an initialized stream state.
|
||||||
|
func NewChatCompletionsToResponsesStreamState(model string) *ChatCompletionsToResponsesStreamState {
|
||||||
|
return &ChatCompletionsToResponsesStreamState{
|
||||||
|
ResponseID: generateResponsesID(),
|
||||||
|
Model: model,
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
ToolCalls: make(map[int]*ChatToolCall),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatCompletionsChunkToResponsesEvents converts one Chat Completions stream
|
||||||
|
// chunk into zero or more Responses stream events.
|
||||||
|
func ChatCompletionsChunkToResponsesEvents(
|
||||||
|
chunk *ChatCompletionsChunk,
|
||||||
|
state *ChatCompletionsToResponsesStreamState,
|
||||||
|
) []ResponsesStreamEvent {
|
||||||
|
if chunk == nil || state == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if chunk.ID != "" {
|
||||||
|
state.ResponseID = chunk.ID
|
||||||
|
}
|
||||||
|
if state.Model == "" && chunk.Model != "" {
|
||||||
|
state.Model = chunk.Model
|
||||||
|
}
|
||||||
|
if chunk.Usage != nil {
|
||||||
|
state.Usage = ChatUsageToResponsesUsage(chunk.Usage)
|
||||||
|
}
|
||||||
|
|
||||||
|
var events []ResponsesStreamEvent
|
||||||
|
events = append(events, ensureChatToResponsesCreated(state)...)
|
||||||
|
|
||||||
|
for _, choice := range chunk.Choices {
|
||||||
|
if choice.Delta.Content != nil {
|
||||||
|
events = append(events, ensureChatToResponsesMessageItem(state)...)
|
||||||
|
_, _ = state.Text.WriteString(*choice.Delta.Content)
|
||||||
|
events = append(events, chatToResponsesEvent(state, "response.output_text.delta", &ResponsesStreamEvent{
|
||||||
|
OutputIndex: 0,
|
||||||
|
ContentIndex: 0,
|
||||||
|
Delta: *choice.Delta.Content,
|
||||||
|
ItemID: state.MessageItemID,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
if choice.Delta.ReasoningContent != nil {
|
||||||
|
_, _ = state.Reasoning.WriteString(*choice.Delta.ReasoningContent)
|
||||||
|
events = append(events, chatToResponsesEvent(state, "response.reasoning_summary_text.delta", &ResponsesStreamEvent{
|
||||||
|
OutputIndex: 0,
|
||||||
|
SummaryIndex: 0,
|
||||||
|
Delta: *choice.Delta.ReasoningContent,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
for _, toolCall := range choice.Delta.ToolCalls {
|
||||||
|
idx := 0
|
||||||
|
if toolCall.Index != nil {
|
||||||
|
idx = *toolCall.Index
|
||||||
|
}
|
||||||
|
stored, ok := state.ToolCalls[idx]
|
||||||
|
if !ok {
|
||||||
|
copyCall := toolCall
|
||||||
|
if copyCall.ID == "" {
|
||||||
|
copyCall.ID = generateItemID()
|
||||||
|
}
|
||||||
|
copyCall.Type = "function"
|
||||||
|
state.ToolCalls[idx] = ©Call
|
||||||
|
stored = ©Call
|
||||||
|
events = append(events, chatToResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
|
||||||
|
OutputIndex: idx + 1,
|
||||||
|
Item: &ResponsesOutput{
|
||||||
|
Type: "function_call",
|
||||||
|
ID: generateItemID(),
|
||||||
|
CallID: stored.ID,
|
||||||
|
Name: stored.Function.Name,
|
||||||
|
Status: "in_progress",
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
} else {
|
||||||
|
if toolCall.ID != "" {
|
||||||
|
stored.ID = toolCall.ID
|
||||||
|
}
|
||||||
|
if toolCall.Function.Name != "" {
|
||||||
|
stored.Function.Name = toolCall.Function.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if toolCall.Function.Arguments != "" {
|
||||||
|
stored.Function.Arguments += toolCall.Function.Arguments
|
||||||
|
events = append(events, chatToResponsesEvent(state, "response.function_call_arguments.delta", &ResponsesStreamEvent{
|
||||||
|
OutputIndex: idx + 1,
|
||||||
|
Delta: toolCall.Function.Arguments,
|
||||||
|
CallID: stored.ID,
|
||||||
|
Name: stored.Function.Name,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if choice.FinishReason != nil && *choice.FinishReason != "" {
|
||||||
|
state.FinishReason = *choice.FinishReason
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return events
|
||||||
|
}
|
||||||
|
|
||||||
|
// FinalizeChatCompletionsResponsesStream emits terminal Responses events.
|
||||||
|
func FinalizeChatCompletionsResponsesStream(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent {
|
||||||
|
if state == nil || state.CompletedSent {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var events []ResponsesStreamEvent
|
||||||
|
events = append(events, ensureChatToResponsesCreated(state)...)
|
||||||
|
if state.MessageItemID != "" {
|
||||||
|
events = append(events, chatToResponsesEvent(state, "response.output_text.done", &ResponsesStreamEvent{
|
||||||
|
OutputIndex: 0,
|
||||||
|
ContentIndex: 0,
|
||||||
|
Text: state.Text.String(),
|
||||||
|
ItemID: state.MessageItemID,
|
||||||
|
}))
|
||||||
|
events = append(events, chatToResponsesEvent(state, "response.output_item.done", &ResponsesStreamEvent{
|
||||||
|
OutputIndex: 0,
|
||||||
|
Item: &ResponsesOutput{
|
||||||
|
Type: "message",
|
||||||
|
ID: state.MessageItemID,
|
||||||
|
Role: "assistant",
|
||||||
|
Status: "completed",
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
status := "completed"
|
||||||
|
var incompleteDetails *ResponsesIncompleteDetails
|
||||||
|
if state.FinishReason == "length" {
|
||||||
|
status = "incomplete"
|
||||||
|
incompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"}
|
||||||
|
}
|
||||||
|
|
||||||
|
state.CompletedSent = true
|
||||||
|
events = append(events, chatToResponsesEvent(state, "response.completed", &ResponsesStreamEvent{
|
||||||
|
Response: &ResponsesResponse{
|
||||||
|
ID: state.ResponseID,
|
||||||
|
Object: "response",
|
||||||
|
Model: state.Model,
|
||||||
|
Status: status,
|
||||||
|
Output: state.chatOutput(),
|
||||||
|
Usage: state.Usage,
|
||||||
|
IncompleteDetails: incompleteDetails,
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
return events
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureChatToResponsesCreated(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent {
|
||||||
|
if state.CreatedSent {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state.CreatedSent = true
|
||||||
|
return []ResponsesStreamEvent{chatToResponsesEvent(state, "response.created", &ResponsesStreamEvent{
|
||||||
|
Response: &ResponsesResponse{
|
||||||
|
ID: state.ResponseID,
|
||||||
|
Object: "response",
|
||||||
|
Model: state.Model,
|
||||||
|
Status: "in_progress",
|
||||||
|
Output: []ResponsesOutput{},
|
||||||
|
},
|
||||||
|
})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureChatToResponsesMessageItem(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent {
|
||||||
|
if state.MessageItemID != "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state.MessageItemID = generateItemID()
|
||||||
|
return []ResponsesStreamEvent{chatToResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
|
||||||
|
OutputIndex: 0,
|
||||||
|
Item: &ResponsesOutput{
|
||||||
|
Type: "message",
|
||||||
|
ID: state.MessageItemID,
|
||||||
|
Role: "assistant",
|
||||||
|
Status: "in_progress",
|
||||||
|
},
|
||||||
|
})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state *ChatCompletionsToResponsesStreamState) chatOutput() []ResponsesOutput {
|
||||||
|
var outputs []ResponsesOutput
|
||||||
|
if state.Reasoning.Len() > 0 {
|
||||||
|
outputs = append(outputs, ResponsesOutput{
|
||||||
|
Type: "reasoning",
|
||||||
|
ID: generateItemID(),
|
||||||
|
Summary: []ResponsesSummary{{
|
||||||
|
Type: "summary_text",
|
||||||
|
Text: state.Reasoning.String(),
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if state.MessageItemID != "" || len(state.ToolCalls) == 0 {
|
||||||
|
outputs = append(outputs, ResponsesOutput{
|
||||||
|
Type: "message",
|
||||||
|
ID: nonEmpty(state.MessageItemID, generateItemID()),
|
||||||
|
Role: "assistant",
|
||||||
|
Content: []ResponsesContentPart{{
|
||||||
|
Type: "output_text",
|
||||||
|
Text: state.Text.String(),
|
||||||
|
}},
|
||||||
|
Status: "completed",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for i := 0; i < len(state.ToolCalls); i++ {
|
||||||
|
toolCall, ok := state.ToolCalls[i]
|
||||||
|
if !ok || toolCall == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
arguments := toolCall.Function.Arguments
|
||||||
|
if strings.TrimSpace(arguments) == "" {
|
||||||
|
arguments = "{}"
|
||||||
|
}
|
||||||
|
outputs = append(outputs, ResponsesOutput{
|
||||||
|
Type: "function_call",
|
||||||
|
ID: generateItemID(),
|
||||||
|
CallID: toolCall.ID,
|
||||||
|
Name: toolCall.Function.Name,
|
||||||
|
Arguments: arguments,
|
||||||
|
Status: "completed",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
}
|
||||||
|
|
||||||
|
func chatToResponsesEvent(
|
||||||
|
state *ChatCompletionsToResponsesStreamState,
|
||||||
|
eventType string,
|
||||||
|
template *ResponsesStreamEvent,
|
||||||
|
) ResponsesStreamEvent {
|
||||||
|
seq := state.SequenceNumber
|
||||||
|
state.SequenceNumber++
|
||||||
|
evt := *template
|
||||||
|
evt.Type = eventType
|
||||||
|
evt.SequenceNumber = seq
|
||||||
|
return evt
|
||||||
|
}
|
||||||
|
|
||||||
|
func rawString(raw json.RawMessage) string {
|
||||||
|
raw = bytesTrimSpace(raw)
|
||||||
|
if len(raw) == 0 || string(raw) == "null" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(raw, &s); err == nil {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func rawNestedString(raw json.RawMessage, key string) string {
|
||||||
|
var obj map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(raw, &obj); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return rawString(obj[key])
|
||||||
|
}
|
||||||
|
|
||||||
|
func bytesTrimSpace(raw json.RawMessage) json.RawMessage {
|
||||||
|
return json.RawMessage(strings.TrimSpace(string(raw)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func nonEmpty(value, fallback string) string {
|
||||||
|
if value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
@ -30,6 +30,17 @@ var CodexOfficialClientOriginatorPrefixes = []string{
|
|||||||
"codex ",
|
"codex ",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsBrowserUserAgent 判断 User-Agent 是否来自浏览器(Chrome/Firefox/Safari/Edge/Opera 等)。
|
||||||
|
// 所有现代浏览器的 UA 均以 "Mozilla/" 作为前缀,CLI 工具(codex/claude/curl/postman/python-requests 等)不会。
|
||||||
|
// 该判定用于避免 Cloudflare 对浏览器型 UA 在 OpenAI 上游接口上触发 JS 质询。
|
||||||
|
func IsBrowserUserAgent(userAgent string) bool {
|
||||||
|
ua := strings.TrimSpace(userAgent)
|
||||||
|
if ua == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.HasPrefix(strings.ToLower(ua), "mozilla/")
|
||||||
|
}
|
||||||
|
|
||||||
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
|
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
|
||||||
func IsCodexCLIRequest(userAgent string) bool {
|
func IsCodexCLIRequest(userAgent string) bool {
|
||||||
ua := normalizeCodexClientHeader(userAgent)
|
ua := normalizeCodexClientHeader(userAgent)
|
||||||
|
|||||||
@ -198,6 +198,19 @@ type APIKeyUsageTrendPoint struct {
|
|||||||
Tokens int64 `json:"tokens"`
|
Tokens int64 `json:"tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// APIKeyDailyUsagePoint represents one day of usage for a single API key.
|
||||||
|
type APIKeyDailyUsagePoint struct {
|
||||||
|
Date string `json:"date"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
InputTokens int64 `json:"input_tokens"`
|
||||||
|
OutputTokens int64 `json:"output_tokens"`
|
||||||
|
CacheReadTokens int64 `json:"cache_read_tokens"`
|
||||||
|
CacheWriteTokens int64 `json:"cache_write_tokens"`
|
||||||
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
|
}
|
||||||
|
|
||||||
// UserDashboardStats 用户仪表盘统计
|
// UserDashboardStats 用户仪表盘统计
|
||||||
type UserDashboardStats struct {
|
type UserDashboardStats struct {
|
||||||
// API Key 统计
|
// API Key 统计
|
||||||
|
|||||||
219
backend/internal/repository/aes_encryptor_test.go
Normal file
219
backend/internal/repository/aes_encryptor_test.go
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ── 测试辅助 ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// aesHexKey 构造一个全填充为 b 的 n 字节密钥并以 hex 编码返回。
|
||||||
|
func aesHexKey(n int, b byte) string {
|
||||||
|
raw := make([]byte, n)
|
||||||
|
for i := range raw {
|
||||||
|
raw[i] = b
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
// aesTestCfg 用给定 hex 密钥字符串构造最小 Config。
|
||||||
|
func aesTestCfg(keyHex string) *config.Config {
|
||||||
|
return &config.Config{
|
||||||
|
Totp: config.TotpConfig{EncryptionKey: keyHex},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// aesEncryptor 创建一个持有合法 32 字节密钥的加密器,测试失败时立即终止。
|
||||||
|
func aesEncryptor(t *testing.T) *AESEncryptor {
|
||||||
|
t.Helper()
|
||||||
|
enc, err := NewAESEncryptor(aesTestCfg(aesHexKey(32, 0x42)))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, enc)
|
||||||
|
return enc.(*AESEncryptor)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── NewAESEncryptor ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestNewAESEncryptor_ValidKey32Bytes(t *testing.T) {
|
||||||
|
enc, err := NewAESEncryptor(aesTestCfg(aesHexKey(32, 0x01)))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, enc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 16 / 24 字节密钥在 AES 体系内合法,但本实现仅接受 AES-256(32 字节)。
|
||||||
|
func TestNewAESEncryptor_WrongKeyLength(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
keySize int
|
||||||
|
}{
|
||||||
|
{"16_bytes_AES128", 16},
|
||||||
|
{"24_bytes_AES192", 24},
|
||||||
|
{"1_byte", 1},
|
||||||
|
{"31_bytes", 31},
|
||||||
|
{"33_bytes", 33},
|
||||||
|
{"64_bytes", 64},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := NewAESEncryptor(aesTestCfg(aesHexKey(tt.keySize, 0x00)))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "32 bytes")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// "配置缺失"场景:空字符串与非法 hex 编码。
|
||||||
|
func TestNewAESEncryptor_MissingOrInvalidConfig(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
keyHex string
|
||||||
|
wantContain string
|
||||||
|
}{
|
||||||
|
{"empty_key", "", "32 bytes"},
|
||||||
|
{"invalid_hex_odd_length", "abcde", "invalid totp encryption key"},
|
||||||
|
{"invalid_hex_chars", "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz", "invalid totp encryption key"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := NewAESEncryptor(aesTestCfg(tt.keyHex))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.wantContain)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── 加解密往返(Roundtrip)───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestAESEncryptor_RoundTrip(t *testing.T) {
|
||||||
|
enc := aesEncryptor(t)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
plaintext string
|
||||||
|
}{
|
||||||
|
{"ascii", "Hello, Sub2API!"},
|
||||||
|
{"chinese_multibyte", "你好,世界!这是多字节 UTF-8 文本。"},
|
||||||
|
{"empty_string", ""},
|
||||||
|
{"long_string_gt_1KB", strings.Repeat("x", 2048)},
|
||||||
|
{"special_chars", "!@#$%^&*()_+-=[]{}|;':\",./<>?"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ct, err := enc.Encrypt(tt.plaintext)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, ct, "密文不应为空(即便明文为空字符串)")
|
||||||
|
|
||||||
|
got, err := enc.Decrypt(ct)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.plaintext, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── IV/Nonce 随机性 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestAESEncryptor_Encrypt_NonceRandomness(t *testing.T) {
|
||||||
|
enc := aesEncryptor(t)
|
||||||
|
const iterations = 30
|
||||||
|
plaintext := "same plaintext for every iteration"
|
||||||
|
|
||||||
|
seen := make(map[string]struct{}, iterations)
|
||||||
|
for i := 0; i < iterations; i++ {
|
||||||
|
ct, err := enc.Encrypt(plaintext)
|
||||||
|
require.NoError(t, err)
|
||||||
|
seen[ct] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 30 次加密相同明文,每次因随机 Nonce 应产生不同密文。
|
||||||
|
assert.Len(t, seen, iterations,
|
||||||
|
"每次加密应因随机 Nonce 产生唯一密文,共 %d 次", iterations)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Decrypt 错误路径 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestAESDecrypt_InvalidBase64(t *testing.T) {
|
||||||
|
enc := aesEncryptor(t)
|
||||||
|
_, err := enc.Decrypt("!!!not-valid-base64!!!")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "decode base64")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAESDecrypt_TooShort(t *testing.T) {
|
||||||
|
enc := aesEncryptor(t)
|
||||||
|
// GCM Nonce 为 12 字节;仅提供 2 字节,必然短于 NonceSize。
|
||||||
|
short := base64.StdEncoding.EncodeToString([]byte{0x01, 0x02})
|
||||||
|
_, err := enc.Decrypt(short)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAESDecrypt_TamperedCiphertext(t *testing.T) {
|
||||||
|
enc := aesEncryptor(t)
|
||||||
|
|
||||||
|
ct, err := enc.Encrypt("sensitive payload")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
raw, err := base64.StdEncoding.DecodeString(ct)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Nonce 占前 12 字节;翻转其后第一个字节(密文体)。
|
||||||
|
raw[12] ^= 0xFF
|
||||||
|
_, err = enc.Decrypt(base64.StdEncoding.EncodeToString(raw))
|
||||||
|
require.Error(t, err, "篡改密文体后解密应失败")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAESDecrypt_TamperedTag(t *testing.T) {
|
||||||
|
enc := aesEncryptor(t)
|
||||||
|
|
||||||
|
ct, err := enc.Encrypt("sensitive payload")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
raw, err := base64.StdEncoding.DecodeString(ct)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// GCM 认证标签占最后 16 字节;翻转最后一个字节。
|
||||||
|
raw[len(raw)-1] ^= 0xFF
|
||||||
|
_, err = enc.Decrypt(base64.StdEncoding.EncodeToString(raw))
|
||||||
|
require.Error(t, err, "篡改 GCM 标签后解密应失败")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── 跨实例(Cross-instance)──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestAESEncryptor_CrossInstance_SameKey_CanDecrypt(t *testing.T) {
|
||||||
|
keyHex := aesHexKey(32, 0xDE)
|
||||||
|
|
||||||
|
enc1, err := NewAESEncryptor(aesTestCfg(keyHex))
|
||||||
|
require.NoError(t, err)
|
||||||
|
enc2, err := NewAESEncryptor(aesTestCfg(keyHex))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
plaintext := "cross-instance roundtrip"
|
||||||
|
ct, err := enc1.Encrypt(plaintext)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := enc2.Decrypt(ct)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, plaintext, got, "相同密钥构造的两个实例应可互相解密")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAESEncryptor_CrossInstance_DifferentKey_CannotDecrypt(t *testing.T) {
|
||||||
|
enc1, err := NewAESEncryptor(aesTestCfg(aesHexKey(32, 0xAA)))
|
||||||
|
require.NoError(t, err)
|
||||||
|
enc2, err := NewAESEncryptor(aesTestCfg(aesHexKey(32, 0xBB)))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ct, err := enc1.Encrypt("secret message")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = enc2.Decrypt(ct)
|
||||||
|
require.Error(t, err, "不同密钥的实例不应能解密对方的密文")
|
||||||
|
}
|
||||||
@ -80,7 +80,7 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te
|
|||||||
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
|
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
|
func TestGroupRepository_DeleteCascade_PreservesApiKeyGroupID(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
tx := testEntTx(t)
|
tx := testEntTx(t)
|
||||||
entClient := tx.Client()
|
entClient := tx.Client()
|
||||||
@ -138,8 +138,10 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
|
|||||||
require.NotContains(t, uAfter.AllowedGroups, targetGroup.ID)
|
require.NotContains(t, uAfter.AllowedGroups, targetGroup.ID)
|
||||||
require.Contains(t, uAfter.AllowedGroups, otherGroup.ID)
|
require.Contains(t, uAfter.AllowedGroups, otherGroup.ID)
|
||||||
|
|
||||||
// API keys bound to the deleted group should have group_id cleared.
|
// API keys keep their group_id so auth can reject keys bound to a deleted group.
|
||||||
keyAfter, err := apiKeyRepo.GetByID(ctx, key.ID)
|
keyAfter, err := apiKeyRepo.GetByID(ctx, key.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Nil(t, keyAfter.GroupID)
|
require.NotNil(t, keyAfter.GroupID)
|
||||||
|
require.Equal(t, targetGroup.ID, *keyAfter.GroupID)
|
||||||
|
require.Nil(t, keyAfter.Group)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -9,7 +9,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
|
||||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
@ -94,9 +93,13 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
total, active, _ := r.GetAccountCount(ctx, out.ID)
|
counts, err := r.loadAccountCounts(ctx, []int64{out.ID})
|
||||||
out.AccountCount = total
|
if err == nil {
|
||||||
out.ActiveAccountCount = active
|
c := counts[out.ID]
|
||||||
|
out.AccountCount = c.Total
|
||||||
|
out.ActiveAccountCount = c.Active
|
||||||
|
out.RateLimitedAccountCount = c.RateLimited
|
||||||
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -538,15 +541,12 @@ func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int
|
|||||||
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) {
|
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) {
|
||||||
var rateLimited int64
|
var rateLimited int64
|
||||||
err = scanSingleRow(ctx, r.sql,
|
err = scanSingleRow(ctx, r.sql,
|
||||||
`SELECT COUNT(*),
|
fmt.Sprintf(`SELECT
|
||||||
COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true),
|
COUNT(*) FILTER (WHERE a.deleted_at IS NULL),
|
||||||
COUNT(*) FILTER (WHERE a.status = 'active' AND (
|
COUNT(*) FILTER (WHERE %s),
|
||||||
a.rate_limit_reset_at > NOW() OR
|
COUNT(*) FILTER (WHERE %s)
|
||||||
a.overload_until > NOW() OR
|
|
||||||
a.temp_unschedulable_until > NOW()
|
|
||||||
))
|
|
||||||
FROM account_groups ag JOIN accounts a ON a.id = ag.account_id
|
FROM account_groups ag JOIN accounts a ON a.id = ag.account_id
|
||||||
WHERE ag.group_id = $1`,
|
WHERE ag.group_id = $1`, groupAccountAvailableSQL, groupAccountTemporarilyLimitedSQL),
|
||||||
[]any{groupID}, &total, &active, &rateLimited)
|
[]any{groupID}, &total, &active, &rateLimited)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -636,28 +636,18 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Clear group_id for api keys bound to this group.
|
// 2. Remove the group id from user_allowed_groups join table.
|
||||||
// 仅更新未软删除的记录,避免修改已删除数据,保证审计与历史回溯一致性。
|
|
||||||
// 与 APIKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
|
|
||||||
if _, err := txClient.APIKey.Update().
|
|
||||||
Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()).
|
|
||||||
ClearGroupID().
|
|
||||||
Save(ctx); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. Remove the group id from user_allowed_groups join table.
|
|
||||||
// Legacy users.allowed_groups 列已弃用,不再同步。
|
// Legacy users.allowed_groups 列已弃用,不再同步。
|
||||||
if _, err := exec.ExecContext(ctx, "DELETE FROM user_allowed_groups WHERE group_id = $1", id); err != nil {
|
if _, err := exec.ExecContext(ctx, "DELETE FROM user_allowed_groups WHERE group_id = $1", id); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Delete account_groups join rows.
|
// 3. Delete account_groups join rows.
|
||||||
if _, err := exec.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", id); err != nil {
|
if _, err := exec.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", id); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. Soft-delete group itself.
|
// 4. Soft-delete group itself.
|
||||||
if _, err := txClient.Group.Delete().Where(group.IDEQ(id)).Exec(ctx); err != nil {
|
if _, err := txClient.Group.Delete().Where(group.IDEQ(id)).Exec(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -680,6 +670,28 @@ type groupAccountCounts struct {
|
|||||||
RateLimited int64
|
RateLimited int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// 分组页的"可用"账号数必须与账号仓储的 ListSchedulableByGroupID 过滤口径一致。
|
||||||
|
groupAccountAvailableSQL = `a.deleted_at IS NULL
|
||||||
|
AND a.status = 'active'
|
||||||
|
AND a.schedulable = true
|
||||||
|
AND (a.expires_at IS NULL OR a.expires_at > NOW() OR a.auto_pause_on_expired = FALSE)
|
||||||
|
AND (a.rate_limit_reset_at IS NULL OR a.rate_limit_reset_at <= NOW())
|
||||||
|
AND (a.overload_until IS NULL OR a.overload_until <= NOW())
|
||||||
|
AND (a.temp_unschedulable_until IS NULL OR a.temp_unschedulable_until <= NOW())`
|
||||||
|
|
||||||
|
// 这里沿用历史字段名 RateLimitedAccountCount,但统计的是会让账号暂时退出调度的时间窗口。
|
||||||
|
groupAccountTemporarilyLimitedSQL = `a.deleted_at IS NULL
|
||||||
|
AND a.status = 'active'
|
||||||
|
AND a.schedulable = true
|
||||||
|
AND (a.expires_at IS NULL OR a.expires_at > NOW() OR a.auto_pause_on_expired = FALSE)
|
||||||
|
AND (
|
||||||
|
a.rate_limit_reset_at > NOW() OR
|
||||||
|
a.overload_until > NOW() OR
|
||||||
|
a.temp_unschedulable_until > NOW()
|
||||||
|
)`
|
||||||
|
)
|
||||||
|
|
||||||
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]groupAccountCounts, err error) {
|
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]groupAccountCounts, err error) {
|
||||||
counts = make(map[int64]groupAccountCounts, len(groupIDs))
|
counts = make(map[int64]groupAccountCounts, len(groupIDs))
|
||||||
if len(groupIDs) == 0 {
|
if len(groupIDs) == 0 {
|
||||||
@ -688,18 +700,14 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
|
|||||||
|
|
||||||
rows, err := r.sql.QueryContext(
|
rows, err := r.sql.QueryContext(
|
||||||
ctx,
|
ctx,
|
||||||
`SELECT ag.group_id,
|
fmt.Sprintf(`SELECT ag.group_id,
|
||||||
COUNT(*) AS total,
|
COUNT(*) FILTER (WHERE a.deleted_at IS NULL) AS total,
|
||||||
COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true) AS active,
|
COUNT(*) FILTER (WHERE %s) AS active,
|
||||||
COUNT(*) FILTER (WHERE a.status = 'active' AND (
|
COUNT(*) FILTER (WHERE %s) AS rate_limited
|
||||||
a.rate_limit_reset_at > NOW() OR
|
|
||||||
a.overload_until > NOW() OR
|
|
||||||
a.temp_unschedulable_until > NOW()
|
|
||||||
)) AS rate_limited
|
|
||||||
FROM account_groups ag
|
FROM account_groups ag
|
||||||
JOIN accounts a ON a.id = ag.account_id
|
JOIN accounts a ON a.id = ag.account_id
|
||||||
WHERE ag.group_id = ANY($1)
|
WHERE ag.group_id = ANY($1)
|
||||||
GROUP BY ag.group_id`,
|
GROUP BY ag.group_id`, groupAccountAvailableSQL, groupAccountTemporarilyLimitedSQL),
|
||||||
pq.Array(groupIDs),
|
pq.Array(groupIDs),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -651,6 +651,164 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
|
|||||||
s.Require().Zero(count)
|
s.Require().Zero(count)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestListWithFilters_ActiveAccountCount_LessThanTotal 验证 ActiveAccountCount 正确区分可用与不可用账号。
|
||||||
|
// 当分组内存在 disabled 或 schedulable=false 的账号时,ActiveAccountCount 必须小于 AccountCount,
|
||||||
|
// 且与 GetAccountCount 返回的 active 值一致。
|
||||||
|
func (s *GroupRepoSuite) TestListWithFilters_ActiveAccountCount_LessThanTotal() {
|
||||||
|
g := &service.Group{
|
||||||
|
Name: "g-mixed-status",
|
||||||
|
Platform: service.PlatformAnthropic,
|
||||||
|
RateMultiplier: 1.0,
|
||||||
|
IsExclusive: false,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
SubscriptionType: service.SubscriptionTypeStandard,
|
||||||
|
}
|
||||||
|
s.Require().NoError(s.repo.Create(s.ctx, g))
|
||||||
|
|
||||||
|
insertAccount := func(name, status string, schedulable bool) int64 {
|
||||||
|
var id int64
|
||||||
|
s.Require().NoError(scanSingleRow(
|
||||||
|
s.ctx, s.tx,
|
||||||
|
"INSERT INTO accounts (name, platform, type, status, schedulable) VALUES ($1, $2, $3, $4, $5) RETURNING id",
|
||||||
|
[]any{name, service.PlatformAnthropic, service.AccountTypeOAuth, status, schedulable},
|
||||||
|
&id,
|
||||||
|
))
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
link := func(accountID int64, priority int) {
|
||||||
|
_, err := s.tx.ExecContext(s.ctx,
|
||||||
|
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
|
||||||
|
accountID, g.ID, priority)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// account 1: active + schedulable → counts toward both total and active
|
||||||
|
link(insertAccount("acc-active-sched", service.StatusActive, true), 1)
|
||||||
|
// account 2: disabled → counts toward total only
|
||||||
|
link(insertAccount("acc-disabled", service.StatusDisabled, true), 2)
|
||||||
|
// account 3: active + not schedulable → counts toward total only
|
||||||
|
link(insertAccount("acc-unschedulable", service.StatusActive, false), 3)
|
||||||
|
|
||||||
|
// --- ListWithFilters path ---
|
||||||
|
isExclusive := false
|
||||||
|
groups, _, err := s.repo.ListWithFilters(s.ctx,
|
||||||
|
pagination.PaginationParams{Page: 1, PageSize: 100},
|
||||||
|
service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
var found *service.Group
|
||||||
|
for i := range groups {
|
||||||
|
if groups[i].ID == g.ID {
|
||||||
|
found = &groups[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.Require().NotNil(found, "created group must appear in ListWithFilters result")
|
||||||
|
s.Assert().Equal(int64(3), found.AccountCount, "AccountCount must count all 3 accounts")
|
||||||
|
s.Assert().Equal(int64(1), found.ActiveAccountCount, "ActiveAccountCount must count only the active+schedulable account")
|
||||||
|
|
||||||
|
// --- GetAccountCount must return identical values ---
|
||||||
|
total, active, err := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Assert().Equal(found.AccountCount, total, "GetAccountCount total must match ListWithFilters AccountCount")
|
||||||
|
s.Assert().Equal(found.ActiveAccountCount, active, "GetAccountCount active must match ListWithFilters ActiveAccountCount")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestListWithFilters_RateLimitedAccountCount 验证临时受限账号不会计入可用账号数。
|
||||||
|
// rate_limit / overload / temp_unschedulable 都会让账号退出当前调度池,
|
||||||
|
// 因此 ActiveAccountCount 必须与真实调度查询口径一致。
|
||||||
|
func (s *GroupRepoSuite) TestListWithFilters_RateLimitedAccountCount() {
|
||||||
|
g := &service.Group{
|
||||||
|
Name: "g-rate-limited",
|
||||||
|
Platform: service.PlatformAnthropic,
|
||||||
|
RateMultiplier: 1.0,
|
||||||
|
IsExclusive: false,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
SubscriptionType: service.SubscriptionTypeStandard,
|
||||||
|
}
|
||||||
|
s.Require().NoError(s.repo.Create(s.ctx, g))
|
||||||
|
|
||||||
|
var normalID int64
|
||||||
|
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
|
||||||
|
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||||
|
[]any{"acc-normal", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||||
|
&normalID))
|
||||||
|
|
||||||
|
var rateLimitedID int64
|
||||||
|
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
|
||||||
|
"INSERT INTO accounts (name, platform, type, rate_limit_reset_at) VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour') RETURNING id",
|
||||||
|
[]any{"acc-rate-limited", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||||
|
&rateLimitedID))
|
||||||
|
|
||||||
|
var overloadedID int64
|
||||||
|
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
|
||||||
|
"INSERT INTO accounts (name, platform, type, overload_until) VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour') RETURNING id",
|
||||||
|
[]any{"acc-overloaded", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||||
|
&overloadedID))
|
||||||
|
|
||||||
|
var tempUnschedulableID int64
|
||||||
|
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
|
||||||
|
"INSERT INTO accounts (name, platform, type, temp_unschedulable_until) VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour') RETURNING id",
|
||||||
|
[]any{"acc-temp-unschedulable", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||||
|
&tempUnschedulableID))
|
||||||
|
|
||||||
|
var expiredID int64
|
||||||
|
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
|
||||||
|
"INSERT INTO accounts (name, platform, type, expires_at, auto_pause_on_expired) VALUES ($1, $2, $3, NOW() - INTERVAL '1 hour', TRUE) RETURNING id",
|
||||||
|
[]any{"acc-expired", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||||
|
&expiredID))
|
||||||
|
|
||||||
|
_, err := s.tx.ExecContext(s.ctx,
|
||||||
|
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
|
||||||
|
normalID, g.ID, 1)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
_, err = s.tx.ExecContext(s.ctx,
|
||||||
|
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
|
||||||
|
rateLimitedID, g.ID, 2)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
_, err = s.tx.ExecContext(s.ctx,
|
||||||
|
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
|
||||||
|
overloadedID, g.ID, 3)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
_, err = s.tx.ExecContext(s.ctx,
|
||||||
|
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
|
||||||
|
tempUnschedulableID, g.ID, 4)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
_, err = s.tx.ExecContext(s.ctx,
|
||||||
|
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
|
||||||
|
expiredID, g.ID, 5)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
isExclusive := false
|
||||||
|
groups, _, err := s.repo.ListWithFilters(s.ctx,
|
||||||
|
pagination.PaginationParams{Page: 1, PageSize: 100},
|
||||||
|
service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
var found *service.Group
|
||||||
|
for i := range groups {
|
||||||
|
if groups[i].ID == g.ID {
|
||||||
|
found = &groups[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.Require().NotNil(found, "created group must appear in ListWithFilters result")
|
||||||
|
s.Assert().Equal(int64(5), found.AccountCount, "AccountCount must include all linked accounts")
|
||||||
|
s.Assert().Equal(int64(1), found.ActiveAccountCount, "ActiveAccountCount must include only currently schedulable accounts")
|
||||||
|
s.Assert().Equal(int64(3), found.RateLimitedAccountCount, "RateLimitedAccountCount must include temporarily limited accounts")
|
||||||
|
|
||||||
|
total, active, err := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Assert().Equal(found.AccountCount, total, "GetAccountCount total must match ListWithFilters AccountCount")
|
||||||
|
s.Assert().Equal(found.ActiveAccountCount, active, "GetAccountCount active must match ListWithFilters ActiveAccountCount")
|
||||||
|
|
||||||
|
detail, err := s.repo.GetByID(s.ctx, g.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Assert().Equal(found.AccountCount, detail.AccountCount, "GetByID AccountCount must match ListWithFilters")
|
||||||
|
s.Assert().Equal(found.ActiveAccountCount, detail.ActiveAccountCount, "GetByID ActiveAccountCount must match ListWithFilters")
|
||||||
|
s.Assert().Equal(found.RateLimitedAccountCount, detail.RateLimitedAccountCount, "GetByID RateLimitedAccountCount must match ListWithFilters")
|
||||||
|
}
|
||||||
|
|
||||||
// --- DeleteAccountGroupsByGroupID ---
|
// --- DeleteAccountGroupsByGroupID ---
|
||||||
|
|
||||||
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
|
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
|
||||||
|
|||||||
@ -7,6 +7,67 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TestListWithAccountCountSort_AttachesActiveCount 验证通过 account_count 排序时,
|
||||||
|
// ActiveAccountCount 与 AccountCount 都被正确附加到返回结果中,
|
||||||
|
// 且排序基于 total 账号数而非 active 账号数。
|
||||||
|
func (s *GroupRepoSuite) TestListWithAccountCountSort_AttachesActiveCount() {
|
||||||
|
// Group A: 2 total, 1 active (1 disabled account)
|
||||||
|
gA := &service.Group{Name: "sort-count-a", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard}
|
||||||
|
// Group B: 1 total, 1 active
|
||||||
|
gB := &service.Group{Name: "sort-count-b", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard}
|
||||||
|
s.Require().NoError(s.repo.Create(s.ctx, gA))
|
||||||
|
s.Require().NoError(s.repo.Create(s.ctx, gB))
|
||||||
|
|
||||||
|
insertAccount := func(name, status string) int64 {
|
||||||
|
var id int64
|
||||||
|
s.Require().NoError(scanSingleRow(s.ctx, s.tx,
|
||||||
|
"INSERT INTO accounts (name, platform, type, status) VALUES ($1, $2, $3, $4) RETURNING id",
|
||||||
|
[]any{name, service.PlatformAnthropic, service.AccountTypeOAuth, status},
|
||||||
|
&id))
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
link := func(accountID, groupID int64, priority int) {
|
||||||
|
_, err := s.tx.ExecContext(s.ctx,
|
||||||
|
"INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())",
|
||||||
|
accountID, groupID, priority)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// gA: 1 active + 1 disabled → total=2, active=1
|
||||||
|
link(insertAccount("sa-active", service.StatusActive), gA.ID, 1)
|
||||||
|
link(insertAccount("sa-disabled", service.StatusDisabled), gA.ID, 2)
|
||||||
|
// gB: 1 active → total=1, active=1
|
||||||
|
link(insertAccount("sb-active", service.StatusActive), gB.ID, 1)
|
||||||
|
|
||||||
|
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
|
||||||
|
Page: 1, PageSize: 100, SortBy: "account_count", SortOrder: "desc",
|
||||||
|
}, service.PlatformAnthropic, service.StatusActive, "", nil)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
byID := make(map[int64]service.Group, len(groups))
|
||||||
|
for _, g := range groups {
|
||||||
|
byID[g.ID] = g
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Require().Contains(byID, gA.ID, "gA must appear in results")
|
||||||
|
s.Require().Contains(byID, gB.ID, "gB must appear in results")
|
||||||
|
|
||||||
|
cA := byID[gA.ID]
|
||||||
|
s.Assert().Equal(int64(2), cA.AccountCount, "gA AccountCount must be 2")
|
||||||
|
s.Assert().Equal(int64(1), cA.ActiveAccountCount, "gA ActiveAccountCount must be 1")
|
||||||
|
|
||||||
|
cB := byID[gB.ID]
|
||||||
|
s.Assert().Equal(int64(1), cB.AccountCount, "gB AccountCount must be 1")
|
||||||
|
s.Assert().Equal(int64(1), cB.ActiveAccountCount, "gB ActiveAccountCount must be 1")
|
||||||
|
|
||||||
|
// Sort is by total (not active): gA (total=2) must rank higher than gB (total=1) in desc order
|
||||||
|
indexByID := make(map[int64]int, len(groups))
|
||||||
|
for i, g := range groups {
|
||||||
|
indexByID[g.ID] = i
|
||||||
|
}
|
||||||
|
s.Assert().Less(indexByID[gA.ID], indexByID[gB.ID], "gA (total=2) must rank above gB (total=1) with account_count desc")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GroupRepoSuite) TestList_DefaultSortBySortOrderAsc() {
|
func (s *GroupRepoSuite) TestList_DefaultSortBySortOrderAsc() {
|
||||||
g1 := &service.Group{Name: "g1", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 20}
|
g1 := &service.Group{Name: "g1", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 20}
|
||||||
g2 := &service.Group{Name: "g2", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 10}
|
g2 := &service.Group{Name: "g2", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 10}
|
||||||
|
|||||||
@ -833,6 +833,7 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"payment_visible_method_alipay_enabled": true,
|
"payment_visible_method_alipay_enabled": true,
|
||||||
"payment_visible_method_wxpay_enabled": false,
|
"payment_visible_method_wxpay_enabled": false,
|
||||||
"openai_advanced_scheduler_enabled": true,
|
"openai_advanced_scheduler_enabled": true,
|
||||||
|
"openai_codex_user_agent": "",
|
||||||
"openai_fast_policy_settings": {
|
"openai_fast_policy_settings": {
|
||||||
"rules": []
|
"rules": []
|
||||||
},
|
},
|
||||||
@ -1058,6 +1059,7 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"payment_visible_method_alipay_enabled": false,
|
"payment_visible_method_alipay_enabled": false,
|
||||||
"payment_visible_method_wxpay_enabled": false,
|
"payment_visible_method_wxpay_enabled": false,
|
||||||
"openai_advanced_scheduler_enabled": false,
|
"openai_advanced_scheduler_enabled": false,
|
||||||
|
"openai_codex_user_agent": "",
|
||||||
"openai_fast_policy_settings": {
|
"openai_fast_policy_settings": {
|
||||||
"rules": []
|
"rules": []
|
||||||
},
|
},
|
||||||
|
|||||||
@ -109,6 +109,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
|||||||
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
|
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if abortIfAPIKeyGroupUnavailable(c, apiKey) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// ── 4. SimpleMode → early return ─────────────────────────────
|
// ── 4. SimpleMode → early return ─────────────────────────────
|
||||||
|
|
||||||
@ -251,3 +254,26 @@ func setGroupContext(c *gin.Context, group *service.Group) {
|
|||||||
ctx := context.WithValue(c.Request.Context(), ctxkey.Group, group)
|
ctx := context.WithValue(c.Request.Context(), ctxkey.Group, group)
|
||||||
c.Request = c.Request.WithContext(ctx)
|
c.Request = c.Request.WithContext(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func abortIfAPIKeyGroupUnavailable(c *gin.Context, apiKey *service.APIKey) bool {
|
||||||
|
code, message, ok := validateAPIKeyGroupAvailable(apiKey)
|
||||||
|
if ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
AbortWithError(c, 403, code, message)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateAPIKeyGroupAvailable(apiKey *service.APIKey) (string, string, bool) {
|
||||||
|
if apiKey == nil || apiKey.GroupID == nil {
|
||||||
|
return "", "", true
|
||||||
|
}
|
||||||
|
group := apiKey.Group
|
||||||
|
if group == nil || strings.EqualFold(group.Status, "deleted") {
|
||||||
|
return "GROUP_DELETED", "API Key 所属分组已删除", false
|
||||||
|
}
|
||||||
|
if !group.IsActive() {
|
||||||
|
return "GROUP_DISABLED", "API Key 所属分组已停用", false
|
||||||
|
}
|
||||||
|
return "", "", true
|
||||||
|
}
|
||||||
|
|||||||
@ -54,6 +54,10 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
|
|||||||
abortWithGoogleError(c, 401, "User account is not active")
|
abortWithGoogleError(c, 401, "User account is not active")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if _, message, ok := validateAPIKeyGroupAvailable(apiKey); !ok {
|
||||||
|
abortWithGoogleError(c, 403, message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 简易模式:跳过余额和订阅检查
|
// 简易模式:跳过余额和订阅检查
|
||||||
if cfg.RunMode == config.RunModeSimple {
|
if cfg.RunMode == config.RunModeSimple {
|
||||||
|
|||||||
@ -300,6 +300,104 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
|
|||||||
require.Equal(t, http.StatusOK, w.Code)
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyAuthRejectsUnavailableGroup(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
groupID := int64(101)
|
||||||
|
user := &service.User{
|
||||||
|
ID: 7,
|
||||||
|
Role: service.RoleUser,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Balance: 10,
|
||||||
|
Concurrency: 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
group *service.Group
|
||||||
|
wantStatus int
|
||||||
|
wantCode string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "active group passes",
|
||||||
|
group: &service.Group{
|
||||||
|
ID: groupID,
|
||||||
|
Name: "active",
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Platform: service.PlatformAnthropic,
|
||||||
|
Hydrated: true,
|
||||||
|
},
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disabled group is forbidden",
|
||||||
|
group: &service.Group{
|
||||||
|
ID: groupID,
|
||||||
|
Name: "disabled",
|
||||||
|
Status: service.StatusDisabled,
|
||||||
|
Platform: service.PlatformAnthropic,
|
||||||
|
Hydrated: true,
|
||||||
|
},
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
wantCode: "GROUP_DISABLED",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deleted status group is forbidden",
|
||||||
|
group: &service.Group{
|
||||||
|
ID: groupID,
|
||||||
|
Name: "deleted",
|
||||||
|
Status: "deleted",
|
||||||
|
Platform: service.PlatformAnthropic,
|
||||||
|
Hydrated: true,
|
||||||
|
},
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
wantCode: "GROUP_DELETED",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing group edge is forbidden",
|
||||||
|
group: nil,
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
wantCode: "GROUP_DELETED",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
apiKey := &service.APIKey{
|
||||||
|
ID: 100,
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: &groupID,
|
||||||
|
Key: "test-key",
|
||||||
|
Status: service.StatusActive,
|
||||||
|
User: user,
|
||||||
|
Group: tt.group,
|
||||||
|
}
|
||||||
|
apiKeyRepo := &stubApiKeyRepo{
|
||||||
|
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
|
if key != apiKey.Key {
|
||||||
|
return nil, service.ErrAPIKeyNotFound
|
||||||
|
}
|
||||||
|
clone := *apiKey
|
||||||
|
return &clone, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||||
|
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||||
|
router := newAuthTestRouter(apiKeyService, nil, cfg)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||||
|
req.Header.Set("x-api-key", apiKey.Key)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, tt.wantStatus, w.Code)
|
||||||
|
if tt.wantCode != "" {
|
||||||
|
require.Contains(t, w.Body.String(), tt.wantCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) {
|
func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
@ -422,6 +422,11 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
|
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
|
||||||
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSMTPConnection)
|
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSMTPConnection)
|
||||||
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
|
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
|
||||||
|
adminSettings.GET("/email-templates", h.Admin.Setting.ListEmailTemplates)
|
||||||
|
adminSettings.POST("/email-template-preview", h.Admin.Setting.PreviewEmailTemplate)
|
||||||
|
adminSettings.GET("/email-templates/:event/:locale", h.Admin.Setting.GetEmailTemplate)
|
||||||
|
adminSettings.PUT("/email-templates/:event/:locale", h.Admin.Setting.UpdateEmailTemplate)
|
||||||
|
adminSettings.POST("/email-templates/:event/:locale/restore-official", h.Admin.Setting.RestoreOfficialEmailTemplate)
|
||||||
// Admin API Key 管理
|
// Admin API Key 管理
|
||||||
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
|
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
|
||||||
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
|
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
|
||||||
|
|||||||
@ -214,6 +214,7 @@ func RegisterAuthRoutes(
|
|||||||
settings := v1.Group("/settings")
|
settings := v1.Group("/settings")
|
||||||
{
|
{
|
||||||
settings.GET("/public", h.Setting.GetPublicSettings)
|
settings.GET("/public", h.Setting.GetPublicSettings)
|
||||||
|
settings.GET("/email-unsubscribe", h.Setting.UnsubscribeNotificationEmail)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 需要认证的当前用户信息
|
// 需要认证的当前用户信息
|
||||||
|
|||||||
@ -31,6 +31,7 @@ func RegisterUserRoutes(
|
|||||||
user.POST("/account-bindings/email", h.User.BindEmailIdentity)
|
user.POST("/account-bindings/email", h.User.BindEmailIdentity)
|
||||||
user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity)
|
user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity)
|
||||||
user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding)
|
user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding)
|
||||||
|
user.GET("/api-keys/:id/usage/daily", h.Usage.GetMyAPIKeyDailyUsage)
|
||||||
|
|
||||||
// 通知邮箱管理
|
// 通知邮箱管理
|
||||||
notifyEmail := user.Group("/notify-email")
|
notifyEmail := user.Group("/notify-email")
|
||||||
|
|||||||
@ -244,6 +244,21 @@ func (s *groupRepoStub) UpdateSortOrders(ctx context.Context, updates []GroupSor
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type deleteGroupAPIKeyRepoStub struct {
|
||||||
|
apiKeyRepoStubForGroupUpdate
|
||||||
|
keys []string
|
||||||
|
listErr error
|
||||||
|
listGroupIDs []int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *deleteGroupAPIKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
s.listGroupIDs = append(s.listGroupIDs, groupID)
|
||||||
|
if s.listErr != nil {
|
||||||
|
return nil, s.listErr
|
||||||
|
}
|
||||||
|
return s.keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
type proxyRepoStub struct {
|
type proxyRepoStub struct {
|
||||||
deleteErr error
|
deleteErr error
|
||||||
countErr error
|
countErr error
|
||||||
@ -500,6 +515,23 @@ func TestAdminService_DeleteGroup_Success_WithCacheInvalidation(t *testing.T) {
|
|||||||
}, calls)
|
}, calls)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAdminService_DeleteGroup_InvalidatesAuthCacheForBoundKeys(t *testing.T) {
|
||||||
|
repo := &groupRepoStub{}
|
||||||
|
apiKeyRepo := &deleteGroupAPIKeyRepoStub{keys: []string{"k1", "k2"}}
|
||||||
|
invalidator := &authCacheInvalidatorStub{}
|
||||||
|
svc := &adminServiceImpl{
|
||||||
|
groupRepo: repo,
|
||||||
|
apiKeyRepo: apiKeyRepo,
|
||||||
|
authCacheInvalidator: invalidator,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.DeleteGroup(context.Background(), 5)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []int64{5}, repo.deleteCalls)
|
||||||
|
require.Equal(t, []int64{5}, apiKeyRepo.listGroupIDs)
|
||||||
|
require.Equal(t, []string{"k1", "k2"}, invalidator.keys)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAdminService_DeleteGroup_NotFound(t *testing.T) {
|
func TestAdminService_DeleteGroup_NotFound(t *testing.T) {
|
||||||
repo := &groupRepoStub{deleteErr: ErrGroupNotFound}
|
repo := &groupRepoStub{deleteErr: ErrGroupNotFound}
|
||||||
svc := &adminServiceImpl{groupRepo: repo}
|
svc := &adminServiceImpl{groupRepo: repo}
|
||||||
|
|||||||
@ -14,7 +14,7 @@ import (
|
|||||||
"github.com/dgraph-io/ristretto"
|
"github.com/dgraph-io/ristretto"
|
||||||
)
|
)
|
||||||
|
|
||||||
const apiKeyAuthSnapshotVersion = 9 // v9: added API Key name for audit logs
|
const apiKeyAuthSnapshotVersion = 10 // v10: reload snapshots for group availability checks
|
||||||
|
|
||||||
type apiKeyAuthCacheConfig struct {
|
type apiKeyAuthCacheConfig struct {
|
||||||
l1Size int
|
l1Size int
|
||||||
|
|||||||
@ -94,7 +94,7 @@ func (s *AuthService) BindEmailIdentity(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SendEmailIdentityBindCode sends a verification code for authenticated email binding flows.
|
// SendEmailIdentityBindCode sends a verification code for authenticated email binding flows.
|
||||||
func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string) error {
|
func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string, locale ...string) error {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return ErrServiceUnavailable
|
return ErrServiceUnavailable
|
||||||
}
|
}
|
||||||
@ -128,7 +128,7 @@ func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int6
|
|||||||
if s.settingService != nil {
|
if s.settingService != nil {
|
||||||
siteName = s.settingService.GetSiteName(ctx)
|
siteName = s.settingService.GetSiteName(ctx)
|
||||||
}
|
}
|
||||||
return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName)
|
return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName, firstEmailLocale(locale))
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizeEmailForIdentityBinding(email string) (string, error) {
|
func normalizeEmailForIdentityBinding(email string) (string, error) {
|
||||||
|
|||||||
@ -28,7 +28,7 @@ func normalizeOAuthSignupSource(signupSource string) string {
|
|||||||
|
|
||||||
// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth
|
// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth
|
||||||
// account-creation flows without relying on the public registration gate.
|
// account-creation flows without relying on the public registration gate.
|
||||||
func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
|
func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string, locale ...string) (*SendVerifyCodeResult, error) {
|
||||||
email = strings.TrimSpace(strings.ToLower(email))
|
email = strings.TrimSpace(strings.ToLower(email))
|
||||||
if email == "" {
|
if email == "" {
|
||||||
return nil, ErrEmailVerifyRequired
|
return nil, ErrEmailVerifyRequired
|
||||||
@ -47,7 +47,7 @@ func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email stri
|
|||||||
if s.settingService != nil {
|
if s.settingService != nil {
|
||||||
siteName = s.settingService.GetSiteName(ctx)
|
siteName = s.settingService.GetSiteName(ctx)
|
||||||
}
|
}
|
||||||
if err := s.emailService.SendVerifyCode(ctx, email, siteName); err != nil {
|
if err := s.emailService.SendVerifyCode(ctx, email, siteName, firstEmailLocale(locale)); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &SendVerifyCodeResult{
|
return &SendVerifyCodeResult{
|
||||||
|
|||||||
@ -273,7 +273,7 @@ type SendVerifyCodeResult struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SendVerifyCode 发送邮箱验证码(同步方式)
|
// SendVerifyCode 发送邮箱验证码(同步方式)
|
||||||
func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
func (s *AuthService) SendVerifyCode(ctx context.Context, email string, locale ...string) error {
|
||||||
// 检查是否开放注册(默认关闭)
|
// 检查是否开放注册(默认关闭)
|
||||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||||
return ErrRegDisabled
|
return ErrRegDisabled
|
||||||
@ -307,11 +307,11 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
|||||||
siteName = s.settingService.GetSiteName(ctx)
|
siteName = s.settingService.GetSiteName(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.emailService.SendVerifyCode(ctx, email, siteName)
|
return s.emailService.SendVerifyCode(ctx, email, siteName, firstEmailLocale(locale))
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendVerifyCodeAsync 异步发送邮箱验证码并返回倒计时
|
// SendVerifyCodeAsync 异步发送邮箱验证码并返回倒计时
|
||||||
func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
|
func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string, locale ...string) (*SendVerifyCodeResult, error) {
|
||||||
logger.LegacyPrintf("service.auth", "[Auth] SendVerifyCodeAsync called for email: %s", email)
|
logger.LegacyPrintf("service.auth", "[Auth] SendVerifyCodeAsync called for email: %s", email)
|
||||||
|
|
||||||
// 检查是否开放注册(默认关闭)
|
// 检查是否开放注册(默认关闭)
|
||||||
@ -352,7 +352,7 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
|
|||||||
|
|
||||||
// 异步发送
|
// 异步发送
|
||||||
logger.LegacyPrintf("service.auth", "[Auth] Enqueueing verify code for: %s", email)
|
logger.LegacyPrintf("service.auth", "[Auth] Enqueueing verify code for: %s", email)
|
||||||
if err := s.emailQueueService.EnqueueVerifyCode(email, siteName); err != nil {
|
if err := s.emailQueueService.EnqueueVerifyCode(email, siteName, firstEmailLocale(locale)); err != nil {
|
||||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to enqueue: %v", err)
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to enqueue: %v", err)
|
||||||
return nil, fmt.Errorf("enqueue verify code: %w", err)
|
return nil, fmt.Errorf("enqueue verify code: %w", err)
|
||||||
}
|
}
|
||||||
@ -1251,7 +1251,7 @@ func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendB
|
|||||||
|
|
||||||
// RequestPasswordReset 请求密码重置(同步发送)
|
// RequestPasswordReset 请求密码重置(同步发送)
|
||||||
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
|
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
|
||||||
func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string) error {
|
func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string, locale ...string) error {
|
||||||
if !s.IsPasswordResetEnabled(ctx) {
|
if !s.IsPasswordResetEnabled(ctx) {
|
||||||
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
|
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
|
||||||
}
|
}
|
||||||
@ -1264,7 +1264,7 @@ func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendB
|
|||||||
return nil // Silent success to prevent enumeration
|
return nil // Silent success to prevent enumeration
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
|
if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL, firstEmailLocale(locale)); err != nil {
|
||||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to send password reset email to %s: %v", email, err)
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to send password reset email to %s: %v", email, err)
|
||||||
return nil // Silent success to prevent enumeration
|
return nil // Silent success to prevent enumeration
|
||||||
}
|
}
|
||||||
@ -1275,7 +1275,7 @@ func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendB
|
|||||||
|
|
||||||
// RequestPasswordResetAsync 异步请求密码重置(队列发送)
|
// RequestPasswordResetAsync 异步请求密码重置(队列发送)
|
||||||
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
|
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
|
||||||
func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string) error {
|
func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string, locale ...string) error {
|
||||||
if !s.IsPasswordResetEnabled(ctx) {
|
if !s.IsPasswordResetEnabled(ctx) {
|
||||||
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
|
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
|
||||||
}
|
}
|
||||||
@ -1288,7 +1288,7 @@ func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, fron
|
|||||||
return nil // Silent success to prevent enumeration
|
return nil // Silent success to prevent enumeration
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL); err != nil {
|
if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL, firstEmailLocale(locale)); err != nil {
|
||||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to enqueue password reset email for %s: %v", email, err)
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to enqueue password reset email for %s: %v", email, err)
|
||||||
return nil // Silent success to prevent enumeration
|
return nil // Silent success to prevent enumeration
|
||||||
}
|
}
|
||||||
|
|||||||
@ -39,9 +39,10 @@ type AccountQuotaReader interface {
|
|||||||
|
|
||||||
// BalanceNotifyService handles balance and quota threshold notifications.
|
// BalanceNotifyService handles balance and quota threshold notifications.
|
||||||
type BalanceNotifyService struct {
|
type BalanceNotifyService struct {
|
||||||
emailService *EmailService
|
emailService *EmailService
|
||||||
settingRepo SettingRepository
|
settingRepo SettingRepository
|
||||||
accountRepo AccountQuotaReader
|
accountRepo AccountQuotaReader
|
||||||
|
notificationEmailService *NotificationEmailService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBalanceNotifyService creates a new BalanceNotifyService.
|
// NewBalanceNotifyService creates a new BalanceNotifyService.
|
||||||
@ -53,6 +54,10 @@ func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BalanceNotifyService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) {
|
||||||
|
s.notificationEmailService = notificationEmailService
|
||||||
|
}
|
||||||
|
|
||||||
// resolveBalanceThreshold returns the effective balance threshold.
|
// resolveBalanceThreshold returns the effective balance threshold.
|
||||||
// For percentage type, it computes threshold = totalRecharged * percentage / 100.
|
// For percentage type, it computes threshold = totalRecharged * percentage / 100.
|
||||||
func resolveBalanceThreshold(threshold float64, thresholdType string, totalRecharged float64) float64 {
|
func resolveBalanceThreshold(threshold float64, thresholdType string, totalRecharged float64) float64 {
|
||||||
@ -125,7 +130,7 @@ func (s *BalanceNotifyService) dispatchBalanceLowEmail(ctx context.Context, user
|
|||||||
slog.Error("panic in balance notification", "recover", r)
|
slog.Error("panic in balance notification", "recover", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
s.sendBalanceLowEmails(recipients, user.Username, user.Email, newBalance, threshold, siteName, rechargeURL)
|
s.sendBalanceLowEmails(recipients, user.ID, user.Username, user.Email, newBalance, threshold, siteName, rechargeURL)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -342,11 +347,44 @@ func (s *BalanceNotifyService) sendEmails(recipients []string, subject, body str
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sendBalanceLowEmails sends balance low notification to all recipients.
|
// sendBalanceLowEmails sends balance low notification to all recipients.
|
||||||
func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userName, userEmail string, balance, threshold float64, siteName, rechargeURL string) {
|
func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userID int64, userName, userEmail string, balance, threshold float64, siteName, rechargeURL string) {
|
||||||
displayName := userName
|
displayName := userName
|
||||||
if displayName == "" {
|
if displayName == "" {
|
||||||
displayName = userEmail
|
displayName = userEmail
|
||||||
}
|
}
|
||||||
|
if s.notificationEmailService != nil {
|
||||||
|
fallbackRecipients := make([]string, 0, len(recipients))
|
||||||
|
for _, to := range recipients {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout)
|
||||||
|
err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||||
|
Event: NotificationEmailEventBalanceLow,
|
||||||
|
RecipientEmail: to,
|
||||||
|
RecipientName: displayName,
|
||||||
|
UserID: userID,
|
||||||
|
SourceType: "balance_low",
|
||||||
|
SourceID: firstNonEmpty(strconv.FormatInt(userID, 10), userEmail),
|
||||||
|
ReminderKey: time.Now().UTC().Format("2006-01-02"),
|
||||||
|
Variables: map[string]string{
|
||||||
|
"current_balance": fmt.Sprintf("%.2f", balance),
|
||||||
|
"threshold": fmt.Sprintf("%.2f", threshold),
|
||||||
|
"recharge_url": rechargeURL,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
cancel()
|
||||||
|
if err != nil {
|
||||||
|
if shouldFallbackNotificationEmail(err) {
|
||||||
|
slog.Warn("template balance low notification failed; falling back to built-in body", "to", to, "err", err.Error())
|
||||||
|
fallbackRecipients = append(fallbackRecipients, to)
|
||||||
|
} else {
|
||||||
|
slog.Warn("template balance low notification delivery failed; not sending fallback to avoid duplicates", "to", to, "err", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(fallbackRecipients) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
recipients = fallbackRecipients
|
||||||
|
}
|
||||||
subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", sanitizeEmailHeader(siteName))
|
subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", sanitizeEmailHeader(siteName))
|
||||||
body := s.buildBalanceLowEmailBody(html.EscapeString(displayName), balance, threshold, html.EscapeString(siteName), rechargeURL)
|
body := s.buildBalanceLowEmailBody(html.EscapeString(displayName), balance, threshold, html.EscapeString(siteName), rechargeURL)
|
||||||
s.sendEmails(recipients, subject, body, "user_email", userEmail, "balance", balance)
|
s.sendEmails(recipients, subject, body, "user_email", userEmail, "balance", balance)
|
||||||
@ -369,6 +407,44 @@ func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accoun
|
|||||||
remaining = 0
|
remaining = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.notificationEmailService != nil {
|
||||||
|
fallbackRecipients := make([]string, 0, len(adminEmails))
|
||||||
|
for _, to := range adminEmails {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout)
|
||||||
|
err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||||
|
Event: NotificationEmailEventAccountQuotaAlert,
|
||||||
|
RecipientEmail: to,
|
||||||
|
RecipientName: emailRecipientName(to),
|
||||||
|
SourceType: "account_quota",
|
||||||
|
SourceID: fmt.Sprintf("%d-%s", accountID, dim.name),
|
||||||
|
ReminderKey: time.Now().UTC().Format("2006-01-02"),
|
||||||
|
Variables: map[string]string{
|
||||||
|
"account_id": strconv.FormatInt(accountID, 10),
|
||||||
|
"account_name": accountName,
|
||||||
|
"platform": platform,
|
||||||
|
"quota_dimension": dimLabel,
|
||||||
|
"quota_used": fmt.Sprintf("%.2f", used),
|
||||||
|
"quota_limit": fmt.Sprintf("%.2f", dim.limit),
|
||||||
|
"quota_remaining": fmt.Sprintf("%.2f", remaining),
|
||||||
|
"quota_threshold": thresholdDisplay,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
cancel()
|
||||||
|
if err != nil {
|
||||||
|
if shouldFallbackNotificationEmail(err) {
|
||||||
|
slog.Warn("template account quota alert failed; falling back to built-in body", "to", to, "account_id", accountID, "dimension", dim.name, "err", err.Error())
|
||||||
|
fallbackRecipients = append(fallbackRecipients, to)
|
||||||
|
} else {
|
||||||
|
slog.Warn("template account quota alert delivery failed; not sending fallback to avoid duplicates", "to", to, "account_id", accountID, "dimension", dim.name, "err", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(fallbackRecipients) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
adminEmails = fallbackRecipients
|
||||||
|
}
|
||||||
|
|
||||||
subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", sanitizeEmailHeader(siteName), sanitizeEmailHeader(accountName))
|
subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", sanitizeEmailHeader(siteName), sanitizeEmailHeader(accountName))
|
||||||
body := s.buildQuotaAlertEmailBody(accountID, html.EscapeString(accountName), html.EscapeString(platform), html.EscapeString(dimLabel), used, dim.limit, remaining, thresholdDisplay, html.EscapeString(siteName))
|
body := s.buildQuotaAlertEmailBody(accountID, html.EscapeString(accountName), html.EscapeString(platform), html.EscapeString(dimLabel), used, dim.limit, remaining, thresholdDisplay, html.EscapeString(siteName))
|
||||||
s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dim.name)
|
s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dim.name)
|
||||||
|
|||||||
@ -15,16 +15,19 @@ import (
|
|||||||
//
|
//
|
||||||
// 行为按 tokenType / mimicClaudeCode 分两条路径:
|
// 行为按 tokenType / mimicClaudeCode 分两条路径:
|
||||||
//
|
//
|
||||||
// OAuth mimic 路径 (tokenType == "oauth" && mimicClaudeCode):
|
// OAuth 路径 (tokenType == "oauth"):
|
||||||
// 1. body 中 metadata.user_id 派生的 SessionID 是合法 UUID → canonicalize 写入
|
// OAuth 账号本身就是真实 Claude Code 客户端的凭证,可以信任 body 中的
|
||||||
// 2. 请求 header 中已有合法 UUID → canonicalize 保留
|
// metadata.user_id 派生 session id。
|
||||||
// 3. 否则 → 兜底生成 UUID
|
// 1. metadata.user_id 派生 SessionID 是合法 UUID → canonical 写入
|
||||||
|
// 2. header 已有合法 UUID → canonical 保留
|
||||||
|
// 3. mimicClaudeCode == true → 兜底生成新 UUID
|
||||||
|
// (mimicClaudeCode == false 且无 metadata 时不强制注入)
|
||||||
//
|
//
|
||||||
// API key 透传 / 非 mimic 路径:
|
// API key 透传路径 (tokenType != "oauth"):
|
||||||
// - 不从 body 合成 header(避免污染客户端原始语义)
|
// - 不从 body metadata 派生 header(避免污染客户端原始语义)
|
||||||
// - 但若客户端在 header 中传入了 X-Claude-Code-Session-Id:
|
// - 若客户端在 header 中传入 X-Claude-Code-Session-Id:
|
||||||
// 合法 UUID → canonicalize 保留
|
// 合法 UUID → canonical 保留
|
||||||
// 非法值 → 删除(不向上游转发恶意值,符合 UUID 校验承诺)
|
// 非法值 → 删除(不向上游转发恶意值)
|
||||||
// - 不兜底生成
|
// - 不兜底生成
|
||||||
//
|
//
|
||||||
// 安全说明:metadata.user_id 由客户端控制,ParseMetadataUserID 的正则仅约束字符集,
|
// 安全说明:metadata.user_id 由客户端控制,ParseMetadataUserID 的正则仅约束字符集,
|
||||||
@ -37,10 +40,10 @@ func ensureClaudeCodeSessionID(req *http.Request, body []byte, tokenType string,
|
|||||||
req.Header = make(http.Header)
|
req.Header = make(http.Header)
|
||||||
}
|
}
|
||||||
|
|
||||||
isOAuthMimic := tokenType == "oauth" && mimicClaudeCode
|
isOAuth := tokenType == "oauth"
|
||||||
|
|
||||||
// OAuth mimic 路径:从 metadata 派生(仅在 mimic 场景写 header)。
|
// OAuth 路径:从 metadata 派生(OAuth 凭证可信任)。
|
||||||
if isOAuthMimic {
|
if isOAuth {
|
||||||
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
||||||
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
||||||
if id, err := uuid.Parse(parsed.SessionID); err == nil {
|
if id, err := uuid.Parse(parsed.SessionID); err == nil {
|
||||||
@ -65,9 +68,9 @@ func ensureClaudeCodeSessionID(req *http.Request, body []byte, tokenType string,
|
|||||||
req.Header.Del("X-Claude-Code-Session-Id")
|
req.Header.Del("X-Claude-Code-Session-Id")
|
||||||
}
|
}
|
||||||
|
|
||||||
// OAuth mimic 兜底生成(仅 mimic 场景;API key 不污染)。
|
// OAuth mimic 兜底生成(仅 mimic 场景;API key/非 mimic 不污染)。
|
||||||
// uuid.NewString() 走 crypto/rand。
|
// uuid.NewString() 走 crypto/rand。
|
||||||
if isOAuthMimic {
|
if isOAuth && mimicClaudeCode {
|
||||||
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", uuid.NewString())
|
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", uuid.NewString())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -136,15 +136,17 @@ func TestEnsureClaudeCodeSessionID_APIKeyIgnoresMetadata(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// OAuth 但非 mimic 模式也不应该从 metadata 派生 header。
|
// OAuth 路径即使 mimic=false 也应该从 metadata 派生 header:
|
||||||
func TestEnsureClaudeCodeSessionID_OAuthNonMimicIgnoresMetadata(t *testing.T) {
|
// OAuth 凭证本身就是 Claude Code 类型账号,metadata.user_id 可信任。
|
||||||
|
// 这与 API key 路径不同(API key 是任意第三方调用方)。
|
||||||
|
func TestEnsureClaudeCodeSessionID_OAuthNonMimicDerivesFromMetadata(t *testing.T) {
|
||||||
req := newReq(t)
|
req := newReq(t)
|
||||||
body := []byte(`{"metadata":{"user_id":"{\"device_id\":\"abc\",\"account_uuid\":\"\",\"session_id\":\"` + testValidUUID + `\"}"}}`)
|
body := []byte(`{"metadata":{"user_id":"{\"device_id\":\"abc\",\"account_uuid\":\"\",\"session_id\":\"` + testValidUUID + `\"}"}}`)
|
||||||
ensureClaudeCodeSessionID(req, body, "oauth", false)
|
ensureClaudeCodeSessionID(req, body, "oauth", false)
|
||||||
|
|
||||||
got := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id")
|
got := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id")
|
||||||
if got != "" {
|
if got != testValidUUID {
|
||||||
t.Fatalf("Non-mimic OAuth must NOT derive session-id from metadata, got %q", got)
|
t.Fatalf("OAuth must derive session-id from metadata regardless of mimic flag, got %q want %q", got, testValidUUID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1463,6 +1463,24 @@ func (s *ContentModerationService) applyFlaggedSideEffects(ctx context.Context,
|
|||||||
|
|
||||||
func (s *ContentModerationService) sendViolationEmail(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) error {
|
func (s *ContentModerationService) sendViolationEmail(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) error {
|
||||||
siteName := s.siteName(ctx)
|
siteName := s.siteName(ctx)
|
||||||
|
if s.emailService.notificationEmailService != nil {
|
||||||
|
if err := s.emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||||
|
Event: NotificationEmailEventContentModerationViolation,
|
||||||
|
RecipientEmail: log.UserEmail,
|
||||||
|
RecipientName: emailRecipientName(log.UserEmail),
|
||||||
|
UserID: contentModerationEmailUserID(log),
|
||||||
|
SourceType: "content_moderation",
|
||||||
|
SourceID: contentModerationEmailSourceID(log),
|
||||||
|
Variables: contentModerationEmailVariables(log, cfg),
|
||||||
|
}); err == nil {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
if !shouldFallbackNotificationEmail(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
slog.Warn("template content moderation violation email failed; falling back to built-in body", "log_id", log.ID, "recipient_hash", notificationEmailHash(log.UserEmail), "err", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
subject := fmt.Sprintf("[%s] 账户风控提醒 / Risk Control Notice", sanitizeEmailHeader(siteName))
|
subject := fmt.Sprintf("[%s] 账户风控提醒 / Risk Control Notice", sanitizeEmailHeader(siteName))
|
||||||
body := buildContentModerationViolationEmailBody(siteName, log, cfg)
|
body := buildContentModerationViolationEmailBody(siteName, log, cfg)
|
||||||
return s.emailService.SendEmail(ctx, log.UserEmail, subject, body)
|
return s.emailService.SendEmail(ctx, log.UserEmail, subject, body)
|
||||||
@ -1470,11 +1488,71 @@ func (s *ContentModerationService) sendViolationEmail(ctx context.Context, cfg *
|
|||||||
|
|
||||||
func (s *ContentModerationService) sendAccountDisabledEmail(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) error {
|
func (s *ContentModerationService) sendAccountDisabledEmail(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) error {
|
||||||
siteName := s.siteName(ctx)
|
siteName := s.siteName(ctx)
|
||||||
|
if s.emailService.notificationEmailService != nil {
|
||||||
|
if err := s.emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||||
|
Event: NotificationEmailEventContentModerationDisabled,
|
||||||
|
RecipientEmail: log.UserEmail,
|
||||||
|
RecipientName: emailRecipientName(log.UserEmail),
|
||||||
|
UserID: contentModerationEmailUserID(log),
|
||||||
|
SourceType: "content_moderation",
|
||||||
|
SourceID: contentModerationEmailSourceID(log),
|
||||||
|
Variables: contentModerationEmailVariables(log, cfg),
|
||||||
|
}); err == nil {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
if !shouldFallbackNotificationEmail(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
slog.Warn("template content moderation disabled email failed; falling back to built-in body", "log_id", log.ID, "recipient_hash", notificationEmailHash(log.UserEmail), "err", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
subject := fmt.Sprintf("[%s] 账户已被禁用 / Account Disabled", sanitizeEmailHeader(siteName))
|
subject := fmt.Sprintf("[%s] 账户已被禁用 / Account Disabled", sanitizeEmailHeader(siteName))
|
||||||
body := buildContentModerationAccountDisabledEmailBody(siteName, log, cfg)
|
body := buildContentModerationAccountDisabledEmailBody(siteName, log, cfg)
|
||||||
return s.emailService.SendEmail(ctx, log.UserEmail, subject, body)
|
return s.emailService.SendEmail(ctx, log.UserEmail, subject, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func contentModerationEmailUserID(log *ContentModerationLog) int64 {
|
||||||
|
if log == nil || log.UserID == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return *log.UserID
|
||||||
|
}
|
||||||
|
|
||||||
|
func contentModerationEmailSourceID(log *ContentModerationLog) string {
|
||||||
|
if log == nil || log.ID <= 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%d", log.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func contentModerationEmailVariables(log *ContentModerationLog, cfg *ContentModerationConfig) map[string]string {
|
||||||
|
variables := map[string]string{
|
||||||
|
"triggered_at": time.Now().UTC().Format(time.RFC3339),
|
||||||
|
"group_name": "-",
|
||||||
|
"moderation_category": "-",
|
||||||
|
"moderation_score": "0.000",
|
||||||
|
"violation_count": "0",
|
||||||
|
"ban_threshold": "0",
|
||||||
|
}
|
||||||
|
if log != nil {
|
||||||
|
if !log.CreatedAt.IsZero() {
|
||||||
|
variables["triggered_at"] = log.CreatedAt.UTC().Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(log.GroupName) != "" {
|
||||||
|
variables["group_name"] = strings.TrimSpace(log.GroupName)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(log.HighestCategory) != "" {
|
||||||
|
variables["moderation_category"] = strings.TrimSpace(log.HighestCategory)
|
||||||
|
}
|
||||||
|
variables["moderation_score"] = fmt.Sprintf("%.3f", log.HighestScore)
|
||||||
|
variables["violation_count"] = fmt.Sprintf("%d", log.ViolationCount)
|
||||||
|
}
|
||||||
|
if cfg != nil {
|
||||||
|
variables["ban_threshold"] = fmt.Sprintf("%d", cfg.BanThreshold)
|
||||||
|
}
|
||||||
|
return variables
|
||||||
|
}
|
||||||
|
|
||||||
func (s *ContentModerationService) siteName(ctx context.Context) string {
|
func (s *ContentModerationService) siteName(ctx context.Context) string {
|
||||||
if s == nil || s.settingRepo == nil {
|
if s == nil || s.settingRepo == nil {
|
||||||
return "Sub2API"
|
return "Sub2API"
|
||||||
|
|||||||
@ -401,6 +401,10 @@ const (
|
|||||||
SettingKeyRewriteMessageCacheControl = "rewrite_message_cache_control"
|
SettingKeyRewriteMessageCacheControl = "rewrite_message_cache_control"
|
||||||
// SettingKeyAntigravityUserAgentVersion Antigravity 上游 User-Agent 版本号(空值使用环境变量/默认值)
|
// SettingKeyAntigravityUserAgentVersion Antigravity 上游 User-Agent 版本号(空值使用环境变量/默认值)
|
||||||
SettingKeyAntigravityUserAgentVersion = "antigravity_user_agent_version"
|
SettingKeyAntigravityUserAgentVersion = "antigravity_user_agent_version"
|
||||||
|
// SettingKeyOpenAICodexUserAgent OpenAI Codex 完整 User-Agent(空值使用内置默认)
|
||||||
|
// 当客户端 UA 被识别为浏览器(Chrome/Firefox/Safari/Edge 等)时,转发给 OpenAI 上游前会替换为此值,
|
||||||
|
// 用于避免 Cloudflare 对浏览器型 UA 的质询拦截。
|
||||||
|
SettingKeyOpenAICodexUserAgent = "openai_codex_user_agent"
|
||||||
|
|
||||||
// Balance Low Notification
|
// Balance Low Notification
|
||||||
SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
|
SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
|
||||||
|
|||||||
@ -21,6 +21,7 @@ type EmailTask struct {
|
|||||||
SiteName string
|
SiteName string
|
||||||
TaskType string // "verify_code" or "password_reset"
|
TaskType string // "verify_code" or "password_reset"
|
||||||
ResetURL string // Only used for password_reset task type
|
ResetURL string // Only used for password_reset task type
|
||||||
|
Locale string // Optional Accept-Language locale hint
|
||||||
}
|
}
|
||||||
|
|
||||||
// EmailQueueService 异步邮件队列服务
|
// EmailQueueService 异步邮件队列服务
|
||||||
@ -82,13 +83,13 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
|
|||||||
|
|
||||||
switch task.TaskType {
|
switch task.TaskType {
|
||||||
case TaskTypeVerifyCode:
|
case TaskTypeVerifyCode:
|
||||||
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil {
|
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName, task.Locale); err != nil {
|
||||||
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
|
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
|
||||||
} else {
|
} else {
|
||||||
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
|
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
|
||||||
}
|
}
|
||||||
case TaskTypePasswordReset:
|
case TaskTypePasswordReset:
|
||||||
if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil {
|
if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL, task.Locale); err != nil {
|
||||||
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err)
|
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err)
|
||||||
} else {
|
} else {
|
||||||
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email)
|
logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email)
|
||||||
@ -99,11 +100,12 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// EnqueueVerifyCode 将验证码发送任务加入队列
|
// EnqueueVerifyCode 将验证码发送任务加入队列
|
||||||
func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
|
func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string, locale ...string) error {
|
||||||
task := EmailTask{
|
task := EmailTask{
|
||||||
Email: email,
|
Email: email,
|
||||||
SiteName: siteName,
|
SiteName: siteName,
|
||||||
TaskType: TaskTypeVerifyCode,
|
TaskType: TaskTypeVerifyCode,
|
||||||
|
Locale: firstEmailLocale(locale),
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -116,12 +118,13 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// EnqueuePasswordReset 将密码重置邮件任务加入队列
|
// EnqueuePasswordReset 将密码重置邮件任务加入队列
|
||||||
func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL string) error {
|
func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL string, locale ...string) error {
|
||||||
task := EmailTask{
|
task := EmailTask{
|
||||||
Email: email,
|
Email: email,
|
||||||
SiteName: siteName,
|
SiteName: siteName,
|
||||||
TaskType: TaskTypePasswordReset,
|
TaskType: TaskTypePasswordReset,
|
||||||
ResetURL: resetURL,
|
ResetURL: resetURL,
|
||||||
|
Locale: firstEmailLocale(locale),
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
|||||||
@ -94,8 +94,9 @@ type SMTPConfig struct {
|
|||||||
|
|
||||||
// EmailService 邮件服务
|
// EmailService 邮件服务
|
||||||
type EmailService struct {
|
type EmailService struct {
|
||||||
settingRepo SettingRepository
|
settingRepo SettingRepository
|
||||||
cache EmailCache
|
cache EmailCache
|
||||||
|
notificationEmailService *NotificationEmailService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewEmailService 创建邮件服务实例
|
// NewEmailService 创建邮件服务实例
|
||||||
@ -106,6 +107,28 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *EmailService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) {
|
||||||
|
s.notificationEmailService = notificationEmailService
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstEmailLocale(locales []string) string {
|
||||||
|
if len(locales) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(locales[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func emailRecipientName(email string) string {
|
||||||
|
trimmed := strings.TrimSpace(email)
|
||||||
|
if trimmed == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if at := strings.Index(trimmed, "@"); at > 0 {
|
||||||
|
return trimmed[:at]
|
||||||
|
}
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
|
||||||
// GetSMTPConfig 从数据库获取SMTP配置
|
// GetSMTPConfig 从数据库获取SMTP配置
|
||||||
func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
|
func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
|
||||||
keys := []string{
|
keys := []string{
|
||||||
@ -301,7 +324,7 @@ func (s *EmailService) GenerateVerifyCode() (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SendVerifyCode 发送验证码邮件
|
// SendVerifyCode 发送验证码邮件
|
||||||
func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string) error {
|
func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string, locale ...string) error {
|
||||||
// 检查是否在冷却期内
|
// 检查是否在冷却期内
|
||||||
existing, err := s.cache.GetVerificationCode(ctx, email)
|
existing, err := s.cache.GetVerificationCode(ctx, email)
|
||||||
if err == nil && existing != nil {
|
if err == nil && existing != nil {
|
||||||
@ -327,6 +350,26 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
|
|||||||
return fmt.Errorf("save verify code: %w", err)
|
return fmt.Errorf("save verify code: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.notificationEmailService != nil {
|
||||||
|
err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||||
|
Event: NotificationEmailEventAuthVerifyCode,
|
||||||
|
Locale: firstEmailLocale(locale),
|
||||||
|
RecipientEmail: email,
|
||||||
|
RecipientName: emailRecipientName(email),
|
||||||
|
Variables: map[string]string{
|
||||||
|
"verification_code": code,
|
||||||
|
"expires_in_minutes": strconv.Itoa(int(verifyCodeTTL / time.Minute)),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !shouldFallbackNotificationEmail(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
slog.Warn("failed to send templated verification email, falling back to legacy template", "recipient_hash", notificationEmailHash(email), "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
// 构建邮件内容
|
// 构建邮件内容
|
||||||
subject := fmt.Sprintf("[%s] Email Verification Code", siteName)
|
subject := fmt.Sprintf("[%s] Email Verification Code", siteName)
|
||||||
body := s.buildVerifyCodeEmailBody(code, siteName)
|
body := s.buildVerifyCodeEmailBody(code, siteName)
|
||||||
@ -469,7 +512,7 @@ func (s *EmailService) GeneratePasswordResetToken() (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SendPasswordResetEmail sends a password reset email with a reset link
|
// SendPasswordResetEmail sends a password reset email with a reset link
|
||||||
func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteName, resetURL string) error {
|
func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteName, resetURL string, locale ...string) error {
|
||||||
var token string
|
var token string
|
||||||
var needSaveToken bool
|
var needSaveToken bool
|
||||||
|
|
||||||
@ -502,6 +545,26 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa
|
|||||||
// Build full reset URL with URL-encoded token and email
|
// Build full reset URL with URL-encoded token and email
|
||||||
fullResetURL := fmt.Sprintf("%s?email=%s&token=%s", resetURL, url.QueryEscape(email), url.QueryEscape(token))
|
fullResetURL := fmt.Sprintf("%s?email=%s&token=%s", resetURL, url.QueryEscape(email), url.QueryEscape(token))
|
||||||
|
|
||||||
|
if s.notificationEmailService != nil {
|
||||||
|
err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||||
|
Event: NotificationEmailEventAuthPasswordReset,
|
||||||
|
Locale: firstEmailLocale(locale),
|
||||||
|
RecipientEmail: email,
|
||||||
|
RecipientName: emailRecipientName(email),
|
||||||
|
Variables: map[string]string{
|
||||||
|
"reset_url": fullResetURL,
|
||||||
|
"expires_in_minutes": strconv.Itoa(int(passwordResetTokenTTL / time.Minute)),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !shouldFallbackNotificationEmail(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
slog.Warn("failed to send templated password reset email, falling back to legacy template", "recipient_hash", notificationEmailHash(email), "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Build email content
|
// Build email content
|
||||||
subject := fmt.Sprintf("[%s] 密码重置请求", siteName)
|
subject := fmt.Sprintf("[%s] 密码重置请求", siteName)
|
||||||
body := s.buildPasswordResetEmailBody(fullResetURL, siteName)
|
body := s.buildPasswordResetEmailBody(fullResetURL, siteName)
|
||||||
@ -516,7 +579,7 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa
|
|||||||
|
|
||||||
// SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker)
|
// SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker)
|
||||||
// This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing
|
// This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing
|
||||||
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error {
|
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string, locale ...string) error {
|
||||||
// Check email cooldown to prevent email bombing
|
// Check email cooldown to prevent email bombing
|
||||||
if s.cache.IsPasswordResetEmailInCooldown(ctx, email) {
|
if s.cache.IsPasswordResetEmailInCooldown(ctx, email) {
|
||||||
slog.Info("password reset email skipped due to cooldown", "email", email)
|
slog.Info("password reset email skipped due to cooldown", "email", email)
|
||||||
@ -524,7 +587,7 @@ func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Send email using core method
|
// Send email using core method
|
||||||
if err := s.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
|
if err := s.SendPasswordResetEmail(ctx, email, siteName, resetURL, firstEmailLocale(locale)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
1347
backend/internal/service/notification_email_service.go
Normal file
1347
backend/internal/service/notification_email_service.go
Normal file
File diff suppressed because it is too large
Load Diff
571
backend/internal/service/notification_email_service_test.go
Normal file
571
backend/internal/service/notification_email_service_test.go
Normal file
@ -0,0 +1,571 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNotificationEmailPreviewEscapesHTMLAndSanitizesSubject(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil)
|
||||||
|
|
||||||
|
preview, err := svc.PreviewTemplate(ctx, NotificationEmailPreviewInput{
|
||||||
|
Event: NotificationEmailEventBalanceLow,
|
||||||
|
Locale: "en-US,en;q=0.9",
|
||||||
|
Subject: "Low balance for {{recipient_name}}\r\nInjected",
|
||||||
|
HTML: `<p>{{recipient_name}}</p><a href="{{recharge_url}}">Recharge</a>`,
|
||||||
|
Variables: map[string]string{
|
||||||
|
"recipient_name": `<script>alert("x")</script>`,
|
||||||
|
"recharge_url": `javascript:alert(1)`,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotContains(t, preview.Subject, "\r")
|
||||||
|
require.NotContains(t, preview.Subject, "\n")
|
||||||
|
require.Contains(t, preview.Subject, `Low balance for <script>alert("x")</script>Injected`)
|
||||||
|
require.Contains(t, preview.HTML, `<script>alert("x")</script>`)
|
||||||
|
require.NotContains(t, preview.HTML, `javascript:alert(1)`)
|
||||||
|
require.Contains(t, preview.HTML, `href=""`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotificationEmailTemplateOverrideAndRestore(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
repo := newNotificationEmailMemorySettingRepo()
|
||||||
|
svc := NewNotificationEmailService(repo, nil)
|
||||||
|
|
||||||
|
official, err := svc.GetTemplate(ctx, NotificationEmailEventBalanceRechargeSuccess, "en")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, official.IsCustom)
|
||||||
|
|
||||||
|
updated, err := svc.UpdateTemplate(
|
||||||
|
ctx,
|
||||||
|
NotificationEmailEventBalanceRechargeSuccess,
|
||||||
|
"zh-Hans",
|
||||||
|
"充值完成:{{recharge_amount}}",
|
||||||
|
"<p>{{recipient_name}} 已充值 {{recharge_amount}}</p>",
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, updated.IsCustom)
|
||||||
|
require.Equal(t, "zh", updated.Locale)
|
||||||
|
require.Equal(t, "充值完成:{{recharge_amount}}", updated.Subject)
|
||||||
|
require.NotNil(t, updated.UpdatedAt)
|
||||||
|
|
||||||
|
restored, err := svc.RestoreOfficialTemplate(ctx, NotificationEmailEventBalanceRechargeSuccess, "zh")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, restored.IsCustom)
|
||||||
|
require.NotEqual(t, updated.Subject, restored.Subject)
|
||||||
|
_, err = repo.GetValue(ctx, notificationEmailTemplateKey(NotificationEmailEventBalanceRechargeSuccess, "zh"))
|
||||||
|
require.ErrorIs(t, err, ErrSettingNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotificationEmailTemplateRejectsUnsupportedPlaceholder(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil)
|
||||||
|
|
||||||
|
_, err := svc.UpdateTemplate(
|
||||||
|
ctx,
|
||||||
|
NotificationEmailEventSubscriptionPurchaseSuccess,
|
||||||
|
"en",
|
||||||
|
"Purchased {{not_allowed}}",
|
||||||
|
"<p>{{subscription_group}}</p>",
|
||||||
|
)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "unsupported placeholder")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotificationEmailAuthTemplatesAreListedAndPreviewable(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil)
|
||||||
|
|
||||||
|
infos := svc.ListEventInfos()
|
||||||
|
events := make(map[string]NotificationEmailEventInfo, len(infos))
|
||||||
|
for _, info := range infos {
|
||||||
|
events[info.Event] = info
|
||||||
|
}
|
||||||
|
require.Contains(t, events, NotificationEmailEventAuthVerifyCode)
|
||||||
|
require.Contains(t, events, NotificationEmailEventAuthPasswordReset)
|
||||||
|
require.False(t, events[NotificationEmailEventAuthVerifyCode].Optional)
|
||||||
|
require.False(t, events[NotificationEmailEventAuthPasswordReset].Optional)
|
||||||
|
require.Contains(t, events[NotificationEmailEventAuthVerifyCode].Placeholders, "verification_code")
|
||||||
|
require.Contains(t, events[NotificationEmailEventAuthPasswordReset].Placeholders, "reset_url")
|
||||||
|
|
||||||
|
verifyPreview, err := svc.PreviewTemplate(ctx, NotificationEmailPreviewInput{
|
||||||
|
Event: NotificationEmailEventAuthVerifyCode,
|
||||||
|
Locale: "zh-CN",
|
||||||
|
Variables: map[string]string{
|
||||||
|
"verification_code": "654321",
|
||||||
|
"expires_in_minutes": "15",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Contains(t, verifyPreview.Subject, "邮箱验证码")
|
||||||
|
require.Contains(t, verifyPreview.HTML, "654321")
|
||||||
|
|
||||||
|
resetPreview, err := svc.PreviewTemplate(ctx, NotificationEmailPreviewInput{
|
||||||
|
Event: NotificationEmailEventAuthPasswordReset,
|
||||||
|
Locale: "en",
|
||||||
|
Variables: map[string]string{
|
||||||
|
"reset_url": "https://example.com/reset?token=abc",
|
||||||
|
"expires_in_minutes": "30",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Contains(t, resetPreview.Subject, "Password reset")
|
||||||
|
require.Contains(t, resetPreview.HTML, "https://example.com/reset?token=abc")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotificationEmailAdditionalEventsAreListedAndPreviewable(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil)
|
||||||
|
|
||||||
|
infos := svc.ListEventInfos()
|
||||||
|
events := make(map[string]NotificationEmailEventInfo, len(infos))
|
||||||
|
for _, info := range infos {
|
||||||
|
events[info.Event] = info
|
||||||
|
}
|
||||||
|
|
||||||
|
checks := []struct {
|
||||||
|
event string
|
||||||
|
placeholder string
|
||||||
|
}{
|
||||||
|
{NotificationEmailEventNotificationEmailVerifyCode, "verification_code"},
|
||||||
|
{NotificationEmailEventAccountQuotaAlert, "account_name"},
|
||||||
|
{NotificationEmailEventContentModerationViolation, "moderation_category"},
|
||||||
|
{NotificationEmailEventContentModerationDisabled, "violation_count"},
|
||||||
|
{NotificationEmailEventOpsAlert, "rule_name"},
|
||||||
|
{NotificationEmailEventOpsScheduledReport, "report_html"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, check := range checks {
|
||||||
|
info, ok := events[check.event]
|
||||||
|
require.Truef(t, ok, "expected %s to be listed", check.event)
|
||||||
|
require.False(t, info.Optional)
|
||||||
|
require.Contains(t, info.Placeholders, check.placeholder)
|
||||||
|
|
||||||
|
preview, err := svc.PreviewTemplate(ctx, NotificationEmailPreviewInput{Event: check.event, Locale: "zh"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, preview.Subject)
|
||||||
|
require.NotEmpty(t, preview.HTML)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotificationEmailRawHTMLVariablesAreTrustedOnlyForHTMLPlaceholders(t *testing.T) {
|
||||||
|
require.True(t, notificationEmailRawHTMLAllowed(NotificationEmailEventOpsScheduledReport, "report_html"))
|
||||||
|
require.False(t, notificationEmailRawHTMLAllowed(NotificationEmailEventOpsScheduledReport, "recipient_name"))
|
||||||
|
require.False(t, notificationEmailRawHTMLAllowed(NotificationEmailEventOpsAlert, "report_html"))
|
||||||
|
|
||||||
|
preview, err := renderNotificationEmail(
|
||||||
|
NotificationEmailEventOpsScheduledReport,
|
||||||
|
"Report for {{recipient_name}}",
|
||||||
|
`<section>{{report_html}}</section><p>{{recipient_name}}</p>`,
|
||||||
|
map[string]string{
|
||||||
|
"recipient_name": `<script>alert("x")</script>`,
|
||||||
|
"report_html": `<p>escaped report</p>`,
|
||||||
|
},
|
||||||
|
map[string]string{
|
||||||
|
"report_html": `<table><tr><td>trusted report</td></tr></table>`,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Contains(t, preview.HTML, `<table><tr><td>trusted report</td></tr></table>`)
|
||||||
|
require.NotContains(t, preview.HTML, `escaped report`)
|
||||||
|
require.Contains(t, preview.HTML, `<script>alert("x")</script>`)
|
||||||
|
require.Contains(t, preview.Subject, `<script>alert("x")</script>`)
|
||||||
|
|
||||||
|
preview, err = renderNotificationEmail(
|
||||||
|
NotificationEmailEventOpsScheduledReport,
|
||||||
|
"Recipient {{recipient_name}}",
|
||||||
|
`<p>{{recipient_name}}</p>`,
|
||||||
|
map[string]string{"recipient_name": `<em>escaped</em>`},
|
||||||
|
map[string]string{"recipient_name": `<strong>raw</strong>`},
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Contains(t, preview.HTML, `<em>escaped</em>`)
|
||||||
|
require.NotContains(t, preview.HTML, `<strong>raw</strong>`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotificationEmailFallbackClassification(t *testing.T) {
|
||||||
|
templateErr := notificationEmailTemplateErr(errors.New("bad template"))
|
||||||
|
configErr := notificationEmailConfigErr(errors.New("missing email service"))
|
||||||
|
deliveryErr := notificationEmailDeliveryErr(errors.New("smtp timeout"))
|
||||||
|
|
||||||
|
require.True(t, shouldFallbackNotificationEmail(templateErr))
|
||||||
|
require.True(t, shouldFallbackNotificationEmail(configErr))
|
||||||
|
require.False(t, shouldFallbackNotificationEmail(deliveryErr))
|
||||||
|
require.True(t, isNotificationEmailDeliveryError(deliveryErr))
|
||||||
|
require.False(t, isNotificationEmailDeliveryError(templateErr))
|
||||||
|
require.False(t, shouldFallbackNotificationEmail(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmailQueueTasksPreserveLocaleHints(t *testing.T) {
|
||||||
|
queue := &EmailQueueService{taskChan: make(chan EmailTask, 2)}
|
||||||
|
require.NoError(t, queue.EnqueueVerifyCode("user@example.com", "Sub2API", "zh-CN"))
|
||||||
|
require.NoError(t, queue.EnqueuePasswordReset("user@example.com", "Sub2API", "https://example.com/reset", "en-US"))
|
||||||
|
|
||||||
|
verifyTask := <-queue.taskChan
|
||||||
|
require.Equal(t, TaskTypeVerifyCode, verifyTask.TaskType)
|
||||||
|
require.Equal(t, "zh-CN", verifyTask.Locale)
|
||||||
|
|
||||||
|
resetTask := <-queue.taskChan
|
||||||
|
require.Equal(t, TaskTypePasswordReset, resetTask.TaskType)
|
||||||
|
require.Equal(t, "en-US", resetTask.Locale)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpsScheduledReportDeliverySourceIDIncludesReportIdentity(t *testing.T) {
|
||||||
|
report := &opsScheduledReport{Name: "日报", ReportType: "daily_summary", Schedule: "0 9 * * *"}
|
||||||
|
sourceID := opsScheduledReportDeliverySourceID(report)
|
||||||
|
require.Contains(t, sourceID, "daily_summary")
|
||||||
|
require.Contains(t, sourceID, "日报")
|
||||||
|
require.Contains(t, sourceID, "0 9 * * *")
|
||||||
|
require.NotEqual(t, sourceID, opsScheduledReportDeliverySourceID(&opsScheduledReport{Name: "周报", ReportType: "weekly_summary", Schedule: "0 9 * * 1"}))
|
||||||
|
require.Equal(t, "scheduled_report", opsScheduledReportDeliverySourceID(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotificationEmailUnsubscribeOnlyAllowsOptionalEvents(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil)
|
||||||
|
|
||||||
|
token, err := svc.createUnsubscribeToken(ctx, "User@Example.com", NotificationEmailEventBalanceLow)
|
||||||
|
require.NoError(t, err)
|
||||||
|
result, err := svc.Unsubscribe(ctx, token)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, result.Done)
|
||||||
|
require.Equal(t, NotificationEmailEventBalanceLow, result.Event)
|
||||||
|
unsubscribed, err := svc.IsUnsubscribed(ctx, "user@example.com", NotificationEmailEventBalanceLow)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, unsubscribed)
|
||||||
|
|
||||||
|
transactionalToken, err := svc.createUnsubscribeToken(ctx, "user@example.com", NotificationEmailEventBalanceRechargeSuccess)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = svc.Unsubscribe(ctx, transactionalToken)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "transactional")
|
||||||
|
|
||||||
|
authToken, err := svc.createUnsubscribeToken(ctx, "user@example.com", NotificationEmailEventAuthVerifyCode)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = svc.Unsubscribe(ctx, authToken)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "transactional")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotificationEmailLocaleMemoryNormalizesAcceptLanguage(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
svc := NewNotificationEmailService(newNotificationEmailMemorySettingRepo(), nil)
|
||||||
|
|
||||||
|
svc.RememberRecipientLocale(ctx, 42, "User@Example.com", "zh-CN,zh;q=0.9,en;q=0.8")
|
||||||
|
require.Equal(t, "zh", svc.ResolveRecipientLocale(ctx, 42, "user@example.com"))
|
||||||
|
require.Equal(t, "zh", svc.ResolveRecipientLocale(ctx, 0, "user@example.com"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotificationEmailDeliveryKeyUsesShortStableHash(t *testing.T) {
|
||||||
|
key := notificationEmailDeliveryKey(
|
||||||
|
NotificationEmailEventSubscriptionExpiryReminder,
|
||||||
|
"user_subscription",
|
||||||
|
"1234567890",
|
||||||
|
"User@Example.com",
|
||||||
|
"7d",
|
||||||
|
)
|
||||||
|
require.NotEmpty(t, key)
|
||||||
|
require.LessOrEqual(t, len(key), 100)
|
||||||
|
require.True(t, strings.HasPrefix(key, notificationEmailDeliveryKeyPrefix+"v2:"))
|
||||||
|
require.Equal(t, key, notificationEmailDeliveryKey(
|
||||||
|
NotificationEmailEventSubscriptionExpiryReminder,
|
||||||
|
"user_subscription",
|
||||||
|
"1234567890",
|
||||||
|
"user@example.com",
|
||||||
|
"7d",
|
||||||
|
))
|
||||||
|
require.NotEqual(t, key, notificationEmailDeliveryKey(
|
||||||
|
NotificationEmailEventSubscriptionExpiryReminder,
|
||||||
|
"user_subscription",
|
||||||
|
"1234567890",
|
||||||
|
"user@example.com",
|
||||||
|
"3d",
|
||||||
|
))
|
||||||
|
|
||||||
|
legacyKey := legacyNotificationEmailDeliveryKey(
|
||||||
|
NotificationEmailEventSubscriptionExpiryReminder,
|
||||||
|
"user_subscription",
|
||||||
|
"1234567890",
|
||||||
|
"user@example.com",
|
||||||
|
"7d",
|
||||||
|
)
|
||||||
|
require.Greater(t, len(legacyKey), 100)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotificationEmailPreferenceKeyUsesShortStableHashAndReadsLegacyKey(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
repo := newNotificationEmailMemorySettingRepo()
|
||||||
|
svc := NewNotificationEmailService(repo, nil)
|
||||||
|
|
||||||
|
key := notificationEmailPreferenceKey(NotificationEmailEventSubscriptionExpiryReminder, "User@Example.com")
|
||||||
|
require.NotEmpty(t, key)
|
||||||
|
require.LessOrEqual(t, len(key), 100)
|
||||||
|
require.True(t, strings.HasPrefix(key, notificationEmailPreferenceKeyPrefix+"v2:"))
|
||||||
|
require.Equal(t, key, notificationEmailPreferenceKey(NotificationEmailEventSubscriptionExpiryReminder, "user@example.com"))
|
||||||
|
|
||||||
|
legacyKey := legacyNotificationEmailPreferenceKey(NotificationEmailEventSubscriptionExpiryReminder, "user@example.com")
|
||||||
|
require.Greater(t, len(legacyKey), 100)
|
||||||
|
require.NoError(t, repo.Set(ctx, legacyKey, "unsubscribed"))
|
||||||
|
|
||||||
|
unsubscribed, err := svc.IsUnsubscribed(ctx, "User@Example.com", NotificationEmailEventSubscriptionExpiryReminder)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, unsubscribed)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotificationEmailSendDeduplicatesSubscriptionExpiryReminder(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
repo := newNotificationEmailMemorySettingRepo()
|
||||||
|
smtpServer := startNotificationEmailTestSMTPServer(t)
|
||||||
|
require.NoError(t, repo.SetMultiple(ctx, smtpServer.settings()))
|
||||||
|
|
||||||
|
emailSvc := NewEmailService(repo, nil)
|
||||||
|
svc := NewNotificationEmailService(repo, emailSvc)
|
||||||
|
input := NotificationEmailSendInput{
|
||||||
|
Event: NotificationEmailEventSubscriptionExpiryReminder,
|
||||||
|
RecipientEmail: "User@Example.com",
|
||||||
|
RecipientName: "User",
|
||||||
|
UserID: 42,
|
||||||
|
SourceType: "user_subscription",
|
||||||
|
SourceID: "1234567890",
|
||||||
|
ReminderKey: "7d",
|
||||||
|
Variables: map[string]string{
|
||||||
|
"subscription_group": "Codex",
|
||||||
|
"expiry_time": "2026-05-27 12:00",
|
||||||
|
"days_remaining": "7",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, svc.Send(ctx, input))
|
||||||
|
require.Equal(t, int64(1), smtpServer.messageCount())
|
||||||
|
|
||||||
|
key := notificationEmailDeliveryKey(input.Event, input.SourceType, input.SourceID, input.RecipientEmail, input.ReminderKey)
|
||||||
|
require.LessOrEqual(t, len(key), 100)
|
||||||
|
_, err := repo.GetValue(ctx, key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, svc.Send(ctx, input))
|
||||||
|
require.Equal(t, int64(1), smtpServer.messageCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotificationEmailSendRespectsLegacyDeliveryKey(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
repo := newNotificationEmailMemorySettingRepo()
|
||||||
|
svc := NewNotificationEmailService(repo, nil)
|
||||||
|
input := NotificationEmailSendInput{
|
||||||
|
Event: NotificationEmailEventSubscriptionExpiryReminder,
|
||||||
|
RecipientEmail: "user@example.com",
|
||||||
|
SourceType: "user_subscription",
|
||||||
|
SourceID: "1234567890",
|
||||||
|
ReminderKey: "7d",
|
||||||
|
}
|
||||||
|
legacyKey := legacyNotificationEmailDeliveryKey(input.Event, input.SourceType, input.SourceID, input.RecipientEmail, input.ReminderKey)
|
||||||
|
require.NoError(t, repo.Set(ctx, legacyKey, "sent"))
|
||||||
|
|
||||||
|
require.NoError(t, svc.Send(ctx, input))
|
||||||
|
}
|
||||||
|
|
||||||
|
type notificationEmailMemorySettingRepo struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
values map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newNotificationEmailMemorySettingRepo() *notificationEmailMemorySettingRepo {
|
||||||
|
return ¬ificationEmailMemorySettingRepo{values: make(map[string]string)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *notificationEmailMemorySettingRepo) Get(_ context.Context, key string) (*Setting, error) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
value, ok := r.values[key]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrSettingNotFound
|
||||||
|
}
|
||||||
|
return &Setting{Key: key, Value: value}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *notificationEmailMemorySettingRepo) GetValue(ctx context.Context, key string) (string, error) {
|
||||||
|
setting, err := r.Get(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return setting.Value, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *notificationEmailMemorySettingRepo) Set(_ context.Context, key, value string) error {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.values[key] = value
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *notificationEmailMemorySettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
out := make(map[string]string, len(keys))
|
||||||
|
for _, key := range keys {
|
||||||
|
if value, ok := r.values[key]; ok {
|
||||||
|
out[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *notificationEmailMemorySettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
for key, value := range settings {
|
||||||
|
r.values[key] = value
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *notificationEmailMemorySettingRepo) GetAll(_ context.Context) (map[string]string, error) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
out := make(map[string]string, len(r.values))
|
||||||
|
for key, value := range r.values {
|
||||||
|
out[key] = value
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *notificationEmailMemorySettingRepo) Delete(_ context.Context, key string) error {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
if _, ok := r.values[key]; !ok {
|
||||||
|
return ErrSettingNotFound
|
||||||
|
}
|
||||||
|
delete(r.values, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotificationEmailMemorySettingRepoSatisfiesInterface(t *testing.T) {
|
||||||
|
var _ SettingRepository = (*notificationEmailMemorySettingRepo)(nil)
|
||||||
|
require.False(t, strings.Contains(notificationEmailPreferenceKey(NotificationEmailEventBalanceLow, "User@Example.com"), "User@Example.com"))
|
||||||
|
}
|
||||||
|
|
||||||
|
type notificationEmailTestSMTPServer struct {
|
||||||
|
listener net.Listener
|
||||||
|
wg sync.WaitGroup
|
||||||
|
messages atomic.Int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func startNotificationEmailTestSMTPServer(t *testing.T) *notificationEmailTestSMTPServer {
|
||||||
|
t.Helper()
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
server := ¬ificationEmailTestSMTPServer{listener: listener}
|
||||||
|
server.wg.Add(1)
|
||||||
|
go server.serve()
|
||||||
|
t.Cleanup(server.close)
|
||||||
|
return server
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *notificationEmailTestSMTPServer) settings() map[string]string {
|
||||||
|
host, port, _ := net.SplitHostPort(s.listener.Addr().String())
|
||||||
|
return map[string]string{
|
||||||
|
SettingKeySMTPHost: host,
|
||||||
|
SettingKeySMTPPort: port,
|
||||||
|
SettingKeySMTPUsername: "user",
|
||||||
|
SettingKeySMTPPassword: "password",
|
||||||
|
SettingKeySMTPFrom: "noreply@example.com",
|
||||||
|
SettingKeySMTPFromName: "Sub2API",
|
||||||
|
SettingKeySMTPUseTLS: "false",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *notificationEmailTestSMTPServer) messageCount() int64 {
|
||||||
|
return s.messages.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *notificationEmailTestSMTPServer) close() {
|
||||||
|
_ = s.listener.Close()
|
||||||
|
s.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *notificationEmailTestSMTPServer) serve() {
|
||||||
|
defer s.wg.Done()
|
||||||
|
for {
|
||||||
|
conn, err := s.listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.handleConn(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *notificationEmailTestSMTPServer) handleConn(conn net.Conn) {
|
||||||
|
defer func() { _ = conn.Close() }()
|
||||||
|
rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
|
||||||
|
writeLine := func(line string) bool {
|
||||||
|
if _, err := rw.WriteString(line + "\r\n"); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return rw.Flush() == nil
|
||||||
|
}
|
||||||
|
if !writeLine("220 localhost ESMTP") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
line, err := rw.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cmd := strings.ToUpper(strings.TrimRight(line, "\r\n"))
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(cmd, "EHLO"), strings.HasPrefix(cmd, "HELO"):
|
||||||
|
if _, err := rw.WriteString("250-localhost\r\n250 AUTH PLAIN\r\n"); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := rw.Flush(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case strings.HasPrefix(cmd, "AUTH"):
|
||||||
|
if !writeLine("235 2.7.0 Authentication successful") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case strings.HasPrefix(cmd, "MAIL FROM:"):
|
||||||
|
if !writeLine("250 2.1.0 OK") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case strings.HasPrefix(cmd, "RCPT TO:"):
|
||||||
|
if !writeLine("250 2.1.5 OK") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case strings.HasPrefix(cmd, "DATA"):
|
||||||
|
if !writeLine("354 End data with <CR><LF>.<CR><LF>") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
dataLine, err := rw.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.TrimRight(dataLine, "\r\n") == "." {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.messages.Add(1)
|
||||||
|
if !writeLine("250 2.0.0 OK") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case strings.HasPrefix(cmd, "QUIT"):
|
||||||
|
_ = writeLine("221 2.0.0 Bye")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
if !writeLine("250 OK") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,428 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// forwardResponsesViaRawChatCompletions serves /v1/responses clients through an
|
||||||
|
// upstream that only supports /v1/chat/completions.
|
||||||
|
func (s *OpenAIGatewayService) forwardResponsesViaRawChatCompletions(
|
||||||
|
ctx context.Context,
|
||||||
|
c *gin.Context,
|
||||||
|
account *Account,
|
||||||
|
body []byte,
|
||||||
|
) (*OpenAIForwardResult, error) {
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
var responsesReq apicompat.ResponsesRequest
|
||||||
|
if err := json.Unmarshal(body, &responsesReq); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"message": "Failed to parse request body",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return nil, fmt.Errorf("parse responses request: %w", err)
|
||||||
|
}
|
||||||
|
originalModel := strings.TrimSpace(responsesReq.Model)
|
||||||
|
if originalModel == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"message": "model is required",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return nil, fmt.Errorf("missing model in request")
|
||||||
|
}
|
||||||
|
|
||||||
|
clientStream := responsesReq.Stream
|
||||||
|
reasoningEffort := extractOpenAIReasoningEffortFromBody(body, originalModel)
|
||||||
|
serviceTier := extractOpenAIServiceTierFromBody(body)
|
||||||
|
|
||||||
|
chatReq, err := apicompat.ResponsesToChatCompletionsRequest(&responsesReq)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"message": err.Error(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return nil, fmt.Errorf("convert responses to chat completions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
billingModel := resolveOpenAIForwardModel(account, originalModel, "")
|
||||||
|
upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)
|
||||||
|
chatReq.Model = upstreamModel
|
||||||
|
if clientStream {
|
||||||
|
chatReq.StreamOptions = &apicompat.ChatStreamOptions{IncludeUsage: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
chatBody, err := json.Marshal(chatReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal chat completions fallback request: %w", err)
|
||||||
|
}
|
||||||
|
chatBody, err = s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, chatBody)
|
||||||
|
if err != nil {
|
||||||
|
var blocked *OpenAIFastBlockedError
|
||||||
|
if errors.As(err, &blocked) {
|
||||||
|
writeOpenAIFastPolicyBlockedResponse(c, blocked)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if serviceTier == nil {
|
||||||
|
serviceTier = extractOpenAIServiceTierFromBody(chatBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.L().Debug("openai responses: forwarding via raw chat completions",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.String("original_model", originalModel),
|
||||||
|
zap.String("billing_model", billingModel),
|
||||||
|
zap.String("upstream_model", upstreamModel),
|
||||||
|
zap.Bool("stream", clientStream),
|
||||||
|
)
|
||||||
|
|
||||||
|
apiKey := account.GetOpenAIApiKey()
|
||||||
|
if apiKey == "" {
|
||||||
|
return nil, fmt.Errorf("account %d missing api_key", account.ID)
|
||||||
|
}
|
||||||
|
baseURL := account.GetOpenAIBaseURL()
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = "https://api.openai.com"
|
||||||
|
}
|
||||||
|
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid base_url: %w", err)
|
||||||
|
}
|
||||||
|
targetURL := buildOpenAIChatCompletionsURL(validatedURL)
|
||||||
|
|
||||||
|
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||||
|
upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(chatBody))
|
||||||
|
releaseUpstreamCtx()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||||
|
}
|
||||||
|
upstreamReq.Header.Set("Content-Type", "application/json")
|
||||||
|
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
if clientStream {
|
||||||
|
upstreamReq.Header.Set("Accept", "text/event-stream")
|
||||||
|
} else {
|
||||||
|
upstreamReq.Header.Set("Accept", "application/json")
|
||||||
|
}
|
||||||
|
for key, values := range c.Request.Header {
|
||||||
|
lowerKey := strings.ToLower(key)
|
||||||
|
if openaiCCRawAllowedHeaders[lowerKey] {
|
||||||
|
for _, v := range values {
|
||||||
|
upstreamReq.Header.Add(key, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if customUA := account.GetOpenAIUserAgent(); customUA != "" {
|
||||||
|
upstreamReq.Header.Set("user-agent", customUA)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL := ""
|
||||||
|
if account.Proxy != nil {
|
||||||
|
proxyURL = account.Proxy.URL()
|
||||||
|
}
|
||||||
|
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||||
|
if err != nil {
|
||||||
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||||
|
setOpsUpstreamError(c, 0, safeErr, "")
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: 0,
|
||||||
|
Kind: "request_error",
|
||||||
|
Message: safeErr,
|
||||||
|
})
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": "Upstream request failed",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||||
|
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
if s.rateLimitService != nil {
|
||||||
|
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
|
}
|
||||||
|
return nil, &UpstreamFailoverError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
ResponseBody: respBody,
|
||||||
|
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.handleErrorResponse(ctx, resp, c, account, chatBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
if clientStream {
|
||||||
|
return s.streamChatCompletionsAsResponses(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
|
||||||
|
}
|
||||||
|
return s.bufferChatCompletionsAsResponses(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) bufferChatCompletionsAsResponses(
|
||||||
|
c *gin.Context,
|
||||||
|
resp *http.Response,
|
||||||
|
originalModel string,
|
||||||
|
billingModel string,
|
||||||
|
upstreamModel string,
|
||||||
|
reasoningEffort *string,
|
||||||
|
serviceTier *string,
|
||||||
|
startTime time.Time,
|
||||||
|
) (*OpenAIForwardResult, error) {
|
||||||
|
requestID := resp.Header.Get("x-request-id")
|
||||||
|
respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "api_error",
|
||||||
|
"message": "Failed to read upstream response",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("read upstream body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ccResp apicompat.ChatCompletionsResponse
|
||||||
|
if err := json.Unmarshal(respBody, &ccResp); err != nil {
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "api_error",
|
||||||
|
"message": "Failed to parse upstream response",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return nil, fmt.Errorf("parse chat completions response: %w", err)
|
||||||
|
}
|
||||||
|
responsesResp := apicompat.ChatCompletionsResponseToResponses(&ccResp, originalModel)
|
||||||
|
|
||||||
|
usage := OpenAIUsage{}
|
||||||
|
if parsed, ok := extractOpenAIUsageFromJSONBytes(respBody); ok {
|
||||||
|
usage = parsed
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.responseHeaderFilter != nil {
|
||||||
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, responsesResp)
|
||||||
|
|
||||||
|
return &OpenAIForwardResult{
|
||||||
|
RequestID: requestID,
|
||||||
|
Usage: usage,
|
||||||
|
Model: originalModel,
|
||||||
|
BillingModel: billingModel,
|
||||||
|
UpstreamModel: upstreamModel,
|
||||||
|
ReasoningEffort: reasoningEffort,
|
||||||
|
ServiceTier: serviceTier,
|
||||||
|
Stream: false,
|
||||||
|
Duration: time.Since(startTime),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) streamChatCompletionsAsResponses(
|
||||||
|
c *gin.Context,
|
||||||
|
resp *http.Response,
|
||||||
|
originalModel string,
|
||||||
|
billingModel string,
|
||||||
|
upstreamModel string,
|
||||||
|
reasoningEffort *string,
|
||||||
|
serviceTier *string,
|
||||||
|
startTime time.Time,
|
||||||
|
) (*OpenAIForwardResult, error) {
|
||||||
|
requestID := resp.Header.Get("x-request-id")
|
||||||
|
headersWritten := false
|
||||||
|
writeStreamHeaders := func() {
|
||||||
|
if headersWritten {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
headersWritten = true
|
||||||
|
if s.responseHeaderFilter != nil {
|
||||||
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Writer.Header().Set("Connection", "keep-alive")
|
||||||
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
c.Writer.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
state := apicompat.NewChatCompletionsToResponsesStreamState(originalModel)
|
||||||
|
var usage OpenAIUsage
|
||||||
|
var firstTokenMs *int
|
||||||
|
clientDisconnected := false
|
||||||
|
sawDone := false
|
||||||
|
|
||||||
|
writeEvents := func(events []apicompat.ResponsesStreamEvent) {
|
||||||
|
if clientDisconnected || len(events) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeStreamHeaders()
|
||||||
|
for _, event := range events {
|
||||||
|
sse, err := apicompat.ResponsesEventToSSE(event)
|
||||||
|
if err != nil {
|
||||||
|
logger.L().Warn("openai responses chat fallback: failed to marshal stream event",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||||
|
clientDisconnected = true
|
||||||
|
logger.L().Debug("openai responses chat fallback: client disconnected, continuing to drain upstream for billing",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Writer.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
maxLineSize := defaultMaxLineSize
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||||
|
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||||
|
}
|
||||||
|
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
payload, ok := extractOpenAISSEDataLine(line)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
payload = strings.TrimSpace(payload)
|
||||||
|
if payload == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if payload == "[DONE]" {
|
||||||
|
sawDone = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if u := extractCCStreamUsage(payload); u != nil {
|
||||||
|
usage = *u
|
||||||
|
}
|
||||||
|
|
||||||
|
var chunk apicompat.ChatCompletionsChunk
|
||||||
|
if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
|
||||||
|
logger.L().Warn("openai responses chat fallback: failed to parse chat stream chunk",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if firstTokenMs == nil && !isOpenAIChatUsageOnlyStreamChunk(payload) && chatChunkStartsResponsesOutput(&chunk) {
|
||||||
|
ms := int(time.Since(startTime).Milliseconds())
|
||||||
|
firstTokenMs = &ms
|
||||||
|
}
|
||||||
|
writeEvents(apicompat.ChatCompletionsChunkToResponsesEvents(&chunk, state))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
logger.L().Warn("openai responses chat fallback: stream read error",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return &OpenAIForwardResult{
|
||||||
|
RequestID: requestID,
|
||||||
|
Usage: usage,
|
||||||
|
Model: originalModel,
|
||||||
|
BillingModel: billingModel,
|
||||||
|
UpstreamModel: upstreamModel,
|
||||||
|
ReasoningEffort: reasoningEffort,
|
||||||
|
ServiceTier: serviceTier,
|
||||||
|
Stream: true,
|
||||||
|
Duration: time.Since(startTime),
|
||||||
|
FirstTokenMs: firstTokenMs,
|
||||||
|
}, fmt.Errorf("stream usage incomplete: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
writeEvents(apicompat.FinalizeChatCompletionsResponsesStream(state))
|
||||||
|
if !clientDisconnected {
|
||||||
|
writeStreamHeaders()
|
||||||
|
if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil {
|
||||||
|
clientDisconnected = true
|
||||||
|
}
|
||||||
|
if !clientDisconnected {
|
||||||
|
c.Writer.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !sawDone {
|
||||||
|
logger.L().Debug("openai responses chat fallback: upstream stream ended without done sentinel",
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &OpenAIForwardResult{
|
||||||
|
RequestID: requestID,
|
||||||
|
Usage: usage,
|
||||||
|
Model: originalModel,
|
||||||
|
BillingModel: billingModel,
|
||||||
|
UpstreamModel: upstreamModel,
|
||||||
|
ReasoningEffort: reasoningEffort,
|
||||||
|
ServiceTier: serviceTier,
|
||||||
|
Stream: true,
|
||||||
|
Duration: time.Since(startTime),
|
||||||
|
FirstTokenMs: firstTokenMs,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func chatChunkStartsResponsesOutput(chunk *apicompat.ChatCompletionsChunk) bool {
|
||||||
|
if chunk == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, choice := range chunk.Choices {
|
||||||
|
if choice.Delta.Content != nil || choice.Delta.ReasoningContent != nil || len(choice.Delta.ToolCalls) > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@ -0,0 +1,145 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestForwardResponses_ForceChatCompletionsRoutesNonStreamingToChatCompletions(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
body := []byte(`{"model":"gpt-5.4","input":"hello","stream":false}`)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_resp_chat_json"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(
|
||||||
|
`{"id":"chatcmpl_json","object":"chat.completion","model":"gpt-5.4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":3,"completion_tokens":2,"total_tokens":5,"prompt_tokens_details":{"cached_tokens":1}}}`,
|
||||||
|
)),
|
||||||
|
}}
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: rawChatCompletionsTestConfig(),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), c, forceChatResponsesFallbackAccount(), body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, "http://upstream.example/v1/chat/completions", upstream.lastReq.URL.String())
|
||||||
|
require.Equal(t, "hello", gjson.GetBytes(upstream.lastBody, "messages.0.content").String())
|
||||||
|
require.False(t, gjson.GetBytes(upstream.lastBody, "input").Exists())
|
||||||
|
require.Equal(t, "response", gjson.Get(rec.Body.String(), "object").String())
|
||||||
|
require.Equal(t, "ok", gjson.Get(rec.Body.String(), "output.0.content.0.text").String())
|
||||||
|
require.Equal(t, 3, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 2, result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 1, result.Usage.CacheReadInputTokens)
|
||||||
|
require.False(t, result.Stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardResponses_ForceChatCompletionsRoutesStreamingToChatCompletions(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
body := []byte(`{"model":"gpt-5.4","input":"hello","stream":true}`)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstreamBody := strings.Join([]string{
|
||||||
|
`data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}`,
|
||||||
|
"",
|
||||||
|
`data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"he"},"finish_reason":null}]}`,
|
||||||
|
"",
|
||||||
|
`data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"llo"},"finish_reason":null}]}`,
|
||||||
|
"",
|
||||||
|
`data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
|
||||||
|
"",
|
||||||
|
`data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":4,"completion_tokens":3,"total_tokens":7}}`,
|
||||||
|
"",
|
||||||
|
"data: [DONE]",
|
||||||
|
"",
|
||||||
|
}, "\n")
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_resp_chat_stream"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||||
|
}}
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: rawChatCompletionsTestConfig(),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), c, forceChatResponsesFallbackAccount(), body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, "http://upstream.example/v1/chat/completions", upstream.lastReq.URL.String())
|
||||||
|
require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool())
|
||||||
|
require.Contains(t, rec.Body.String(), "event: response.output_text.delta")
|
||||||
|
require.Contains(t, rec.Body.String(), `"delta":"he"`)
|
||||||
|
require.Contains(t, rec.Body.String(), "event: response.completed")
|
||||||
|
require.Contains(t, rec.Body.String(), `"input_tokens":4`)
|
||||||
|
require.Contains(t, rec.Body.String(), "data: [DONE]")
|
||||||
|
require.Equal(t, 4, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 3, result.Usage.OutputTokens)
|
||||||
|
require.True(t, result.Stream)
|
||||||
|
require.NotNil(t, result.FirstTokenMs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardResponses_AutoSupportedAccountStillUsesResponsesEndpoint(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
body := []byte(`{"model":"gpt-5.4","input":"hello","stream":false}`)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_resp_native"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(
|
||||||
|
`{"id":"resp_native","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"ok"}],"status":"completed"}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}`,
|
||||||
|
)),
|
||||||
|
}}
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: rawChatCompletionsTestConfig(),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
account := rawChatCompletionsTestAccount()
|
||||||
|
account.Extra = map[string]any{
|
||||||
|
openai_compat.ExtraKeyResponsesMode: string(openai_compat.ResponsesSupportModeAuto),
|
||||||
|
openai_compat.ExtraKeyResponsesSupported: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), c, account, body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, "http://upstream.example/v1/responses", upstream.lastReq.URL.String())
|
||||||
|
require.True(t, gjson.GetBytes(upstream.lastBody, "input").Exists())
|
||||||
|
require.False(t, gjson.GetBytes(upstream.lastBody, "messages").Exists())
|
||||||
|
require.Equal(t, "ok", gjson.Get(rec.Body.String(), "output.0.content.0.text").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func forceChatResponsesFallbackAccount() *Account {
|
||||||
|
account := rawChatCompletionsTestAccount()
|
||||||
|
account.Extra = map[string]any{
|
||||||
|
openai_compat.ExtraKeyResponsesMode: string(openai_compat.ResponsesSupportModeForceChatCompletions),
|
||||||
|
}
|
||||||
|
return account
|
||||||
|
}
|
||||||
@ -24,6 +24,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||||
"github.com/cespare/xxhash/v2"
|
"github.com/cespare/xxhash/v2"
|
||||||
@ -2018,6 +2019,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
originalBody := body
|
originalBody := body
|
||||||
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
|
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
|
||||||
originalModel := reqModel
|
originalModel := reqModel
|
||||||
|
|
||||||
|
if account.Type == AccountTypeAPIKey && !openai_compat.ShouldUseResponsesAPI(account.Extra) {
|
||||||
|
return s.forwardResponsesViaRawChatCompletions(ctx, c, account, body)
|
||||||
|
}
|
||||||
|
|
||||||
compatMessagesBridge := isOpenAICompatMessagesBridgeBody(body)
|
compatMessagesBridge := isOpenAICompatMessagesBridgeBody(body)
|
||||||
setOpenAICompatMessagesBridgeContext(c, compatMessagesBridge)
|
setOpenAICompatMessagesBridgeContext(c, compatMessagesBridge)
|
||||||
|
|
||||||
@ -3231,6 +3237,10 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
|
|||||||
req.Header.Set("user-agent", codexCLIUserAgent)
|
req.Header.Set("user-agent", codexCLIUserAgent)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 浏览器型 UA 兜底:仅 OAuth(ChatGPT 内部接口)账号生效,若最终 user-agent 仍为浏览器
|
||||||
|
// (Chrome/Firefox/Safari/Edge 等),替换为后台配置的 Codex UA,避免 Cloudflare 触发 JS 质询。
|
||||||
|
s.overrideBrowserUserAgent(ctx, account, req)
|
||||||
|
|
||||||
if req.Header.Get("content-type") == "" {
|
if req.Header.Get("content-type") == "" {
|
||||||
req.Header.Set("content-type", "application/json")
|
req.Header.Set("content-type", "application/json")
|
||||||
}
|
}
|
||||||
@ -3947,6 +3957,10 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
|||||||
req.Header.Set("user-agent", codexCLIUserAgent)
|
req.Header.Set("user-agent", codexCLIUserAgent)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 浏览器型 UA 兜底:仅 OAuth(ChatGPT 内部接口)账号生效,若最终 user-agent 仍为浏览器
|
||||||
|
// (Chrome/Firefox/Safari/Edge 等),替换为后台配置的 Codex UA,避免 Cloudflare 触发 JS 质询。
|
||||||
|
s.overrideBrowserUserAgent(ctx, account, req)
|
||||||
|
|
||||||
// Ensure required headers exist
|
// Ensure required headers exist
|
||||||
if req.Header.Get("content-type") == "" {
|
if req.Header.Get("content-type") == "" {
|
||||||
req.Header.Set("content-type", "application/json")
|
req.Header.Set("content-type", "application/json")
|
||||||
@ -3955,6 +3969,30 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
|||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// overrideBrowserUserAgent 检查请求的最终 user-agent,若为浏览器 UA 则替换为后台配置的 Codex UA。
|
||||||
|
// 用于规避 Cloudflare 对浏览器型 UA 在 ChatGPT 内部接口上的访问质询。
|
||||||
|
// 影响范围严格限定:仅 OAuth(Codex/ChatGPT 内部接口)账号生效;API Key 等其他账号原样透传。
|
||||||
|
// 仅在识别为浏览器(Mozilla/...)时改写,其他 CLI/工具 UA 不动。
|
||||||
|
func (s *OpenAIGatewayService) overrideBrowserUserAgent(ctx context.Context, account *Account, req *http.Request) {
|
||||||
|
if req == nil || account == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if account.Type != AccountTypeOAuth {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
currentUA := req.Header.Get("user-agent")
|
||||||
|
if !openai.IsBrowserUserAgent(currentUA) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
codexUA := DefaultOpenAICodexUserAgent
|
||||||
|
if s != nil && s.settingService != nil {
|
||||||
|
if v := strings.TrimSpace(s.settingService.GetOpenAICodexUserAgent(ctx)); v != "" {
|
||||||
|
codexUA = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
req.Header.Set("user-agent", codexUA)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) handleErrorResponse(
|
func (s *OpenAIGatewayService) handleErrorResponse(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
resp *http.Response,
|
resp *http.Response,
|
||||||
|
|||||||
@ -262,6 +262,9 @@ func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel st
|
|||||||
tool := []byte(`{"type":"image_generation","action":"","model":""}`)
|
tool := []byte(`{"type":"image_generation","action":"","model":""}`)
|
||||||
tool, _ = sjson.SetBytes(tool, "action", action)
|
tool, _ = sjson.SetBytes(tool, "action", action)
|
||||||
tool, _ = sjson.SetBytes(tool, "model", strings.TrimSpace(toolModel))
|
tool, _ = sjson.SetBytes(tool, "model", strings.TrimSpace(toolModel))
|
||||||
|
if shouldPassOpenAIImagesN(toolModel, parsed.N) {
|
||||||
|
tool, _ = sjson.SetBytes(tool, "n", parsed.N)
|
||||||
|
}
|
||||||
|
|
||||||
for _, field := range []struct {
|
for _, field := range []struct {
|
||||||
path string
|
path string
|
||||||
@ -302,6 +305,13 @@ func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel st
|
|||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldPassOpenAIImagesN(model string, n int) bool {
|
||||||
|
if n <= 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !strings.EqualFold(strings.TrimSpace(model), "dall-e-3")
|
||||||
|
}
|
||||||
|
|
||||||
func extractOpenAIImagesFromResponsesCompleted(payload []byte) ([]openAIResponsesImageResult, int64, []byte, openAIResponsesImageResult, error) {
|
func extractOpenAIImagesFromResponsesCompleted(payload []byte) ([]openAIResponsesImageResult, int64, []byte, openAIResponsesImageResult, error) {
|
||||||
if gjson.GetBytes(payload, "type").String() != "response.completed" {
|
if gjson.GetBytes(payload, "type").String() != "response.completed" {
|
||||||
return nil, 0, nil, openAIResponsesImageResult{}, fmt.Errorf("unexpected event type")
|
return nil, 0, nil, openAIResponsesImageResult{}, fmt.Errorf("unexpected event type")
|
||||||
@ -957,16 +967,6 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
|||||||
account.Type,
|
account.Type,
|
||||||
len(parsed.Uploads),
|
len(parsed.Uploads),
|
||||||
)
|
)
|
||||||
if parsed.N > 1 {
|
|
||||||
logger.LegacyPrintf(
|
|
||||||
"service.openai_gateway",
|
|
||||||
"[Warning] Codex /responses image tool requested n=%d; falling back to n=1 request_model=%s endpoint=%s",
|
|
||||||
parsed.N,
|
|
||||||
requestModel,
|
|
||||||
parsed.Endpoint,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||||
defer releaseUpstreamCtx()
|
defer releaseUpstreamCtx()
|
||||||
|
|
||||||
|
|||||||
@ -474,9 +474,9 @@ func findOpenAIImageTestSSEEvent(events []openAIImageTestSSEEvent, name string)
|
|||||||
return openAIImageTestSSEEvent{}, false
|
return openAIImageTestSSEEvent{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
|
func TestOpenAIGatewayServiceForwardImages_OAuthPassesNAndReturnsAllImages(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":2}`)
|
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":3}`)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@ -497,7 +497,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
|
|||||||
"X-Request-Id": []string{"req_img_123"},
|
"X-Request-Id": []string{"req_img_123"},
|
||||||
},
|
},
|
||||||
Body: io.NopCloser(strings.NewReader(
|
Body: io.NopCloser(strings.NewReader(
|
||||||
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000000,\"usage\":{\"input_tokens\":11,\"output_tokens\":22,\"input_tokens_details\":{\"cached_tokens\":3},\"output_tokens_details\":{\"image_tokens\":7}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
|
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000000,\"usage\":{\"input_tokens\":11,\"output_tokens\":22,\"input_tokens_details\":{\"cached_tokens\":3},\"output_tokens_details\":{\"image_tokens\":7}},\"tool_usage\":{\"image_gen\":{\"images\":3}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMQ==\",\"revised_prompt\":\"draw a cat 1\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"},{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMg==\",\"revised_prompt\":\"draw a cat 2\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"},{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMw==\",\"revised_prompt\":\"draw a cat 3\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
|
||||||
"data: [DONE]\n\n",
|
"data: [DONE]\n\n",
|
||||||
)),
|
)),
|
||||||
},
|
},
|
||||||
@ -520,7 +520,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
|
|||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.Equal(t, "gpt-image-2", result.Model)
|
require.Equal(t, "gpt-image-2", result.Model)
|
||||||
require.Equal(t, "gpt-image-2", result.UpstreamModel)
|
require.Equal(t, "gpt-image-2", result.UpstreamModel)
|
||||||
require.Equal(t, 1, result.ImageCount)
|
require.Equal(t, 3, result.ImageCount)
|
||||||
require.Equal(t, 11, result.Usage.InputTokens)
|
require.Equal(t, 11, result.Usage.InputTokens)
|
||||||
require.Equal(t, 22, result.Usage.OutputTokens)
|
require.Equal(t, 22, result.Usage.OutputTokens)
|
||||||
require.Equal(t, 7, result.Usage.ImageOutputTokens)
|
require.Equal(t, 7, result.Usage.ImageOutputTokens)
|
||||||
@ -540,13 +540,17 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
|
|||||||
require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String())
|
require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String())
|
||||||
require.Equal(t, "1024x1024", gjson.GetBytes(upstream.lastBody, "tools.0.size").String())
|
require.Equal(t, "1024x1024", gjson.GetBytes(upstream.lastBody, "tools.0.size").String())
|
||||||
require.Equal(t, "high", gjson.GetBytes(upstream.lastBody, "tools.0.quality").String())
|
require.Equal(t, "high", gjson.GetBytes(upstream.lastBody, "tools.0.quality").String())
|
||||||
require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.n").Exists())
|
require.Equal(t, int64(3), gjson.GetBytes(upstream.lastBody, "tools.0.n").Int())
|
||||||
require.Equal(t, "draw a cat", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String())
|
require.Equal(t, "draw a cat", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String())
|
||||||
|
|
||||||
require.Equal(t, http.StatusOK, rec.Code)
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
require.Equal(t, "gpt-image-2", gjson.Get(rec.Body.String(), "model").String())
|
require.Equal(t, "gpt-image-2", gjson.Get(rec.Body.String(), "model").String())
|
||||||
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
require.Len(t, gjson.Get(rec.Body.String(), "data").Array(), 3)
|
||||||
require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
|
require.Equal(t, "aW1hZ2UtMQ==", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||||
|
require.Equal(t, "aW1hZ2UtMg==", gjson.Get(rec.Body.String(), "data.1.b64_json").String())
|
||||||
|
require.Equal(t, "aW1hZ2UtMw==", gjson.Get(rec.Body.String(), "data.2.b64_json").String())
|
||||||
|
require.Equal(t, "draw a cat 1", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
|
||||||
|
require.Equal(t, "draw a cat 3", gjson.Get(rec.Body.String(), "data.2.revised_prompt").String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseURL(t *testing.T) {
|
func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseURL(t *testing.T) {
|
||||||
@ -1112,7 +1116,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthEditsStreamingTransformsEvents(t
|
|||||||
require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists())
|
require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *testing.T) {
|
func TestBuildOpenAIImagesResponsesRequest_PassesThroughNForMultiImageModels(t *testing.T) {
|
||||||
parsed := &OpenAIImagesRequest{
|
parsed := &OpenAIImagesRequest{
|
||||||
Endpoint: openAIImagesGenerationsEndpoint,
|
Endpoint: openAIImagesGenerationsEndpoint,
|
||||||
Model: "gpt-image-2",
|
Model: "gpt-image-2",
|
||||||
@ -1123,11 +1127,26 @@ func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *t
|
|||||||
body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2")
|
body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, body)
|
require.NotNil(t, body)
|
||||||
require.False(t, gjson.GetBytes(body, "tools.0.n").Exists())
|
require.Equal(t, int64(2), gjson.GetBytes(body, "tools.0.n").Int())
|
||||||
require.Equal(t, "gpt-image-2", gjson.GetBytes(body, "tools.0.model").String())
|
require.Equal(t, "gpt-image-2", gjson.GetBytes(body, "tools.0.model").String())
|
||||||
require.Equal(t, "draw a cat", gjson.GetBytes(body, "input.0.content.0.text").String())
|
require.Equal(t, "draw a cat", gjson.GetBytes(body, "input.0.content.0.text").String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildOpenAIImagesResponsesRequest_DoesNotPassNForDallE3(t *testing.T) {
|
||||||
|
parsed := &OpenAIImagesRequest{
|
||||||
|
Endpoint: openAIImagesGenerationsEndpoint,
|
||||||
|
Model: "dall-e-3",
|
||||||
|
Prompt: "draw a cat",
|
||||||
|
N: 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := buildOpenAIImagesResponsesRequest(parsed, "dall-e-3")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, body)
|
||||||
|
require.False(t, gjson.GetBytes(body, "tools.0.n").Exists())
|
||||||
|
require.Equal(t, "dall-e-3", gjson.GetBytes(body, "tools.0.model").String())
|
||||||
|
}
|
||||||
|
|
||||||
func TestBuildOpenAIImagesResponsesRequest_StripsInputFidelity(t *testing.T) {
|
func TestBuildOpenAIImagesResponsesRequest_StripsInputFidelity(t *testing.T) {
|
||||||
parsed := &OpenAIImagesRequest{
|
parsed := &OpenAIImagesRequest{
|
||||||
Endpoint: openAIImagesEditsEndpoint,
|
Endpoint: openAIImagesEditsEndpoint,
|
||||||
|
|||||||
@ -686,6 +686,21 @@ func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runt
|
|||||||
if !s.emailLimiter.Allow(time.Now().UTC()) {
|
if !s.emailLimiter.Allow(time.Now().UTC()) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if s.emailService.notificationEmailService != nil {
|
||||||
|
if err := s.emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||||
|
Event: NotificationEmailEventOpsAlert,
|
||||||
|
RecipientEmail: addr,
|
||||||
|
RecipientName: emailRecipientName(addr),
|
||||||
|
SourceType: "ops_alert",
|
||||||
|
SourceID: fmt.Sprintf("%d", event.ID),
|
||||||
|
Variables: opsAlertEmailVariables(rule, event),
|
||||||
|
}); err == nil {
|
||||||
|
anySent = true
|
||||||
|
continue
|
||||||
|
} else if !shouldFallbackNotificationEmail(err) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
if err := s.emailService.SendEmail(ctx, addr, subject, body); err != nil {
|
if err := s.emailService.SendEmail(ctx, addr, subject, body); err != nil {
|
||||||
// Ignore per-recipient failures; continue best-effort.
|
// Ignore per-recipient failures; continue best-effort.
|
||||||
continue
|
continue
|
||||||
@ -699,6 +714,46 @@ func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runt
|
|||||||
return anySent
|
return anySent
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func opsAlertEmailVariables(rule *OpsAlertRule, event *OpsAlertEvent) map[string]string {
|
||||||
|
variables := map[string]string{
|
||||||
|
"rule_name": "-",
|
||||||
|
"severity": "-",
|
||||||
|
"alert_status": "-",
|
||||||
|
"metric_type": "-",
|
||||||
|
"operator": "-",
|
||||||
|
"metric_value": "-",
|
||||||
|
"threshold_value": "-",
|
||||||
|
"triggered_at": time.Now().UTC().Format(time.RFC3339),
|
||||||
|
"alert_description": "-",
|
||||||
|
}
|
||||||
|
if rule != nil {
|
||||||
|
variables["rule_name"] = strings.TrimSpace(rule.Name)
|
||||||
|
variables["severity"] = strings.TrimSpace(rule.Severity)
|
||||||
|
variables["metric_type"] = strings.TrimSpace(rule.MetricType)
|
||||||
|
variables["operator"] = strings.TrimSpace(rule.Operator)
|
||||||
|
variables["threshold_value"] = fmt.Sprintf("%.2f", rule.Threshold)
|
||||||
|
if strings.TrimSpace(rule.Description) != "" {
|
||||||
|
variables["alert_description"] = strings.TrimSpace(rule.Description)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if event != nil {
|
||||||
|
variables["alert_status"] = strings.TrimSpace(event.Status)
|
||||||
|
if event.MetricValue != nil {
|
||||||
|
variables["metric_value"] = fmt.Sprintf("%.2f", *event.MetricValue)
|
||||||
|
}
|
||||||
|
if event.ThresholdValue != nil {
|
||||||
|
variables["threshold_value"] = fmt.Sprintf("%.2f", *event.ThresholdValue)
|
||||||
|
}
|
||||||
|
if !event.FiredAt.IsZero() {
|
||||||
|
variables["triggered_at"] = event.FiredAt.UTC().Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(event.Description) != "" {
|
||||||
|
variables["alert_description"] = strings.TrimSpace(event.Description)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return variables
|
||||||
|
}
|
||||||
|
|
||||||
func buildOpsAlertEmailBody(rule *OpsAlertRule, event *OpsAlertEvent) string {
|
func buildOpsAlertEmailBody(rule *OpsAlertRule, event *OpsAlertEvent) string {
|
||||||
if rule == nil || event == nil {
|
if rule == nil || event == nil {
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@ -337,6 +337,7 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc
|
|||||||
}
|
}
|
||||||
|
|
||||||
subject := fmt.Sprintf("[Ops Report] %s", strings.TrimSpace(report.Name))
|
subject := fmt.Sprintf("[Ops Report] %s", strings.TrimSpace(report.Name))
|
||||||
|
templateVariables := opsScheduledReportEmailVariables(report, now)
|
||||||
|
|
||||||
attempts := 0
|
attempts := 0
|
||||||
for _, to := range recipients {
|
for _, to := range recipients {
|
||||||
@ -345,6 +346,24 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
attempts++
|
attempts++
|
||||||
|
if s.emailService.notificationEmailService != nil {
|
||||||
|
if err := s.emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||||
|
Event: NotificationEmailEventOpsScheduledReport,
|
||||||
|
RecipientEmail: addr,
|
||||||
|
RecipientName: emailRecipientName(addr),
|
||||||
|
SourceType: "ops_scheduled_report",
|
||||||
|
SourceID: opsScheduledReportDeliverySourceID(report),
|
||||||
|
ReminderKey: now.UTC().Format("2006-01-02T15:04"),
|
||||||
|
Variables: templateVariables,
|
||||||
|
RawHTMLVariables: map[string]string{
|
||||||
|
"report_html": content,
|
||||||
|
},
|
||||||
|
}); err == nil {
|
||||||
|
continue
|
||||||
|
} else if !shouldFallbackNotificationEmail(err) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
if err := s.emailService.SendEmail(ctx, addr, subject, content); err != nil {
|
if err := s.emailService.SendEmail(ctx, addr, subject, content); err != nil {
|
||||||
// Ignore per-recipient failures; continue best-effort.
|
// Ignore per-recipient failures; continue best-effort.
|
||||||
continue
|
continue
|
||||||
@ -353,6 +372,46 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc
|
|||||||
return attempts, nil
|
return attempts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func opsScheduledReportDeliverySourceID(report *opsScheduledReport) string {
|
||||||
|
if report == nil {
|
||||||
|
return "scheduled_report"
|
||||||
|
}
|
||||||
|
parts := []string{
|
||||||
|
strings.TrimSpace(report.ReportType),
|
||||||
|
strings.TrimSpace(report.Name),
|
||||||
|
strings.TrimSpace(report.Schedule),
|
||||||
|
}
|
||||||
|
joined := strings.Trim(strings.Join(parts, ":"), ":")
|
||||||
|
if joined == "" {
|
||||||
|
return "scheduled_report"
|
||||||
|
}
|
||||||
|
return joined
|
||||||
|
}
|
||||||
|
|
||||||
|
func opsScheduledReportEmailVariables(report *opsScheduledReport, now time.Time) map[string]string {
|
||||||
|
end := now.UTC()
|
||||||
|
start := end
|
||||||
|
name := "Ops report"
|
||||||
|
reportType := "scheduled_report"
|
||||||
|
if report != nil {
|
||||||
|
if strings.TrimSpace(report.Name) != "" {
|
||||||
|
name = strings.TrimSpace(report.Name)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(report.ReportType) != "" {
|
||||||
|
reportType = strings.TrimSpace(report.ReportType)
|
||||||
|
}
|
||||||
|
if report.TimeRange > 0 {
|
||||||
|
start = end.Add(-report.TimeRange)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return map[string]string{
|
||||||
|
"report_name": name,
|
||||||
|
"report_type": reportType,
|
||||||
|
"report_start_time": start.Format(time.RFC3339),
|
||||||
|
"report_end_time": end.Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpsScheduledReportService) generateReportHTML(ctx context.Context, report *opsScheduledReport, now time.Time) (string, error) {
|
func (s *OpsScheduledReportService) generateReportHTML(ctx context.Context, report *opsScheduledReport, now time.Time) (string, error) {
|
||||||
if s == nil || s.opsService == nil || report == nil {
|
if s == nil || s.opsService == nil || report == nil {
|
||||||
return "", fmt.Errorf("service not initialized")
|
return "", fmt.Errorf("service not initialized")
|
||||||
|
|||||||
@ -310,9 +310,87 @@ func (s *PaymentService) markCompleted(ctx context.Context, o *dbent.PaymentOrde
|
|||||||
"creditedAmount": o.Amount,
|
"creditedAmount": o.Amount,
|
||||||
"payAmount": o.PayAmount,
|
"payAmount": o.PayAmount,
|
||||||
})
|
})
|
||||||
|
s.dispatchPaymentFulfillmentNotification(o, auditAction)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *PaymentService) dispatchPaymentFulfillmentNotification(o *dbent.PaymentOrder, auditAction string) {
|
||||||
|
if s == nil || s.notificationEmailService == nil || o == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout)
|
||||||
|
defer cancel()
|
||||||
|
var err error
|
||||||
|
switch auditAction {
|
||||||
|
case "RECHARGE_SUCCESS":
|
||||||
|
err = s.sendBalanceRechargeSuccessNotification(ctx, o)
|
||||||
|
case "SUBSCRIPTION_SUCCESS":
|
||||||
|
err = s.sendSubscriptionPurchaseSuccessNotification(ctx, o)
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("payment fulfillment notification email failed", "order_id", o.ID, "action", auditAction, "err", err.Error())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PaymentService) sendBalanceRechargeSuccessNotification(ctx context.Context, o *dbent.PaymentOrder) error {
|
||||||
|
currentBalance := ""
|
||||||
|
if s.userRepo != nil {
|
||||||
|
if user, err := s.userRepo.GetByID(ctx, o.UserID); err == nil && user != nil {
|
||||||
|
currentBalance = fmt.Sprintf("%.2f", user.Balance)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||||
|
Event: NotificationEmailEventBalanceRechargeSuccess,
|
||||||
|
RecipientEmail: o.UserEmail,
|
||||||
|
RecipientName: firstNonEmpty(o.UserName, o.UserEmail),
|
||||||
|
UserID: o.UserID,
|
||||||
|
SourceType: "payment_order",
|
||||||
|
SourceID: strconv.FormatInt(o.ID, 10),
|
||||||
|
Variables: map[string]string{
|
||||||
|
"recharge_amount": fmt.Sprintf("%.2f", o.Amount),
|
||||||
|
"current_balance": currentBalance,
|
||||||
|
"order_id": strconv.FormatInt(o.ID, 10),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PaymentService) sendSubscriptionPurchaseSuccessNotification(ctx context.Context, o *dbent.PaymentOrder) error {
|
||||||
|
variables := map[string]string{
|
||||||
|
"subscription_group": "Subscription",
|
||||||
|
"subscription_days": "",
|
||||||
|
"expiry_time": "",
|
||||||
|
"order_id": strconv.FormatInt(o.ID, 10),
|
||||||
|
}
|
||||||
|
if o.SubscriptionDays != nil {
|
||||||
|
variables["subscription_days"] = strconv.Itoa(*o.SubscriptionDays)
|
||||||
|
}
|
||||||
|
if o.SubscriptionGroupID != nil {
|
||||||
|
if s.groupRepo != nil {
|
||||||
|
if group, err := s.groupRepo.GetByID(ctx, *o.SubscriptionGroupID); err == nil && group != nil && strings.TrimSpace(group.Name) != "" {
|
||||||
|
variables["subscription_group"] = group.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.subscriptionSvc != nil {
|
||||||
|
if sub, err := s.subscriptionSvc.GetActiveSubscription(ctx, o.UserID, *o.SubscriptionGroupID); err == nil && sub != nil {
|
||||||
|
variables["expiry_time"] = sub.ExpiresAt.Format("2006-01-02 15:04")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||||
|
Event: NotificationEmailEventSubscriptionPurchaseSuccess,
|
||||||
|
RecipientEmail: o.UserEmail,
|
||||||
|
RecipientName: firstNonEmpty(o.UserName, o.UserEmail),
|
||||||
|
UserID: o.UserID,
|
||||||
|
SourceType: "payment_order",
|
||||||
|
SourceID: strconv.FormatInt(o.ID, 10),
|
||||||
|
Variables: variables,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *PaymentService) ExecuteSubscriptionFulfillment(ctx context.Context, oid int64) error {
|
func (s *PaymentService) ExecuteSubscriptionFulfillment(ctx context.Context, oid int64) error {
|
||||||
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
|
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -48,6 +48,9 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest
|
|||||||
if user.Status != payment.EntityStatusActive {
|
if user.Status != payment.EntityStatusActive {
|
||||||
return nil, infraerrors.Forbidden("USER_INACTIVE", "user account is disabled")
|
return nil, infraerrors.Forbidden("USER_INACTIVE", "user account is disabled")
|
||||||
}
|
}
|
||||||
|
if s.notificationEmailService != nil {
|
||||||
|
s.notificationEmailService.RememberRecipientLocale(ctx, req.UserID, user.Email, req.Locale)
|
||||||
|
}
|
||||||
orderAmount := req.Amount
|
orderAmount := req.Amount
|
||||||
limitAmount := req.Amount
|
limitAmount := req.Amount
|
||||||
if plan != nil {
|
if plan != nil {
|
||||||
|
|||||||
@ -83,6 +83,7 @@ type CreateOrderRequest struct {
|
|||||||
PaymentSource string
|
PaymentSource string
|
||||||
OrderType string
|
OrderType string
|
||||||
PlanID int64
|
PlanID int64
|
||||||
|
Locale string
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateOrderResponse struct {
|
type CreateOrderResponse struct {
|
||||||
@ -174,18 +175,19 @@ type TopUserStat struct {
|
|||||||
// --- Service ---
|
// --- Service ---
|
||||||
|
|
||||||
type PaymentService struct {
|
type PaymentService struct {
|
||||||
providerMu sync.Mutex
|
providerMu sync.Mutex
|
||||||
providersLoaded bool
|
providersLoaded bool
|
||||||
entClient *dbent.Client
|
entClient *dbent.Client
|
||||||
registry *payment.Registry
|
registry *payment.Registry
|
||||||
loadBalancer payment.LoadBalancer
|
loadBalancer payment.LoadBalancer
|
||||||
redeemService *RedeemService
|
redeemService *RedeemService
|
||||||
subscriptionSvc *SubscriptionService
|
subscriptionSvc *SubscriptionService
|
||||||
configService *PaymentConfigService
|
configService *PaymentConfigService
|
||||||
userRepo UserRepository
|
userRepo UserRepository
|
||||||
groupRepo GroupRepository
|
groupRepo GroupRepository
|
||||||
resumeService *PaymentResumeService
|
resumeService *PaymentResumeService
|
||||||
affiliateService *AffiliateService
|
affiliateService *AffiliateService
|
||||||
|
notificationEmailService *NotificationEmailService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository, affiliateService *AffiliateService) *PaymentService {
|
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository, affiliateService *AffiliateService) *PaymentService {
|
||||||
@ -194,6 +196,10 @@ func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, load
|
|||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *PaymentService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) {
|
||||||
|
s.notificationEmailService = notificationEmailService
|
||||||
|
}
|
||||||
|
|
||||||
// --- Provider Registry ---
|
// --- Provider Registry ---
|
||||||
|
|
||||||
// EnsureProviders lazily initializes the provider registry on first call.
|
// EnsureProviders lazily initializes the provider registry on first call.
|
||||||
|
|||||||
@ -128,6 +128,19 @@ const antigravityUserAgentVersionCacheTTL = 60 * time.Second
|
|||||||
const antigravityUserAgentVersionErrorTTL = 5 * time.Second
|
const antigravityUserAgentVersionErrorTTL = 5 * time.Second
|
||||||
const antigravityUserAgentVersionDBTimeout = 5 * time.Second
|
const antigravityUserAgentVersionDBTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
// DefaultOpenAICodexUserAgent OpenAI Codex 默认 User-Agent(用于规避 Cloudflare 对浏览器 UA 的质询)
|
||||||
|
const DefaultOpenAICodexUserAgent = "codex-tui/0.125.0 (Ubuntu 22.4.0; x86_64) xterm-256color (codex-tui; 0.125.0)"
|
||||||
|
|
||||||
|
// cachedOpenAICodexUserAgent 缓存 OpenAI Codex UA(进程内缓存,60s TTL)
|
||||||
|
type cachedOpenAICodexUserAgent struct {
|
||||||
|
value string
|
||||||
|
expiresAt int64 // unix nano
|
||||||
|
}
|
||||||
|
|
||||||
|
const openAICodexUserAgentCacheTTL = 60 * time.Second
|
||||||
|
const openAICodexUserAgentErrorTTL = 5 * time.Second
|
||||||
|
const openAICodexUserAgentDBTimeout = 5 * time.Second
|
||||||
|
|
||||||
// DefaultSubscriptionGroupReader validates group references used by default subscriptions.
|
// DefaultSubscriptionGroupReader validates group references used by default subscriptions.
|
||||||
type DefaultSubscriptionGroupReader interface {
|
type DefaultSubscriptionGroupReader interface {
|
||||||
GetByID(ctx context.Context, id int64) (*Group, error)
|
GetByID(ctx context.Context, id int64) (*Group, error)
|
||||||
@ -148,6 +161,8 @@ type SettingService struct {
|
|||||||
webSearchManagerBuilder WebSearchManagerBuilder
|
webSearchManagerBuilder WebSearchManagerBuilder
|
||||||
antigravityUAVersionCache atomic.Value // *cachedAntigravityUserAgentVersion
|
antigravityUAVersionCache atomic.Value // *cachedAntigravityUserAgentVersion
|
||||||
antigravityUAVersionSF singleflight.Group
|
antigravityUAVersionSF singleflight.Group
|
||||||
|
openAICodexUACache atomic.Value // *cachedOpenAICodexUserAgent
|
||||||
|
openAICodexUASF singleflight.Group
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProviderDefaultGrantSettings struct {
|
type ProviderDefaultGrantSettings struct {
|
||||||
@ -907,6 +922,55 @@ func (s *SettingService) GetAntigravityUserAgentVersion(ctx context.Context) str
|
|||||||
return fallback
|
return fallback
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetOpenAICodexUserAgent 返回 OpenAI Codex 上游请求使用的 User-Agent。
|
||||||
|
// 后台设置优先;为空时回退到内置默认值。
|
||||||
|
func (s *SettingService) GetOpenAICodexUserAgent(ctx context.Context) string {
|
||||||
|
fallback := DefaultOpenAICodexUserAgent
|
||||||
|
if s == nil || s.settingRepo == nil {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
if cached, ok := s.openAICodexUACache.Load().(*cachedOpenAICodexUserAgent); ok && cached != nil {
|
||||||
|
if time.Now().UnixNano() < cached.expiresAt {
|
||||||
|
return cached.value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _, _ := s.openAICodexUASF.Do("openai_codex_user_agent", func() (any, error) {
|
||||||
|
if cached, ok := s.openAICodexUACache.Load().(*cachedOpenAICodexUserAgent); ok && cached != nil {
|
||||||
|
if time.Now().UnixNano() < cached.expiresAt {
|
||||||
|
return cached.value, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAICodexUserAgentDBTimeout)
|
||||||
|
defer cancel()
|
||||||
|
value, err := s.settingRepo.GetValue(dbCtx, SettingKeyOpenAICodexUserAgent)
|
||||||
|
if err != nil && !errors.Is(err, ErrSettingNotFound) {
|
||||||
|
slog.Warn("failed to get openai codex user agent setting", "error", err)
|
||||||
|
s.openAICodexUACache.Store(&cachedOpenAICodexUserAgent{
|
||||||
|
value: fallback,
|
||||||
|
expiresAt: time.Now().Add(openAICodexUserAgentErrorTTL).UnixNano(),
|
||||||
|
})
|
||||||
|
return fallback, nil
|
||||||
|
}
|
||||||
|
ua := strings.TrimSpace(value)
|
||||||
|
if ua == "" {
|
||||||
|
ua = fallback
|
||||||
|
}
|
||||||
|
s.openAICodexUACache.Store(&cachedOpenAICodexUserAgent{
|
||||||
|
value: ua,
|
||||||
|
expiresAt: time.Now().Add(openAICodexUserAgentCacheTTL).UnixNano(),
|
||||||
|
})
|
||||||
|
return ua, nil
|
||||||
|
})
|
||||||
|
if ua, ok := result.(string); ok && ua != "" {
|
||||||
|
return ua
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
// SetOnUpdateCallback sets a callback function to be called when settings are updated
|
// SetOnUpdateCallback sets a callback function to be called when settings are updated
|
||||||
// This is used for cache invalidation (e.g., HTML cache in frontend server)
|
// This is used for cache invalidation (e.g., HTML cache in frontend server)
|
||||||
func (s *SettingService) SetOnUpdateCallback(callback func()) {
|
func (s *SettingService) SetOnUpdateCallback(callback func()) {
|
||||||
@ -1706,6 +1770,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
|||||||
updates[SettingKeyEnableAnthropicCacheTTL1hInjection] = strconv.FormatBool(settings.EnableAnthropicCacheTTL1hInjection)
|
updates[SettingKeyEnableAnthropicCacheTTL1hInjection] = strconv.FormatBool(settings.EnableAnthropicCacheTTL1hInjection)
|
||||||
updates[SettingKeyRewriteMessageCacheControl] = strconv.FormatBool(settings.RewriteMessageCacheControl)
|
updates[SettingKeyRewriteMessageCacheControl] = strconv.FormatBool(settings.RewriteMessageCacheControl)
|
||||||
updates[SettingKeyAntigravityUserAgentVersion] = antigravity.NormalizeUserAgentVersion(settings.AntigravityUserAgentVersion)
|
updates[SettingKeyAntigravityUserAgentVersion] = antigravity.NormalizeUserAgentVersion(settings.AntigravityUserAgentVersion)
|
||||||
|
updates[SettingKeyOpenAICodexUserAgent] = strings.TrimSpace(settings.OpenAICodexUserAgent)
|
||||||
updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource
|
updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource
|
||||||
updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource
|
updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource
|
||||||
updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled)
|
updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled)
|
||||||
@ -1788,6 +1853,15 @@ func (s *SettingService) refreshCachedSettings(settings *SystemSettings) {
|
|||||||
version: antigravityUserAgentVersion,
|
version: antigravityUserAgentVersion,
|
||||||
expiresAt: time.Now().Add(antigravityUserAgentVersionCacheTTL).UnixNano(),
|
expiresAt: time.Now().Add(antigravityUserAgentVersionCacheTTL).UnixNano(),
|
||||||
})
|
})
|
||||||
|
s.openAICodexUASF.Forget("openai_codex_user_agent")
|
||||||
|
codexUA := strings.TrimSpace(settings.OpenAICodexUserAgent)
|
||||||
|
if codexUA == "" {
|
||||||
|
codexUA = DefaultOpenAICodexUserAgent
|
||||||
|
}
|
||||||
|
s.openAICodexUACache.Store(&cachedOpenAICodexUserAgent{
|
||||||
|
value: codexUA,
|
||||||
|
expiresAt: time.Now().Add(openAICodexUserAgentCacheTTL).UnixNano(),
|
||||||
|
})
|
||||||
openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
|
openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
|
||||||
openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
|
openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
|
||||||
enabled: settings.OpenAIAdvancedSchedulerEnabled,
|
enabled: settings.OpenAIAdvancedSchedulerEnabled,
|
||||||
@ -2529,6 +2603,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
|||||||
SettingKeyEnableAnthropicCacheTTL1hInjection: "false",
|
SettingKeyEnableAnthropicCacheTTL1hInjection: "false",
|
||||||
SettingKeyRewriteMessageCacheControl: strconv.FormatBool(s.defaultRewriteMessageCacheControl()),
|
SettingKeyRewriteMessageCacheControl: strconv.FormatBool(s.defaultRewriteMessageCacheControl()),
|
||||||
SettingKeyAntigravityUserAgentVersion: "",
|
SettingKeyAntigravityUserAgentVersion: "",
|
||||||
|
SettingKeyOpenAICodexUserAgent: "",
|
||||||
SettingPaymentVisibleMethodAlipaySource: "",
|
SettingPaymentVisibleMethodAlipaySource: "",
|
||||||
SettingPaymentVisibleMethodWxpaySource: "",
|
SettingPaymentVisibleMethodWxpaySource: "",
|
||||||
SettingPaymentVisibleMethodAlipayEnabled: "false",
|
SettingPaymentVisibleMethodAlipayEnabled: "false",
|
||||||
@ -3041,6 +3116,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
|||||||
result.RewriteMessageCacheControl = s.defaultRewriteMessageCacheControl()
|
result.RewriteMessageCacheControl = s.defaultRewriteMessageCacheControl()
|
||||||
}
|
}
|
||||||
result.AntigravityUserAgentVersion = antigravity.NormalizeUserAgentVersion(settings[SettingKeyAntigravityUserAgentVersion])
|
result.AntigravityUserAgentVersion = antigravity.NormalizeUserAgentVersion(settings[SettingKeyAntigravityUserAgentVersion])
|
||||||
|
result.OpenAICodexUserAgent = strings.TrimSpace(settings[SettingKeyOpenAICodexUserAgent])
|
||||||
|
|
||||||
// Web search emulation: quick enabled check from the JSON config
|
// Web search emulation: quick enabled check from the JSON config
|
||||||
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
|
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
|
||||||
|
|||||||
@ -193,6 +193,7 @@ type SystemSettings struct {
|
|||||||
EnableAnthropicCacheTTL1hInjection bool // 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false)
|
EnableAnthropicCacheTTL1hInjection bool // 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false)
|
||||||
RewriteMessageCacheControl bool // 是否改写 messages[*].content[*].cache_control(默认 false)
|
RewriteMessageCacheControl bool // 是否改写 messages[*].content[*].cache_control(默认 false)
|
||||||
AntigravityUserAgentVersion string // Antigravity 上游 User-Agent 版本号;空值使用配置/默认值
|
AntigravityUserAgentVersion string // Antigravity 上游 User-Agent 版本号;空值使用配置/默认值
|
||||||
|
OpenAICodexUserAgent string // OpenAI Codex 上游完整 User-Agent;空值使用内置默认
|
||||||
|
|
||||||
// Web Search Emulation
|
// Web Search Emulation
|
||||||
WebSearchEmulationEnabled bool // 是否启用 web search 模拟
|
WebSearchEmulationEnabled bool // 是否启用 web search 模拟
|
||||||
|
|||||||
@ -2,18 +2,23 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SubscriptionExpiryService periodically updates expired subscription status.
|
// SubscriptionExpiryService periodically updates expired subscription status.
|
||||||
type SubscriptionExpiryService struct {
|
type SubscriptionExpiryService struct {
|
||||||
userSubRepo UserSubscriptionRepository
|
userSubRepo UserSubscriptionRepository
|
||||||
interval time.Duration
|
notificationEmailService *NotificationEmailService
|
||||||
stopCh chan struct{}
|
interval time.Duration
|
||||||
stopOnce sync.Once
|
stopCh chan struct{}
|
||||||
wg sync.WaitGroup
|
stopOnce sync.Once
|
||||||
|
wg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, interval time.Duration) *SubscriptionExpiryService {
|
func NewSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, interval time.Duration) *SubscriptionExpiryService {
|
||||||
@ -24,6 +29,10 @@ func NewSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, interv
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SubscriptionExpiryService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) {
|
||||||
|
s.notificationEmailService = notificationEmailService
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SubscriptionExpiryService) Start() {
|
func (s *SubscriptionExpiryService) Start() {
|
||||||
if s == nil || s.userSubRepo == nil || s.interval <= 0 {
|
if s == nil || s.userSubRepo == nil || s.interval <= 0 {
|
||||||
return
|
return
|
||||||
@ -68,4 +77,50 @@ func (s *SubscriptionExpiryService) runOnce() {
|
|||||||
if updated > 0 {
|
if updated > 0 {
|
||||||
log.Printf("[SubscriptionExpiry] Updated %d expired subscriptions", updated)
|
log.Printf("[SubscriptionExpiry] Updated %d expired subscriptions", updated)
|
||||||
}
|
}
|
||||||
|
s.sendExpiryReminders(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SubscriptionExpiryService) sendExpiryReminders(ctx context.Context) {
|
||||||
|
if s == nil || s.userSubRepo == nil || s.notificationEmailService == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for page := 1; ; page++ {
|
||||||
|
subs, pag, err := s.userSubRepo.List(ctx, pagination.PaginationParams{Page: page, PageSize: 200}, nil, nil, SubscriptionStatusActive, "", "expires_at", "asc")
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[SubscriptionExpiry] List active subscriptions for reminder failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := range subs {
|
||||||
|
s.sendExpiryReminderIfDue(ctx, &subs[i])
|
||||||
|
}
|
||||||
|
if pag == nil || page >= pag.Pages || len(subs) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SubscriptionExpiryService) sendExpiryReminderIfDue(ctx context.Context, sub *UserSubscription) {
|
||||||
|
if sub == nil || sub.User == nil || sub.Group == nil || sub.User.Email == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
daysRemaining := sub.DaysRemaining()
|
||||||
|
if daysRemaining != 7 && daysRemaining != 3 && daysRemaining != 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := s.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||||
|
Event: NotificationEmailEventSubscriptionExpiryReminder,
|
||||||
|
RecipientEmail: sub.User.Email,
|
||||||
|
RecipientName: firstNonEmpty(sub.User.Username, sub.User.Email),
|
||||||
|
UserID: sub.UserID,
|
||||||
|
SourceType: "user_subscription",
|
||||||
|
SourceID: strconv.FormatInt(sub.ID, 10),
|
||||||
|
ReminderKey: fmt.Sprintf("%dd", daysRemaining),
|
||||||
|
Variables: map[string]string{
|
||||||
|
"subscription_group": sub.Group.Name,
|
||||||
|
"expiry_time": sub.ExpiresAt.Format("2006-01-02 15:04"),
|
||||||
|
"days_remaining": strconv.Itoa(daysRemaining),
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
log.Printf("[SubscriptionExpiry] Send expiry reminder failed: subscription=%d user=%d err=%v", sub.ID, sub.UserID, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -517,7 +517,7 @@ func (s *TotpService) GetVerificationMethod(ctx context.Context) *VerificationMe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SendVerifyCode sends an email verification code for TOTP operations
|
// SendVerifyCode sends an email verification code for TOTP operations
|
||||||
func (s *TotpService) SendVerifyCode(ctx context.Context, userID int64) error {
|
func (s *TotpService) SendVerifyCode(ctx context.Context, userID int64, locale ...string) error {
|
||||||
// Check if email verification is enabled
|
// Check if email verification is enabled
|
||||||
if !s.settingService.IsEmailVerifyEnabled(ctx) {
|
if !s.settingService.IsEmailVerifyEnabled(ctx) {
|
||||||
return infraerrors.BadRequest("EMAIL_VERIFY_NOT_ENABLED", "email verification is not enabled")
|
return infraerrors.BadRequest("EMAIL_VERIFY_NOT_ENABLED", "email verification is not enabled")
|
||||||
@ -533,5 +533,5 @@ func (s *TotpService) SendVerifyCode(ctx context.Context, userID int64) error {
|
|||||||
siteName := s.settingService.GetSiteName(ctx)
|
siteName := s.settingService.GetSiteName(ctx)
|
||||||
|
|
||||||
// Send verification code via queue
|
// Send verification code via queue
|
||||||
return s.emailQueueService.EnqueueVerifyCode(user.Email, siteName)
|
return s.emailQueueService.EnqueueVerifyCode(user.Email, siteName, firstEmailLocale(locale))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -324,6 +324,30 @@ func (s *UsageService) GetAPIKeyModelStats(ctx context.Context, apiKeyID int64,
|
|||||||
return stats, nil
|
return stats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAPIKeyDailyUsage returns daily usage stats for a user's API key.
|
||||||
|
func (s *UsageService) GetAPIKeyDailyUsage(ctx context.Context, userID, apiKeyID int64, startTime, endTime time.Time) ([]usagestats.APIKeyDailyUsagePoint, error) {
|
||||||
|
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, "day", userID, apiKeyID, 0, 0, "", nil, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get api key daily usage: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
points := make([]usagestats.APIKeyDailyUsagePoint, 0, len(trend))
|
||||||
|
for _, row := range trend {
|
||||||
|
points = append(points, usagestats.APIKeyDailyUsagePoint{
|
||||||
|
Date: row.Date,
|
||||||
|
Requests: row.Requests,
|
||||||
|
InputTokens: row.InputTokens,
|
||||||
|
OutputTokens: row.OutputTokens,
|
||||||
|
CacheReadTokens: row.CacheReadTokens,
|
||||||
|
CacheWriteTokens: row.CacheCreationTokens,
|
||||||
|
TotalTokens: row.TotalTokens,
|
||||||
|
Cost: row.Cost,
|
||||||
|
ActualCost: row.ActualCost,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return points, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
|
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
|
||||||
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||||
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)
|
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)
|
||||||
|
|||||||
@ -1122,7 +1122,7 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SendNotifyEmailCode sends a verification code to the extra notification email.
|
// SendNotifyEmailCode sends a verification code to the extra notification email.
|
||||||
func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, email string, emailService *EmailService, cache EmailCache) error {
|
func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, email string, emailService *EmailService, cache EmailCache, locale ...string) error {
|
||||||
if err := checkNotifyCodeRateLimit(ctx, cache, userID, email); err != nil {
|
if err := checkNotifyCodeRateLimit(ctx, cache, userID, email); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -1134,7 +1134,7 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema
|
|||||||
|
|
||||||
// Send email first — if SMTP fails, don't write cache or increment counters,
|
// Send email first — if SMTP fails, don't write cache or increment counters,
|
||||||
// so the user is not locked out by cooldown/rate-limit for a code they never received.
|
// so the user is not locked out by cooldown/rate-limit for a code they never received.
|
||||||
if err := s.sendNotifyVerifyEmail(ctx, emailService, email, code); err != nil {
|
if err := s.sendNotifyVerifyEmail(ctx, emailService, userID, email, code, firstEmailLocale(locale)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1180,13 +1180,33 @@ func saveNotifyVerifyCode(ctx context.Context, cache EmailCache, email, code str
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sendNotifyVerifyEmail builds and sends the verification email.
|
// sendNotifyVerifyEmail builds and sends the verification email.
|
||||||
func (s *UserService) sendNotifyVerifyEmail(ctx context.Context, emailService *EmailService, email, code string) error {
|
func (s *UserService) sendNotifyVerifyEmail(ctx context.Context, emailService *EmailService, userID int64, email, code, locale string) error {
|
||||||
siteName := "Sub2API"
|
siteName := "Sub2API"
|
||||||
if s.settingRepo != nil {
|
if s.settingRepo != nil {
|
||||||
if name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName); err == nil && name != "" {
|
if name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName); err == nil && name != "" {
|
||||||
siteName = name
|
siteName = name
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if emailService.notificationEmailService != nil {
|
||||||
|
if err := emailService.notificationEmailService.Send(ctx, NotificationEmailSendInput{
|
||||||
|
Event: NotificationEmailEventNotificationEmailVerifyCode,
|
||||||
|
Locale: locale,
|
||||||
|
RecipientEmail: email,
|
||||||
|
RecipientName: emailRecipientName(email),
|
||||||
|
UserID: userID,
|
||||||
|
Variables: map[string]string{
|
||||||
|
"verification_code": code,
|
||||||
|
"expires_in_minutes": strconv.Itoa(int(verifyCodeTTL / time.Minute)),
|
||||||
|
},
|
||||||
|
}); err == nil {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
if !shouldFallbackNotificationEmail(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
slog.Warn("template notification email verification failed; falling back to built-in body", "recipient_hash", notificationEmailHash(email), "err", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
subject := fmt.Sprintf("[%s] 通知邮箱验证码 / Notification Email Verification", siteName)
|
subject := fmt.Sprintf("[%s] 通知邮箱验证码 / Notification Email Verification", siteName)
|
||||||
body := buildNotifyVerifyEmailBody(code, siteName)
|
body := buildNotifyVerifyEmailBody(code, siteName)
|
||||||
return emailService.SendEmail(ctx, email, subject, body)
|
return emailService.SendEmail(ctx, email, subject, body)
|
||||||
|
|||||||
@ -154,8 +154,9 @@ func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpirySe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ProvideSubscriptionExpiryService creates and starts SubscriptionExpiryService.
|
// ProvideSubscriptionExpiryService creates and starts SubscriptionExpiryService.
|
||||||
func ProvideSubscriptionExpiryService(userSubRepo UserSubscriptionRepository) *SubscriptionExpiryService {
|
func ProvideSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, notificationEmailService *NotificationEmailService) *SubscriptionExpiryService {
|
||||||
svc := NewSubscriptionExpiryService(userSubRepo, time.Minute)
|
svc := NewSubscriptionExpiryService(userSubRepo, time.Minute)
|
||||||
|
svc.SetNotificationEmailService(notificationEmailService)
|
||||||
svc.Start()
|
svc.Start()
|
||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
@ -484,6 +485,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
ProvideOpsCleanupService,
|
ProvideOpsCleanupService,
|
||||||
ProvideOpsScheduledReportService,
|
ProvideOpsScheduledReportService,
|
||||||
NewEmailService,
|
NewEmailService,
|
||||||
|
NewNotificationEmailService,
|
||||||
ProvideEmailQueueService,
|
ProvideEmailQueueService,
|
||||||
NewTurnstileService,
|
NewTurnstileService,
|
||||||
NewSubscriptionService,
|
NewSubscriptionService,
|
||||||
@ -520,7 +522,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewContentModerationService,
|
NewContentModerationService,
|
||||||
NewAffiliateService,
|
NewAffiliateService,
|
||||||
ProvidePaymentConfigService,
|
ProvidePaymentConfigService,
|
||||||
NewPaymentService,
|
ProvidePaymentService,
|
||||||
ProvidePaymentOrderExpiryService,
|
ProvidePaymentOrderExpiryService,
|
||||||
ProvideBalanceNotifyService,
|
ProvideBalanceNotifyService,
|
||||||
ProvideWindsurfAuthService,
|
ProvideWindsurfAuthService,
|
||||||
@ -648,8 +650,17 @@ func ProvidePaymentConfigService(entClient *dbent.Client, settingRepo SettingRep
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ProvideBalanceNotifyService creates BalanceNotifyService
|
// ProvideBalanceNotifyService creates BalanceNotifyService
|
||||||
func ProvideBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountRepository) *BalanceNotifyService {
|
func ProvideBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountRepository, notificationEmailService *NotificationEmailService) *BalanceNotifyService {
|
||||||
return NewBalanceNotifyService(emailService, settingRepo, accountRepo)
|
svc := NewBalanceNotifyService(emailService, settingRepo, accountRepo)
|
||||||
|
svc.SetNotificationEmailService(notificationEmailService)
|
||||||
|
return svc
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProvidePaymentService creates PaymentService and attaches notification email delivery.
|
||||||
|
func ProvidePaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository, affiliateService *AffiliateService, notificationEmailService *NotificationEmailService) *PaymentService {
|
||||||
|
svc := NewPaymentService(entClient, registry, loadBalancer, redeemService, subscriptionSvc, configService, userRepo, groupRepo, affiliateService)
|
||||||
|
svc.SetNotificationEmailService(notificationEmailService)
|
||||||
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProvidePaymentOrderExpiryService creates and starts PaymentOrderExpiryService.
|
// ProvidePaymentOrderExpiryService creates and starts PaymentOrderExpiryService.
|
||||||
|
|||||||
@ -5,6 +5,45 @@
|
|||||||
import { config } from '@vue/test-utils'
|
import { config } from '@vue/test-utils'
|
||||||
import { vi } from 'vitest'
|
import { vi } from 'vitest'
|
||||||
|
|
||||||
|
function createMemoryStorage(): Storage {
|
||||||
|
const values = new Map<string, string>()
|
||||||
|
|
||||||
|
return {
|
||||||
|
get length() {
|
||||||
|
return values.size
|
||||||
|
},
|
||||||
|
clear() {
|
||||||
|
values.clear()
|
||||||
|
},
|
||||||
|
getItem(key: string) {
|
||||||
|
return values.has(key) ? values.get(key)! : null
|
||||||
|
},
|
||||||
|
key(index: number) {
|
||||||
|
return Array.from(values.keys())[index] ?? null
|
||||||
|
},
|
||||||
|
removeItem(key: string) {
|
||||||
|
values.delete(key)
|
||||||
|
},
|
||||||
|
setItem(key: string, value: string) {
|
||||||
|
values.set(key, String(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (typeof globalThis.localStorage === 'undefined' || typeof globalThis.localStorage.getItem !== 'function') {
|
||||||
|
Object.defineProperty(globalThis, 'localStorage', {
|
||||||
|
configurable: true,
|
||||||
|
value: createMemoryStorage()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if (typeof window !== 'undefined' && typeof window.localStorage.getItem !== 'function') {
|
||||||
|
Object.defineProperty(window, 'localStorage', {
|
||||||
|
configurable: true,
|
||||||
|
value: globalThis.localStorage
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Mock requestIdleCallback (Safari < 15 不支持)
|
// Mock requestIdleCallback (Safari < 15 不支持)
|
||||||
if (typeof globalThis.requestIdleCallback === 'undefined') {
|
if (typeof globalThis.requestIdleCallback === 'undefined') {
|
||||||
globalThis.requestIdleCallback = ((callback: IdleRequestCallback) => {
|
globalThis.requestIdleCallback = ((callback: IdleRequestCallback) => {
|
||||||
|
|||||||
@ -505,6 +505,7 @@ export interface SystemSettings {
|
|||||||
enable_anthropic_cache_ttl_1h_injection: boolean;
|
enable_anthropic_cache_ttl_1h_injection: boolean;
|
||||||
rewrite_message_cache_control: boolean;
|
rewrite_message_cache_control: boolean;
|
||||||
antigravity_user_agent_version: string;
|
antigravity_user_agent_version: string;
|
||||||
|
openai_codex_user_agent: string;
|
||||||
web_search_emulation_enabled?: boolean;
|
web_search_emulation_enabled?: boolean;
|
||||||
|
|
||||||
// Payment configuration
|
// Payment configuration
|
||||||
@ -726,6 +727,7 @@ export interface UpdateSettingsRequest {
|
|||||||
enable_anthropic_cache_ttl_1h_injection?: boolean;
|
enable_anthropic_cache_ttl_1h_injection?: boolean;
|
||||||
rewrite_message_cache_control?: boolean;
|
rewrite_message_cache_control?: boolean;
|
||||||
antigravity_user_agent_version?: string;
|
antigravity_user_agent_version?: string;
|
||||||
|
openai_codex_user_agent?: string;
|
||||||
// Payment configuration
|
// Payment configuration
|
||||||
payment_enabled?: boolean;
|
payment_enabled?: boolean;
|
||||||
risk_control_enabled?: boolean;
|
risk_control_enabled?: boolean;
|
||||||
@ -854,6 +856,105 @@ export async function sendTestEmail(
|
|||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ==================== Email Template Settings ====================
|
||||||
|
|
||||||
|
export interface EmailTemplateOption {
|
||||||
|
value: string;
|
||||||
|
label?: string;
|
||||||
|
description?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export type EmailTemplateEventOption = string | EmailTemplateOption;
|
||||||
|
|
||||||
|
export interface EmailTemplateSummary {
|
||||||
|
event: string;
|
||||||
|
locale: string;
|
||||||
|
subject: string;
|
||||||
|
is_custom?: boolean;
|
||||||
|
updated_at?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface EmailTemplateListResponse {
|
||||||
|
events: EmailTemplateEventOption[];
|
||||||
|
locales: string[];
|
||||||
|
templates?: EmailTemplateSummary[];
|
||||||
|
placeholders?: string[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface EmailTemplateDetail {
|
||||||
|
event: string;
|
||||||
|
locale: string;
|
||||||
|
subject: string;
|
||||||
|
html: string;
|
||||||
|
is_custom?: boolean;
|
||||||
|
updated_at?: string;
|
||||||
|
placeholders?: string[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface UpdateEmailTemplateRequest {
|
||||||
|
subject: string;
|
||||||
|
html: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PreviewEmailTemplateRequest extends UpdateEmailTemplateRequest {
|
||||||
|
event: string;
|
||||||
|
locale: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface EmailTemplatePreviewResponse {
|
||||||
|
subject: string;
|
||||||
|
html: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function getEmailTemplates(): Promise<EmailTemplateListResponse> {
|
||||||
|
const { data } = await apiClient.get<EmailTemplateListResponse>(
|
||||||
|
"/admin/settings/email-templates",
|
||||||
|
);
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function getEmailTemplate(
|
||||||
|
event: string,
|
||||||
|
locale: string,
|
||||||
|
): Promise<EmailTemplateDetail> {
|
||||||
|
const { data } = await apiClient.get<EmailTemplateDetail>(
|
||||||
|
`/admin/settings/email-templates/${encodeURIComponent(event)}/${encodeURIComponent(locale)}`,
|
||||||
|
);
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function updateEmailTemplate(
|
||||||
|
event: string,
|
||||||
|
locale: string,
|
||||||
|
request: UpdateEmailTemplateRequest,
|
||||||
|
): Promise<EmailTemplateDetail> {
|
||||||
|
const { data } = await apiClient.put<EmailTemplateDetail>(
|
||||||
|
`/admin/settings/email-templates/${encodeURIComponent(event)}/${encodeURIComponent(locale)}`,
|
||||||
|
request,
|
||||||
|
);
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function restoreOfficialEmailTemplate(
|
||||||
|
event: string,
|
||||||
|
locale: string,
|
||||||
|
): Promise<EmailTemplateDetail> {
|
||||||
|
const { data } = await apiClient.post<EmailTemplateDetail>(
|
||||||
|
`/admin/settings/email-templates/${encodeURIComponent(event)}/${encodeURIComponent(locale)}/restore-official`,
|
||||||
|
);
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function previewEmailTemplate(
|
||||||
|
request: PreviewEmailTemplateRequest,
|
||||||
|
): Promise<EmailTemplatePreviewResponse> {
|
||||||
|
const { data } = await apiClient.post<EmailTemplatePreviewResponse>(
|
||||||
|
"/admin/settings/email-template-preview",
|
||||||
|
request,
|
||||||
|
);
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Admin API Key status response
|
* Admin API Key status response
|
||||||
*/
|
*/
|
||||||
@ -1160,6 +1261,11 @@ export const settingsAPI = {
|
|||||||
updateSettings,
|
updateSettings,
|
||||||
testSmtpConnection,
|
testSmtpConnection,
|
||||||
sendTestEmail,
|
sendTestEmail,
|
||||||
|
getEmailTemplates,
|
||||||
|
getEmailTemplate,
|
||||||
|
updateEmailTemplate,
|
||||||
|
restoreOfficialEmailTemplate,
|
||||||
|
previewEmailTemplate,
|
||||||
getAdminApiKey,
|
getAdminApiKey,
|
||||||
regenerateAdminApiKey,
|
regenerateAdminApiKey,
|
||||||
deleteAdminApiKey,
|
deleteAdminApiKey,
|
||||||
|
|||||||
@ -69,6 +69,25 @@ export interface ModelStatsResponse {
|
|||||||
end_date: string
|
end_date: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface ApiKeyDailyUsagePoint {
|
||||||
|
date: string
|
||||||
|
requests: number
|
||||||
|
input_tokens: number
|
||||||
|
output_tokens: number
|
||||||
|
cache_read_tokens: number
|
||||||
|
cache_write_tokens: number
|
||||||
|
total_tokens: number
|
||||||
|
cost: number
|
||||||
|
actual_cost: number
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ApiKeyDailyUsageResponse {
|
||||||
|
items: ApiKeyDailyUsagePoint[]
|
||||||
|
days: number
|
||||||
|
start_date: string
|
||||||
|
end_date: string
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* List usage logs with optional filters
|
* List usage logs with optional filters
|
||||||
* @param page - Page number (default: 1)
|
* @param page - Page number (default: 1)
|
||||||
@ -234,6 +253,23 @@ export async function getDashboardModels(params?: {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get daily usage details for one API key owned by the current user.
|
||||||
|
* @param apiKeyId - API key ID
|
||||||
|
* @param days - Number of days to include (1-90)
|
||||||
|
* @returns Daily usage detail rows
|
||||||
|
*/
|
||||||
|
export async function getMyApiKeyDailyUsage(
|
||||||
|
apiKeyId: number,
|
||||||
|
days: number = 30
|
||||||
|
): Promise<ApiKeyDailyUsageResponse> {
|
||||||
|
const { data } = await apiClient.get<ApiKeyDailyUsageResponse>(
|
||||||
|
`/user/api-keys/${apiKeyId}/usage/daily`,
|
||||||
|
{ params: { days } }
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
export interface BatchApiKeyUsageStats {
|
export interface BatchApiKeyUsageStats {
|
||||||
api_key_id: number
|
api_key_id: number
|
||||||
today_actual_cost: number
|
today_actual_cost: number
|
||||||
@ -279,6 +315,7 @@ export const usageAPI = {
|
|||||||
getDashboardStats,
|
getDashboardStats,
|
||||||
getDashboardTrend,
|
getDashboardTrend,
|
||||||
getDashboardModels,
|
getDashboardModels,
|
||||||
|
getMyApiKeyDailyUsage,
|
||||||
getDashboardApiKeysUsage
|
getDashboardApiKeysUsage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -123,19 +123,23 @@ export default {
|
|||||||
dateRangeToday: 'Today',
|
dateRangeToday: 'Today',
|
||||||
dateRange7d: '7 Days',
|
dateRange7d: '7 Days',
|
||||||
dateRange30d: '30 Days',
|
dateRange30d: '30 Days',
|
||||||
|
dateRange90d: '90 Days',
|
||||||
dateRangeCustom: 'Custom',
|
dateRangeCustom: 'Custom',
|
||||||
apply: 'Apply',
|
apply: 'Apply',
|
||||||
used: 'Used',
|
used: 'Used',
|
||||||
detailInfo: 'Detail Information',
|
detailInfo: 'Detail Information',
|
||||||
tokenStats: 'Token Statistics',
|
tokenStats: 'Token Statistics',
|
||||||
|
dailyDetail: 'Daily Detail',
|
||||||
modelStats: 'Model Usage Statistics',
|
modelStats: 'Model Usage Statistics',
|
||||||
// Table headers
|
// Table headers
|
||||||
|
date: 'Date',
|
||||||
model: 'Model',
|
model: 'Model',
|
||||||
requests: 'Requests',
|
requests: 'Requests',
|
||||||
inputTokens: 'Input Tokens',
|
inputTokens: 'Input Tokens',
|
||||||
outputTokens: 'Output Tokens',
|
outputTokens: 'Output Tokens',
|
||||||
cacheCreationTokens: 'Cache Creation',
|
cacheCreationTokens: 'Cache Creation',
|
||||||
cacheReadTokens: 'Cache Read',
|
cacheReadTokens: 'Cache Read',
|
||||||
|
cacheWriteTokens: 'Cache Write',
|
||||||
totalTokens: 'Total Tokens',
|
totalTokens: 'Total Tokens',
|
||||||
cost: 'Cost',
|
cost: 'Cost',
|
||||||
// Status
|
// Status
|
||||||
@ -179,6 +183,7 @@ export default {
|
|||||||
querySuccess: 'Query successful',
|
querySuccess: 'Query successful',
|
||||||
queryFailed: 'Query failed',
|
queryFailed: 'Query failed',
|
||||||
queryFailedRetry: 'Query failed, please try again later',
|
queryFailedRetry: 'Query failed, please try again later',
|
||||||
|
noDailyUsage: 'No daily usage data',
|
||||||
},
|
},
|
||||||
|
|
||||||
// Setup Wizard
|
// Setup Wizard
|
||||||
@ -4176,6 +4181,22 @@ export default {
|
|||||||
},
|
},
|
||||||
userPrefix: 'User #{id}',
|
userPrefix: 'User #{id}',
|
||||||
exportCsv: 'Export CSV',
|
exportCsv: 'Export CSV',
|
||||||
|
batchUpdate: 'Batch Update',
|
||||||
|
batchUpdateTitle: 'Batch Update Redeem Codes',
|
||||||
|
selectedCount: '{count} redeem code(s) selected',
|
||||||
|
clearSelection: 'Clear selection',
|
||||||
|
selectCodesFirst: 'Select redeem codes first',
|
||||||
|
noBatchFieldsSelected: 'Select at least one field to update',
|
||||||
|
batchUpdateSuccess: 'Updated {count} redeem code(s)',
|
||||||
|
failedToBatchUpdate: 'Failed to batch update redeem codes',
|
||||||
|
batchFields: {
|
||||||
|
status: 'Status',
|
||||||
|
expiresAt: 'Expires At',
|
||||||
|
notes: 'Notes',
|
||||||
|
group: 'Group'
|
||||||
|
},
|
||||||
|
batchNotesPlaceholder: 'Enter the new note, or leave blank to clear it',
|
||||||
|
clearGroup: 'Clear group',
|
||||||
deleteAllUnused: 'Delete All Unused Codes',
|
deleteAllUnused: 'Delete All Unused Codes',
|
||||||
deleteCode: 'Delete Redeem Code',
|
deleteCode: 'Delete Redeem Code',
|
||||||
deleteCodeConfirm:
|
deleteCodeConfirm:
|
||||||
@ -5515,6 +5536,9 @@ export default {
|
|||||||
antigravityUserAgentVersion: 'Antigravity UA Version',
|
antigravityUserAgentVersion: 'Antigravity UA Version',
|
||||||
antigravityUserAgentVersionPlaceholder: '1.23.2',
|
antigravityUserAgentVersionPlaceholder: '1.23.2',
|
||||||
antigravityUserAgentVersionHint: 'Leave empty to use ANTIGRAVITY_USER_AGENT_VERSION or the built-in default 1.23.2; when set, the admin setting takes precedence.',
|
antigravityUserAgentVersionHint: 'Leave empty to use ANTIGRAVITY_USER_AGENT_VERSION or the built-in default 1.23.2; when set, the admin setting takes precedence.',
|
||||||
|
openaiCodexUserAgent: 'OpenAI Codex UA',
|
||||||
|
openaiCodexUserAgentPlaceholder: 'codex-tui/0.125.0 (Ubuntu 22.4.0; x86_64) xterm-256color (codex-tui; 0.125.0)',
|
||||||
|
openaiCodexUserAgentHint: 'Used to bypass Cloudflare browser-UA challenges on the OpenAI upstream. Only applies when the client User-Agent is detected as a browser (Mozilla/...). Leave empty to use the built-in default.',
|
||||||
},
|
},
|
||||||
webSearchEmulation: {
|
webSearchEmulation: {
|
||||||
title: 'Web Search Emulation',
|
title: 'Web Search Emulation',
|
||||||
@ -5854,6 +5878,36 @@ export default {
|
|||||||
sending: 'Sending...',
|
sending: 'Sending...',
|
||||||
enterRecipientHint: 'Please enter a recipient email address'
|
enterRecipientHint: 'Please enter a recipient email address'
|
||||||
},
|
},
|
||||||
|
emailTemplates: {
|
||||||
|
title: 'Email Templates',
|
||||||
|
description: 'Customize notification email subjects and HTML content for each event and locale.',
|
||||||
|
event: 'Event',
|
||||||
|
locale: 'Locale',
|
||||||
|
localeEn: 'English',
|
||||||
|
localeZh: 'Chinese',
|
||||||
|
subject: 'Subject',
|
||||||
|
subjectPlaceholder: 'Enter the email subject',
|
||||||
|
html: 'HTML Template',
|
||||||
|
htmlPlaceholder: 'Edit the email HTML template',
|
||||||
|
placeholders: 'Available Placeholders',
|
||||||
|
placeholdersHelp: 'Click a placeholder to copy it. The backend replaces these values when sending emails.',
|
||||||
|
livePreview: 'Live Preview',
|
||||||
|
previewSecurityHint: 'Preview HTML is generated by the backend preview endpoint and displayed in a sandboxed iframe with scripts disabled.',
|
||||||
|
preview: 'Preview / Refresh',
|
||||||
|
previewing: 'Previewing...',
|
||||||
|
save: 'Save Template',
|
||||||
|
saving: 'Saving...',
|
||||||
|
restoreOfficial: 'Restore Official',
|
||||||
|
restoring: 'Restoring...',
|
||||||
|
restoreConfirm: 'Restore the official template for this event and locale? Your custom version will be replaced.',
|
||||||
|
restoreSuccess: 'Official template restored',
|
||||||
|
saveSuccess: 'Email template saved',
|
||||||
|
placeholderCopied: 'Placeholder copied',
|
||||||
|
validationRequired: 'Subject and HTML template are required',
|
||||||
|
empty: 'No email template events or locales are available yet.',
|
||||||
|
noPreview: 'Refresh the preview to see the rendered email subject.',
|
||||||
|
customized: 'Customized'
|
||||||
|
},
|
||||||
opsMonitoring: {
|
opsMonitoring: {
|
||||||
title: 'Ops Monitoring',
|
title: 'Ops Monitoring',
|
||||||
description: 'Enable ops monitoring for troubleshooting and health visibility',
|
description: 'Enable ops monitoring for troubleshooting and health visibility',
|
||||||
|
|||||||
@ -123,19 +123,23 @@ export default {
|
|||||||
dateRangeToday: '今日',
|
dateRangeToday: '今日',
|
||||||
dateRange7d: '7 天',
|
dateRange7d: '7 天',
|
||||||
dateRange30d: '30 天',
|
dateRange30d: '30 天',
|
||||||
|
dateRange90d: '90 天',
|
||||||
dateRangeCustom: '自定义',
|
dateRangeCustom: '自定义',
|
||||||
apply: '应用',
|
apply: '应用',
|
||||||
used: '已使用',
|
used: '已使用',
|
||||||
detailInfo: '详细信息',
|
detailInfo: '详细信息',
|
||||||
tokenStats: 'Token 统计',
|
tokenStats: 'Token 统计',
|
||||||
|
dailyDetail: '按日明细',
|
||||||
modelStats: '模型用量统计',
|
modelStats: '模型用量统计',
|
||||||
// Table headers
|
// Table headers
|
||||||
|
date: '日期',
|
||||||
model: '模型',
|
model: '模型',
|
||||||
requests: '请求数',
|
requests: '请求数',
|
||||||
inputTokens: '输入 Tokens',
|
inputTokens: '输入 Tokens',
|
||||||
outputTokens: '输出 Tokens',
|
outputTokens: '输出 Tokens',
|
||||||
cacheCreationTokens: '缓存创建',
|
cacheCreationTokens: '缓存创建',
|
||||||
cacheReadTokens: '缓存读取',
|
cacheReadTokens: '缓存读取',
|
||||||
|
cacheWriteTokens: '缓存写入',
|
||||||
totalTokens: '总 Tokens',
|
totalTokens: '总 Tokens',
|
||||||
cost: '费用',
|
cost: '费用',
|
||||||
// Status
|
// Status
|
||||||
@ -179,6 +183,7 @@ export default {
|
|||||||
querySuccess: '查询成功',
|
querySuccess: '查询成功',
|
||||||
queryFailed: '查询失败',
|
queryFailed: '查询失败',
|
||||||
queryFailedRetry: '查询失败,请稍后重试',
|
queryFailedRetry: '查询失败,请稍后重试',
|
||||||
|
noDailyUsage: '暂无按日用量数据',
|
||||||
},
|
},
|
||||||
|
|
||||||
// Setup Wizard
|
// Setup Wizard
|
||||||
@ -4310,6 +4315,22 @@ export default {
|
|||||||
used: '已使用',
|
used: '已使用',
|
||||||
searchCodes: '搜索兑换码或邮箱...',
|
searchCodes: '搜索兑换码或邮箱...',
|
||||||
exportCsv: '导出 CSV',
|
exportCsv: '导出 CSV',
|
||||||
|
batchUpdate: '批量修改',
|
||||||
|
batchUpdateTitle: '批量修改兑换码',
|
||||||
|
selectedCount: '已选择 {count} 个兑换码',
|
||||||
|
clearSelection: '清空选择',
|
||||||
|
selectCodesFirst: '请先选择兑换码',
|
||||||
|
noBatchFieldsSelected: '请至少勾选一个要修改的字段',
|
||||||
|
batchUpdateSuccess: '成功修改 {count} 个兑换码',
|
||||||
|
failedToBatchUpdate: '批量修改兑换码失败',
|
||||||
|
batchFields: {
|
||||||
|
status: '状态',
|
||||||
|
expiresAt: '过期时间',
|
||||||
|
notes: '备注',
|
||||||
|
group: '分组'
|
||||||
|
},
|
||||||
|
batchNotesPlaceholder: '输入新的备注,留空可清空备注',
|
||||||
|
clearGroup: '清空分组',
|
||||||
deleteAllUnused: '删除全部未使用',
|
deleteAllUnused: '删除全部未使用',
|
||||||
deleteCodeConfirm: '确定要删除此兑换码吗?此操作无法撤销。',
|
deleteCodeConfirm: '确定要删除此兑换码吗?此操作无法撤销。',
|
||||||
deleteAllUnusedConfirm: '确定要删除全部未使用的兑换码吗?此操作无法撤销。',
|
deleteAllUnusedConfirm: '确定要删除全部未使用的兑换码吗?此操作无法撤销。',
|
||||||
@ -5673,6 +5694,9 @@ export default {
|
|||||||
antigravityUserAgentVersion: 'Antigravity UA 版本',
|
antigravityUserAgentVersion: 'Antigravity UA 版本',
|
||||||
antigravityUserAgentVersionPlaceholder: '1.23.2',
|
antigravityUserAgentVersionPlaceholder: '1.23.2',
|
||||||
antigravityUserAgentVersionHint: '留空时使用 ANTIGRAVITY_USER_AGENT_VERSION 或内置默认值 1.23.2;填写后后台设置优先。',
|
antigravityUserAgentVersionHint: '留空时使用 ANTIGRAVITY_USER_AGENT_VERSION 或内置默认值 1.23.2;填写后后台设置优先。',
|
||||||
|
openaiCodexUserAgent: 'OpenAI Codex UA',
|
||||||
|
openaiCodexUserAgentPlaceholder: 'codex-tui/0.125.0 (Ubuntu 22.4.0; x86_64) xterm-256color (codex-tui; 0.125.0)',
|
||||||
|
openaiCodexUserAgentHint: '用于规避 OpenAI 上游 Cloudflare 对浏览器 UA 的访问质询。仅在检测到客户端 User-Agent 为浏览器(Mozilla/...)时生效,其他客户端原样透传。留空使用内置默认值。',
|
||||||
},
|
},
|
||||||
webSearchEmulation: {
|
webSearchEmulation: {
|
||||||
title: 'Web Search 模拟',
|
title: 'Web Search 模拟',
|
||||||
@ -6014,6 +6038,36 @@ export default {
|
|||||||
sending: '发送中...',
|
sending: '发送中...',
|
||||||
enterRecipientHint: '请输入收件人邮箱地址'
|
enterRecipientHint: '请输入收件人邮箱地址'
|
||||||
},
|
},
|
||||||
|
emailTemplates: {
|
||||||
|
title: '邮件模板',
|
||||||
|
description: '按事件和语言自定义通知邮件主题与 HTML 内容。',
|
||||||
|
event: '事件',
|
||||||
|
locale: '语言',
|
||||||
|
localeEn: '英文',
|
||||||
|
localeZh: '中文',
|
||||||
|
subject: '主题',
|
||||||
|
subjectPlaceholder: '输入邮件主题',
|
||||||
|
html: 'HTML 模板',
|
||||||
|
htmlPlaceholder: '编辑邮件 HTML 模板',
|
||||||
|
placeholders: '可用占位符',
|
||||||
|
placeholdersHelp: '点击占位符可复制。后端发送邮件时会替换这些值。',
|
||||||
|
livePreview: '实时预览',
|
||||||
|
previewSecurityHint: '预览 HTML 由后端预览接口生成,并在禁用脚本的沙盒 iframe 中展示。',
|
||||||
|
preview: '预览 / 刷新',
|
||||||
|
previewing: '预览中...',
|
||||||
|
save: '保存模板',
|
||||||
|
saving: '保存中...',
|
||||||
|
restoreOfficial: '恢复官方模板',
|
||||||
|
restoring: '恢复中...',
|
||||||
|
restoreConfirm: '确定恢复此事件和语言的官方模板吗?当前自定义版本将被替换。',
|
||||||
|
restoreSuccess: '已恢复官方模板',
|
||||||
|
saveSuccess: '邮件模板已保存',
|
||||||
|
placeholderCopied: '占位符已复制',
|
||||||
|
validationRequired: '主题和 HTML 模板不能为空',
|
||||||
|
empty: '暂无可用的邮件模板事件或语言。',
|
||||||
|
noPreview: '刷新预览后查看渲染后的邮件主题。',
|
||||||
|
customized: '已自定义'
|
||||||
|
},
|
||||||
opsMonitoring: {
|
opsMonitoring: {
|
||||||
title: '运维监控',
|
title: '运维监控',
|
||||||
description: '启用运维监控模块,用于排障与健康可视化',
|
description: '启用运维监控模块,用于排障与健康可视化',
|
||||||
|
|||||||
@ -289,6 +289,62 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Daily Usage Table -->
|
||||||
|
<div
|
||||||
|
v-if="showDailyUsage"
|
||||||
|
class="fade-up fade-up-delay-4 rounded-2xl border border-gray-200 bg-white/90 backdrop-blur-sm overflow-hidden dark:border-dark-700 dark:bg-dark-900/90"
|
||||||
|
>
|
||||||
|
<div class="flex flex-col gap-3 px-8 py-5 border-b border-gray-200 dark:border-dark-700 sm:flex-row sm:items-center sm:justify-between">
|
||||||
|
<h3 class="text-sm font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.dailyDetail') }}</h3>
|
||||||
|
<div class="inline-flex rounded-lg border border-gray-200 bg-white p-0.5 dark:border-dark-700 dark:bg-dark-950">
|
||||||
|
<button
|
||||||
|
v-for="option in dailyUsageOptions"
|
||||||
|
:key="option.value"
|
||||||
|
@click="setDailyUsageDays(option.value)"
|
||||||
|
class="min-w-12 rounded-md px-3 py-1.5 text-xs font-medium transition-colors"
|
||||||
|
:class="dailyUsageDays === option.value
|
||||||
|
? 'bg-primary-500 text-white'
|
||||||
|
: 'text-gray-600 hover:bg-gray-100 dark:text-dark-300 dark:hover:bg-dark-800'"
|
||||||
|
>
|
||||||
|
{{ option.label }}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div v-if="dailyUsageRows.length > 0" class="overflow-x-auto">
|
||||||
|
<table class="w-full">
|
||||||
|
<thead>
|
||||||
|
<tr class="border-b border-gray-200 bg-gray-50 dark:border-dark-700 dark:bg-dark-950">
|
||||||
|
<th class="px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.date') }}</th>
|
||||||
|
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.requests') }}</th>
|
||||||
|
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.inputTokens') }}</th>
|
||||||
|
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.outputTokens') }}</th>
|
||||||
|
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.cacheReadTokens') }}</th>
|
||||||
|
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.cacheWriteTokens') }}</th>
|
||||||
|
<th class="px-4 py-3 text-right text-xs font-semibold uppercase tracking-wider text-gray-500 dark:text-dark-400">{{ t('keyUsage.cost') }}</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
<tr
|
||||||
|
v-for="row in dailyUsageRows"
|
||||||
|
:key="row.date"
|
||||||
|
class="border-b border-gray-100 last:border-b-0 dark:border-dark-800"
|
||||||
|
>
|
||||||
|
<td class="px-4 py-3 text-sm font-medium whitespace-nowrap text-gray-900 dark:text-white">{{ row.date }}</td>
|
||||||
|
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(row.requests) }}</td>
|
||||||
|
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(row.input_tokens) }}</td>
|
||||||
|
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(row.output_tokens) }}</td>
|
||||||
|
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(row.cache_read_tokens) }}</td>
|
||||||
|
<td class="px-4 py-3 text-sm tabular-nums text-right text-gray-700 dark:text-dark-200">{{ fmtNum(row.cache_write_tokens) }}</td>
|
||||||
|
<td class="px-4 py-3 text-sm tabular-nums text-right font-medium text-gray-900 dark:text-white">{{ usd(row.actual_cost != null ? row.actual_cost : row.cost) }}</td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
<div v-else class="px-8 py-8 text-center text-sm text-gray-500 dark:text-dark-400">
|
||||||
|
{{ t('keyUsage.noDailyUsage') }}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- Model Stats Table -->
|
<!-- Model Stats Table -->
|
||||||
<div
|
<div
|
||||||
v-if="modelStats.length > 0"
|
v-if="modelStats.length > 0"
|
||||||
@ -408,6 +464,7 @@ type DateRangeKey = 'today' | '7d' | '30d' | 'custom'
|
|||||||
const currentRange = ref<DateRangeKey>('today')
|
const currentRange = ref<DateRangeKey>('today')
|
||||||
const customStartDate = ref('')
|
const customStartDate = ref('')
|
||||||
const customEndDate = ref('')
|
const customEndDate = ref('')
|
||||||
|
const dailyUsageDays = ref<7 | 30 | 90>(30)
|
||||||
|
|
||||||
const dateRanges = computed(() => [
|
const dateRanges = computed(() => [
|
||||||
{ key: 'today' as const, label: t('keyUsage.dateRangeToday') },
|
{ key: 'today' as const, label: t('keyUsage.dateRangeToday') },
|
||||||
@ -416,6 +473,12 @@ const dateRanges = computed(() => [
|
|||||||
{ key: 'custom' as const, label: t('keyUsage.dateRangeCustom') },
|
{ key: 'custom' as const, label: t('keyUsage.dateRangeCustom') },
|
||||||
])
|
])
|
||||||
|
|
||||||
|
const dailyUsageOptions = computed(() => [
|
||||||
|
{ value: 7 as const, label: t('keyUsage.dateRange7d') },
|
||||||
|
{ value: 30 as const, label: t('keyUsage.dateRange30d') },
|
||||||
|
{ value: 90 as const, label: t('keyUsage.dateRange90d') },
|
||||||
|
])
|
||||||
|
|
||||||
function setDateRange(key: DateRangeKey) {
|
function setDateRange(key: DateRangeKey) {
|
||||||
currentRange.value = key
|
currentRange.value = key
|
||||||
if (key !== 'custom') {
|
if (key !== 'custom') {
|
||||||
@ -426,23 +489,36 @@ function setDateRange(key: DateRangeKey) {
|
|||||||
function getDateParams(): string {
|
function getDateParams(): string {
|
||||||
const now = new Date()
|
const now = new Date()
|
||||||
const fmt = (d: Date) => d.toISOString().split('T')[0]
|
const fmt = (d: Date) => d.toISOString().split('T')[0]
|
||||||
|
const params = new URLSearchParams()
|
||||||
|
|
||||||
if (currentRange.value === 'custom') {
|
if (currentRange.value === 'custom') {
|
||||||
if (customStartDate.value && customEndDate.value) {
|
if (customStartDate.value && customEndDate.value) {
|
||||||
return `start_date=${customStartDate.value}&end_date=${customEndDate.value}`
|
params.set('start_date', customStartDate.value)
|
||||||
|
params.set('end_date', customEndDate.value)
|
||||||
}
|
}
|
||||||
return ''
|
} else {
|
||||||
|
const end = fmt(now)
|
||||||
|
let start: string
|
||||||
|
switch (currentRange.value) {
|
||||||
|
case 'today': start = end; break
|
||||||
|
case '7d': start = fmt(new Date(now.getTime() - 7 * 86400000)); break
|
||||||
|
case '30d': start = fmt(new Date(now.getTime() - 30 * 86400000)); break
|
||||||
|
default: start = fmt(new Date(now.getTime() - 30 * 86400000))
|
||||||
|
}
|
||||||
|
params.set('start_date', start)
|
||||||
|
params.set('end_date', end)
|
||||||
}
|
}
|
||||||
|
params.set('days', String(dailyUsageDays.value))
|
||||||
|
params.set('timezone', getBrowserTimezone())
|
||||||
|
return params.toString()
|
||||||
|
}
|
||||||
|
|
||||||
const end = fmt(now)
|
function setDailyUsageDays(days: 7 | 30 | 90) {
|
||||||
let start: string
|
if (dailyUsageDays.value === days) return
|
||||||
switch (currentRange.value) {
|
dailyUsageDays.value = days
|
||||||
case 'today': start = end; break
|
if (resultData.value && apiKey.value.trim()) {
|
||||||
case '7d': start = fmt(new Date(now.getTime() - 7 * 86400000)); break
|
queryKey()
|
||||||
case '30d': start = fmt(new Date(now.getTime() - 30 * 86400000)); break
|
|
||||||
default: start = fmt(new Date(now.getTime() - 30 * 86400000))
|
|
||||||
}
|
}
|
||||||
return `start_date=${start}&end_date=${end}`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================== Ring Animation ====================
|
// ==================== Ring Animation ====================
|
||||||
@ -731,6 +807,24 @@ const usageStatCells = computed<StatCell[]>(() => {
|
|||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
const modelStats = computed<any[]>(() => resultData.value?.model_stats || [])
|
const modelStats = computed<any[]>(() => resultData.value?.model_stats || [])
|
||||||
|
|
||||||
|
interface DailyUsageRow {
|
||||||
|
date: string
|
||||||
|
requests: number
|
||||||
|
input_tokens: number
|
||||||
|
output_tokens: number
|
||||||
|
cache_read_tokens: number
|
||||||
|
cache_write_tokens: number
|
||||||
|
cost: number
|
||||||
|
actual_cost?: number
|
||||||
|
}
|
||||||
|
|
||||||
|
const dailyUsageRows = computed<DailyUsageRow[]>(() => {
|
||||||
|
const rows = resultData.value?.daily_usage
|
||||||
|
return Array.isArray(rows) ? rows : []
|
||||||
|
})
|
||||||
|
|
||||||
|
const showDailyUsage = computed(() => Boolean(resultData.value && Array.isArray(resultData.value.daily_usage)))
|
||||||
|
|
||||||
// ==================== Utility Functions ====================
|
// ==================== Utility Functions ====================
|
||||||
|
|
||||||
function usd(value: number | null | undefined): string {
|
function usd(value: number | null | undefined): string {
|
||||||
@ -750,6 +844,14 @@ function formatDate(iso: string | null | undefined): string {
|
|||||||
return d.toLocaleDateString(loc, { year: 'numeric', month: 'long', day: 'numeric' })
|
return d.toLocaleDateString(loc, { year: 'numeric', month: 'long', day: 'numeric' })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function getBrowserTimezone(): string {
|
||||||
|
try {
|
||||||
|
return Intl.DateTimeFormat().resolvedOptions().timeZone || 'UTC'
|
||||||
|
} catch {
|
||||||
|
return 'UTC'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ==================== API Query ====================
|
// ==================== API Query ====================
|
||||||
|
|
||||||
async function fetchUsage(key: string) {
|
async function fetchUsage(key: string) {
|
||||||
|
|||||||
208
frontend/src/views/__tests__/KeyUsageView.spec.ts
Normal file
208
frontend/src/views/__tests__/KeyUsageView.spec.ts
Normal file
@ -0,0 +1,208 @@
|
|||||||
|
import { describe, expect, it, beforeEach, afterEach, vi } from 'vitest'
|
||||||
|
import { flushPromises, mount } from '@vue/test-utils'
|
||||||
|
import { nextTick } from 'vue'
|
||||||
|
|
||||||
|
import KeyUsageView from '../KeyUsageView.vue'
|
||||||
|
|
||||||
|
const { showInfo, showSuccess, showError, fetchPublicSettings } = vi.hoisted(() => ({
|
||||||
|
showInfo: vi.fn(),
|
||||||
|
showSuccess: vi.fn(),
|
||||||
|
showError: vi.fn(),
|
||||||
|
fetchPublicSettings: vi.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
const messages: Record<string, string> = {
|
||||||
|
'keyUsage.title': 'API Key Usage',
|
||||||
|
'keyUsage.subtitle': 'Usage status',
|
||||||
|
'keyUsage.placeholder': 'sk-test',
|
||||||
|
'keyUsage.query': 'Query',
|
||||||
|
'keyUsage.querying': 'Querying...',
|
||||||
|
'keyUsage.privacyNote': 'Privacy note',
|
||||||
|
'keyUsage.dateRange': 'Date Range:',
|
||||||
|
'keyUsage.dateRangeToday': 'Today',
|
||||||
|
'keyUsage.dateRange7d': '7 Days',
|
||||||
|
'keyUsage.dateRange30d': '30 Days',
|
||||||
|
'keyUsage.dateRange90d': '90 Days',
|
||||||
|
'keyUsage.dateRangeCustom': 'Custom',
|
||||||
|
'keyUsage.apply': 'Apply',
|
||||||
|
'keyUsage.used': 'Used',
|
||||||
|
'keyUsage.detailInfo': 'Detail Information',
|
||||||
|
'keyUsage.tokenStats': 'Token Statistics',
|
||||||
|
'keyUsage.dailyDetail': 'Daily Detail',
|
||||||
|
'keyUsage.date': 'Date',
|
||||||
|
'keyUsage.requests': 'Requests',
|
||||||
|
'keyUsage.inputTokens': 'Input Tokens',
|
||||||
|
'keyUsage.outputTokens': 'Output Tokens',
|
||||||
|
'keyUsage.cacheReadTokens': 'Cache Read',
|
||||||
|
'keyUsage.cacheWriteTokens': 'Cache Write',
|
||||||
|
'keyUsage.cost': 'Cost',
|
||||||
|
'keyUsage.quotaMode': 'Key Quota Mode',
|
||||||
|
'keyUsage.walletBalance': 'Wallet Balance',
|
||||||
|
'keyUsage.totalQuota': 'Total Quota',
|
||||||
|
'keyUsage.limit5h': '5-Hour Limit',
|
||||||
|
'keyUsage.limitDaily': 'Daily Limit',
|
||||||
|
'keyUsage.limit7d': '7-Day Limit',
|
||||||
|
'keyUsage.limitWeekly': 'Weekly Limit',
|
||||||
|
'keyUsage.limitMonthly': 'Monthly Limit',
|
||||||
|
'keyUsage.remainingQuota': 'Remaining Quota',
|
||||||
|
'keyUsage.usedQuota': 'Used Quota',
|
||||||
|
'keyUsage.subscriptionType': 'Subscription Type',
|
||||||
|
'keyUsage.todayRequests': 'Today Requests',
|
||||||
|
'keyUsage.todayInputTokens': 'Today Input',
|
||||||
|
'keyUsage.todayOutputTokens': 'Today Output',
|
||||||
|
'keyUsage.todayTokens': 'Today Tokens',
|
||||||
|
'keyUsage.todayCacheCreation': 'Today Cache Creation',
|
||||||
|
'keyUsage.todayCacheRead': 'Today Cache Read',
|
||||||
|
'keyUsage.todayCost': 'Today Cost',
|
||||||
|
'keyUsage.rpmTpm': 'RPM / TPM',
|
||||||
|
'keyUsage.totalRequests': 'Total Requests',
|
||||||
|
'keyUsage.totalInputTokens': 'Total Input',
|
||||||
|
'keyUsage.totalOutputTokens': 'Total Output',
|
||||||
|
'keyUsage.totalTokensLabel': 'Total Tokens',
|
||||||
|
'keyUsage.totalCacheCreation': 'Total Cache Creation',
|
||||||
|
'keyUsage.totalCacheRead': 'Total Cache Read',
|
||||||
|
'keyUsage.totalCost': 'Total Cost',
|
||||||
|
'keyUsage.avgDuration': 'Avg Duration',
|
||||||
|
'keyUsage.querySuccess': 'Query successful',
|
||||||
|
'keyUsage.queryFailed': 'Query failed',
|
||||||
|
'keyUsage.queryFailedRetry': 'Query failed, please try again later',
|
||||||
|
'home.viewDocs': 'Docs',
|
||||||
|
'home.switchToLight': 'Light',
|
||||||
|
'home.switchToDark': 'Dark',
|
||||||
|
'home.footer.allRightsReserved': 'All rights reserved.',
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mock('vue-i18n', async () => {
|
||||||
|
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
|
||||||
|
return {
|
||||||
|
...actual,
|
||||||
|
useI18n: () => ({
|
||||||
|
t: (key: string) => messages[key] ?? key,
|
||||||
|
locale: { value: 'en' },
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
vi.mock('@/stores', () => ({
|
||||||
|
useAppStore: () => ({
|
||||||
|
cachedPublicSettings: null,
|
||||||
|
siteName: 'Sub2API',
|
||||||
|
siteLogo: '',
|
||||||
|
docUrl: '',
|
||||||
|
publicSettingsLoaded: true,
|
||||||
|
fetchPublicSettings,
|
||||||
|
showInfo,
|
||||||
|
showSuccess,
|
||||||
|
showError,
|
||||||
|
}),
|
||||||
|
}))
|
||||||
|
|
||||||
|
describe('KeyUsageView daily detail', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
showInfo.mockReset()
|
||||||
|
showSuccess.mockReset()
|
||||||
|
showError.mockReset()
|
||||||
|
fetchPublicSettings.mockReset()
|
||||||
|
localStorage.clear()
|
||||||
|
|
||||||
|
Object.defineProperty(window, 'matchMedia', {
|
||||||
|
configurable: true,
|
||||||
|
value: vi.fn().mockReturnValue({ matches: false }),
|
||||||
|
})
|
||||||
|
vi.stubGlobal('requestAnimationFrame', (cb: FrameRequestCallback) => window.setTimeout(() => cb(0), 0))
|
||||||
|
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
|
||||||
|
ok: true,
|
||||||
|
json: async () => ({
|
||||||
|
mode: 'quota_limited',
|
||||||
|
isValid: true,
|
||||||
|
status: 'active',
|
||||||
|
quota: {
|
||||||
|
limit: 10,
|
||||||
|
used: 1,
|
||||||
|
remaining: 9,
|
||||||
|
unit: 'USD',
|
||||||
|
},
|
||||||
|
usage: {
|
||||||
|
today: {
|
||||||
|
requests: 1,
|
||||||
|
input_tokens: 10,
|
||||||
|
output_tokens: 20,
|
||||||
|
cache_creation_tokens: 0,
|
||||||
|
cache_read_tokens: 0,
|
||||||
|
total_tokens: 30,
|
||||||
|
actual_cost: 0.01,
|
||||||
|
},
|
||||||
|
total: {
|
||||||
|
requests: 12,
|
||||||
|
input_tokens: 100,
|
||||||
|
output_tokens: 200,
|
||||||
|
cache_creation_tokens: 10,
|
||||||
|
cache_read_tokens: 30,
|
||||||
|
total_tokens: 340,
|
||||||
|
actual_cost: 0.12,
|
||||||
|
},
|
||||||
|
rpm: 0,
|
||||||
|
tpm: 0,
|
||||||
|
},
|
||||||
|
daily_usage: [
|
||||||
|
{
|
||||||
|
date: '2026-05-19',
|
||||||
|
requests: 12,
|
||||||
|
input_tokens: 100,
|
||||||
|
output_tokens: 200,
|
||||||
|
cache_read_tokens: 30,
|
||||||
|
cache_write_tokens: 10,
|
||||||
|
total_tokens: 340,
|
||||||
|
cost: 0.15,
|
||||||
|
actual_cost: 0.12,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}),
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
vi.unstubAllGlobals()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('renders daily usage detail rows after a successful query', async () => {
|
||||||
|
const wrapper = mount(KeyUsageView, {
|
||||||
|
global: {
|
||||||
|
stubs: {
|
||||||
|
RouterLink: { template: '<a><slot /></a>' },
|
||||||
|
LocaleSwitcher: true,
|
||||||
|
Icon: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
await wrapper.find('input').setValue('sk-test-key')
|
||||||
|
await wrapper.find('input').trigger('keydown.enter')
|
||||||
|
await flushPromises()
|
||||||
|
await nextTick()
|
||||||
|
|
||||||
|
const fetchMock = vi.mocked(fetch)
|
||||||
|
expect(fetchMock).toHaveBeenCalledWith(
|
||||||
|
expect.stringContaining('/v1/usage?'),
|
||||||
|
expect.objectContaining({
|
||||||
|
headers: { Authorization: 'Bearer sk-test-key' },
|
||||||
|
})
|
||||||
|
)
|
||||||
|
expect(String(fetchMock.mock.calls[0][0])).toContain('days=30')
|
||||||
|
|
||||||
|
const text = wrapper.text()
|
||||||
|
expect(text).toContain('Daily Detail')
|
||||||
|
expect(text).toContain('Date')
|
||||||
|
expect(text).toContain('Cache Read')
|
||||||
|
expect(text).toContain('Cache Write')
|
||||||
|
expect(text).toContain('2026-05-19')
|
||||||
|
expect(text).toContain('12')
|
||||||
|
expect(text).toContain('100')
|
||||||
|
expect(text).toContain('200')
|
||||||
|
expect(text).toContain('30')
|
||||||
|
expect(text).toContain('10')
|
||||||
|
expect(text).toContain('$0.12')
|
||||||
|
|
||||||
|
wrapper.unmount()
|
||||||
|
})
|
||||||
|
})
|
||||||
@ -3769,6 +3769,36 @@
|
|||||||
}}
|
}}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- OpenAI Codex UA -->
|
||||||
|
<div>
|
||||||
|
<label
|
||||||
|
class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300"
|
||||||
|
>
|
||||||
|
{{
|
||||||
|
t(
|
||||||
|
"admin.settings.gatewayForwarding.openaiCodexUserAgent",
|
||||||
|
)
|
||||||
|
}}
|
||||||
|
</label>
|
||||||
|
<input
|
||||||
|
v-model="form.openai_codex_user_agent"
|
||||||
|
type="text"
|
||||||
|
class="input w-full font-mono text-sm"
|
||||||
|
:placeholder="
|
||||||
|
t(
|
||||||
|
'admin.settings.gatewayForwarding.openaiCodexUserAgentPlaceholder',
|
||||||
|
)
|
||||||
|
"
|
||||||
|
/>
|
||||||
|
<p class="mt-1.5 text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{
|
||||||
|
t(
|
||||||
|
"admin.settings.gatewayForwarding.openaiCodexUserAgentHint",
|
||||||
|
)
|
||||||
|
}}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<!-- Web Search Emulation -->
|
<!-- Web Search Emulation -->
|
||||||
@ -6225,6 +6255,9 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<EmailTemplateEditor />
|
||||||
|
|
||||||
<!-- Balance Low Notification -->
|
<!-- Balance Low Notification -->
|
||||||
<div class="card">
|
<div class="card">
|
||||||
<div
|
<div
|
||||||
@ -6482,6 +6515,7 @@ import Toggle from "@/components/common/Toggle.vue";
|
|||||||
import ProxySelector from "@/components/common/ProxySelector.vue";
|
import ProxySelector from "@/components/common/ProxySelector.vue";
|
||||||
import ImageUpload from "@/components/common/ImageUpload.vue";
|
import ImageUpload from "@/components/common/ImageUpload.vue";
|
||||||
import BackupSettings from "@/views/admin/BackupView.vue";
|
import BackupSettings from "@/views/admin/BackupView.vue";
|
||||||
|
import EmailTemplateEditor from "@/views/admin/settings/EmailTemplateEditor.vue";
|
||||||
import { useClipboard } from "@/composables/useClipboard";
|
import { useClipboard } from "@/composables/useClipboard";
|
||||||
import { affiliatesAPI, type AffiliateAdminEntry, type SimpleUser as AffiliateSimpleUser } from "@/api/admin/affiliates";
|
import { affiliatesAPI, type AffiliateAdminEntry, type SimpleUser as AffiliateSimpleUser } from "@/api/admin/affiliates";
|
||||||
import { extractApiErrorMessage, extractI18nErrorMessage } from "@/utils/apiError";
|
import { extractApiErrorMessage, extractI18nErrorMessage } from "@/utils/apiError";
|
||||||
@ -6943,6 +6977,7 @@ const form = reactive<SettingsForm>({
|
|||||||
enable_anthropic_cache_ttl_1h_injection: false,
|
enable_anthropic_cache_ttl_1h_injection: false,
|
||||||
rewrite_message_cache_control: false,
|
rewrite_message_cache_control: false,
|
||||||
antigravity_user_agent_version: "",
|
antigravity_user_agent_version: "",
|
||||||
|
openai_codex_user_agent: "",
|
||||||
// Balance & quota notification
|
// Balance & quota notification
|
||||||
balance_low_notify_enabled: false,
|
balance_low_notify_enabled: false,
|
||||||
balance_low_notify_threshold: 0,
|
balance_low_notify_threshold: 0,
|
||||||
@ -8044,6 +8079,8 @@ async function saveSettings() {
|
|||||||
rewrite_message_cache_control: form.rewrite_message_cache_control,
|
rewrite_message_cache_control: form.rewrite_message_cache_control,
|
||||||
antigravity_user_agent_version:
|
antigravity_user_agent_version:
|
||||||
form.antigravity_user_agent_version?.trim() || "",
|
form.antigravity_user_agent_version?.trim() || "",
|
||||||
|
openai_codex_user_agent:
|
||||||
|
form.openai_codex_user_agent?.trim() || "",
|
||||||
// Payment configuration
|
// Payment configuration
|
||||||
payment_enabled: form.payment_enabled,
|
payment_enabled: form.payment_enabled,
|
||||||
risk_control_enabled: form.risk_control_enabled,
|
risk_control_enabled: form.risk_control_enabled,
|
||||||
|
|||||||
@ -371,6 +371,7 @@ const baseSettingsResponse = {
|
|||||||
enable_anthropic_cache_ttl_1h_injection: false,
|
enable_anthropic_cache_ttl_1h_injection: false,
|
||||||
rewrite_message_cache_control: false,
|
rewrite_message_cache_control: false,
|
||||||
antigravity_user_agent_version: "",
|
antigravity_user_agent_version: "",
|
||||||
|
openai_codex_user_agent: "",
|
||||||
payment_enabled: true,
|
payment_enabled: true,
|
||||||
payment_min_amount: 1,
|
payment_min_amount: 1,
|
||||||
payment_max_amount: 10000,
|
payment_max_amount: 10000,
|
||||||
|
|||||||
483
frontend/src/views/admin/settings/EmailTemplateEditor.vue
Normal file
483
frontend/src/views/admin/settings/EmailTemplateEditor.vue
Normal file
@ -0,0 +1,483 @@
|
|||||||
|
<template>
|
||||||
|
<div class="card">
|
||||||
|
<div
|
||||||
|
class="flex flex-col gap-3 border-b border-gray-100 px-6 py-4 dark:border-dark-700 lg:flex-row lg:items-start lg:justify-between"
|
||||||
|
>
|
||||||
|
<div>
|
||||||
|
<h2 class="text-lg font-semibold text-gray-900 dark:text-white">
|
||||||
|
{{ t("admin.settings.emailTemplates.title") }}
|
||||||
|
</h2>
|
||||||
|
<p class="mt-1 text-sm text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t("admin.settings.emailTemplates.description") }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div class="flex flex-wrap gap-2">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="btn btn-secondary btn-sm"
|
||||||
|
:disabled="loadingTemplate || previewing || !canPreview"
|
||||||
|
@click="refreshPreview"
|
||||||
|
>
|
||||||
|
{{ previewing ? t("admin.settings.emailTemplates.previewing") : t("admin.settings.emailTemplates.preview") }}
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="btn btn-secondary btn-sm"
|
||||||
|
:disabled="loadingTemplate || restoring || !selectedEvent || !selectedLocale"
|
||||||
|
@click="restoreOfficial"
|
||||||
|
>
|
||||||
|
{{ restoring ? t("admin.settings.emailTemplates.restoring") : t("admin.settings.emailTemplates.restoreOfficial") }}
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="btn btn-primary btn-sm"
|
||||||
|
:disabled="loadingTemplate || saving || !canSave"
|
||||||
|
@click="saveTemplate"
|
||||||
|
>
|
||||||
|
{{ saving ? t("admin.settings.emailTemplates.saving") : t("admin.settings.emailTemplates.save") }}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="space-y-6 p-6">
|
||||||
|
<div
|
||||||
|
v-if="loadingList"
|
||||||
|
class="flex items-center gap-2 text-sm text-gray-500 dark:text-gray-400"
|
||||||
|
>
|
||||||
|
<span
|
||||||
|
class="h-4 w-4 animate-spin rounded-full border-b-2 border-primary-600"
|
||||||
|
></span>
|
||||||
|
{{ t("common.loading") }}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<template v-else>
|
||||||
|
<div class="grid grid-cols-1 gap-4 md:grid-cols-2">
|
||||||
|
<div>
|
||||||
|
<label class="input-label" for="email-template-event">
|
||||||
|
{{ t("admin.settings.emailTemplates.event") }}
|
||||||
|
</label>
|
||||||
|
<select
|
||||||
|
id="email-template-event"
|
||||||
|
v-model="selectedEvent"
|
||||||
|
class="input"
|
||||||
|
:disabled="loadingTemplate || eventOptions.length === 0"
|
||||||
|
>
|
||||||
|
<option
|
||||||
|
v-for="option in eventOptions"
|
||||||
|
:key="option.value"
|
||||||
|
:value="option.value"
|
||||||
|
>
|
||||||
|
{{ option.label || option.value }}
|
||||||
|
</option>
|
||||||
|
</select>
|
||||||
|
<p v-if="selectedEventDescription" class="input-hint">
|
||||||
|
{{ selectedEventDescription }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label class="input-label" for="email-template-locale">
|
||||||
|
{{ t("admin.settings.emailTemplates.locale") }}
|
||||||
|
</label>
|
||||||
|
<select
|
||||||
|
id="email-template-locale"
|
||||||
|
v-model="selectedLocale"
|
||||||
|
class="input"
|
||||||
|
:disabled="loadingTemplate || localeOptions.length === 0"
|
||||||
|
>
|
||||||
|
<option
|
||||||
|
v-for="localeOption in localeOptions"
|
||||||
|
:key="localeOption"
|
||||||
|
:value="localeOption"
|
||||||
|
>
|
||||||
|
{{ formatLocale(localeOption) }}
|
||||||
|
</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div
|
||||||
|
v-if="!eventOptions.length || !localeOptions.length"
|
||||||
|
class="rounded-lg border border-amber-200 bg-amber-50 p-4 text-sm text-amber-700 dark:border-amber-800 dark:bg-amber-900/20 dark:text-amber-300"
|
||||||
|
>
|
||||||
|
{{ t("admin.settings.emailTemplates.empty") }}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-else class="grid grid-cols-1 gap-6 xl:grid-cols-2">
|
||||||
|
<div class="space-y-4">
|
||||||
|
<div>
|
||||||
|
<label class="input-label" for="email-template-subject">
|
||||||
|
{{ t("admin.settings.emailTemplates.subject") }}
|
||||||
|
</label>
|
||||||
|
<input
|
||||||
|
id="email-template-subject"
|
||||||
|
v-model="subject"
|
||||||
|
type="text"
|
||||||
|
class="input"
|
||||||
|
:disabled="loadingTemplate"
|
||||||
|
:placeholder="t('admin.settings.emailTemplates.subjectPlaceholder')"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<label class="input-label" for="email-template-html">
|
||||||
|
{{ t("admin.settings.emailTemplates.html") }}
|
||||||
|
</label>
|
||||||
|
<textarea
|
||||||
|
id="email-template-html"
|
||||||
|
v-model="html"
|
||||||
|
rows="18"
|
||||||
|
class="input min-h-[28rem] resize-y font-mono text-sm leading-6"
|
||||||
|
:disabled="loadingTemplate"
|
||||||
|
:placeholder="t('admin.settings.emailTemplates.htmlPlaceholder')"
|
||||||
|
></textarea>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div
|
||||||
|
class="rounded-lg border border-gray-200 bg-gray-50 p-4 dark:border-dark-700 dark:bg-dark-800/60"
|
||||||
|
>
|
||||||
|
<div class="text-sm font-medium text-gray-900 dark:text-white">
|
||||||
|
{{ t("admin.settings.emailTemplates.placeholders") }}
|
||||||
|
</div>
|
||||||
|
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t("admin.settings.emailTemplates.placeholdersHelp") }}
|
||||||
|
</p>
|
||||||
|
<div class="mt-3 flex flex-wrap gap-2">
|
||||||
|
<button
|
||||||
|
v-for="placeholder in placeholderList"
|
||||||
|
:key="placeholder"
|
||||||
|
type="button"
|
||||||
|
class="rounded-full border border-gray-200 bg-white px-3 py-1 font-mono text-xs text-gray-700 transition-colors hover:border-primary-300 hover:text-primary-600 dark:border-dark-600 dark:bg-dark-700 dark:text-gray-200 dark:hover:border-primary-500 dark:hover:text-primary-300"
|
||||||
|
@click="copyPlaceholder(placeholder)"
|
||||||
|
>
|
||||||
|
{{ placeholder }}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="space-y-4">
|
||||||
|
<div
|
||||||
|
class="rounded-lg border border-gray-200 bg-white dark:border-dark-700 dark:bg-dark-800"
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
class="flex items-center justify-between border-b border-gray-100 px-4 py-3 dark:border-dark-700"
|
||||||
|
>
|
||||||
|
<div>
|
||||||
|
<div class="text-sm font-medium text-gray-900 dark:text-white">
|
||||||
|
{{ t("admin.settings.emailTemplates.livePreview") }}
|
||||||
|
</div>
|
||||||
|
<div class="mt-0.5 text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ previewSubject || t("admin.settings.emailTemplates.noPreview") }}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<span
|
||||||
|
v-if="isCustomTemplate"
|
||||||
|
class="rounded-full bg-primary-50 px-2.5 py-1 text-xs font-medium text-primary-700 dark:bg-primary-900/30 dark:text-primary-300"
|
||||||
|
>
|
||||||
|
{{ t("admin.settings.emailTemplates.customized") }}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<div class="bg-gray-100 p-3 dark:bg-dark-900">
|
||||||
|
<iframe
|
||||||
|
class="h-[36rem] w-full rounded-md border border-gray-200 bg-white dark:border-dark-700"
|
||||||
|
sandbox=""
|
||||||
|
:srcdoc="previewHtml"
|
||||||
|
:title="t('admin.settings.emailTemplates.livePreview')"
|
||||||
|
></iframe>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<p class="text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t("admin.settings.emailTemplates.previewSecurityHint") }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import { computed, onMounted, ref, watch } from "vue";
|
||||||
|
import { useI18n } from "vue-i18n";
|
||||||
|
import { adminAPI } from "@/api";
|
||||||
|
import type {
|
||||||
|
EmailTemplateEventOption,
|
||||||
|
EmailTemplateOption,
|
||||||
|
} from "@/api/admin/settings";
|
||||||
|
import { useAppStore } from "@/stores";
|
||||||
|
import { extractApiErrorMessage } from "@/utils/apiError";
|
||||||
|
|
||||||
|
const { t, locale } = useI18n();
|
||||||
|
const appStore = useAppStore();
|
||||||
|
|
||||||
|
const fallbackPlaceholders = [
|
||||||
|
"{{site_name}}",
|
||||||
|
"{{recipient_name}}",
|
||||||
|
"{{recipient_email}}",
|
||||||
|
"{{verification_code}}",
|
||||||
|
"{{expires_in_minutes}}",
|
||||||
|
"{{reset_url}}",
|
||||||
|
"{{subscription_group}}",
|
||||||
|
"{{subscription_days}}",
|
||||||
|
"{{expiry_time}}",
|
||||||
|
"{{days_remaining}}",
|
||||||
|
"{{current_balance}}",
|
||||||
|
"{{threshold}}",
|
||||||
|
"{{recharge_url}}",
|
||||||
|
"{{recharge_amount}}",
|
||||||
|
"{{order_id}}",
|
||||||
|
"{{unsubscribe_url}}",
|
||||||
|
"{{account_id}}",
|
||||||
|
"{{account_name}}",
|
||||||
|
"{{platform}}",
|
||||||
|
"{{quota_dimension}}",
|
||||||
|
"{{quota_used}}",
|
||||||
|
"{{quota_limit}}",
|
||||||
|
"{{quota_remaining}}",
|
||||||
|
"{{quota_threshold}}",
|
||||||
|
"{{triggered_at}}",
|
||||||
|
"{{group_name}}",
|
||||||
|
"{{moderation_category}}",
|
||||||
|
"{{moderation_score}}",
|
||||||
|
"{{violation_count}}",
|
||||||
|
"{{ban_threshold}}",
|
||||||
|
"{{rule_name}}",
|
||||||
|
"{{severity}}",
|
||||||
|
"{{alert_status}}",
|
||||||
|
"{{metric_type}}",
|
||||||
|
"{{operator}}",
|
||||||
|
"{{metric_value}}",
|
||||||
|
"{{threshold_value}}",
|
||||||
|
"{{alert_description}}",
|
||||||
|
"{{report_name}}",
|
||||||
|
"{{report_type}}",
|
||||||
|
"{{report_start_time}}",
|
||||||
|
"{{report_end_time}}",
|
||||||
|
"{{report_html}}",
|
||||||
|
];
|
||||||
|
|
||||||
|
const loadingList = ref(true);
|
||||||
|
const loadingTemplate = ref(false);
|
||||||
|
const saving = ref(false);
|
||||||
|
const previewing = ref(false);
|
||||||
|
const restoring = ref(false);
|
||||||
|
const eventOptions = ref<EmailTemplateOption[]>([]);
|
||||||
|
const localeOptions = ref<string[]>([]);
|
||||||
|
const selectedEvent = ref("");
|
||||||
|
const selectedLocale = ref("");
|
||||||
|
const subject = ref("");
|
||||||
|
const html = ref("");
|
||||||
|
const isCustomTemplate = ref(false);
|
||||||
|
const placeholders = ref<string[]>([]);
|
||||||
|
const previewSubject = ref("");
|
||||||
|
const previewHtml = ref("");
|
||||||
|
const initializingSelection = ref(false);
|
||||||
|
|
||||||
|
function normalizeEventOption(option: EmailTemplateEventOption): EmailTemplateOption {
|
||||||
|
if (typeof option === "string") {
|
||||||
|
return { value: option };
|
||||||
|
}
|
||||||
|
return option;
|
||||||
|
}
|
||||||
|
|
||||||
|
const selectedEventDescription = computed(() => {
|
||||||
|
return (
|
||||||
|
eventOptions.value.find((option) => option.value === selectedEvent.value)
|
||||||
|
?.description || ""
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
const placeholderList = computed(() => {
|
||||||
|
const combined = [...placeholders.value, ...fallbackPlaceholders];
|
||||||
|
return Array.from(
|
||||||
|
new Set(
|
||||||
|
combined
|
||||||
|
.map((item) => formatPlaceholder(item))
|
||||||
|
.filter((item) => item.length > 0),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
function formatPlaceholder(placeholder: string): string {
|
||||||
|
const trimmed = placeholder.trim();
|
||||||
|
if (!trimmed) return "";
|
||||||
|
if (trimmed.startsWith("{{") && trimmed.endsWith("}}")) return trimmed;
|
||||||
|
return `{{${trimmed}}}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
const canSave = computed(
|
||||||
|
() =>
|
||||||
|
Boolean(selectedEvent.value && selectedLocale.value) &&
|
||||||
|
subject.value.trim().length > 0 &&
|
||||||
|
html.value.trim().length > 0,
|
||||||
|
);
|
||||||
|
|
||||||
|
const canPreview = computed(
|
||||||
|
() => Boolean(selectedEvent.value && selectedLocale.value) && html.value.trim().length > 0,
|
||||||
|
);
|
||||||
|
|
||||||
|
function formatLocale(locale: string): string {
|
||||||
|
const lower = locale.toLowerCase();
|
||||||
|
if (lower === "zh" || lower.startsWith("zh-")) {
|
||||||
|
return t("admin.settings.emailTemplates.localeZh");
|
||||||
|
}
|
||||||
|
if (lower === "en" || lower.startsWith("en-")) {
|
||||||
|
return t("admin.settings.emailTemplates.localeEn");
|
||||||
|
}
|
||||||
|
return locale;
|
||||||
|
}
|
||||||
|
|
||||||
|
function selectInitialLocale(locales: string[]): string {
|
||||||
|
const currentLocale = locale.value.toLowerCase();
|
||||||
|
const exactMatch = locales.find(
|
||||||
|
(availableLocale) => availableLocale.toLowerCase() === currentLocale,
|
||||||
|
);
|
||||||
|
if (exactMatch) return exactMatch;
|
||||||
|
|
||||||
|
const currentLanguage = currentLocale.split("-")[0];
|
||||||
|
const languageMatch = locales.find(
|
||||||
|
(availableLocale) => availableLocale.toLowerCase().split("-")[0] === currentLanguage,
|
||||||
|
);
|
||||||
|
if (languageMatch) return languageMatch;
|
||||||
|
|
||||||
|
return locales[0] || "";
|
||||||
|
}
|
||||||
|
|
||||||
|
function applyTemplate(template: {
|
||||||
|
subject: string;
|
||||||
|
html: string;
|
||||||
|
is_custom?: boolean;
|
||||||
|
placeholders?: string[];
|
||||||
|
}) {
|
||||||
|
subject.value = template.subject;
|
||||||
|
html.value = template.html;
|
||||||
|
isCustomTemplate.value = template.is_custom === true;
|
||||||
|
placeholders.value = template.placeholders || [];
|
||||||
|
}
|
||||||
|
|
||||||
|
async function loadTemplate() {
|
||||||
|
if (!selectedEvent.value || !selectedLocale.value) return;
|
||||||
|
loadingTemplate.value = true;
|
||||||
|
try {
|
||||||
|
const template = await adminAPI.settings.getEmailTemplate(
|
||||||
|
selectedEvent.value,
|
||||||
|
selectedLocale.value,
|
||||||
|
);
|
||||||
|
applyTemplate(template);
|
||||||
|
await refreshPreview();
|
||||||
|
} catch (err: unknown) {
|
||||||
|
appStore.showError(extractApiErrorMessage(err, t("common.error")));
|
||||||
|
} finally {
|
||||||
|
loadingTemplate.value = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function loadTemplateList() {
|
||||||
|
loadingList.value = true;
|
||||||
|
try {
|
||||||
|
const response = await adminAPI.settings.getEmailTemplates();
|
||||||
|
eventOptions.value = response.events.map(normalizeEventOption);
|
||||||
|
localeOptions.value = response.locales;
|
||||||
|
placeholders.value = response.placeholders || [];
|
||||||
|
initializingSelection.value = true;
|
||||||
|
selectedEvent.value = eventOptions.value[0]?.value || "";
|
||||||
|
selectedLocale.value = selectInitialLocale(response.locales);
|
||||||
|
await loadTemplate();
|
||||||
|
initializingSelection.value = false;
|
||||||
|
} catch (err: unknown) {
|
||||||
|
initializingSelection.value = false;
|
||||||
|
appStore.showError(extractApiErrorMessage(err, t("common.error")));
|
||||||
|
} finally {
|
||||||
|
loadingList.value = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function saveTemplate() {
|
||||||
|
if (!canSave.value) {
|
||||||
|
appStore.showError(t("admin.settings.emailTemplates.validationRequired"));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
saving.value = true;
|
||||||
|
try {
|
||||||
|
const template = await adminAPI.settings.updateEmailTemplate(
|
||||||
|
selectedEvent.value,
|
||||||
|
selectedLocale.value,
|
||||||
|
{
|
||||||
|
subject: subject.value,
|
||||||
|
html: html.value,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
applyTemplate(template);
|
||||||
|
await refreshPreview();
|
||||||
|
appStore.showSuccess(t("admin.settings.emailTemplates.saveSuccess"));
|
||||||
|
} catch (err: unknown) {
|
||||||
|
appStore.showError(extractApiErrorMessage(err, t("common.error")));
|
||||||
|
} finally {
|
||||||
|
saving.value = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function refreshPreview() {
|
||||||
|
if (!canPreview.value) {
|
||||||
|
previewSubject.value = "";
|
||||||
|
previewHtml.value = "";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
previewing.value = true;
|
||||||
|
try {
|
||||||
|
const preview = await adminAPI.settings.previewEmailTemplate({
|
||||||
|
event: selectedEvent.value,
|
||||||
|
locale: selectedLocale.value,
|
||||||
|
subject: subject.value,
|
||||||
|
html: html.value,
|
||||||
|
});
|
||||||
|
previewSubject.value = preview.subject;
|
||||||
|
previewHtml.value = preview.html;
|
||||||
|
} catch (err: unknown) {
|
||||||
|
appStore.showError(extractApiErrorMessage(err, t("common.error")));
|
||||||
|
} finally {
|
||||||
|
previewing.value = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function restoreOfficial() {
|
||||||
|
if (!selectedEvent.value || !selectedLocale.value) return;
|
||||||
|
if (!window.confirm(t("admin.settings.emailTemplates.restoreConfirm"))) return;
|
||||||
|
|
||||||
|
restoring.value = true;
|
||||||
|
try {
|
||||||
|
const template = await adminAPI.settings.restoreOfficialEmailTemplate(
|
||||||
|
selectedEvent.value,
|
||||||
|
selectedLocale.value,
|
||||||
|
);
|
||||||
|
applyTemplate(template);
|
||||||
|
await refreshPreview();
|
||||||
|
appStore.showSuccess(t("admin.settings.emailTemplates.restoreSuccess"));
|
||||||
|
} catch (err: unknown) {
|
||||||
|
appStore.showError(extractApiErrorMessage(err, t("common.error")));
|
||||||
|
} finally {
|
||||||
|
restoring.value = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function copyPlaceholder(placeholder: string) {
|
||||||
|
try {
|
||||||
|
await navigator.clipboard.writeText(placeholder);
|
||||||
|
appStore.showSuccess(t("admin.settings.emailTemplates.placeholderCopied"));
|
||||||
|
} catch {
|
||||||
|
appStore.showError(t("common.error"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
watch([selectedEvent, selectedLocale], ([eventValue, localeValue], [oldEvent, oldLocale]) => {
|
||||||
|
if (initializingSelection.value) return;
|
||||||
|
if (!eventValue || !localeValue) return;
|
||||||
|
if (eventValue === oldEvent && localeValue === oldLocale) return;
|
||||||
|
void loadTemplate();
|
||||||
|
});
|
||||||
|
|
||||||
|
onMounted(() => {
|
||||||
|
void loadTemplateList();
|
||||||
|
});
|
||||||
|
</script>
|
||||||
@ -13,6 +13,7 @@ export default defineConfig({
|
|||||||
test: {
|
test: {
|
||||||
globals: true,
|
globals: true,
|
||||||
environment: 'jsdom',
|
environment: 'jsdom',
|
||||||
|
setupFiles: ['./src/__tests__/setup.ts'],
|
||||||
include: ['src/**/*.{test,spec}.{js,ts,jsx,tsx}'],
|
include: ['src/**/*.{test,spec}.{js,ts,jsx,tsx}'],
|
||||||
exclude: ['node_modules', 'dist'],
|
exclude: ['node_modules', 'dist'],
|
||||||
coverage: {
|
coverage: {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user