diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 81225ca6..881f2e69 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -230,14 +230,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db) channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository) channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService) + contentModerationRepository := repository.NewContentModerationRepository(db) + contentModerationHashCache := repository.NewContentModerationHashCache(redisClient) + contentModerationService := service.NewContentModerationService(settingRepository, contentModerationRepository, contentModerationHashCache, groupRepository, userRepository, apiKeyAuthCacheInvalidator, emailService) + contentModerationHandler := admin.NewContentModerationHandler(contentModerationService) paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService) affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, affiliateHandler) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, contentModerationHandler, paymentHandler, affiliateHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService) - openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, userMessageQueueService, configConfig, settingService) + openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) totpHandler := handler.NewTotpHandler(totpService) handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService) diff --git a/backend/internal/handler/admin/content_moderation_handler.go b/backend/internal/handler/admin/content_moderation_handler.go new file mode 100644 index 00000000..88b93527 --- /dev/null +++ b/backend/internal/handler/admin/content_moderation_handler.go @@ -0,0 +1,234 @@ +package admin + +import ( + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +type ContentModerationHandler struct { + service *service.ContentModerationService +} + +func NewContentModerationHandler(svc *service.ContentModerationService) *ContentModerationHandler { + return &ContentModerationHandler{service: svc} +} + +type contentModerationConfigRequest struct { + Enabled *bool `json:"enabled"` + Mode *string `json:"mode"` + BaseURL *string `json:"base_url"` + Model *string `json:"model"` + APIKey *string `json:"api_key"` + APIKeys *[]string `json:"api_keys"` + ClearAPIKey bool `json:"clear_api_key"` + TimeoutMS *int `json:"timeout_ms"` + SampleRate *int `json:"sample_rate"` + AllGroups *bool `json:"all_groups"` + GroupIDs *[]int64 `json:"group_ids"` + RecordNonHits *bool `json:"record_non_hits"` + WorkerCount *int `json:"worker_count"` + QueueSize *int `json:"queue_size"` + BlockStatus *int `json:"block_status"` + BlockMessage *string `json:"block_message"` + EmailOnHit *bool `json:"email_on_hit"` + AutoBanEnabled *bool `json:"auto_ban_enabled"` + BanThreshold *int `json:"ban_threshold"` + ViolationWindowHours *int `json:"violation_window_hours"` + RetryCount *int `json:"retry_count"` + HitRetentionDays *int `json:"hit_retention_days"` + NonHitRetentionDays *int `json:"non_hit_retention_days"` + PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"` +} + +type contentModerationAPIKeyTestRequest struct { + APIKeys []string `json:"api_keys"` + BaseURL string `json:"base_url"` + Model string `json:"model"` + TimeoutMS int `json:"timeout_ms"` + Prompt string `json:"prompt"` + Images []string `json:"images"` +} + +type contentModerationHashRequest struct { + InputHash string `json:"input_hash"` +} + +func (h *ContentModerationHandler) GetConfig(c *gin.Context) { + cfg, err := h.service.GetConfig(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +func (h *ContentModerationHandler) UpdateConfig(c *gin.Context) { + var req contentModerationConfigRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + cfg, err := h.service.UpdateConfig(c.Request.Context(), service.UpdateContentModerationConfigInput{ + Enabled: req.Enabled, + Mode: req.Mode, + BaseURL: req.BaseURL, + Model: req.Model, + APIKey: req.APIKey, + APIKeys: req.APIKeys, + ClearAPIKey: req.ClearAPIKey, + TimeoutMS: req.TimeoutMS, + SampleRate: req.SampleRate, + AllGroups: req.AllGroups, + GroupIDs: req.GroupIDs, + RecordNonHits: req.RecordNonHits, + WorkerCount: req.WorkerCount, + QueueSize: req.QueueSize, + BlockStatus: req.BlockStatus, + BlockMessage: req.BlockMessage, + EmailOnHit: req.EmailOnHit, + AutoBanEnabled: req.AutoBanEnabled, + BanThreshold: req.BanThreshold, + ViolationWindowHours: req.ViolationWindowHours, + RetryCount: req.RetryCount, + HitRetentionDays: req.HitRetentionDays, + NonHitRetentionDays: req.NonHitRetentionDays, + PreHashCheckEnabled: req.PreHashCheckEnabled, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +func (h *ContentModerationHandler) TestAPIKeys(c *gin.Context) { + var req contentModerationAPIKeyTestRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + result, err := h.service.TestAPIKeys(c.Request.Context(), service.TestContentModerationAPIKeysInput{ + APIKeys: req.APIKeys, + BaseURL: req.BaseURL, + Model: req.Model, + TimeoutMS: req.TimeoutMS, + Prompt: req.Prompt, + Images: req.Images, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + +func (h *ContentModerationHandler) GetStatus(c *gin.Context) { + status, err := h.service.GetStatus(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, status) +} + +func (h *ContentModerationHandler) ListLogs(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + filter := service.ContentModerationLogFilter{ + Pagination: pagination.PaginationParams{ + Page: page, + PageSize: pageSize, + SortOrder: pagination.SortOrderDesc, + }, + Result: c.Query("result"), + Endpoint: c.Query("endpoint"), + Search: c.Query("search"), + } + if raw := strings.TrimSpace(c.Query("group_id")); raw != "" { + groupID, err := strconv.ParseInt(raw, 10, 64) + if err != nil || groupID <= 0 { + response.BadRequest(c, "Invalid group_id") + return + } + filter.GroupID = &groupID + } + if raw := strings.TrimSpace(c.Query("from")); raw != "" { + t, _, err := parseContentModerationDate(raw) + if err != nil { + response.BadRequest(c, "Invalid from") + return + } + filter.From = &t + } + if raw := strings.TrimSpace(c.Query("to")); raw != "" { + t, dateOnly, err := parseContentModerationDate(raw) + if err != nil { + response.BadRequest(c, "Invalid to") + return + } + if dateOnly { + t = t.Add(24*time.Hour - time.Nanosecond) + } + filter.To = &t + } + items, pageResult, err := h.service.ListLogs(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Paginated(c, items, pageResult.Total, pageResult.Page, pageResult.PageSize) +} + +func (h *ContentModerationHandler) UnbanUser(c *gin.Context) { + userID, err := strconv.ParseInt(strings.TrimSpace(c.Param("user_id")), 10, 64) + if err != nil || userID <= 0 { + response.BadRequest(c, "Invalid user_id") + return + } + result, err := h.service.UnbanUser(c.Request.Context(), userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + +func (h *ContentModerationHandler) DeleteFlaggedHash(c *gin.Context) { + var req contentModerationHashRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + result, err := h.service.DeleteFlaggedInputHash(c.Request.Context(), req.InputHash) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + +func (h *ContentModerationHandler) ClearFlaggedHashes(c *gin.Context) { + result, err := h.service.ClearFlaggedInputHashes(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + +func parseContentModerationDate(raw string) (time.Time, bool, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return time.Time{}, false, nil + } + if t, err := time.Parse(time.RFC3339, raw); err == nil { + return t, false, nil + } + t, err := time.Parse("2006-01-02", raw) + return t, err == nil, err +} diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 0cec89aa..2ff94fe6 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -185,6 +185,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, + RiskControlEnabled: settings.RiskControlEnabled, AffiliateRebateRate: settings.AffiliateRebateRate, AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours, AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays, @@ -497,6 +498,9 @@ type UpdateSettingsRequest struct { // Affiliate (邀请返利) feature switch AffiliateEnabled *bool `json:"affiliate_enabled"` + // 风控中心功能开关 + RiskControlEnabled *bool `json:"risk_control_enabled"` + // OpenAI fast/flex policy (optional, only updated when provided) OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"` } @@ -1365,6 +1369,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } return previousSettings.AffiliateEnabled }(), + RiskControlEnabled: func() bool { + if req.RiskControlEnabled != nil { + return *req.RiskControlEnabled + } + return previousSettings.RiskControlEnabled + }(), } authSourceDefaults := &service.AuthSourceDefaultSettings{ @@ -1616,6 +1626,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled, AffiliateEnabled: updatedSettings.AffiliateEnabled, + + RiskControlEnabled: updatedSettings.RiskControlEnabled, } if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil { slog.Error("openai_fast_policy_settings_get_failed", "error", err) @@ -2004,6 +2016,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.AffiliateEnabled != after.AffiliateEnabled { changed = append(changed, "affiliate_enabled") } + if before.RiskControlEnabled != after.RiskControlEnabled { + changed = append(changed, "risk_control_enabled") + } changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults) return changed } diff --git a/backend/internal/handler/content_moderation_helper.go b/backend/internal/handler/content_moderation_helper.go new file mode 100644 index 00000000..af6dbd8e --- /dev/null +++ b/backend/internal/handler/content_moderation_helper.go @@ -0,0 +1,130 @@ +package handler + +import ( + "context" + "net/http" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +func (h *GatewayHandler) checkContentModeration(c *gin.Context, reqLog *zap.Logger, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision { + if h == nil || h.contentModerationService == nil { + return nil + } + return runContentModeration(c, reqLog, h.contentModerationService, apiKey, subject, protocol, model, body) +} + +func contentModerationStatus(decision *service.ContentModerationDecision) int { + if decision == nil || decision.StatusCode < 400 || decision.StatusCode > 599 { + return http.StatusForbidden + } + return decision.StatusCode +} + +func contentModerationErrorCode(decision *service.ContentModerationDecision) string { + return "content_policy_violation" +} + +func (h *OpenAIGatewayHandler) checkContentModeration(c *gin.Context, reqLog *zap.Logger, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision { + if h == nil || h.contentModerationService == nil { + return nil + } + return runContentModeration(c, reqLog, h.contentModerationService, apiKey, subject, protocol, model, body) +} + +func runContentModeration(c *gin.Context, reqLog *zap.Logger, svc *service.ContentModerationService, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision { + if svc == nil || c == nil || c.Request == nil { + return nil + } + input := buildContentModerationInput(c, apiKey, subject, protocol, model, body) + if reqLog != nil { + reqLog.Info("content_moderation.gateway_check_start", + zap.String("request_id", input.RequestID), + zap.Int64("user_id", input.UserID), + zap.Int64("api_key_id", input.APIKeyID), + zap.String("api_key_name", input.APIKeyName), + zap.Int64p("group_id", input.GroupID), + zap.String("group_name", input.GroupName), + zap.String("endpoint", input.Endpoint), + zap.String("provider", input.Provider), + zap.String("protocol", input.Protocol), + zap.String("model", input.Model), + zap.Int("body_bytes", len(body)), + ) + } + decision, err := svc.Check(c.Request.Context(), input) + if err != nil { + if reqLog != nil { + reqLog.Warn("content_moderation.check_failed", zap.Error(err)) + } + return nil + } + if reqLog != nil && decision != nil { + reqLog.Info("content_moderation.gateway_check_done", + zap.String("request_id", input.RequestID), + zap.Bool("allowed", decision.Allowed), + zap.Bool("blocked", decision.Blocked), + zap.Bool("flagged", decision.Flagged), + zap.String("action", decision.Action), + zap.Int("status_code", decision.StatusCode), + zap.String("highest_category", decision.HighestCategory), + zap.Float64("highest_score", decision.HighestScore), + ) + } + return decision +} + +func buildContentModerationInput(c *gin.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) service.ContentModerationCheckInput { + input := service.ContentModerationCheckInput{ + RequestID: contentModerationRequestID(c.Request.Context()), + UserID: subject.UserID, + Endpoint: GetInboundEndpoint(c), + Provider: contentModerationProvider(apiKey), + Model: strings.TrimSpace(model), + Protocol: protocol, + Body: body, + } + if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok { + input.Provider = strings.TrimSpace(forcedPlatform) + } + if apiKey != nil { + input.APIKeyID = apiKey.ID + input.APIKeyName = apiKey.Name + if apiKey.User != nil { + input.UserEmail = apiKey.User.Email + } + if apiKey.GroupID != nil { + groupID := *apiKey.GroupID + input.GroupID = &groupID + } + if apiKey.Group != nil { + input.GroupName = apiKey.Group.Name + } + } + if input.Endpoint == "" && c.Request != nil && c.Request.URL != nil { + input.Endpoint = c.Request.URL.Path + } + return input +} + +func contentModerationProvider(apiKey *service.APIKey) string { + if apiKey == nil || apiKey.Group == nil { + return "" + } + return strings.TrimSpace(apiKey.Group.Platform) +} + +func contentModerationRequestID(ctx context.Context) string { + if ctx == nil { + return "" + } + if requestID, ok := ctx.Value(ctxkey.RequestID).(string); ok { + return strings.TrimSpace(requestID) + } + return "" +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 0bc834fe..fba85cf2 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -197,6 +197,9 @@ type SystemSettings struct { // Available Channels feature switch (user-facing aggregate view) AvailableChannelsEnabled bool `json:"available_channels_enabled"` + // 风控中心功能开关 + RiskControlEnabled bool `json:"risk_control_enabled"` + // Affiliate (邀请返利) feature switch AffiliateEnabled bool `json:"affiliate_enabled"` @@ -256,6 +259,8 @@ type PublicSettings struct { AvailableChannelsEnabled bool `json:"available_channels_enabled"` AffiliateEnabled bool `json:"affiliate_enabled"` + + RiskControlEnabled bool `json:"risk_control_enabled"` } // OverloadCooldownSettings 529过载冷却配置 DTO diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 7b082b07..65836a7e 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -45,6 +45,7 @@ type GatewayHandler struct { apiKeyService *service.APIKeyService usageRecordWorkerPool *service.UsageRecordWorkerPool errorPassthroughService *service.ErrorPassthroughService + contentModerationService *service.ContentModerationService concurrencyHelper *ConcurrencyHelper userMsgQueueHelper *UserMsgQueueHelper maxAccountSwitches int @@ -65,6 +66,7 @@ func NewGatewayHandler( apiKeyService *service.APIKeyService, usageRecordWorkerPool *service.UsageRecordWorkerPool, errorPassthroughService *service.ErrorPassthroughService, + contentModerationService *service.ContentModerationService, userMsgQueueService *service.UserMessageQueueService, cfg *config.Config, settingService *service.SettingService, @@ -98,6 +100,7 @@ func NewGatewayHandler( apiKeyService: apiKeyService, usageRecordWorkerPool: usageRecordWorkerPool, errorPassthroughService: errorPassthroughService, + contentModerationService: contentModerationService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), userMsgQueueHelper: umqHelper, maxAccountSwitches: maxAccountSwitches, @@ -189,6 +192,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolAnthropicMessages, reqModel, body); decision != nil && decision.Blocked { + h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } + // Track if we've started streaming (for error handling) streamStarted := false diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index 4290e54b..c6b73190 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -91,6 +91,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { return } + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIChat, reqModel, body); decision != nil && decision.Blocked { + h.chatCompletionsErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } + // Error passthrough binding if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go index 683cf2b7..a97f572d 100644 --- a/backend/internal/handler/gateway_handler_responses.go +++ b/backend/internal/handler/gateway_handler_responses.go @@ -96,6 +96,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) { return } + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, body); decision != nil && decision.Blocked { + h.responsesErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } + // Error passthrough binding if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 2a34e3f0..90ebe9ec 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -185,6 +185,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { setOpsRequestContext(c, modelName, stream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false))) + if decision := h.checkContentModeration(c, reqLog, apiKey, authSubject, service.ContentModerationProtocolGemini, modelName, body); decision != nil && decision.Blocked { + googleError(c, contentModerationStatus(decision), decision.Message) + return + } + // 解析渠道级模型映射 channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName) reqModel := modelName // 保存映射前的原始模型名 diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 13e3ac88..308b2199 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -33,6 +33,7 @@ type AdminHandlers struct { Channel *admin.ChannelHandler ChannelMonitor *admin.ChannelMonitorHandler ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler + ContentModeration *admin.ContentModerationHandler Payment *admin.PaymentHandler Affiliate *admin.AffiliateHandler } diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 06ab9d52..de384710 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -81,6 +81,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIChat, reqModel, body); decision != nil && decision.Blocked { + h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } + // 解析渠道级模型映射 channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 3997a0ee..6b07b7ba 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -27,15 +27,16 @@ import ( // OpenAIGatewayHandler handles OpenAI API gateway requests type OpenAIGatewayHandler struct { - gatewayService *service.OpenAIGatewayService - billingCacheService *service.BillingCacheService - apiKeyService *service.APIKeyService - usageRecordWorkerPool *service.UsageRecordWorkerPool - errorPassthroughService *service.ErrorPassthroughService - concurrencyHelper *ConcurrencyHelper - imageLimiter *imageConcurrencyLimiter - maxAccountSwitches int - cfg *config.Config + gatewayService *service.OpenAIGatewayService + billingCacheService *service.BillingCacheService + apiKeyService *service.APIKeyService + usageRecordWorkerPool *service.UsageRecordWorkerPool + errorPassthroughService *service.ErrorPassthroughService + contentModerationService *service.ContentModerationService + concurrencyHelper *ConcurrencyHelper + imageLimiter *imageConcurrencyLimiter + maxAccountSwitches int + cfg *config.Config } func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string { @@ -53,6 +54,7 @@ func NewOpenAIGatewayHandler( apiKeyService *service.APIKeyService, usageRecordWorkerPool *service.UsageRecordWorkerPool, errorPassthroughService *service.ErrorPassthroughService, + contentModerationService *service.ContentModerationService, cfg *config.Config, ) *OpenAIGatewayHandler { pingInterval := time.Duration(0) @@ -64,15 +66,16 @@ func NewOpenAIGatewayHandler( } } return &OpenAIGatewayHandler{ - gatewayService: gatewayService, - billingCacheService: billingCacheService, - apiKeyService: apiKeyService, - usageRecordWorkerPool: usageRecordWorkerPool, - errorPassthroughService: errorPassthroughService, - concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), - imageLimiter: &imageConcurrencyLimiter{}, - maxAccountSwitches: maxAccountSwitches, - cfg: cfg, + gatewayService: gatewayService, + billingCacheService: billingCacheService, + apiKeyService: apiKeyService, + usageRecordWorkerPool: usageRecordWorkerPool, + errorPassthroughService: errorPassthroughService, + contentModerationService: contentModerationService, + concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), + imageLimiter: &imageConcurrencyLimiter{}, + maxAccountSwitches: maxAccountSwitches, + cfg: cfg, } } @@ -189,6 +192,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, body); decision != nil && decision.Blocked { + h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } + imageIntent := service.IsImageGenerationIntent("/v1/responses", reqModel, body) if imageIntent && !service.GroupAllowsImageGeneration(apiKey.Group) { h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage()) @@ -599,6 +607,11 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolAnthropicMessages, reqModel, body); decision != nil && decision.Blocked { + h.anthropicErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } + // 解析渠道级模型映射 channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) @@ -1153,6 +1166,12 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { setOpsRequestContext(c, reqModel, true, firstMessage) setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2)) + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, firstMessage); decision != nil && decision.Blocked { + writeContentModerationWSError(ctx, wsConn, decision) + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, decision.Message) + return + } + if service.IsImageGenerationIntent("/v1/responses", reqModel, firstMessage) && !service.GroupAllowsImageGeneration(apiKey.Group) { closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, service.ImageGenerationPermissionMessage()) return @@ -1268,6 +1287,26 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { hooks := &service.OpenAIWSIngressHooks{ InitialRequestModel: reqModel, + BeforeRequest: func(turn int, payload []byte, originalModel string) error { + if turn == 1 { + return nil + } + if !gjson.ValidBytes(payload) { + return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json")) + } + model := strings.TrimSpace(originalModel) + if model == "" { + model = strings.TrimSpace(gjson.GetBytes(payload, "model").String()) + } + if model == "" { + model = reqModel + } + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked { + writeContentModerationWSError(ctx, wsConn, decision) + return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil) + } + return nil + }, BeforeTurn: func(turn int) error { if turn == 1 { return nil @@ -1712,6 +1751,34 @@ func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason s _ = conn.CloseNow() } +func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) { + if conn == nil || decision == nil { + return + } + if ctx == nil { + ctx = context.Background() + } + message := strings.TrimSpace(decision.Message) + if message == "" { + message = "content moderation blocked this request" + } + payload, err := json.Marshal(gin.H{ + "event_id": "evt_content_moderation_blocked", + "type": "error", + "error": gin.H{ + "type": "invalid_request_error", + "code": contentModerationErrorCode(decision), + "message": message, + }, + }) + if err != nil { + payload = []byte(`{"event_id":"evt_content_moderation_blocked","type":"error","error":{"type":"invalid_request_error","code":"content_policy_violation","message":"content moderation blocked this request"}}`) + } + writeCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + _ = conn.Write(writeCtx, coderws.MessageText, payload) +} + func summarizeWSCloseErrorForLog(err error) (string, string) { if err == nil { return "-", "-" diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index c560350e..6bddbce9 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -12,6 +12,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" coderws "github.com/coder/websocket" @@ -646,6 +647,180 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot") } +type contentModerationHandlerSettingRepo struct { + values map[string]string +} + +func (r *contentModerationHandlerSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) { + if value, ok := r.values[key]; ok { + return &service.Setting{Key: key, Value: value}, nil + } + return nil, service.ErrSettingNotFound +} + +func (r *contentModerationHandlerSettingRepo) GetValue(ctx context.Context, key string) (string, error) { + if value, ok := r.values[key]; ok { + return value, nil + } + return "", service.ErrSettingNotFound +} + +func (r *contentModerationHandlerSettingRepo) Set(ctx context.Context, key, value string) error { + if r.values == nil { + r.values = map[string]string{} + } + r.values[key] = value + return nil +} + +func (r *contentModerationHandlerSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := map[string]string{} + for _, key := range keys { + if value, ok := r.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (r *contentModerationHandlerSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error { + if r.values == nil { + r.values = map[string]string{} + } + for key, value := range settings { + r.values[key] = value + } + return nil +} + +func (r *contentModerationHandlerSettingRepo) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(r.values)) + for key, value := range r.values { + out[key] = value + } + return out, nil +} + +func (r *contentModerationHandlerSettingRepo) Delete(ctx context.Context, key string) error { + delete(r.values, key) + return nil +} + +type contentModerationHandlerTestRepo struct { + logs []service.ContentModerationLog +} + +func (r *contentModerationHandlerTestRepo) CreateLog(ctx context.Context, log *service.ContentModerationLog) error { + if log != nil { + r.logs = append(r.logs, *log) + } + return nil +} + +func (r *contentModerationHandlerTestRepo) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) { + return nil, nil, nil +} + +func (r *contentModerationHandlerTestRepo) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) { + return 0, nil +} + +func (r *contentModerationHandlerTestRepo) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*service.ContentModerationCleanupResult, error) { + return &service.ContentModerationCleanupResult{}, nil +} + +func TestOpenAIResponsesWebSocket_ContentModerationBlocksFirstFrame(t *testing.T) { + gin.SetMode(gin.TestMode) + + moderationServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/v1/moderations", r.URL.Path) + _, _ = w.Write([]byte(`{"results":[{"category_scores":{"sexual":0.9}}]}`)) + })) + defer moderationServer.Close() + + cfg := &service.ContentModerationConfig{ + Enabled: true, + Mode: service.ContentModerationModePreBlock, + BaseURL: moderationServer.URL, + Model: "omni-moderation-latest", + APIKeys: []string{"sk-test"}, + SampleRate: 100, + AllGroups: true, + BlockMessage: "内容审计测试阻断", + } + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + repo := &contentModerationHandlerTestRepo{} + settingRepo := &contentModerationHandlerSettingRepo{values: map[string]string{ + service.SettingKeyRiskControlEnabled: "true", + service.SettingKeyContentModerationConfig: string(rawCfg), + }} + moderationSvc := service.NewContentModerationService( + settingRepo, + repo, + nil, + nil, + nil, + nil, + nil, + ) + decision, err := moderationSvc.Check(context.Background(), service.ContentModerationCheckInput{ + UserID: 1, + Endpoint: "/v1/responses", + Provider: "openai", + Model: "gpt-5.5", + Protocol: service.ContentModerationProtocolOpenAIResponses, + Body: []byte(`{"model":"gpt-5.5","input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"bad prompt"}]}]}`), + }) + require.NoError(t, err) + require.True(t, decision.Blocked) + repo.logs = nil + h := &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: &service.BillingCacheService{}, + apiKeyService: &service.APIKeyService{}, + contentModerationService: moderationSvc, + concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(&concurrencyCacheMock{}), SSEPingFormatNone, time.Second), + } + wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1}) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{ + "type":"response.create", + "model":"gpt-5.5", + "input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"bad prompt"}]}] + }`)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, payload, readErr := clientConn.Read(readCtx) + cancelRead() + if readErr == nil { + require.Contains(t, string(payload), "content_policy_violation") + require.Contains(t, string(payload), "内容审计测试阻断") + } else { + var closeErr coderws.CloseError + require.ErrorAs(t, readErr, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code) + require.Contains(t, closeErr.Reason, "内容审计测试阻断") + } + require.Len(t, repo.logs, 1) + require.True(t, repo.logs[0].Flagged) + require.Equal(t, service.ContentModerationActionBlock, repo.logs[0].Action) + require.Equal(t, "bad prompt", repo.logs[0].InputExcerpt) +} + func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) { got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{ firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`, diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go index eba701f1..08a6b6e8 100644 --- a/backend/internal/handler/openai_images.go +++ b/backend/internal/handler/openai_images.go @@ -85,6 +85,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage()) return } + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIImages, parsed.Model, parsed.ModerationBody()); decision != nil && decision.Blocked { + h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } imageReleaseFunc, acquired := h.acquireImageGenerationSlot(c, streamStarted) if !acquired { return diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 22f2aa15..776d0790 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -77,5 +77,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { AvailableChannelsEnabled: settings.AvailableChannelsEnabled, AffiliateEnabled: settings.AffiliateEnabled, + + RiskControlEnabled: settings.RiskControlEnabled, }) } diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index a8725875..7f9f9e3c 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -36,6 +36,7 @@ func ProvideAdminHandlers( channelHandler *admin.ChannelHandler, channelMonitorHandler *admin.ChannelMonitorHandler, channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler, + contentModerationHandler *admin.ContentModerationHandler, paymentHandler *admin.PaymentHandler, affiliateHandler *admin.AffiliateHandler, ) *AdminHandlers { @@ -67,6 +68,7 @@ func ProvideAdminHandlers( Channel: channelHandler, ChannelMonitor: channelMonitorHandler, ChannelMonitorTemplate: channelMonitorTemplateHandler, + ContentModeration: contentModerationHandler, Payment: paymentHandler, Affiliate: affiliateHandler, } @@ -170,6 +172,7 @@ var ProviderSet = wire.NewSet( admin.NewChannelHandler, admin.NewChannelMonitorHandler, admin.NewChannelMonitorRequestTemplateHandler, + admin.NewContentModerationHandler, admin.NewPaymentHandler, admin.NewAffiliateHandler, diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 68895475..43b13937 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -125,6 +125,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID, + apikey.FieldName, apikey.FieldStatus, apikey.FieldIPWhitelist, apikey.FieldIPBlacklist, diff --git a/backend/internal/repository/api_key_repo_messages_dispatch_unit_test.go b/backend/internal/repository/api_key_repo_messages_dispatch_unit_test.go index aba62ead..4a462ab1 100644 --- a/backend/internal/repository/api_key_repo_messages_dispatch_unit_test.go +++ b/backend/internal/repository/api_key_repo_messages_dispatch_unit_test.go @@ -69,6 +69,7 @@ func TestAPIKeyRepository_GetByKeyForAuth_PreservesMessagesDispatchModelConfig_S got, err := repo.GetByKeyForAuth(ctx, key.Key) require.NoError(t, err) + require.Equal(t, key.Name, got.Name) require.NotNil(t, got.Group) require.Equal(t, group.MessagesDispatchModelConfig, got.Group.MessagesDispatchModelConfig) } diff --git a/backend/internal/repository/content_moderation_hash_cache.go b/backend/internal/repository/content_moderation_hash_cache.go new file mode 100644 index 00000000..782999e7 --- /dev/null +++ b/backend/internal/repository/content_moderation_hash_cache.go @@ -0,0 +1,71 @@ +package repository + +import ( + "context" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +const contentModerationFlaggedHashSetKey = "content_moderation:flagged_hashes" + +type contentModerationHashCache struct { + rdb *redis.Client +} + +func NewContentModerationHashCache(rdb *redis.Client) service.ContentModerationHashCache { + return &contentModerationHashCache{rdb: rdb} +} + +func (c *contentModerationHashCache) RecordFlaggedInputHash(ctx context.Context, inputHash string) error { + inputHash = strings.TrimSpace(inputHash) + if c == nil || c.rdb == nil || inputHash == "" { + return nil + } + return c.rdb.SAdd(ctx, contentModerationFlaggedHashSetKey, inputHash).Err() +} + +func (c *contentModerationHashCache) HasFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) { + inputHash = strings.TrimSpace(inputHash) + if c == nil || c.rdb == nil || inputHash == "" { + return false, nil + } + return c.rdb.SIsMember(ctx, contentModerationFlaggedHashSetKey, inputHash).Result() +} + +func (c *contentModerationHashCache) DeleteFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) { + inputHash = strings.TrimSpace(inputHash) + if c == nil || c.rdb == nil || inputHash == "" { + return false, nil + } + deleted, err := c.rdb.SRem(ctx, contentModerationFlaggedHashSetKey, inputHash).Result() + if err != nil { + return false, err + } + return deleted > 0, nil +} + +func (c *contentModerationHashCache) ClearFlaggedInputHashes(ctx context.Context) (int64, error) { + if c == nil || c.rdb == nil { + return 0, nil + } + deleted, err := c.rdb.SCard(ctx, contentModerationFlaggedHashSetKey).Result() + if err != nil { + return 0, err + } + if deleted == 0 { + return 0, nil + } + if err := c.rdb.Del(ctx, contentModerationFlaggedHashSetKey).Err(); err != nil { + return 0, err + } + return deleted, nil +} + +func (c *contentModerationHashCache) CountFlaggedInputHashes(ctx context.Context) (int64, error) { + if c == nil || c.rdb == nil { + return 0, nil + } + return c.rdb.SCard(ctx, contentModerationFlaggedHashSetKey).Result() +} diff --git a/backend/internal/repository/content_moderation_repo.go b/backend/internal/repository/content_moderation_repo.go new file mode 100644 index 00000000..6ada004a --- /dev/null +++ b/backend/internal/repository/content_moderation_repo.go @@ -0,0 +1,274 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type contentModerationRepository struct { + db *sql.DB +} + +func NewContentModerationRepository(db *sql.DB) service.ContentModerationRepository { + return &contentModerationRepository{db: db} +} + +func (r *contentModerationRepository) CreateLog(ctx context.Context, log *service.ContentModerationLog) error { + if log == nil { + return nil + } + categoryScores, err := json.Marshal(log.CategoryScores) + if err != nil { + return fmt.Errorf("marshal moderation category scores: %w", err) + } + thresholdSnapshot, err := json.Marshal(log.ThresholdSnapshot) + if err != nil { + return fmt.Errorf("marshal moderation thresholds: %w", err) + } + var userID any + if log.UserID != nil { + userID = *log.UserID + } + var apiKeyID any + if log.APIKeyID != nil { + apiKeyID = *log.APIKeyID + } + var groupID any + if log.GroupID != nil { + groupID = *log.GroupID + } + var latency any + if log.UpstreamLatencyMS != nil { + latency = *log.UpstreamLatencyMS + } + err = r.db.QueryRowContext(ctx, ` +INSERT INTO content_moderation_logs ( + request_id, user_id, user_email, api_key_id, api_key_name, group_id, group_name, + endpoint, provider, model, mode, action, flagged, highest_category, highest_score, + category_scores, threshold_snapshot, input_excerpt, upstream_latency_ms, error, + violation_count, auto_banned, email_sent, queue_delay_ms +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, + $8, $9, $10, $11, $12, $13, $14, $15, + $16::jsonb, $17::jsonb, $18, $19, $20, + $21, $22, $23, $24 +) RETURNING id, created_at`, + log.RequestID, userID, log.UserEmail, apiKeyID, log.APIKeyName, groupID, log.GroupName, + log.Endpoint, log.Provider, log.Model, log.Mode, log.Action, log.Flagged, log.HighestCategory, log.HighestScore, + string(categoryScores), string(thresholdSnapshot), log.InputExcerpt, latency, log.Error, + log.ViolationCount, log.AutoBanned, log.EmailSent, nullableIntPtr(log.QueueDelayMS), + ).Scan(&log.ID, &log.CreatedAt) + if err != nil { + return fmt.Errorf("insert content moderation log: %w", err) + } + return nil +} + +func (r *contentModerationRepository) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) { + where, args := buildContentModerationLogWhere(filter) + whereSQL := "WHERE " + strings.Join(where, " AND ") + + var total int64 + if err := r.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM content_moderation_logs l "+whereSQL, args...).Scan(&total); err != nil { + return nil, nil, fmt.Errorf("count content moderation logs: %w", err) + } + + params := filter.Pagination + if params.Page <= 0 { + params.Page = 1 + } + if params.PageSize <= 0 { + params.PageSize = 20 + } + if params.PageSize > 100 { + params.PageSize = 100 + } + queryArgs := append([]any{}, args...) + queryArgs = append(queryArgs, params.Limit(), params.Offset()) + rows, err := r.db.QueryContext(ctx, ` +SELECT + l.id, l.request_id, l.user_id, l.user_email, l.api_key_id, l.api_key_name, l.group_id, l.group_name, + l.endpoint, l.provider, l.model, l.mode, l.action, l.flagged, l.highest_category, l.highest_score, + l.category_scores, l.threshold_snapshot, l.input_excerpt, l.upstream_latency_ms, l.error, + l.violation_count, l.auto_banned, l.email_sent, COALESCE(u.status, ''), l.queue_delay_ms, l.created_at +FROM content_moderation_logs l +LEFT JOIN users u ON u.id = l.user_id `+whereSQL+` +ORDER BY l.created_at DESC, l.id DESC +LIMIT $`+fmt.Sprint(len(queryArgs)-1)+` OFFSET $`+fmt.Sprint(len(queryArgs)), + queryArgs..., + ) + if err != nil { + return nil, nil, fmt.Errorf("list content moderation logs: %w", err) + } + defer func() { _ = rows.Close() }() + + items := make([]service.ContentModerationLog, 0) + for rows.Next() { + var item service.ContentModerationLog + var userID, apiKeyID, groupID, latency, queueDelay sql.NullInt64 + var scoresRaw, thresholdsRaw []byte + if err := rows.Scan( + &item.ID, + &item.RequestID, + &userID, + &item.UserEmail, + &apiKeyID, + &item.APIKeyName, + &groupID, + &item.GroupName, + &item.Endpoint, + &item.Provider, + &item.Model, + &item.Mode, + &item.Action, + &item.Flagged, + &item.HighestCategory, + &item.HighestScore, + &scoresRaw, + &thresholdsRaw, + &item.InputExcerpt, + &latency, + &item.Error, + &item.ViolationCount, + &item.AutoBanned, + &item.EmailSent, + &item.UserStatus, + &queueDelay, + &item.CreatedAt, + ); err != nil { + return nil, nil, fmt.Errorf("scan content moderation log: %w", err) + } + if userID.Valid { + v := userID.Int64 + item.UserID = &v + } + if apiKeyID.Valid { + v := apiKeyID.Int64 + item.APIKeyID = &v + } + if groupID.Valid { + v := groupID.Int64 + item.GroupID = &v + } + if latency.Valid { + v := int(latency.Int64) + item.UpstreamLatencyMS = &v + } + if queueDelay.Valid { + v := int(queueDelay.Int64) + item.QueueDelayMS = &v + } + item.CategoryScores = map[string]float64{} + _ = json.Unmarshal(scoresRaw, &item.CategoryScores) + item.ThresholdSnapshot = map[string]float64{} + _ = json.Unmarshal(thresholdsRaw, &item.ThresholdSnapshot) + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, nil, fmt.Errorf("iterate content moderation logs: %w", err) + } + return items, paginationResultFromTotal(total, params), nil +} + +func (r *contentModerationRepository) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) { + if userID <= 0 { + return 0, nil + } + var count int + err := r.db.QueryRowContext(ctx, ` +WITH last_auto_ban AS ( + SELECT MAX(created_at) AS at + FROM content_moderation_logs + WHERE user_id = $1 AND auto_banned = TRUE +) +SELECT COUNT(*) +FROM content_moderation_logs +WHERE user_id = $1 + AND flagged = TRUE + AND created_at >= $2 + AND created_at > COALESCE((SELECT at FROM last_auto_ban), '-infinity'::timestamptz) +`, userID, since).Scan(&count) + if err != nil { + return 0, fmt.Errorf("count user content moderation flagged logs: %w", err) + } + return count, nil +} + +func (r *contentModerationRepository) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*service.ContentModerationCleanupResult, error) { + result := &service.ContentModerationCleanupResult{FinishedAt: time.Now()} + if r == nil || r.db == nil { + return result, nil + } + hitExec, err := r.db.ExecContext(ctx, ` +DELETE FROM content_moderation_logs +WHERE flagged = TRUE AND created_at < $1 +`, hitBefore) + if err != nil { + return nil, fmt.Errorf("delete expired hit content moderation logs: %w", err) + } + result.DeletedHit, _ = hitExec.RowsAffected() + + nonHitExec, err := r.db.ExecContext(ctx, ` +DELETE FROM content_moderation_logs +WHERE flagged = FALSE AND created_at < $1 +`, nonHitBefore) + if err != nil { + return nil, fmt.Errorf("delete expired non-hit content moderation logs: %w", err) + } + result.DeletedNonHit, _ = nonHitExec.RowsAffected() + + result.FinishedAt = time.Now() + return result, nil +} + +func nullableIntPtr(value *int) any { + if value == nil { + return nil + } + return *value +} + +func buildContentModerationLogWhere(filter service.ContentModerationLogFilter) ([]string, []any) { + where := []string{"l.id IS NOT NULL"} + args := make([]any, 0) + add := func(expr string, value any) { + args = append(args, value) + where = append(where, fmt.Sprintf(expr, len(args))) + } + switch strings.ToLower(strings.TrimSpace(filter.Result)) { + case "hit", "flagged": + where = append(where, "l.flagged = TRUE") + case "blocked", "block": + where = append(where, "l.action = 'block'") + case "pass", "allow": + where = append(where, "l.flagged = FALSE AND l.error = ''") + case "error": + where = append(where, "l.error <> ''") + } + if filter.GroupID != nil { + add("l.group_id = $%d", *filter.GroupID) + } + if endpoint := strings.TrimSpace(filter.Endpoint); endpoint != "" { + add("l.endpoint = $%d", endpoint) + } + if search := strings.TrimSpace(filter.Search); search != "" { + like := "%" + search + "%" + args = append(args, like, like, like, like, like) + idx := len(args) - 4 + where = append(where, fmt.Sprintf("(l.request_id ILIKE $%d OR l.user_email ILIKE $%d OR l.api_key_name ILIKE $%d OR l.model ILIKE $%d OR l.input_excerpt ILIKE $%d)", idx, idx+1, idx+2, idx+3, idx+4)) + } + if filter.From != nil && !filter.From.IsZero() { + add("l.created_at >= $%d", *filter.From) + } + if filter.To != nil && !filter.To.IsZero() { + add("l.created_at <= $%d", *filter.To) + } + return where, args +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index f07bbb33..3c0ee9cb 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -91,6 +91,7 @@ var ProviderSet = wire.NewSet( NewChannelRepository, NewChannelMonitorRepository, NewChannelMonitorRequestTemplateRepository, + NewContentModerationRepository, NewAffiliateRepository, // Cache implementations @@ -119,6 +120,7 @@ var ProviderSet = wire.NewSet( NewRefreshTokenCache, NewErrorPassthroughCache, NewTLSFingerprintProfileCache, + NewContentModerationHashCache, // Encryptors NewAESEncryptor, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 34f560fc..37606d94 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -792,6 +792,7 @@ func TestAPIContracts(t *testing.T) { "channel_monitor_enabled": true, "channel_monitor_default_interval_seconds": 60, "available_channels_enabled": false, + "risk_control_enabled": false, "affiliate_enabled": false, "wechat_connect_enabled": false, "wechat_connect_app_id": "", @@ -983,6 +984,7 @@ func TestAPIContracts(t *testing.T) { "channel_monitor_enabled": true, "channel_monitor_default_interval_seconds": 60, "available_channels_enabled": false, + "risk_control_enabled": false, "affiliate_enabled": false, "wechat_connect_enabled": true, "wechat_connect_app_id": "wx-open-config", diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 20f5d619..a2d225e0 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -92,11 +92,28 @@ func RegisterAdminRoutes( // 渠道监控 registerChannelMonitorRoutes(admin, h) + // 风控中心 + registerContentModerationRoutes(admin, h) + // 邀请返利(专属用户管理) registerAffiliateRoutes(admin, h) } } +func registerContentModerationRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + risk := admin.Group("/risk-control") + { + risk.GET("/config", h.Admin.ContentModeration.GetConfig) + risk.PUT("/config", h.Admin.ContentModeration.UpdateConfig) + risk.POST("/api-keys/test", h.Admin.ContentModeration.TestAPIKeys) + risk.GET("/status", h.Admin.ContentModeration.GetStatus) + risk.GET("/logs", h.Admin.ContentModeration.ListLogs) + risk.POST("/users/:user_id/unban", h.Admin.ContentModeration.UnbanUser) + risk.DELETE("/hashes", h.Admin.ContentModeration.DeleteFlaggedHash) + risk.DELETE("/hashes/all", h.Admin.ContentModeration.ClearFlaggedHashes) + } +} + func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) { apiKeys := admin.Group("/api-keys") { diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 4432ad7d..3553a18a 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -8,6 +8,7 @@ type APIKeyAuthSnapshot struct { APIKeyID int64 `json:"api_key_id"` UserID int64 `json:"user_id"` GroupID *int64 `json:"group_id,omitempty"` + Name string `json:"name"` Status string `json:"status"` IPWhitelist []string `json:"ip_whitelist,omitempty"` IPBlacklist []string `json:"ip_blacklist,omitempty"` diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 0f9d4214..877888b1 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -14,7 +14,7 @@ import ( "github.com/dgraph-io/ristretto" ) -const apiKeyAuthSnapshotVersion = 8 // v8: added group image generation controls +const apiKeyAuthSnapshotVersion = 9 // v9: added API Key name for audit logs type apiKeyAuthCacheConfig struct { l1Size int @@ -210,6 +210,7 @@ func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey) APIKeyID: apiKey.ID, UserID: apiKey.UserID, GroupID: apiKey.GroupID, + Name: apiKey.Name, Status: apiKey.Status, IPWhitelist: apiKey.IPWhitelist, IPBlacklist: apiKey.IPBlacklist, @@ -286,6 +287,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho UserID: snapshot.UserID, GroupID: snapshot.GroupID, Key: key, + Name: snapshot.Name, Status: snapshot.Status, IPWhitelist: snapshot.IPWhitelist, IPBlacklist: snapshot.IPBlacklist, diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index 8cb1b8c4..eaac9a1c 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -235,6 +235,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t UserID: 2, GroupID: &groupID, Key: "k-roundtrip", + Name: "Audit Key", Status: StatusActive, User: &User{ ID: 2, @@ -267,6 +268,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot) require.NotNil(t, roundTrip) + require.Equal(t, apiKey.Name, roundTrip.Name) require.NotNil(t, roundTrip.Group) require.Equal(t, apiKey.Group.MessagesDispatchModelConfig, roundTrip.Group.MessagesDispatchModelConfig) } diff --git a/backend/internal/service/content_moderation.go b/backend/internal/service/content_moderation.go new file mode 100644 index 00000000..192946ce --- /dev/null +++ b/backend/internal/service/content_moderation.go @@ -0,0 +1,1982 @@ +package service + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "sort" + "strings" + "sync" + "sync/atomic" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +const ( + ContentModerationModeOff = "off" + ContentModerationModeObserve = "observe" + ContentModerationModePreBlock = "pre_block" + + ContentModerationActionAllow = "allow" + ContentModerationActionBlock = "block" + ContentModerationActionHashBlock = "hash_block" + ContentModerationActionError = "error" + + ContentModerationProtocolAnthropicMessages = "anthropic_messages" + ContentModerationProtocolOpenAIResponses = "openai_responses" + ContentModerationProtocolOpenAIChat = "openai_chat_completions" + ContentModerationProtocolGemini = "gemini" + ContentModerationProtocolOpenAIImages = "openai_images" + + defaultContentModerationBaseURL = "https://api.openai.com" + defaultContentModerationModel = "omni-moderation-latest" + defaultContentModerationTimeoutMS = 3000 + maxContentModerationTimeoutMS = 30000 + maxModerationInputRunes = 12000 + maxModerationExcerptRunes = 240 + + defaultContentModerationWorkerCount = 4 + maxContentModerationWorkerCount = 32 + defaultContentModerationQueueSize = 32768 + maxContentModerationQueueSize = 100000 + defaultContentModerationBanThreshold = 10 + defaultContentModerationViolationWindowHours = 720 + defaultContentModerationBlockHTTPStatus = http.StatusForbidden + defaultContentModerationBlockMessage = "内容审计命中风险规则,请调整输入后重试" + defaultContentModerationRetryCount = 2 + maxContentModerationRetryCount = 5 + defaultContentModerationHitRetentionDays = 180 + defaultContentModerationNonHitRetentionDays = 3 + maxContentModerationRetentionDays = 3650 + maxContentModerationNonHitRetentionDays = 3 + contentModerationKeyFailureFreezeThreshold = 3 + contentModerationKeyFreezeDuration = time.Minute + maxContentModerationTestImages = 4 + maxContentModerationTestImageBytes = 8 * 1024 * 1024 + maxContentModerationTestImageDataURLBytes = 12 * 1024 * 1024 + + contentModerationCleanupInterval = 24 * time.Hour + contentModerationCleanupTimeout = 30 * time.Minute + contentModerationCleanupDelay = 5 * time.Minute +) + +var contentModerationCategoryOrder = []string{ + "harassment", + "harassment/threatening", + "hate", + "hate/threatening", + "illicit", + "illicit/violent", + "self-harm", + "self-harm/intent", + "self-harm/instructions", + "sexual", + "sexual/minors", + "violence", + "violence/graphic", +} + +func ContentModerationDefaultThresholds() map[string]float64 { + return map[string]float64{ + "harassment": 0.98, + "harassment/threatening": 0.90, + "hate": 0.65, + "hate/threatening": 0.65, + "illicit": 0.95, + "illicit/violent": 0.95, + "self-harm": 0.65, + "self-harm/intent": 0.85, + "self-harm/instructions": 0.65, + "sexual": 0.65, + "sexual/minors": 0.65, + "violence": 0.95, + "violence/graphic": 0.95, + } +} + +func ContentModerationCategories() []string { + out := make([]string, len(contentModerationCategoryOrder)) + copy(out, contentModerationCategoryOrder) + return out +} + +type ContentModerationConfig struct { + Enabled bool `json:"enabled"` + Mode string `json:"mode"` + BaseURL string `json:"base_url"` + Model string `json:"model"` + APIKey string `json:"api_key,omitempty"` + APIKeys []string `json:"api_keys,omitempty"` + TimeoutMS int `json:"timeout_ms"` + SampleRate int `json:"sample_rate"` + AllGroups bool `json:"all_groups"` + GroupIDs []int64 `json:"group_ids"` + RecordNonHits bool `json:"record_non_hits"` + Thresholds map[string]float64 `json:"thresholds"` + WorkerCount int `json:"worker_count"` + QueueSize int `json:"queue_size"` + BlockStatus int `json:"block_status"` + BlockMessage string `json:"block_message"` + EmailOnHit bool `json:"email_on_hit"` + AutoBanEnabled bool `json:"auto_ban_enabled"` + BanThreshold int `json:"ban_threshold"` + ViolationWindowHours int `json:"violation_window_hours"` + RetryCount int `json:"retry_count"` + HitRetentionDays int `json:"hit_retention_days"` + NonHitRetentionDays int `json:"non_hit_retention_days"` + PreHashCheckEnabled bool `json:"pre_hash_check_enabled"` +} + +type ContentModerationConfigView struct { + Enabled bool `json:"enabled"` + Mode string `json:"mode"` + BaseURL string `json:"base_url"` + Model string `json:"model"` + APIKeyConfigured bool `json:"api_key_configured"` + APIKeyMasked string `json:"api_key_masked"` + APIKeyCount int `json:"api_key_count"` + APIKeyMasks []string `json:"api_key_masks"` + APIKeyStatuses []ContentModerationAPIKeyStatus `json:"api_key_statuses"` + TimeoutMS int `json:"timeout_ms"` + SampleRate int `json:"sample_rate"` + AllGroups bool `json:"all_groups"` + GroupIDs []int64 `json:"group_ids"` + RecordNonHits bool `json:"record_non_hits"` + WorkerCount int `json:"worker_count"` + QueueSize int `json:"queue_size"` + BlockStatus int `json:"block_status"` + BlockMessage string `json:"block_message"` + EmailOnHit bool `json:"email_on_hit"` + AutoBanEnabled bool `json:"auto_ban_enabled"` + BanThreshold int `json:"ban_threshold"` + ViolationWindowHours int `json:"violation_window_hours"` + RetryCount int `json:"retry_count"` + HitRetentionDays int `json:"hit_retention_days"` + NonHitRetentionDays int `json:"non_hit_retention_days"` + PreHashCheckEnabled bool `json:"pre_hash_check_enabled"` +} + +type ContentModerationAPIKeyStatus struct { + Index int `json:"index"` + KeyHash string `json:"key_hash"` + Masked string `json:"masked"` + Status string `json:"status"` + FailureCount int `json:"failure_count"` + SuccessCount int64 `json:"success_count"` + LastError string `json:"last_error"` + LastCheckedAt *time.Time `json:"last_checked_at,omitempty"` + FrozenUntil *time.Time `json:"frozen_until,omitempty"` + LastLatencyMS int `json:"last_latency_ms"` + LastHTTPStatus int `json:"last_http_status"` + LastTested bool `json:"last_tested"` + Configured bool `json:"configured"` +} + +type TestContentModerationAPIKeysInput struct { + APIKeys []string `json:"api_keys"` + BaseURL string `json:"base_url"` + Model string `json:"model"` + TimeoutMS int `json:"timeout_ms"` + Prompt string `json:"prompt"` + Images []string `json:"images"` +} + +type TestContentModerationAPIKeysResult struct { + Items []ContentModerationAPIKeyStatus `json:"items"` + AuditResult *ContentModerationTestAuditResult `json:"audit_result,omitempty"` + ImageCount int `json:"image_count"` +} + +type ContentModerationTestAuditResult struct { + Flagged bool `json:"flagged"` + HighestCategory string `json:"highest_category"` + HighestScore float64 `json:"highest_score"` + CompositeScore float64 `json:"composite_score"` + CategoryScores map[string]float64 `json:"category_scores"` + Thresholds map[string]float64 `json:"thresholds"` +} + +type UpdateContentModerationConfigInput struct { + Enabled *bool `json:"enabled"` + Mode *string `json:"mode"` + BaseURL *string `json:"base_url"` + Model *string `json:"model"` + APIKey *string `json:"api_key"` + APIKeys *[]string `json:"api_keys"` + ClearAPIKey bool `json:"clear_api_key"` + TimeoutMS *int `json:"timeout_ms"` + SampleRate *int `json:"sample_rate"` + AllGroups *bool `json:"all_groups"` + GroupIDs *[]int64 `json:"group_ids"` + RecordNonHits *bool `json:"record_non_hits"` + WorkerCount *int `json:"worker_count"` + QueueSize *int `json:"queue_size"` + BlockStatus *int `json:"block_status"` + BlockMessage *string `json:"block_message"` + EmailOnHit *bool `json:"email_on_hit"` + AutoBanEnabled *bool `json:"auto_ban_enabled"` + BanThreshold *int `json:"ban_threshold"` + ViolationWindowHours *int `json:"violation_window_hours"` + RetryCount *int `json:"retry_count"` + HitRetentionDays *int `json:"hit_retention_days"` + NonHitRetentionDays *int `json:"non_hit_retention_days"` + PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"` +} + +type ContentModerationCheckInput struct { + RequestID string + UserID int64 + UserEmail string + APIKeyID int64 + APIKeyName string + GroupID *int64 + GroupName string + Endpoint string + Provider string + Model string + Protocol string + Body []byte +} + +type ContentModerationInput struct { + Text string + Images []string +} + +func (in *ContentModerationInput) Normalize() { + if in == nil { + return + } + in.Text = trimRunes(normalizeContentModerationText(in.Text), maxModerationInputRunes) + in.Images = normalizeModerationImages(in.Images) +} + +func (in ContentModerationInput) IsEmpty() bool { + return strings.TrimSpace(in.Text) == "" && len(in.Images) == 0 +} + +func (in ContentModerationInput) ModerationInput() any { + if len(in.Images) == 0 { + return in.Text + } + parts := make([]moderationAPIInputPart, 0, len(in.Images)+1) + if strings.TrimSpace(in.Text) != "" { + parts = append(parts, moderationAPIInputPart{Type: "text", Text: in.Text}) + } + for _, image := range in.Images { + parts = append(parts, moderationAPIInputPart{ + Type: "image_url", + ImageURL: &moderationAPIImageURLRef{URL: image}, + }) + } + return parts +} + +func (in ContentModerationInput) ExcerptText() string { + return in.Text +} + +func (in ContentModerationInput) Hash() string { + h := sha256.New() + _, _ = h.Write([]byte("text:")) + _, _ = h.Write([]byte(in.Text)) + for _, image := range in.Images { + imageHash := sha256.Sum256([]byte(image)) + _, _ = h.Write([]byte("\nimage:")) + _, _ = h.Write([]byte(hex.EncodeToString(imageHash[:]))) + } + return hex.EncodeToString(h.Sum(nil)) +} + +type ContentModerationDecision struct { + Allowed bool `json:"allowed"` + Blocked bool `json:"blocked"` + Flagged bool `json:"flagged"` + Message string `json:"message"` + StatusCode int `json:"status_code"` + InputHash string `json:"input_hash,omitempty"` + HighestCategory string `json:"highest_category"` + HighestScore float64 `json:"highest_score"` + CategoryScores map[string]float64 `json:"category_scores"` + Action string `json:"action"` +} + +type ContentModerationLog struct { + ID int64 `json:"id"` + RequestID string `json:"request_id"` + UserID *int64 `json:"user_id,omitempty"` + UserEmail string `json:"user_email"` + APIKeyID *int64 `json:"api_key_id,omitempty"` + APIKeyName string `json:"api_key_name"` + GroupID *int64 `json:"group_id,omitempty"` + GroupName string `json:"group_name"` + Endpoint string `json:"endpoint"` + Provider string `json:"provider"` + Model string `json:"model"` + Mode string `json:"mode"` + Action string `json:"action"` + Flagged bool `json:"flagged"` + HighestCategory string `json:"highest_category"` + HighestScore float64 `json:"highest_score"` + CategoryScores map[string]float64 `json:"category_scores"` + ThresholdSnapshot map[string]float64 `json:"threshold_snapshot"` + InputExcerpt string `json:"input_excerpt"` + UpstreamLatencyMS *int `json:"upstream_latency_ms,omitempty"` + Error string `json:"error"` + ViolationCount int `json:"violation_count"` + AutoBanned bool `json:"auto_banned"` + EmailSent bool `json:"email_sent"` + UserStatus string `json:"user_status"` + QueueDelayMS *int `json:"queue_delay_ms,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +type ContentModerationLogFilter struct { + Pagination pagination.PaginationParams + Result string + GroupID *int64 + Endpoint string + Search string + From *time.Time + To *time.Time +} + +type ContentModerationCleanupResult struct { + DeletedHit int64 `json:"deleted_hit"` + DeletedNonHit int64 `json:"deleted_non_hit"` + FinishedAt time.Time `json:"finished_at"` +} + +type ContentModerationRuntimeStatus struct { + Enabled bool `json:"enabled"` + RiskControlEnabled bool `json:"risk_control_enabled"` + Mode string `json:"mode"` + WorkerCount int `json:"worker_count"` + MaxWorkers int `json:"max_workers"` + ActiveWorkers int `json:"active_workers"` + IdleWorkers int `json:"idle_workers"` + QueueSize int `json:"queue_size"` + QueueLength int `json:"queue_length"` + QueueUsagePercent float64 `json:"queue_usage_percent"` + Enqueued int64 `json:"enqueued"` + Dropped int64 `json:"dropped"` + Processed int64 `json:"processed"` + Errors int64 `json:"errors"` + APIKeyStatuses []ContentModerationAPIKeyStatus `json:"api_key_statuses"` + FlaggedHashCount int64 `json:"flagged_hash_count"` + LastCleanupAt *time.Time `json:"last_cleanup_at,omitempty"` + LastCleanupDeletedHit int64 `json:"last_cleanup_deleted_hit"` + LastCleanupDeletedNonHit int64 `json:"last_cleanup_deleted_non_hit"` +} + +type ContentModerationUnbanUserResult struct { + UserID int64 `json:"user_id"` + Status string `json:"status"` +} + +type ContentModerationDeleteHashResult struct { + InputHash string `json:"input_hash"` + Deleted bool `json:"deleted"` +} + +type ContentModerationClearHashesResult struct { + Deleted int64 `json:"deleted"` +} + +type ContentModerationRepository interface { + CreateLog(ctx context.Context, log *ContentModerationLog) error + ListLogs(ctx context.Context, filter ContentModerationLogFilter) ([]ContentModerationLog, *pagination.PaginationResult, error) + CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) + CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*ContentModerationCleanupResult, error) +} + +type ContentModerationHashCache interface { + RecordFlaggedInputHash(ctx context.Context, inputHash string) error + HasFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) + DeleteFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) + ClearFlaggedInputHashes(ctx context.Context) (int64, error) + CountFlaggedInputHashes(ctx context.Context) (int64, error) +} + +type ContentModerationService struct { + settingRepo SettingRepository + repo ContentModerationRepository + hashCache ContentModerationHashCache + groupRepo GroupRepository + userRepo UserRepository + authCacheInvalidator APIKeyAuthCacheInvalidator + emailService *EmailService + httpClient *http.Client + asyncQueue chan contentModerationTask + workerCount int + apiKeyCursor atomic.Uint64 + asyncActive atomic.Int64 + asyncEnqueued atomic.Int64 + asyncDropped atomic.Int64 + asyncProcessed atomic.Int64 + asyncErrors atomic.Int64 + lastCleanupUnix atomic.Int64 + lastCleanupDeletedHit atomic.Int64 + lastCleanupDeletedNonHit atomic.Int64 + keyHealthMu sync.Mutex + keyHealth map[string]*contentModerationKeyHealth +} + +type contentModerationTask struct { + input ContentModerationCheckInput + content ContentModerationInput + inputHash string + enqueuedAt time.Time +} + +type contentModerationKeyHealth struct { + Hash string + Masked string + FailureCount int + SuccessCount int64 + LastError string + LastCheckedAt time.Time + FrozenUntil time.Time + LastLatencyMS int + LastHTTPStatus int + LastTested bool +} + +func NewContentModerationService( + settingRepo SettingRepository, + repo ContentModerationRepository, + hashCache ContentModerationHashCache, + groupRepo GroupRepository, + userRepo UserRepository, + authCacheInvalidator APIKeyAuthCacheInvalidator, + emailService *EmailService, +) *ContentModerationService { + svc := &ContentModerationService{ + settingRepo: settingRepo, + repo: repo, + hashCache: hashCache, + groupRepo: groupRepo, + userRepo: userRepo, + authCacheInvalidator: authCacheInvalidator, + emailService: emailService, + httpClient: &http.Client{}, + workerCount: maxContentModerationWorkerCount, + asyncQueue: make(chan contentModerationTask, maxContentModerationQueueSize), + keyHealth: make(map[string]*contentModerationKeyHealth), + } + if settingRepo != nil && repo != nil { + for i := 0; i < svc.workerCount; i++ { + go svc.worker(i) + } + go svc.cleanupWorker() + } + return svc +} + +func (s *ContentModerationService) GetConfig(ctx context.Context) (*ContentModerationConfigView, error) { + cfg, err := s.loadConfig(ctx) + if err != nil { + return nil, err + } + return s.configView(cfg), nil +} + +func (s *ContentModerationService) UpdateConfig(ctx context.Context, input UpdateContentModerationConfigInput) (*ContentModerationConfigView, error) { + cfg, err := s.loadConfig(ctx) + if err != nil { + return nil, err + } + if input.Enabled != nil { + cfg.Enabled = *input.Enabled + } + if input.Mode != nil { + cfg.Mode = strings.TrimSpace(*input.Mode) + } + if input.BaseURL != nil { + cfg.BaseURL = strings.TrimSpace(*input.BaseURL) + } + if input.Model != nil { + cfg.Model = strings.TrimSpace(*input.Model) + } + if input.TimeoutMS != nil { + cfg.TimeoutMS = *input.TimeoutMS + } + if input.SampleRate != nil { + cfg.SampleRate = *input.SampleRate + } + if input.WorkerCount != nil { + cfg.WorkerCount = *input.WorkerCount + } + if input.QueueSize != nil { + cfg.QueueSize = *input.QueueSize + } + if input.BlockStatus != nil { + cfg.BlockStatus = *input.BlockStatus + } + if input.BlockMessage != nil { + cfg.BlockMessage = strings.TrimSpace(*input.BlockMessage) + } + if input.EmailOnHit != nil { + cfg.EmailOnHit = *input.EmailOnHit + } + if input.AutoBanEnabled != nil { + cfg.AutoBanEnabled = *input.AutoBanEnabled + } + if input.BanThreshold != nil { + cfg.BanThreshold = *input.BanThreshold + } + if input.ViolationWindowHours != nil { + cfg.ViolationWindowHours = *input.ViolationWindowHours + } + if input.RetryCount != nil { + cfg.RetryCount = *input.RetryCount + } + if input.HitRetentionDays != nil { + cfg.HitRetentionDays = *input.HitRetentionDays + } + if input.NonHitRetentionDays != nil { + cfg.NonHitRetentionDays = *input.NonHitRetentionDays + } + if input.PreHashCheckEnabled != nil { + cfg.PreHashCheckEnabled = *input.PreHashCheckEnabled + } + if input.AllGroups != nil { + cfg.AllGroups = *input.AllGroups + } + if input.GroupIDs != nil { + cfg.GroupIDs = normalizeInt64IDs(*input.GroupIDs) + } + if input.RecordNonHits != nil { + cfg.RecordNonHits = *input.RecordNonHits + } + if input.ClearAPIKey { + cfg.APIKey = "" + cfg.APIKeys = []string{} + } else { + if input.APIKeys != nil { + cfg.APIKeys = normalizeModerationAPIKeys(*input.APIKeys) + cfg.APIKey = "" + } + if input.APIKey != nil && strings.TrimSpace(*input.APIKey) != "" { + cfg.APIKeys = normalizeModerationAPIKeys(append(cfg.APIKeys, *input.APIKey)) + cfg.APIKey = "" + } + } + if err := s.validateConfig(ctx, cfg); err != nil { + return nil, err + } + cfg.normalize() + raw, err := json.Marshal(cfg) + if err != nil { + return nil, fmt.Errorf("marshal content moderation config: %w", err) + } + if err := s.settingRepo.Set(ctx, SettingKeyContentModerationConfig, string(raw)); err != nil { + return nil, fmt.Errorf("save content moderation config: %w", err) + } + return s.configView(cfg), nil +} + +func (s *ContentModerationService) TestAPIKeys(ctx context.Context, input TestContentModerationAPIKeysInput) (*TestContentModerationAPIKeysResult, error) { + cfg, err := s.loadConfig(ctx) + if err != nil { + return nil, err + } + keys := normalizeModerationAPIKeys(input.APIKeys) + configured := false + if len(keys) == 0 { + keys = cfg.apiKeys() + configured = true + } + if strings.TrimSpace(input.BaseURL) != "" { + cfg.BaseURL = input.BaseURL + } + if strings.TrimSpace(input.Model) != "" { + cfg.Model = input.Model + } + if input.TimeoutMS > 0 { + cfg.TimeoutMS = input.TimeoutMS + } + cfg.normalize() + testInput, imageCount, err := buildModerationTestInput(input.Prompt, input.Images) + if err != nil { + return nil, err + } + auditOnly := contentModerationTestHasAuditInput(input.Prompt, input.Images) + if configured && auditOnly { + key, ok := s.nextUsableAPIKey(cfg) + if !ok { + return &TestContentModerationAPIKeysResult{ + Items: s.apiKeyStatuses(keys), + ImageCount: imageCount, + }, nil + } + keys = []string{key} + } + if len(keys) == 0 { + return &TestContentModerationAPIKeysResult{Items: []ContentModerationAPIKeyStatus{}, ImageCount: imageCount}, nil + } + items := make([]ContentModerationAPIKeyStatus, 0, len(keys)) + var auditResult *ContentModerationTestAuditResult + for idx, key := range keys { + start := time.Now() + httpStatus := 0 + result, err := s.callModerationOnceWithInput(ctx, cfg, key, testInput, &httpStatus) + latency := int(time.Since(start).Milliseconds()) + keyHash := moderationAPIKeyHash(key) + if err != nil { + s.markAPIKeyFailure(key, err.Error(), latency, httpStatus) + } else { + s.markAPIKeySuccess(key, latency, httpStatus) + if auditResult == nil { + auditResult = buildContentModerationTestAuditResult(result, cfg.Thresholds) + } + } + status := s.apiKeyStatusForHash(idx, keyHash, maskSecretTail(key), configured) + status.LastTested = true + items = append(items, status) + } + return &TestContentModerationAPIKeysResult{Items: items, AuditResult: auditResult, ImageCount: imageCount}, nil +} + +func (s *ContentModerationService) Check(ctx context.Context, input ContentModerationCheckInput) (*ContentModerationDecision, error) { + allow := &ContentModerationDecision{Allowed: true, Action: ContentModerationActionAllow} + if s == nil || s.settingRepo == nil || s.repo == nil { + slog.Info("content_moderation.skip_unavailable", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol) + return allow, nil + } + if !s.isRiskControlEnabled(ctx) { + slog.Info("content_moderation.skip_feature_disabled", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol) + return allow, nil + } + cfg, err := s.loadConfig(ctx) + if err != nil { + slog.Warn("content_moderation.skip_config_load_failed", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "error", err) + return allow, nil + } + inScope := cfg.includesGroup(input.GroupID) + slog.Info("content_moderation.config_loaded", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "group_name", input.GroupName, + "endpoint", input.Endpoint, + "provider", input.Provider, + "protocol", input.Protocol, + "model", input.Model, + "enabled", cfg.Enabled, + "mode", cfg.Mode, + "all_groups", cfg.AllGroups, + "configured_group_ids", cfg.GroupIDs, + "in_scope", inScope, + "sample_rate", cfg.SampleRate, + "api_key_count", len(cfg.apiKeys()), + "pre_hash_check_enabled", cfg.PreHashCheckEnabled, + "record_non_hits", cfg.RecordNonHits) + if !cfg.Enabled { + slog.Info("content_moderation.skip_config_disabled", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol) + return allow, nil + } + if cfg.Mode == ContentModerationModeOff { + slog.Info("content_moderation.skip_mode_off", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol) + return allow, nil + } + if !inScope { + slog.Info("content_moderation.skip_group_out_of_scope", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "group_name", input.GroupName, + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "all_groups", cfg.AllGroups, + "configured_group_ids", cfg.GroupIDs) + return allow, nil + } + content := ExtractContentModerationInput(input.Protocol, input.Body) + if content.IsEmpty() { + slog.Info("content_moderation.skip_empty_input", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "body_bytes", len(input.Body)) + return allow, nil + } + content.Normalize() + slog.Info("content_moderation.input_extracted", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "text_runes", len([]rune(content.Text)), + "image_count", len(content.Images)) + hashText := content.Hash() + if cfg.PreHashCheckEnabled && s.hashCache != nil { + matched, err := s.hashCache.HasFlaggedInputHash(ctx, hashText) + if err != nil { + slog.Warn("content_moderation.hash_check_failed", "user_id", input.UserID, "endpoint", input.Endpoint, "error", err) + } + if matched { + slog.Info("content_moderation.hash_block", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "input_hash", hashText) + message := cfg.BlockMessage + if message != "" { + message = fmt.Sprintf("%s(hash: %s)", message, hashText) + } + return &ContentModerationDecision{ + Allowed: false, + Blocked: true, + Flagged: true, + Message: message, + StatusCode: cfg.BlockStatus, + InputHash: hashText, + Action: ContentModerationActionHashBlock, + }, nil + } + } + if !cfg.shouldSample(hashText) { + slog.Info("content_moderation.skip_sample_rate", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "sample_rate", cfg.SampleRate) + return allow, nil + } + if len(cfg.apiKeys()) == 0 { + slog.Warn("content_moderation.skip_no_audit_api_keys", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol) + return allow, nil + } + if cfg.Mode == ContentModerationModeObserve { + slog.Info("content_moderation.enqueue_observe", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "queue_len", len(s.asyncQueue)) + s.enqueueAsync(input, cfg, content, hashText) + return allow, nil + } + + return s.checkSync(ctx, input, cfg, content, hashText, nil, true), nil +} + +func (s *ContentModerationService) checkSync(ctx context.Context, input ContentModerationCheckInput, cfg *ContentModerationConfig, content ContentModerationInput, hashText string, queueDelay *int, allowBlock bool) *ContentModerationDecision { + allow := &ContentModerationDecision{Allowed: true, Action: ContentModerationActionAllow} + start := time.Now() + result, err := s.callModeration(ctx, cfg, content.ModerationInput()) + latency := int(time.Since(start).Milliseconds()) + if err != nil { + slog.Warn("content_moderation.audit_api_failed", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "mode", cfg.Mode, + "allow_block", allowBlock, + "queue_delay_ms", queueDelay, + "latency_ms", latency, + "error", err) + if queueDelay != nil { + s.asyncErrors.Add(1) + } + if cfg.RecordNonHits { + log := s.buildLog(input, cfg, ContentModerationActionError, false, "", 0, nil, content.ExcerptText(), &latency, queueDelay, err.Error()) + _ = s.repo.CreateLog(ctx, log) + } + return allow + } + + flagged, highestCategory, highestScore := evaluateModerationScores(result.CategoryScores, cfg.Thresholds) + action := ContentModerationActionAllow + blocked := false + if allowBlock && flagged && cfg.Mode == ContentModerationModePreBlock { + action = ContentModerationActionBlock + blocked = true + } + slog.Info("content_moderation.audit_result", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "group_name", input.GroupName, + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "mode", cfg.Mode, + "allow_block", allowBlock, + "flagged", flagged, + "blocked", blocked, + "action", action, + "highest_category", highestCategory, + "highest_score", highestScore, + "latency_ms", latency, + "queue_delay_ms", queueDelay) + if flagged || cfg.RecordNonHits { + log := s.buildLog(input, cfg, action, flagged, highestCategory, highestScore, result.CategoryScores, content.ExcerptText(), &latency, queueDelay, "") + if flagged && s.hashCache != nil { + if err := s.hashCache.RecordFlaggedInputHash(ctx, hashText); err != nil { + slog.Warn("content_moderation.record_hash_failed", "user_id", input.UserID, "endpoint", input.Endpoint, "error", err) + } + } + s.applyFlaggedSideEffects(ctx, cfg, log) + _ = s.repo.CreateLog(ctx, log) + } + if blocked { + return &ContentModerationDecision{ + Allowed: false, + Blocked: true, + Flagged: true, + Message: cfg.BlockMessage, + StatusCode: cfg.BlockStatus, + HighestCategory: highestCategory, + HighestScore: highestScore, + CategoryScores: result.CategoryScores, + Action: action, + } + } + return &ContentModerationDecision{ + Allowed: true, + Flagged: flagged, + Message: "", + HighestCategory: highestCategory, + HighestScore: highestScore, + CategoryScores: result.CategoryScores, + Action: action, + } +} + +func (s *ContentModerationService) enqueueAsync(input ContentModerationCheckInput, cfg *ContentModerationConfig, content ContentModerationInput, hashText string) { + if s == nil || s.asyncQueue == nil { + return + } + queueSize := defaultContentModerationQueueSize + if cfg != nil && cfg.QueueSize > 0 { + queueSize = cfg.QueueSize + } + if len(s.asyncQueue) >= queueSize { + slog.Warn("content_moderation.async_queue_full", "user_id", input.UserID, "endpoint", input.Endpoint, "queue_size", queueSize) + s.asyncDropped.Add(1) + return + } + task := contentModerationTask{ + input: input, + content: content, + inputHash: hashText, + enqueuedAt: time.Now(), + } + select { + case s.asyncQueue <- task: + s.asyncEnqueued.Add(1) + default: + slog.Warn("content_moderation.async_queue_full", "user_id", input.UserID, "endpoint", input.Endpoint) + s.asyncDropped.Add(1) + } +} + +func (s *ContentModerationService) worker(id int) { + for { + ctx, cancel := context.WithTimeout(context.Background(), maxContentModerationTimeoutMS*time.Millisecond+10*time.Second) + cfg, err := s.loadConfig(ctx) + if err != nil || !cfg.Enabled || cfg.Mode == ContentModerationModeOff || len(cfg.apiKeys()) == 0 || id >= cfg.WorkerCount { + cancel() + time.Sleep(time.Second) + continue + } + task, ok := s.dequeueAsyncTask(ctx, time.Second) + if !ok { + cancel() + continue + } + func() { + defer cancel() + defer func() { + if r := recover(); r != nil { + slog.Error("content_moderation.worker_panic", "worker_id", id, "recover", r) + } + }() + if !cfg.includesGroup(task.input.GroupID) { + return + } + s.asyncActive.Add(1) + defer s.asyncActive.Add(-1) + queueDelay := int(time.Since(task.enqueuedAt).Milliseconds()) + _ = s.checkSync(ctx, task.input, cfg, task.content, task.inputHash, &queueDelay, false) + s.asyncProcessed.Add(1) + }() + } +} + +func (s *ContentModerationService) dequeueAsyncTask(ctx context.Context, idleWait time.Duration) (contentModerationTask, bool) { + var zero contentModerationTask + if s == nil || s.asyncQueue == nil { + return zero, false + } + if idleWait <= 0 { + idleWait = time.Second + } + timer := time.NewTimer(idleWait) + defer timer.Stop() + select { + case task, ok := <-s.asyncQueue: + return task, ok + case <-ctx.Done(): + return zero, false + case <-timer.C: + return zero, false + } +} + +func (s *ContentModerationService) ListLogs(ctx context.Context, filter ContentModerationLogFilter) ([]ContentModerationLog, *pagination.PaginationResult, error) { + if filter.Pagination.Page <= 0 { + filter.Pagination.Page = 1 + } + if filter.Pagination.PageSize <= 0 { + filter.Pagination.PageSize = 20 + } + if filter.Pagination.PageSize > 100 { + filter.Pagination.PageSize = 100 + } + if filter.Pagination.SortOrder == "" { + filter.Pagination.SortOrder = pagination.SortOrderDesc + } + return s.repo.ListLogs(ctx, filter) +} + +func (s *ContentModerationService) UnbanUser(ctx context.Context, userID int64) (*ContentModerationUnbanUserResult, error) { + if s == nil || s.userRepo == nil { + return nil, infraerrors.InternalServer("CONTENT_MODERATION_USER_REPOSITORY_UNAVAILABLE", "用户仓储不可用") + } + if userID <= 0 { + return nil, infraerrors.BadRequest("INVALID_USER_ID", "用户 ID 无效") + } + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + return nil, infraerrors.NotFound("USER_NOT_FOUND", "用户不存在") + } + return nil, fmt.Errorf("get content moderation unban user: %w", err) + } + if user.Status != StatusActive { + user.Status = StatusActive + if err := s.userRepo.Update(ctx, user); err != nil { + return nil, fmt.Errorf("update content moderation unban user: %w", err) + } + } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + return &ContentModerationUnbanUserResult{ + UserID: userID, + Status: StatusActive, + }, nil +} + +func (s *ContentModerationService) DeleteFlaggedInputHash(ctx context.Context, inputHash string) (*ContentModerationDeleteHashResult, error) { + inputHash = normalizeContentModerationHash(inputHash) + if inputHash == "" { + return nil, infraerrors.BadRequest("INVALID_CONTENT_MODERATION_HASH", "风险输入哈希无效") + } + if s == nil || s.hashCache == nil { + return nil, infraerrors.InternalServer("CONTENT_MODERATION_HASH_CACHE_UNAVAILABLE", "内容审计哈希缓存不可用") + } + deleted, err := s.hashCache.DeleteFlaggedInputHash(ctx, inputHash) + if err != nil { + return nil, fmt.Errorf("delete content moderation flagged hash: %w", err) + } + return &ContentModerationDeleteHashResult{ + InputHash: inputHash, + Deleted: deleted, + }, nil +} + +func (s *ContentModerationService) ClearFlaggedInputHashes(ctx context.Context) (*ContentModerationClearHashesResult, error) { + if s == nil || s.hashCache == nil { + return nil, infraerrors.InternalServer("CONTENT_MODERATION_HASH_CACHE_UNAVAILABLE", "内容审计哈希缓存不可用") + } + deleted, err := s.hashCache.ClearFlaggedInputHashes(ctx) + if err != nil { + return nil, fmt.Errorf("clear content moderation flagged hashes: %w", err) + } + return &ContentModerationClearHashesResult{Deleted: deleted}, nil +} + +func (s *ContentModerationService) GetStatus(ctx context.Context) (*ContentModerationRuntimeStatus, error) { + if s == nil { + return &ContentModerationRuntimeStatus{}, nil + } + cfg, err := s.loadConfig(ctx) + if err != nil { + return nil, err + } + riskEnabled := s.isRiskControlEnabled(ctx) + active := int(s.asyncActive.Load()) + if active < 0 { + active = 0 + } + if active > cfg.WorkerCount { + active = cfg.WorkerCount + } + queueLength := 0 + if s.asyncQueue != nil { + queueLength = len(s.asyncQueue) + } + queueUsage := 0.0 + if cfg.QueueSize > 0 { + queueUsage = float64(queueLength) * 100 / float64(cfg.QueueSize) + } + var flaggedHashCount int64 + if s.hashCache != nil { + if n, err := s.hashCache.CountFlaggedInputHashes(ctx); err == nil { + flaggedHashCount = n + } else { + slog.Warn("content_moderation.hash_count_failed", "error", err) + } + } + var lastCleanupAt *time.Time + if unix := s.lastCleanupUnix.Load(); unix > 0 { + t := time.Unix(unix, 0) + lastCleanupAt = &t + } + return &ContentModerationRuntimeStatus{ + Enabled: cfg.Enabled, + RiskControlEnabled: riskEnabled, + Mode: cfg.Mode, + WorkerCount: cfg.WorkerCount, + MaxWorkers: maxContentModerationWorkerCount, + ActiveWorkers: active, + IdleWorkers: cfg.WorkerCount - active, + QueueSize: cfg.QueueSize, + QueueLength: queueLength, + QueueUsagePercent: queueUsage, + Enqueued: s.asyncEnqueued.Load(), + Dropped: s.asyncDropped.Load(), + Processed: s.asyncProcessed.Load(), + Errors: s.asyncErrors.Load(), + APIKeyStatuses: s.apiKeyStatuses(cfg.apiKeys()), + FlaggedHashCount: flaggedHashCount, + LastCleanupAt: lastCleanupAt, + LastCleanupDeletedHit: s.lastCleanupDeletedHit.Load(), + LastCleanupDeletedNonHit: s.lastCleanupDeletedNonHit.Load(), + }, nil +} + +func (s *ContentModerationService) cleanupWorker() { + timer := time.NewTimer(contentModerationCleanupDelay) + defer timer.Stop() + for { + <-timer.C + s.runCleanupOnce() + timer.Reset(contentModerationCleanupInterval) + } +} + +func (s *ContentModerationService) runCleanupOnce() { + if s == nil || s.repo == nil || s.settingRepo == nil { + return + } + ctx, cancel := context.WithTimeout(context.Background(), contentModerationCleanupTimeout) + defer cancel() + cfg, err := s.loadConfig(ctx) + if err != nil { + slog.Warn("content_moderation.cleanup_load_config_failed", "error", err) + return + } + now := time.Now() + hitBefore := now.AddDate(0, 0, -cfg.HitRetentionDays) + nonHitBefore := now.AddDate(0, 0, -cfg.NonHitRetentionDays) + result, err := s.repo.CleanupExpiredLogs(ctx, hitBefore, nonHitBefore) + if err != nil { + slog.Warn("content_moderation.cleanup_failed", "error", err) + return + } + if result == nil { + return + } + s.lastCleanupUnix.Store(result.FinishedAt.Unix()) + s.lastCleanupDeletedHit.Store(result.DeletedHit) + s.lastCleanupDeletedNonHit.Store(result.DeletedNonHit) +} + +func (s *ContentModerationService) loadConfig(ctx context.Context) (*ContentModerationConfig, error) { + cfg := defaultContentModerationConfig() + raw, err := s.settingRepo.GetValue(ctx, SettingKeyContentModerationConfig) + if err != nil { + if errors.Is(err, ErrSettingNotFound) { + cfg.normalize() + return cfg, nil + } + return nil, fmt.Errorf("get content moderation config: %w", err) + } + if strings.TrimSpace(raw) == "" { + cfg.normalize() + return cfg, nil + } + if err := json.Unmarshal([]byte(raw), cfg); err != nil { + return nil, infraerrors.BadRequest("INVALID_CONTENT_MODERATION_CONFIG", "内容审计配置不是有效 JSON") + } + cfg.normalize() + return cfg, nil +} + +func (s *ContentModerationService) isRiskControlEnabled(ctx context.Context) bool { + raw, err := s.settingRepo.GetValue(ctx, SettingKeyRiskControlEnabled) + if err != nil { + return false + } + return raw == "true" +} + +func (s *ContentModerationService) validateConfig(ctx context.Context, cfg *ContentModerationConfig) error { + if cfg == nil { + return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_CONFIG", "内容审计配置不能为空") + } + cfg.normalize() + switch cfg.Mode { + case ContentModerationModeOff, ContentModerationModeObserve, ContentModerationModePreBlock: + default: + return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_MODE", "内容审计模式无效") + } + if _, err := url.ParseRequestURI(cfg.BaseURL); err != nil { + return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_BASE_URL", "OpenAI Base URL 无效") + } + if cfg.BlockStatus < 400 || cfg.BlockStatus > 599 { + return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_BLOCK_STATUS", "拦截 HTTP 状态码必须在 400-599 之间") + } + if !cfg.AllGroups && len(cfg.GroupIDs) > 0 && s.groupRepo != nil { + for _, groupID := range cfg.GroupIDs { + if _, err := s.groupRepo.GetByIDLite(ctx, groupID); err != nil { + return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_GROUP", fmt.Sprintf("审计分组不存在: %d", groupID)) + } + } + } + return nil +} + +func (s *ContentModerationService) callModeration(ctx context.Context, cfg *ContentModerationConfig, input any) (*moderationAPIResult, error) { + attempts := cfg.RetryCount + 1 + if attempts <= 0 { + attempts = 1 + } + if attempts > maxContentModerationRetryCount+1 { + attempts = maxContentModerationRetryCount + 1 + } + var lastErr error + for attempt := 0; attempt < attempts; attempt++ { + key, ok := s.nextUsableAPIKey(cfg) + if !ok { + lastErr = errors.New("no moderation api key available") + break + } + start := time.Now() + httpStatus := 0 + result, err := s.callModerationOnceWithInput(ctx, cfg, key, input, &httpStatus) + latency := int(time.Since(start).Milliseconds()) + if err == nil { + s.markAPIKeySuccess(key, latency, httpStatus) + return result, nil + } + s.markAPIKeyFailure(key, err.Error(), latency, httpStatus) + lastErr = err + if attempt == attempts-1 { + break + } + wait := time.Duration(100*(attempt+1)) * time.Millisecond + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(wait): + } + } + return nil, lastErr +} + +func (s *ContentModerationService) callModerationOnceWithInput(ctx context.Context, cfg *ContentModerationConfig, apiKey string, input any, httpStatus *int) (*moderationAPIResult, error) { + base := strings.TrimRight(cfg.BaseURL, "/") + endpoint, err := url.JoinPath(base, "/v1/moderations") + if err != nil { + return nil, err + } + payload := moderationAPIRequest{ + Model: cfg.Model, + Input: input, + } + raw, err := json.Marshal(payload) + if err != nil { + return nil, err + } + timeout := time.Duration(cfg.TimeoutMS) * time.Millisecond + reqCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, endpoint, bytes.NewReader(raw)) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("Content-Type", "application/json") + + client := s.httpClient + if client == nil { + client = http.DefaultClient + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if httpStatus != nil { + *httpStatus = resp.StatusCode + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) + return nil, fmt.Errorf("moderation api status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + var out moderationAPIResponse + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return nil, err + } + if len(out.Results) == 0 { + return nil, errors.New("moderation api returned empty results") + } + return &out.Results[0], nil +} + +func (s *ContentModerationService) buildLog(input ContentModerationCheckInput, cfg *ContentModerationConfig, action string, flagged bool, highestCategory string, highestScore float64, scores map[string]float64, text string, latency *int, queueDelay *int, errText string) *ContentModerationLog { + var userID *int64 + if input.UserID > 0 { + userID = &input.UserID + } + var apiKeyID *int64 + if input.APIKeyID > 0 { + apiKeyID = &input.APIKeyID + } + return &ContentModerationLog{ + RequestID: input.RequestID, + UserID: userID, + UserEmail: input.UserEmail, + APIKeyID: apiKeyID, + APIKeyName: input.APIKeyName, + GroupID: cloneInt64Ptr(input.GroupID), + GroupName: input.GroupName, + Endpoint: input.Endpoint, + Provider: input.Provider, + Model: input.Model, + Mode: cfg.Mode, + Action: action, + Flagged: flagged, + HighestCategory: highestCategory, + HighestScore: highestScore, + CategoryScores: cloneFloatMap(scores), + ThresholdSnapshot: cloneFloatMap(cfg.Thresholds), + InputExcerpt: trimRunes(redactContentModerationSecrets(text), maxModerationExcerptRunes), + UpstreamLatencyMS: latency, + QueueDelayMS: queueDelay, + Error: errText, + } +} + +func (s *ContentModerationService) applyFlaggedSideEffects(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) { + if s == nil || cfg == nil || log == nil || !log.Flagged || log.UserID == nil || *log.UserID <= 0 { + return + } + count := 1 + if s.repo != nil && cfg.ViolationWindowHours > 0 { + since := time.Now().Add(-time.Duration(cfg.ViolationWindowHours) * time.Hour) + if n, err := s.repo.CountFlaggedByUserSince(ctx, *log.UserID, since); err == nil { + count = n + 1 + } + } + log.ViolationCount = count + autoBanJustApplied := false + if cfg.AutoBanEnabled && cfg.BanThreshold > 0 && count >= cfg.BanThreshold && s.userRepo != nil { + user, err := s.userRepo.GetByID(ctx, *log.UserID) + if err != nil { + slog.Warn("content_moderation.ban_get_user_failed", "user_id", *log.UserID, "error", err) + return + } + if user.Status != StatusDisabled { + user.Status = StatusDisabled + if err := s.userRepo.Update(ctx, user); err != nil { + slog.Warn("content_moderation.ban_update_user_failed", "user_id", *log.UserID, "error", err) + return + } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, *log.UserID) + } + autoBanJustApplied = true + } + log.AutoBanned = true + } + + if s.emailService == nil || strings.TrimSpace(log.UserEmail) == "" { + return + } + emailSent := false + if cfg.EmailOnHit { + if err := s.sendViolationEmail(ctx, cfg, log); err != nil { + slog.Warn("content_moderation.email_failed", "user_id", *log.UserID, "email", log.UserEmail, "error", err) + } else { + emailSent = true + } + } + if autoBanJustApplied { + if err := s.sendAccountDisabledEmail(ctx, cfg, log); err != nil { + slog.Warn("content_moderation.ban_email_failed", "user_id", *log.UserID, "email", log.UserEmail, "error", err) + } else { + emailSent = true + } + } + log.EmailSent = emailSent +} + +func (s *ContentModerationService) sendViolationEmail(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) error { + siteName := s.siteName(ctx) + subject := fmt.Sprintf("[%s] 账户风控提醒 / Risk Control Notice", sanitizeEmailHeader(siteName)) + body := buildContentModerationViolationEmailBody(siteName, log, cfg) + return s.emailService.SendEmail(ctx, log.UserEmail, subject, body) +} + +func (s *ContentModerationService) sendAccountDisabledEmail(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) error { + siteName := s.siteName(ctx) + subject := fmt.Sprintf("[%s] 账户已被禁用 / Account Disabled", sanitizeEmailHeader(siteName)) + body := buildContentModerationAccountDisabledEmailBody(siteName, log, cfg) + return s.emailService.SendEmail(ctx, log.UserEmail, subject, body) +} + +func (s *ContentModerationService) siteName(ctx context.Context) string { + if s == nil || s.settingRepo == nil { + return "Sub2API" + } + name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName) + if err != nil || strings.TrimSpace(name) == "" { + return "Sub2API" + } + return strings.TrimSpace(name) +} + +func defaultContentModerationConfig() *ContentModerationConfig { + return &ContentModerationConfig{ + Enabled: false, + Mode: ContentModerationModePreBlock, + BaseURL: defaultContentModerationBaseURL, + Model: defaultContentModerationModel, + TimeoutMS: defaultContentModerationTimeoutMS, + SampleRate: 100, + AllGroups: true, + GroupIDs: []int64{}, + RecordNonHits: false, + Thresholds: ContentModerationDefaultThresholds(), + WorkerCount: defaultContentModerationWorkerCount, + QueueSize: defaultContentModerationQueueSize, + BlockStatus: defaultContentModerationBlockHTTPStatus, + BlockMessage: defaultContentModerationBlockMessage, + EmailOnHit: true, + AutoBanEnabled: true, + BanThreshold: defaultContentModerationBanThreshold, + ViolationWindowHours: defaultContentModerationViolationWindowHours, + RetryCount: defaultContentModerationRetryCount, + HitRetentionDays: defaultContentModerationHitRetentionDays, + NonHitRetentionDays: defaultContentModerationNonHitRetentionDays, + PreHashCheckEnabled: false, + } +} + +func (cfg *ContentModerationConfig) normalize() { + if cfg.APIKey != "" { + cfg.APIKeys = normalizeModerationAPIKeys(append(cfg.APIKeys, cfg.APIKey)) + cfg.APIKey = "" + } else { + cfg.APIKeys = normalizeModerationAPIKeys(cfg.APIKeys) + } + if cfg.Mode == "" { + cfg.Mode = ContentModerationModePreBlock + } + if cfg.BaseURL == "" { + cfg.BaseURL = defaultContentModerationBaseURL + } + cfg.BaseURL = strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/") + if cfg.Model == "" { + cfg.Model = defaultContentModerationModel + } + cfg.Model = strings.TrimSpace(cfg.Model) + if cfg.TimeoutMS <= 0 { + cfg.TimeoutMS = defaultContentModerationTimeoutMS + } + if cfg.TimeoutMS > maxContentModerationTimeoutMS { + cfg.TimeoutMS = maxContentModerationTimeoutMS + } + if cfg.SampleRate < 0 { + cfg.SampleRate = 0 + } + if cfg.SampleRate > 100 { + cfg.SampleRate = 100 + } + if cfg.WorkerCount <= 0 { + cfg.WorkerCount = defaultContentModerationWorkerCount + } + if cfg.WorkerCount > maxContentModerationWorkerCount { + cfg.WorkerCount = maxContentModerationWorkerCount + } + if cfg.QueueSize <= 0 { + cfg.QueueSize = defaultContentModerationQueueSize + } + if cfg.QueueSize > maxContentModerationQueueSize { + cfg.QueueSize = maxContentModerationQueueSize + } + if strings.TrimSpace(cfg.BlockMessage) == "" { + cfg.BlockMessage = defaultContentModerationBlockMessage + } + cfg.BlockMessage = strings.TrimSpace(cfg.BlockMessage) + if cfg.BlockStatus <= 0 { + cfg.BlockStatus = defaultContentModerationBlockHTTPStatus + } + if cfg.BanThreshold <= 0 { + cfg.BanThreshold = defaultContentModerationBanThreshold + } + if cfg.ViolationWindowHours <= 0 { + cfg.ViolationWindowHours = defaultContentModerationViolationWindowHours + } + if cfg.RetryCount < 0 { + cfg.RetryCount = 0 + } + if cfg.RetryCount > maxContentModerationRetryCount { + cfg.RetryCount = maxContentModerationRetryCount + } + if cfg.HitRetentionDays <= 0 { + cfg.HitRetentionDays = defaultContentModerationHitRetentionDays + } + if cfg.HitRetentionDays > maxContentModerationRetentionDays { + cfg.HitRetentionDays = maxContentModerationRetentionDays + } + if cfg.NonHitRetentionDays <= 0 { + cfg.NonHitRetentionDays = defaultContentModerationNonHitRetentionDays + } + if cfg.NonHitRetentionDays > maxContentModerationNonHitRetentionDays { + cfg.NonHitRetentionDays = maxContentModerationNonHitRetentionDays + } + cfg.GroupIDs = normalizeInt64IDs(cfg.GroupIDs) + cfg.Thresholds = mergeContentModerationThresholds(ContentModerationDefaultThresholds(), cfg.Thresholds) +} + +func (cfg *ContentModerationConfig) includesGroup(groupID *int64) bool { + if cfg.AllGroups { + return true + } + if groupID == nil { + return false + } + for _, id := range cfg.GroupIDs { + if id == *groupID { + return true + } + } + return false +} + +func contentModerationLogGroupID(groupID *int64) int64 { + if groupID == nil { + return 0 + } + return *groupID +} + +func (cfg *ContentModerationConfig) shouldSample(hashText string) bool { + if cfg.SampleRate >= 100 { + return true + } + if cfg.SampleRate <= 0 { + return false + } + raw, err := hex.DecodeString(hashText) + if err != nil || len(raw) < 2 { + return true + } + return int(binary.BigEndian.Uint16(raw[:2])%100) < cfg.SampleRate +} + +func (cfg *ContentModerationConfig) apiKeys() []string { + if cfg == nil { + return nil + } + return normalizeModerationAPIKeys(cfg.APIKeys) +} + +func (s *ContentModerationService) nextUsableAPIKey(cfg *ContentModerationConfig) (string, bool) { + keys := cfg.apiKeys() + if len(keys) == 0 { + return "", false + } + now := time.Now() + for i := 0; i < len(keys); i++ { + idx := int(s.apiKeyCursor.Add(1)-1) % len(keys) + key := keys[idx] + if !s.isAPIKeyFrozen(key, now) { + return key, true + } + } + return "", false +} + +func (s *ContentModerationService) isAPIKeyFrozen(key string, now time.Time) bool { + hash := moderationAPIKeyHash(key) + if hash == "" || s == nil { + return false + } + s.keyHealthMu.Lock() + defer s.keyHealthMu.Unlock() + state := s.keyHealth[hash] + return state != nil && state.FrozenUntil.After(now) +} + +func (s *ContentModerationService) markAPIKeySuccess(key string, latencyMS int, httpStatus int) { + hash := moderationAPIKeyHash(key) + if hash == "" || s == nil { + return + } + s.keyHealthMu.Lock() + defer s.keyHealthMu.Unlock() + state := s.ensureAPIKeyHealthLocked(hash, maskSecretTail(key)) + state.FailureCount = 0 + state.SuccessCount++ + state.LastError = "" + state.LastCheckedAt = time.Now() + state.FrozenUntil = time.Time{} + state.LastLatencyMS = latencyMS + state.LastHTTPStatus = httpStatus + state.LastTested = true +} + +func (s *ContentModerationService) markAPIKeyFailure(key string, errText string, latencyMS int, httpStatus int) { + hash := moderationAPIKeyHash(key) + if hash == "" || s == nil { + return + } + s.keyHealthMu.Lock() + defer s.keyHealthMu.Unlock() + state := s.ensureAPIKeyHealthLocked(hash, maskSecretTail(key)) + state.FailureCount++ + state.LastError = trimRunes(errText, 180) + state.LastCheckedAt = time.Now() + state.LastLatencyMS = latencyMS + state.LastHTTPStatus = httpStatus + state.LastTested = true + if state.FailureCount >= contentModerationKeyFailureFreezeThreshold { + state.FrozenUntil = time.Now().Add(contentModerationKeyFreezeDuration) + } +} + +func (s *ContentModerationService) ensureAPIKeyHealthLocked(hash string, masked string) *contentModerationKeyHealth { + if s.keyHealth == nil { + s.keyHealth = make(map[string]*contentModerationKeyHealth) + } + state := s.keyHealth[hash] + if state == nil { + state = &contentModerationKeyHealth{Hash: hash} + s.keyHealth[hash] = state + } + if strings.TrimSpace(masked) != "" { + state.Masked = masked + } + return state +} + +func (s *ContentModerationService) configView(cfg *ContentModerationConfig) *ContentModerationConfigView { + keys := cfg.apiKeys() + masks := make([]string, 0, len(keys)) + for _, key := range keys { + masks = append(masks, maskSecretTail(key)) + } + apiKeyMasked := "" + if len(masks) > 0 { + apiKeyMasked = masks[0] + } + return &ContentModerationConfigView{ + Enabled: cfg.Enabled, + Mode: cfg.Mode, + BaseURL: cfg.BaseURL, + Model: cfg.Model, + APIKeyConfigured: len(keys) > 0, + APIKeyMasked: apiKeyMasked, + APIKeyCount: len(keys), + APIKeyMasks: masks, + APIKeyStatuses: s.apiKeyStatuses(keys), + TimeoutMS: cfg.TimeoutMS, + SampleRate: cfg.SampleRate, + AllGroups: cfg.AllGroups, + GroupIDs: append([]int64(nil), cfg.GroupIDs...), + RecordNonHits: cfg.RecordNonHits, + WorkerCount: cfg.WorkerCount, + QueueSize: cfg.QueueSize, + BlockStatus: cfg.BlockStatus, + BlockMessage: cfg.BlockMessage, + EmailOnHit: cfg.EmailOnHit, + AutoBanEnabled: cfg.AutoBanEnabled, + BanThreshold: cfg.BanThreshold, + ViolationWindowHours: cfg.ViolationWindowHours, + RetryCount: cfg.RetryCount, + HitRetentionDays: cfg.HitRetentionDays, + NonHitRetentionDays: cfg.NonHitRetentionDays, + PreHashCheckEnabled: cfg.PreHashCheckEnabled, + } +} + +func (s *ContentModerationService) apiKeyStatuses(keys []string) []ContentModerationAPIKeyStatus { + out := make([]ContentModerationAPIKeyStatus, 0, len(keys)) + for idx, key := range keys { + out = append(out, s.apiKeyStatusForHash(idx, moderationAPIKeyHash(key), maskSecretTail(key), true)) + } + return out +} + +func (s *ContentModerationService) apiKeyStatusForHash(index int, hash string, masked string, configured bool) ContentModerationAPIKeyStatus { + status := ContentModerationAPIKeyStatus{ + Index: index, + KeyHash: hash, + Masked: masked, + Status: "unknown", + Configured: configured, + } + if hash == "" || s == nil { + return status + } + now := time.Now() + s.keyHealthMu.Lock() + defer s.keyHealthMu.Unlock() + state := s.keyHealth[hash] + if state == nil { + return status + } + status.FailureCount = state.FailureCount + status.SuccessCount = state.SuccessCount + status.LastError = state.LastError + status.LastLatencyMS = state.LastLatencyMS + status.LastHTTPStatus = state.LastHTTPStatus + status.LastTested = state.LastTested + if !state.LastCheckedAt.IsZero() { + t := state.LastCheckedAt + status.LastCheckedAt = &t + } + if state.FrozenUntil.After(now) { + t := state.FrozenUntil + status.FrozenUntil = &t + status.Status = "frozen" + return status + } + if state.LastError != "" { + status.Status = "error" + return status + } + if state.SuccessCount > 0 || state.LastTested { + status.Status = "ok" + } + return status +} + +func moderationAPIKeyHash(key string) string { + key = strings.TrimSpace(key) + if key == "" { + return "" + } + sum := sha256.Sum256([]byte(key)) + return hex.EncodeToString(sum[:]) +} + +func buildModerationTestInput(prompt string, images []string) (any, int, error) { + prompt = trimRunes(normalizeContentModerationText(prompt), maxModerationInputRunes) + normalizedImages := make([]string, 0, len(images)) + for _, image := range images { + image = strings.TrimSpace(image) + if image == "" { + continue + } + if len(normalizedImages) >= maxContentModerationTestImages { + return nil, 0, infraerrors.BadRequest("TOO_MANY_MODERATION_TEST_IMAGES", fmt.Sprintf("最多上传 %d 张测试图片", maxContentModerationTestImages)) + } + if err := validateModerationTestImageDataURL(image); err != nil { + return nil, 0, err + } + normalizedImages = append(normalizedImages, image) + } + if prompt == "" && len(normalizedImages) == 0 { + return "hello", 0, nil + } + if len(normalizedImages) == 0 { + return prompt, 0, nil + } + parts := make([]moderationAPIInputPart, 0, len(normalizedImages)+1) + if prompt != "" { + parts = append(parts, moderationAPIInputPart{Type: "text", Text: prompt}) + } + for _, image := range normalizedImages { + parts = append(parts, moderationAPIInputPart{ + Type: "image_url", + ImageURL: &moderationAPIImageURLRef{URL: image}, + }) + } + return parts, len(normalizedImages), nil +} + +func contentModerationTestHasAuditInput(prompt string, images []string) bool { + if normalizeContentModerationText(prompt) != "" { + return true + } + for _, image := range images { + if strings.TrimSpace(image) != "" { + return true + } + } + return false +} + +func validateModerationTestImageDataURL(value string) error { + if len(value) > maxContentModerationTestImageDataURLBytes { + return infraerrors.BadRequest("MODERATION_TEST_IMAGE_TOO_LARGE", "测试图片不能超过 8MB") + } + if !strings.HasPrefix(value, "data:image/") { + return infraerrors.BadRequest("INVALID_MODERATION_TEST_IMAGE", "测试图片必须是 data:image/* base64") + } + parts := strings.SplitN(value, ",", 2) + if len(parts) != 2 || !strings.Contains(parts[0], ";base64") { + return infraerrors.BadRequest("INVALID_MODERATION_TEST_IMAGE", "测试图片必须是 base64 data URL") + } + raw, err := base64.StdEncoding.DecodeString(parts[1]) + if err != nil { + return infraerrors.BadRequest("INVALID_MODERATION_TEST_IMAGE", "测试图片 base64 无效") + } + if len(raw) > maxContentModerationTestImageBytes { + return infraerrors.BadRequest("MODERATION_TEST_IMAGE_TOO_LARGE", "测试图片不能超过 8MB") + } + return nil +} + +func buildContentModerationTestAuditResult(result *moderationAPIResult, thresholds map[string]float64) *ContentModerationTestAuditResult { + if result == nil { + return nil + } + scores := make(map[string]float64, len(result.CategoryScores)) + for category, score := range result.CategoryScores { + scores[category] = score + } + thresholdSnapshot := mergeContentModerationThresholds(ContentModerationDefaultThresholds(), thresholds) + flagged, highestCategory, highestScore := evaluateModerationScores(scores, thresholdSnapshot) + compositeScore := highestScore + return &ContentModerationTestAuditResult{ + Flagged: flagged, + HighestCategory: highestCategory, + HighestScore: highestScore, + CompositeScore: compositeScore, + CategoryScores: scores, + Thresholds: thresholdSnapshot, + } +} + +type moderationAPIRequest struct { + Model string `json:"model"` + Input any `json:"input"` +} + +type moderationAPIInputPart struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ImageURL *moderationAPIImageURLRef `json:"image_url,omitempty"` +} + +type moderationAPIImageURLRef struct { + URL string `json:"url"` +} + +type moderationAPIResponse struct { + Results []moderationAPIResult `json:"results"` +} + +type moderationAPIResult struct { + Flagged bool `json:"flagged"` + CategoryScores map[string]float64 `json:"category_scores"` +} + +func evaluateModerationScores(scores map[string]float64, thresholds map[string]float64) (bool, string, float64) { + flagged := false + highestCategory := "" + highestScore := 0.0 + for _, category := range contentModerationCategoryOrder { + score := scores[category] + if score > highestScore || highestCategory == "" { + highestScore = score + highestCategory = category + } + if score >= thresholds[category] { + flagged = true + } + } + for category, score := range scores { + if score > highestScore || highestCategory == "" { + highestScore = score + highestCategory = category + } + } + return flagged, highestCategory, highestScore +} + +func mergeContentModerationThresholds(base map[string]float64, override map[string]float64) map[string]float64 { + out := cloneFloatMap(base) + if out == nil { + out = map[string]float64{} + } + for _, category := range contentModerationCategoryOrder { + if v, ok := override[category]; ok { + if v < 0 { + v = 0 + } + if v > 1 { + v = 1 + } + out[category] = v + } + } + return out +} + +func normalizeInt64IDs(ids []int64) []int64 { + if len(ids) == 0 { + return []int64{} + } + seen := make(map[int64]struct{}, len(ids)) + out := make([]int64, 0, len(ids)) + for _, id := range ids { + if id <= 0 { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + out = append(out, id) + } + sort.Slice(out, func(i, j int) bool { return out[i] < out[j] }) + return out +} + +func normalizeModerationAPIKeys(keys []string) []string { + if len(keys) == 0 { + return []string{} + } + seen := make(map[string]struct{}, len(keys)) + out := make([]string, 0, len(keys)) + for _, key := range keys { + key = strings.TrimSpace(key) + if key == "" { + continue + } + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + out = append(out, key) + } + return out +} + +func normalizeContentModerationHash(inputHash string) string { + inputHash = strings.ToLower(strings.TrimSpace(inputHash)) + if len(inputHash) != sha256.Size*2 { + return "" + } + if _, err := hex.DecodeString(inputHash); err != nil { + return "" + } + return inputHash +} + +func cloneFloatMap(in map[string]float64) map[string]float64 { + if in == nil { + return map[string]float64{} + } + out := make(map[string]float64, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func cloneInt64Ptr(in *int64) *int64 { + if in == nil { + return nil + } + v := *in + return &v +} + +func trimRunes(text string, max int) string { + if max <= 0 { + return "" + } + runes := []rune(text) + if len(runes) <= max { + return text + } + return string(runes[:max]) +} + +func maskSecretTail(secret string) string { + secret = strings.TrimSpace(secret) + if secret == "" { + return "" + } + if len(secret) <= 4 { + return "****" + } + return strings.Repeat("*", 8) + secret[len(secret)-4:] +} diff --git a/backend/internal/service/content_moderation_email.go b/backend/internal/service/content_moderation_email.go new file mode 100644 index 00000000..e462ff88 --- /dev/null +++ b/backend/internal/service/content_moderation_email.go @@ -0,0 +1,117 @@ +package service + +import ( + "fmt" + "html" + "strings" + "time" +) + +func buildContentModerationViolationEmailBody(siteName string, log *ContentModerationLog, cfg *ContentModerationConfig) string { + if log == nil { + return "" + } + userName := strings.TrimSpace(log.UserEmail) + if userName == "" && log.UserID != nil { + userName = fmt.Sprintf("UID %d", *log.UserID) + } + threshold := cfg.BanThreshold + if threshold <= 0 { + threshold = defaultContentModerationBanThreshold + } + statusBlock := "" + if log.AutoBanned { + statusBlock = `
尊敬的用户 %s,您的 API 请求在内容审计中触发平台风控策略。详情如下。
+| 触发时间 | %s |
| 触发来源 | 内容审核 |
| 所属分组 | %s |
| 命中类别 | %s / %.3f |
| 累计触发次数 | %d 次(阈值 %d) |
此邮件由 %s 自动发送,请勿回复。
+尊敬的用户 %s,您的账户在计数周期内多次触发平台风控策略,系统已自动禁用该账户。详情如下。
+| 封禁时间 | %s |
| 触发来源 | 内容审核 |
| 所属分组 | %s |
| 命中类别 | %s / %.3f |
| 累计触发次数 | %d 次(阈值 %d) |
如需申诉或恢复账号,请联系平台管理员处理。
+此邮件由 %s 自动发送,请勿回复。
+{{ t('admin.riskControl.description') }}
+{{ item.label }}
+ + {{ item.badge }} + +{{ item.value }}
+{{ item.meta }}
+{{ t('admin.riskControl.workerStatusHint') }}
+{{ t('admin.riskControl.queueUsage') }}
++ {{ formatNumber(status?.queue_length ?? 0) }} / {{ formatNumber(status?.queue_size ?? configForm.queue_size) }} +
+{{ t('admin.riskControl.activeWorkers') }}
+{{ status?.active_workers ?? 0 }}
+{{ t('admin.riskControl.idleWorkers') }}
+{{ status?.idle_workers ?? configForm.worker_count }}
+{{ t('admin.riskControl.processed') }}
+{{ formatNumber(status?.processed ?? 0) }}
+{{ t('admin.riskControl.droppedErrors') }}
+{{ formatNumber((status?.dropped ?? 0) + (status?.errors ?? 0)) }}
+{{ t('admin.riskControl.workerPool') }}
++ {{ t('admin.riskControl.workerPoolMeta', { active: status?.active_workers ?? 0, idle: status?.idle_workers ?? configForm.worker_count, total: status?.worker_count ?? configForm.worker_count }) }} +
+{{ t('admin.riskControl.recordsHint') }}
+| {{ t('admin.riskControl.table.time') }} | +{{ t('admin.riskControl.table.group') }} | +{{ t('admin.riskControl.table.user') }} | +{{ t('admin.riskControl.table.apiKey') }} | +{{ t('admin.riskControl.table.endpoint') }} | +{{ t('admin.riskControl.table.result') }} | +{{ t('admin.riskControl.table.highest') }} | +{{ t('admin.riskControl.table.actionMeta') }} | +{{ t('admin.riskControl.table.latency') }} | +{{ t('admin.riskControl.table.input') }} | +
|---|---|---|---|---|---|---|---|---|---|
| {{ t('common.loading') }} | +|||||||||
| {{ t('admin.riskControl.emptyLogs') }} | +|||||||||
| {{ formatDateTime(row.created_at) }} | +{{ row.group_name || '-' }} | +
+ {{ row.user_email || '-' }}
+ UID {{ row.user_id }}
+ |
+ {{ row.api_key_name || '-' }} | +
+ {{ row.endpoint || '-' }}
+ {{ row.provider || '-' }} / {{ row.model || '-' }}
+ |
+ + + {{ resultLabel(row) }} + + | +
+ {{ row.highest_category || '-' }}
+ {{ percent(row.highest_score) }}
+ |
+
+ {{ violationCountText(row) }}
+
+ {{ row.email_sent ? t('admin.riskControl.emailSent') : t('admin.riskControl.emailNotSent') }}
+ / {{ t('admin.riskControl.autoBanned') }}
+
+
+ |
+
+ {{ latencyText(row.upstream_latency_ms) }}
+
+ {{ t('admin.riskControl.queueDelay', { ms: row.queue_delay_ms }) }}
+
+ |
+ + + | +
{{ t('admin.riskControl.enabled') }}
+{{ t('admin.riskControl.enabledHint') }}
+{{ modeDescription(configForm.mode) }}
++ {{ t('admin.riskControl.apiKeysHint', { count: configForm.api_key_count }) }} +
+{{ t('admin.riskControl.auditTestInput') }}
+{{ t('admin.riskControl.auditTestInputHint') }}
+{{ t('admin.riskControl.auditTestImages') }}
+{{ t('admin.riskControl.auditTestImagesHint') }}
+{{ t('admin.riskControl.apiKeyHealth') }}
+{{ t('admin.riskControl.apiKeyFreezeRule') }}
+{{ t('admin.riskControl.apiKeyHealthEmpty') }}
+{{ t('admin.riskControl.apiKeyHealthEmptyHint') }}
+{{ apiKeyStatusMeta(row) }}
++ {{ row.last_error }} +
+{{ t('admin.riskControl.auditTestResult') }}
++ {{ t('admin.riskControl.auditTestHighest', { category: moderationTestResult.highest_category || '-', score: percent(moderationTestResult.highest_score) }) }} +
+{{ t('admin.riskControl.groupScopeHint') }}
+{{ t('admin.riskControl.noGroups') }}
+{{ t('admin.riskControl.recordNonHits') }}
+{{ t('admin.riskControl.recordNonHitsHint') }}
+{{ t('admin.riskControl.preHashCheck') }}
+{{ t('admin.riskControl.preHashCheckHint') }}
++ {{ t('admin.riskControl.flaggedHashCount', { count: formatNumber(status?.flagged_hash_count ?? 0) }) }} +
+{{ t('admin.riskControl.flaggedHashHint') }}
+{{ t('admin.riskControl.emailOnHit') }}
+{{ t('admin.riskControl.emailOnHitHint') }}
+{{ t('admin.riskControl.autoBan') }}
+{{ t('admin.riskControl.autoBanHint') }}
+{{ t('admin.riskControl.table.time') }}
+{{ formatDateTime(inputDetailRow.created_at) }}
+{{ t('admin.riskControl.table.user') }}
+{{ inputDetailRow.user_email || '-' }}
+{{ t('admin.riskControl.table.result') }}
+ + {{ resultLabel(inputDetailRow) }} + +{{ t('admin.riskControl.table.highest') }}
++ {{ inputDetailRow.highest_category || '-' }} / {{ percent(inputDetailRow.highest_score) }} +
+{{ t('admin.riskControl.inputDetailContent') }}
++ {{ inputDetailRow.endpoint || '-' }} · {{ inputDetailRow.provider || '-' }} / {{ inputDetailRow.model || '-' }} +
+{{ inputDetailText }}
+ + {{ t('admin.settings.features.riskControl.description') }} +
+
+
+ {{ t('admin.settings.features.riskControl.enabledHint') }} +
+