From fff4a300c6ab0c0cc6fd394925f2811a0e17c882 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 7 May 2026 09:01:48 +0800 Subject: [PATCH] feat(risk-control): add content moderation audit --- backend/cmd/server/wire_gen.go | 10 +- .../admin/content_moderation_handler.go | 234 ++ .../internal/handler/admin/setting_handler.go | 15 + .../handler/content_moderation_helper.go | 130 ++ backend/internal/handler/dto/settings.go | 5 + backend/internal/handler/gateway_handler.go | 8 + .../gateway_handler_chat_completions.go | 5 + .../handler/gateway_handler_responses.go | 5 + .../internal/handler/gemini_v1beta_handler.go | 5 + backend/internal/handler/handler.go | 1 + .../handler/openai_chat_completions.go | 5 + .../handler/openai_gateway_handler.go | 103 +- .../handler/openai_gateway_handler_test.go | 175 ++ backend/internal/handler/openai_images.go | 4 + backend/internal/handler/setting_handler.go | 2 + backend/internal/handler/wire.go | 3 + backend/internal/repository/api_key_repo.go | 1 + ...pi_key_repo_messages_dispatch_unit_test.go | 1 + .../content_moderation_hash_cache.go | 71 + .../repository/content_moderation_repo.go | 274 +++ backend/internal/repository/wire.go | 2 + backend/internal/server/api_contract_test.go | 2 + backend/internal/server/routes/admin.go | 17 + .../internal/service/api_key_auth_cache.go | 1 + .../service/api_key_auth_cache_impl.go | 4 +- .../service/api_key_service_cache_test.go | 2 + .../internal/service/content_moderation.go | 1982 +++++++++++++++++ .../service/content_moderation_email.go | 117 + .../service/content_moderation_input.go | 307 +++ .../service/content_moderation_redact.go | 36 + .../service/content_moderation_test.go | 811 +++++++ backend/internal/service/domain_constants.go | 2 + backend/internal/service/openai_images.go | 63 + .../internal/service/openai_images_test.go | 45 + .../internal/service/openai_ws_forwarder.go | 6 + .../openai_ws_v2_passthrough_adapter.go | 13 + backend/internal/service/setting_service.go | 14 + backend/internal/service/settings_view.go | 4 + backend/internal/service/wire.go | 1 + backend/migrations/135_content_moderation.sql | 45 + frontend/src/api/admin/index.ts | 8 +- frontend/src/api/admin/riskControl.ts | 251 +++ frontend/src/api/admin/settings.ts | 2 + frontend/src/components/layout/AppSidebar.vue | 17 + frontend/src/i18n/locales/en.ts | 205 ++ frontend/src/i18n/locales/zh.ts | 205 ++ frontend/src/router/index.ts | 21 + frontend/src/router/meta.d.ts | 6 + frontend/src/stores/app.ts | 1 + frontend/src/types/index.ts | 1 + frontend/src/utils/featureFlags.ts | 5 + frontend/src/views/admin/RiskControlView.vue | 1574 +++++++++++++ frontend/src/views/admin/SettingsView.vue | 35 + frontend/src/views/auth/LoginView.vue | 12 +- 54 files changed, 6840 insertions(+), 34 deletions(-) create mode 100644 backend/internal/handler/admin/content_moderation_handler.go create mode 100644 backend/internal/handler/content_moderation_helper.go create mode 100644 backend/internal/repository/content_moderation_hash_cache.go create mode 100644 backend/internal/repository/content_moderation_repo.go create mode 100644 backend/internal/service/content_moderation.go create mode 100644 backend/internal/service/content_moderation_email.go create mode 100644 backend/internal/service/content_moderation_input.go create mode 100644 backend/internal/service/content_moderation_redact.go create mode 100644 backend/internal/service/content_moderation_test.go create mode 100644 backend/migrations/135_content_moderation.sql create mode 100644 frontend/src/api/admin/riskControl.ts create mode 100644 frontend/src/views/admin/RiskControlView.vue diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 81225ca6..881f2e69 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -230,14 +230,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db) channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository) channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService) + contentModerationRepository := repository.NewContentModerationRepository(db) + contentModerationHashCache := repository.NewContentModerationHashCache(redisClient) + contentModerationService := service.NewContentModerationService(settingRepository, contentModerationRepository, contentModerationHashCache, groupRepository, userRepository, apiKeyAuthCacheInvalidator, emailService) + contentModerationHandler := admin.NewContentModerationHandler(contentModerationService) paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService) affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, affiliateHandler) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, contentModerationHandler, paymentHandler, affiliateHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService) - openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, userMessageQueueService, configConfig, settingService) + openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, contentModerationService, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) totpHandler := handler.NewTotpHandler(totpService) handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService) diff --git a/backend/internal/handler/admin/content_moderation_handler.go b/backend/internal/handler/admin/content_moderation_handler.go new file mode 100644 index 00000000..88b93527 --- /dev/null +++ b/backend/internal/handler/admin/content_moderation_handler.go @@ -0,0 +1,234 @@ +package admin + +import ( + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +type ContentModerationHandler struct { + service *service.ContentModerationService +} + +func NewContentModerationHandler(svc *service.ContentModerationService) *ContentModerationHandler { + return &ContentModerationHandler{service: svc} +} + +type contentModerationConfigRequest struct { + Enabled *bool `json:"enabled"` + Mode *string `json:"mode"` + BaseURL *string `json:"base_url"` + Model *string `json:"model"` + APIKey *string `json:"api_key"` + APIKeys *[]string `json:"api_keys"` + ClearAPIKey bool `json:"clear_api_key"` + TimeoutMS *int `json:"timeout_ms"` + SampleRate *int `json:"sample_rate"` + AllGroups *bool `json:"all_groups"` + GroupIDs *[]int64 `json:"group_ids"` + RecordNonHits *bool `json:"record_non_hits"` + WorkerCount *int `json:"worker_count"` + QueueSize *int `json:"queue_size"` + BlockStatus *int `json:"block_status"` + BlockMessage *string `json:"block_message"` + EmailOnHit *bool `json:"email_on_hit"` + AutoBanEnabled *bool `json:"auto_ban_enabled"` + BanThreshold *int `json:"ban_threshold"` + ViolationWindowHours *int `json:"violation_window_hours"` + RetryCount *int `json:"retry_count"` + HitRetentionDays *int `json:"hit_retention_days"` + NonHitRetentionDays *int `json:"non_hit_retention_days"` + PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"` +} + +type contentModerationAPIKeyTestRequest struct { + APIKeys []string `json:"api_keys"` + BaseURL string `json:"base_url"` + Model string `json:"model"` + TimeoutMS int `json:"timeout_ms"` + Prompt string `json:"prompt"` + Images []string `json:"images"` +} + +type contentModerationHashRequest struct { + InputHash string `json:"input_hash"` +} + +func (h *ContentModerationHandler) GetConfig(c *gin.Context) { + cfg, err := h.service.GetConfig(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +func (h *ContentModerationHandler) UpdateConfig(c *gin.Context) { + var req contentModerationConfigRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + cfg, err := h.service.UpdateConfig(c.Request.Context(), service.UpdateContentModerationConfigInput{ + Enabled: req.Enabled, + Mode: req.Mode, + BaseURL: req.BaseURL, + Model: req.Model, + APIKey: req.APIKey, + APIKeys: req.APIKeys, + ClearAPIKey: req.ClearAPIKey, + TimeoutMS: req.TimeoutMS, + SampleRate: req.SampleRate, + AllGroups: req.AllGroups, + GroupIDs: req.GroupIDs, + RecordNonHits: req.RecordNonHits, + WorkerCount: req.WorkerCount, + QueueSize: req.QueueSize, + BlockStatus: req.BlockStatus, + BlockMessage: req.BlockMessage, + EmailOnHit: req.EmailOnHit, + AutoBanEnabled: req.AutoBanEnabled, + BanThreshold: req.BanThreshold, + ViolationWindowHours: req.ViolationWindowHours, + RetryCount: req.RetryCount, + HitRetentionDays: req.HitRetentionDays, + NonHitRetentionDays: req.NonHitRetentionDays, + PreHashCheckEnabled: req.PreHashCheckEnabled, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +func (h *ContentModerationHandler) TestAPIKeys(c *gin.Context) { + var req contentModerationAPIKeyTestRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + result, err := h.service.TestAPIKeys(c.Request.Context(), service.TestContentModerationAPIKeysInput{ + APIKeys: req.APIKeys, + BaseURL: req.BaseURL, + Model: req.Model, + TimeoutMS: req.TimeoutMS, + Prompt: req.Prompt, + Images: req.Images, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + +func (h *ContentModerationHandler) GetStatus(c *gin.Context) { + status, err := h.service.GetStatus(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, status) +} + +func (h *ContentModerationHandler) ListLogs(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + filter := service.ContentModerationLogFilter{ + Pagination: pagination.PaginationParams{ + Page: page, + PageSize: pageSize, + SortOrder: pagination.SortOrderDesc, + }, + Result: c.Query("result"), + Endpoint: c.Query("endpoint"), + Search: c.Query("search"), + } + if raw := strings.TrimSpace(c.Query("group_id")); raw != "" { + groupID, err := strconv.ParseInt(raw, 10, 64) + if err != nil || groupID <= 0 { + response.BadRequest(c, "Invalid group_id") + return + } + filter.GroupID = &groupID + } + if raw := strings.TrimSpace(c.Query("from")); raw != "" { + t, _, err := parseContentModerationDate(raw) + if err != nil { + response.BadRequest(c, "Invalid from") + return + } + filter.From = &t + } + if raw := strings.TrimSpace(c.Query("to")); raw != "" { + t, dateOnly, err := parseContentModerationDate(raw) + if err != nil { + response.BadRequest(c, "Invalid to") + return + } + if dateOnly { + t = t.Add(24*time.Hour - time.Nanosecond) + } + filter.To = &t + } + items, pageResult, err := h.service.ListLogs(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Paginated(c, items, pageResult.Total, pageResult.Page, pageResult.PageSize) +} + +func (h *ContentModerationHandler) UnbanUser(c *gin.Context) { + userID, err := strconv.ParseInt(strings.TrimSpace(c.Param("user_id")), 10, 64) + if err != nil || userID <= 0 { + response.BadRequest(c, "Invalid user_id") + return + } + result, err := h.service.UnbanUser(c.Request.Context(), userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + +func (h *ContentModerationHandler) DeleteFlaggedHash(c *gin.Context) { + var req contentModerationHashRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + result, err := h.service.DeleteFlaggedInputHash(c.Request.Context(), req.InputHash) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + +func (h *ContentModerationHandler) ClearFlaggedHashes(c *gin.Context) { + result, err := h.service.ClearFlaggedInputHashes(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + +func parseContentModerationDate(raw string) (time.Time, bool, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return time.Time{}, false, nil + } + if t, err := time.Parse(time.RFC3339, raw); err == nil { + return t, false, nil + } + t, err := time.Parse("2006-01-02", raw) + return t, err == nil, err +} diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 0cec89aa..2ff94fe6 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -185,6 +185,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, + RiskControlEnabled: settings.RiskControlEnabled, AffiliateRebateRate: settings.AffiliateRebateRate, AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours, AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays, @@ -497,6 +498,9 @@ type UpdateSettingsRequest struct { // Affiliate (邀请返利) feature switch AffiliateEnabled *bool `json:"affiliate_enabled"` + // 风控中心功能开关 + RiskControlEnabled *bool `json:"risk_control_enabled"` + // OpenAI fast/flex policy (optional, only updated when provided) OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"` } @@ -1365,6 +1369,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } return previousSettings.AffiliateEnabled }(), + RiskControlEnabled: func() bool { + if req.RiskControlEnabled != nil { + return *req.RiskControlEnabled + } + return previousSettings.RiskControlEnabled + }(), } authSourceDefaults := &service.AuthSourceDefaultSettings{ @@ -1616,6 +1626,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled, AffiliateEnabled: updatedSettings.AffiliateEnabled, + + RiskControlEnabled: updatedSettings.RiskControlEnabled, } if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil { slog.Error("openai_fast_policy_settings_get_failed", "error", err) @@ -2004,6 +2016,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.AffiliateEnabled != after.AffiliateEnabled { changed = append(changed, "affiliate_enabled") } + if before.RiskControlEnabled != after.RiskControlEnabled { + changed = append(changed, "risk_control_enabled") + } changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults) return changed } diff --git a/backend/internal/handler/content_moderation_helper.go b/backend/internal/handler/content_moderation_helper.go new file mode 100644 index 00000000..af6dbd8e --- /dev/null +++ b/backend/internal/handler/content_moderation_helper.go @@ -0,0 +1,130 @@ +package handler + +import ( + "context" + "net/http" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +func (h *GatewayHandler) checkContentModeration(c *gin.Context, reqLog *zap.Logger, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision { + if h == nil || h.contentModerationService == nil { + return nil + } + return runContentModeration(c, reqLog, h.contentModerationService, apiKey, subject, protocol, model, body) +} + +func contentModerationStatus(decision *service.ContentModerationDecision) int { + if decision == nil || decision.StatusCode < 400 || decision.StatusCode > 599 { + return http.StatusForbidden + } + return decision.StatusCode +} + +func contentModerationErrorCode(decision *service.ContentModerationDecision) string { + return "content_policy_violation" +} + +func (h *OpenAIGatewayHandler) checkContentModeration(c *gin.Context, reqLog *zap.Logger, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision { + if h == nil || h.contentModerationService == nil { + return nil + } + return runContentModeration(c, reqLog, h.contentModerationService, apiKey, subject, protocol, model, body) +} + +func runContentModeration(c *gin.Context, reqLog *zap.Logger, svc *service.ContentModerationService, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) *service.ContentModerationDecision { + if svc == nil || c == nil || c.Request == nil { + return nil + } + input := buildContentModerationInput(c, apiKey, subject, protocol, model, body) + if reqLog != nil { + reqLog.Info("content_moderation.gateway_check_start", + zap.String("request_id", input.RequestID), + zap.Int64("user_id", input.UserID), + zap.Int64("api_key_id", input.APIKeyID), + zap.String("api_key_name", input.APIKeyName), + zap.Int64p("group_id", input.GroupID), + zap.String("group_name", input.GroupName), + zap.String("endpoint", input.Endpoint), + zap.String("provider", input.Provider), + zap.String("protocol", input.Protocol), + zap.String("model", input.Model), + zap.Int("body_bytes", len(body)), + ) + } + decision, err := svc.Check(c.Request.Context(), input) + if err != nil { + if reqLog != nil { + reqLog.Warn("content_moderation.check_failed", zap.Error(err)) + } + return nil + } + if reqLog != nil && decision != nil { + reqLog.Info("content_moderation.gateway_check_done", + zap.String("request_id", input.RequestID), + zap.Bool("allowed", decision.Allowed), + zap.Bool("blocked", decision.Blocked), + zap.Bool("flagged", decision.Flagged), + zap.String("action", decision.Action), + zap.Int("status_code", decision.StatusCode), + zap.String("highest_category", decision.HighestCategory), + zap.Float64("highest_score", decision.HighestScore), + ) + } + return decision +} + +func buildContentModerationInput(c *gin.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, protocol string, model string, body []byte) service.ContentModerationCheckInput { + input := service.ContentModerationCheckInput{ + RequestID: contentModerationRequestID(c.Request.Context()), + UserID: subject.UserID, + Endpoint: GetInboundEndpoint(c), + Provider: contentModerationProvider(apiKey), + Model: strings.TrimSpace(model), + Protocol: protocol, + Body: body, + } + if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok { + input.Provider = strings.TrimSpace(forcedPlatform) + } + if apiKey != nil { + input.APIKeyID = apiKey.ID + input.APIKeyName = apiKey.Name + if apiKey.User != nil { + input.UserEmail = apiKey.User.Email + } + if apiKey.GroupID != nil { + groupID := *apiKey.GroupID + input.GroupID = &groupID + } + if apiKey.Group != nil { + input.GroupName = apiKey.Group.Name + } + } + if input.Endpoint == "" && c.Request != nil && c.Request.URL != nil { + input.Endpoint = c.Request.URL.Path + } + return input +} + +func contentModerationProvider(apiKey *service.APIKey) string { + if apiKey == nil || apiKey.Group == nil { + return "" + } + return strings.TrimSpace(apiKey.Group.Platform) +} + +func contentModerationRequestID(ctx context.Context) string { + if ctx == nil { + return "" + } + if requestID, ok := ctx.Value(ctxkey.RequestID).(string); ok { + return strings.TrimSpace(requestID) + } + return "" +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 0bc834fe..fba85cf2 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -197,6 +197,9 @@ type SystemSettings struct { // Available Channels feature switch (user-facing aggregate view) AvailableChannelsEnabled bool `json:"available_channels_enabled"` + // 风控中心功能开关 + RiskControlEnabled bool `json:"risk_control_enabled"` + // Affiliate (邀请返利) feature switch AffiliateEnabled bool `json:"affiliate_enabled"` @@ -256,6 +259,8 @@ type PublicSettings struct { AvailableChannelsEnabled bool `json:"available_channels_enabled"` AffiliateEnabled bool `json:"affiliate_enabled"` + + RiskControlEnabled bool `json:"risk_control_enabled"` } // OverloadCooldownSettings 529过载冷却配置 DTO diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 7b082b07..65836a7e 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -45,6 +45,7 @@ type GatewayHandler struct { apiKeyService *service.APIKeyService usageRecordWorkerPool *service.UsageRecordWorkerPool errorPassthroughService *service.ErrorPassthroughService + contentModerationService *service.ContentModerationService concurrencyHelper *ConcurrencyHelper userMsgQueueHelper *UserMsgQueueHelper maxAccountSwitches int @@ -65,6 +66,7 @@ func NewGatewayHandler( apiKeyService *service.APIKeyService, usageRecordWorkerPool *service.UsageRecordWorkerPool, errorPassthroughService *service.ErrorPassthroughService, + contentModerationService *service.ContentModerationService, userMsgQueueService *service.UserMessageQueueService, cfg *config.Config, settingService *service.SettingService, @@ -98,6 +100,7 @@ func NewGatewayHandler( apiKeyService: apiKeyService, usageRecordWorkerPool: usageRecordWorkerPool, errorPassthroughService: errorPassthroughService, + contentModerationService: contentModerationService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), userMsgQueueHelper: umqHelper, maxAccountSwitches: maxAccountSwitches, @@ -189,6 +192,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolAnthropicMessages, reqModel, body); decision != nil && decision.Blocked { + h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } + // Track if we've started streaming (for error handling) streamStarted := false diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index 4290e54b..c6b73190 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -91,6 +91,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { return } + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIChat, reqModel, body); decision != nil && decision.Blocked { + h.chatCompletionsErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } + // Error passthrough binding if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go index 683cf2b7..a97f572d 100644 --- a/backend/internal/handler/gateway_handler_responses.go +++ b/backend/internal/handler/gateway_handler_responses.go @@ -96,6 +96,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) { return } + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, body); decision != nil && decision.Blocked { + h.responsesErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } + // Error passthrough binding if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 2a34e3f0..90ebe9ec 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -185,6 +185,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { setOpsRequestContext(c, modelName, stream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false))) + if decision := h.checkContentModeration(c, reqLog, apiKey, authSubject, service.ContentModerationProtocolGemini, modelName, body); decision != nil && decision.Blocked { + googleError(c, contentModerationStatus(decision), decision.Message) + return + } + // 解析渠道级模型映射 channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName) reqModel := modelName // 保存映射前的原始模型名 diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 13e3ac88..308b2199 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -33,6 +33,7 @@ type AdminHandlers struct { Channel *admin.ChannelHandler ChannelMonitor *admin.ChannelMonitorHandler ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler + ContentModeration *admin.ContentModerationHandler Payment *admin.PaymentHandler Affiliate *admin.AffiliateHandler } diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 06ab9d52..de384710 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -81,6 +81,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIChat, reqModel, body); decision != nil && decision.Blocked { + h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } + // 解析渠道级模型映射 channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 3997a0ee..6b07b7ba 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -27,15 +27,16 @@ import ( // OpenAIGatewayHandler handles OpenAI API gateway requests type OpenAIGatewayHandler struct { - gatewayService *service.OpenAIGatewayService - billingCacheService *service.BillingCacheService - apiKeyService *service.APIKeyService - usageRecordWorkerPool *service.UsageRecordWorkerPool - errorPassthroughService *service.ErrorPassthroughService - concurrencyHelper *ConcurrencyHelper - imageLimiter *imageConcurrencyLimiter - maxAccountSwitches int - cfg *config.Config + gatewayService *service.OpenAIGatewayService + billingCacheService *service.BillingCacheService + apiKeyService *service.APIKeyService + usageRecordWorkerPool *service.UsageRecordWorkerPool + errorPassthroughService *service.ErrorPassthroughService + contentModerationService *service.ContentModerationService + concurrencyHelper *ConcurrencyHelper + imageLimiter *imageConcurrencyLimiter + maxAccountSwitches int + cfg *config.Config } func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string { @@ -53,6 +54,7 @@ func NewOpenAIGatewayHandler( apiKeyService *service.APIKeyService, usageRecordWorkerPool *service.UsageRecordWorkerPool, errorPassthroughService *service.ErrorPassthroughService, + contentModerationService *service.ContentModerationService, cfg *config.Config, ) *OpenAIGatewayHandler { pingInterval := time.Duration(0) @@ -64,15 +66,16 @@ func NewOpenAIGatewayHandler( } } return &OpenAIGatewayHandler{ - gatewayService: gatewayService, - billingCacheService: billingCacheService, - apiKeyService: apiKeyService, - usageRecordWorkerPool: usageRecordWorkerPool, - errorPassthroughService: errorPassthroughService, - concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), - imageLimiter: &imageConcurrencyLimiter{}, - maxAccountSwitches: maxAccountSwitches, - cfg: cfg, + gatewayService: gatewayService, + billingCacheService: billingCacheService, + apiKeyService: apiKeyService, + usageRecordWorkerPool: usageRecordWorkerPool, + errorPassthroughService: errorPassthroughService, + contentModerationService: contentModerationService, + concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), + imageLimiter: &imageConcurrencyLimiter{}, + maxAccountSwitches: maxAccountSwitches, + cfg: cfg, } } @@ -189,6 +192,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, body); decision != nil && decision.Blocked { + h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } + imageIntent := service.IsImageGenerationIntent("/v1/responses", reqModel, body) if imageIntent && !service.GroupAllowsImageGeneration(apiKey.Group) { h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage()) @@ -599,6 +607,11 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolAnthropicMessages, reqModel, body); decision != nil && decision.Blocked { + h.anthropicErrorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } + // 解析渠道级模型映射 channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) @@ -1153,6 +1166,12 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { setOpsRequestContext(c, reqModel, true, firstMessage) setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2)) + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, firstMessage); decision != nil && decision.Blocked { + writeContentModerationWSError(ctx, wsConn, decision) + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, decision.Message) + return + } + if service.IsImageGenerationIntent("/v1/responses", reqModel, firstMessage) && !service.GroupAllowsImageGeneration(apiKey.Group) { closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, service.ImageGenerationPermissionMessage()) return @@ -1268,6 +1287,26 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { hooks := &service.OpenAIWSIngressHooks{ InitialRequestModel: reqModel, + BeforeRequest: func(turn int, payload []byte, originalModel string) error { + if turn == 1 { + return nil + } + if !gjson.ValidBytes(payload) { + return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json")) + } + model := strings.TrimSpace(originalModel) + if model == "" { + model = strings.TrimSpace(gjson.GetBytes(payload, "model").String()) + } + if model == "" { + model = reqModel + } + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked { + writeContentModerationWSError(ctx, wsConn, decision) + return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil) + } + return nil + }, BeforeTurn: func(turn int) error { if turn == 1 { return nil @@ -1712,6 +1751,34 @@ func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason s _ = conn.CloseNow() } +func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) { + if conn == nil || decision == nil { + return + } + if ctx == nil { + ctx = context.Background() + } + message := strings.TrimSpace(decision.Message) + if message == "" { + message = "content moderation blocked this request" + } + payload, err := json.Marshal(gin.H{ + "event_id": "evt_content_moderation_blocked", + "type": "error", + "error": gin.H{ + "type": "invalid_request_error", + "code": contentModerationErrorCode(decision), + "message": message, + }, + }) + if err != nil { + payload = []byte(`{"event_id":"evt_content_moderation_blocked","type":"error","error":{"type":"invalid_request_error","code":"content_policy_violation","message":"content moderation blocked this request"}}`) + } + writeCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + _ = conn.Write(writeCtx, coderws.MessageText, payload) +} + func summarizeWSCloseErrorForLog(err error) (string, string) { if err == nil { return "-", "-" diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index c560350e..6bddbce9 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -12,6 +12,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" coderws "github.com/coder/websocket" @@ -646,6 +647,180 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot") } +type contentModerationHandlerSettingRepo struct { + values map[string]string +} + +func (r *contentModerationHandlerSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) { + if value, ok := r.values[key]; ok { + return &service.Setting{Key: key, Value: value}, nil + } + return nil, service.ErrSettingNotFound +} + +func (r *contentModerationHandlerSettingRepo) GetValue(ctx context.Context, key string) (string, error) { + if value, ok := r.values[key]; ok { + return value, nil + } + return "", service.ErrSettingNotFound +} + +func (r *contentModerationHandlerSettingRepo) Set(ctx context.Context, key, value string) error { + if r.values == nil { + r.values = map[string]string{} + } + r.values[key] = value + return nil +} + +func (r *contentModerationHandlerSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := map[string]string{} + for _, key := range keys { + if value, ok := r.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (r *contentModerationHandlerSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error { + if r.values == nil { + r.values = map[string]string{} + } + for key, value := range settings { + r.values[key] = value + } + return nil +} + +func (r *contentModerationHandlerSettingRepo) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(r.values)) + for key, value := range r.values { + out[key] = value + } + return out, nil +} + +func (r *contentModerationHandlerSettingRepo) Delete(ctx context.Context, key string) error { + delete(r.values, key) + return nil +} + +type contentModerationHandlerTestRepo struct { + logs []service.ContentModerationLog +} + +func (r *contentModerationHandlerTestRepo) CreateLog(ctx context.Context, log *service.ContentModerationLog) error { + if log != nil { + r.logs = append(r.logs, *log) + } + return nil +} + +func (r *contentModerationHandlerTestRepo) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) { + return nil, nil, nil +} + +func (r *contentModerationHandlerTestRepo) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) { + return 0, nil +} + +func (r *contentModerationHandlerTestRepo) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*service.ContentModerationCleanupResult, error) { + return &service.ContentModerationCleanupResult{}, nil +} + +func TestOpenAIResponsesWebSocket_ContentModerationBlocksFirstFrame(t *testing.T) { + gin.SetMode(gin.TestMode) + + moderationServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/v1/moderations", r.URL.Path) + _, _ = w.Write([]byte(`{"results":[{"category_scores":{"sexual":0.9}}]}`)) + })) + defer moderationServer.Close() + + cfg := &service.ContentModerationConfig{ + Enabled: true, + Mode: service.ContentModerationModePreBlock, + BaseURL: moderationServer.URL, + Model: "omni-moderation-latest", + APIKeys: []string{"sk-test"}, + SampleRate: 100, + AllGroups: true, + BlockMessage: "内容审计测试阻断", + } + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + repo := &contentModerationHandlerTestRepo{} + settingRepo := &contentModerationHandlerSettingRepo{values: map[string]string{ + service.SettingKeyRiskControlEnabled: "true", + service.SettingKeyContentModerationConfig: string(rawCfg), + }} + moderationSvc := service.NewContentModerationService( + settingRepo, + repo, + nil, + nil, + nil, + nil, + nil, + ) + decision, err := moderationSvc.Check(context.Background(), service.ContentModerationCheckInput{ + UserID: 1, + Endpoint: "/v1/responses", + Provider: "openai", + Model: "gpt-5.5", + Protocol: service.ContentModerationProtocolOpenAIResponses, + Body: []byte(`{"model":"gpt-5.5","input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"bad prompt"}]}]}`), + }) + require.NoError(t, err) + require.True(t, decision.Blocked) + repo.logs = nil + h := &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: &service.BillingCacheService{}, + apiKeyService: &service.APIKeyService{}, + contentModerationService: moderationSvc, + concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(&concurrencyCacheMock{}), SSEPingFormatNone, time.Second), + } + wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1}) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{ + "type":"response.create", + "model":"gpt-5.5", + "input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"bad prompt"}]}] + }`)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, payload, readErr := clientConn.Read(readCtx) + cancelRead() + if readErr == nil { + require.Contains(t, string(payload), "content_policy_violation") + require.Contains(t, string(payload), "内容审计测试阻断") + } else { + var closeErr coderws.CloseError + require.ErrorAs(t, readErr, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code) + require.Contains(t, closeErr.Reason, "内容审计测试阻断") + } + require.Len(t, repo.logs, 1) + require.True(t, repo.logs[0].Flagged) + require.Equal(t, service.ContentModerationActionBlock, repo.logs[0].Action) + require.Equal(t, "bad prompt", repo.logs[0].InputExcerpt) +} + func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) { got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{ firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`, diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go index eba701f1..08a6b6e8 100644 --- a/backend/internal/handler/openai_images.go +++ b/backend/internal/handler/openai_images.go @@ -85,6 +85,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage()) return } + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIImages, parsed.Model, parsed.ModerationBody()); decision != nil && decision.Blocked { + h.errorResponse(c, contentModerationStatus(decision), contentModerationErrorCode(decision), decision.Message) + return + } imageReleaseFunc, acquired := h.acquireImageGenerationSlot(c, streamStarted) if !acquired { return diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 22f2aa15..776d0790 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -77,5 +77,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { AvailableChannelsEnabled: settings.AvailableChannelsEnabled, AffiliateEnabled: settings.AffiliateEnabled, + + RiskControlEnabled: settings.RiskControlEnabled, }) } diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index a8725875..7f9f9e3c 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -36,6 +36,7 @@ func ProvideAdminHandlers( channelHandler *admin.ChannelHandler, channelMonitorHandler *admin.ChannelMonitorHandler, channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler, + contentModerationHandler *admin.ContentModerationHandler, paymentHandler *admin.PaymentHandler, affiliateHandler *admin.AffiliateHandler, ) *AdminHandlers { @@ -67,6 +68,7 @@ func ProvideAdminHandlers( Channel: channelHandler, ChannelMonitor: channelMonitorHandler, ChannelMonitorTemplate: channelMonitorTemplateHandler, + ContentModeration: contentModerationHandler, Payment: paymentHandler, Affiliate: affiliateHandler, } @@ -170,6 +172,7 @@ var ProviderSet = wire.NewSet( admin.NewChannelHandler, admin.NewChannelMonitorHandler, admin.NewChannelMonitorRequestTemplateHandler, + admin.NewContentModerationHandler, admin.NewPaymentHandler, admin.NewAffiliateHandler, diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 68895475..43b13937 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -125,6 +125,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID, + apikey.FieldName, apikey.FieldStatus, apikey.FieldIPWhitelist, apikey.FieldIPBlacklist, diff --git a/backend/internal/repository/api_key_repo_messages_dispatch_unit_test.go b/backend/internal/repository/api_key_repo_messages_dispatch_unit_test.go index aba62ead..4a462ab1 100644 --- a/backend/internal/repository/api_key_repo_messages_dispatch_unit_test.go +++ b/backend/internal/repository/api_key_repo_messages_dispatch_unit_test.go @@ -69,6 +69,7 @@ func TestAPIKeyRepository_GetByKeyForAuth_PreservesMessagesDispatchModelConfig_S got, err := repo.GetByKeyForAuth(ctx, key.Key) require.NoError(t, err) + require.Equal(t, key.Name, got.Name) require.NotNil(t, got.Group) require.Equal(t, group.MessagesDispatchModelConfig, got.Group.MessagesDispatchModelConfig) } diff --git a/backend/internal/repository/content_moderation_hash_cache.go b/backend/internal/repository/content_moderation_hash_cache.go new file mode 100644 index 00000000..782999e7 --- /dev/null +++ b/backend/internal/repository/content_moderation_hash_cache.go @@ -0,0 +1,71 @@ +package repository + +import ( + "context" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +const contentModerationFlaggedHashSetKey = "content_moderation:flagged_hashes" + +type contentModerationHashCache struct { + rdb *redis.Client +} + +func NewContentModerationHashCache(rdb *redis.Client) service.ContentModerationHashCache { + return &contentModerationHashCache{rdb: rdb} +} + +func (c *contentModerationHashCache) RecordFlaggedInputHash(ctx context.Context, inputHash string) error { + inputHash = strings.TrimSpace(inputHash) + if c == nil || c.rdb == nil || inputHash == "" { + return nil + } + return c.rdb.SAdd(ctx, contentModerationFlaggedHashSetKey, inputHash).Err() +} + +func (c *contentModerationHashCache) HasFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) { + inputHash = strings.TrimSpace(inputHash) + if c == nil || c.rdb == nil || inputHash == "" { + return false, nil + } + return c.rdb.SIsMember(ctx, contentModerationFlaggedHashSetKey, inputHash).Result() +} + +func (c *contentModerationHashCache) DeleteFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) { + inputHash = strings.TrimSpace(inputHash) + if c == nil || c.rdb == nil || inputHash == "" { + return false, nil + } + deleted, err := c.rdb.SRem(ctx, contentModerationFlaggedHashSetKey, inputHash).Result() + if err != nil { + return false, err + } + return deleted > 0, nil +} + +func (c *contentModerationHashCache) ClearFlaggedInputHashes(ctx context.Context) (int64, error) { + if c == nil || c.rdb == nil { + return 0, nil + } + deleted, err := c.rdb.SCard(ctx, contentModerationFlaggedHashSetKey).Result() + if err != nil { + return 0, err + } + if deleted == 0 { + return 0, nil + } + if err := c.rdb.Del(ctx, contentModerationFlaggedHashSetKey).Err(); err != nil { + return 0, err + } + return deleted, nil +} + +func (c *contentModerationHashCache) CountFlaggedInputHashes(ctx context.Context) (int64, error) { + if c == nil || c.rdb == nil { + return 0, nil + } + return c.rdb.SCard(ctx, contentModerationFlaggedHashSetKey).Result() +} diff --git a/backend/internal/repository/content_moderation_repo.go b/backend/internal/repository/content_moderation_repo.go new file mode 100644 index 00000000..6ada004a --- /dev/null +++ b/backend/internal/repository/content_moderation_repo.go @@ -0,0 +1,274 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type contentModerationRepository struct { + db *sql.DB +} + +func NewContentModerationRepository(db *sql.DB) service.ContentModerationRepository { + return &contentModerationRepository{db: db} +} + +func (r *contentModerationRepository) CreateLog(ctx context.Context, log *service.ContentModerationLog) error { + if log == nil { + return nil + } + categoryScores, err := json.Marshal(log.CategoryScores) + if err != nil { + return fmt.Errorf("marshal moderation category scores: %w", err) + } + thresholdSnapshot, err := json.Marshal(log.ThresholdSnapshot) + if err != nil { + return fmt.Errorf("marshal moderation thresholds: %w", err) + } + var userID any + if log.UserID != nil { + userID = *log.UserID + } + var apiKeyID any + if log.APIKeyID != nil { + apiKeyID = *log.APIKeyID + } + var groupID any + if log.GroupID != nil { + groupID = *log.GroupID + } + var latency any + if log.UpstreamLatencyMS != nil { + latency = *log.UpstreamLatencyMS + } + err = r.db.QueryRowContext(ctx, ` +INSERT INTO content_moderation_logs ( + request_id, user_id, user_email, api_key_id, api_key_name, group_id, group_name, + endpoint, provider, model, mode, action, flagged, highest_category, highest_score, + category_scores, threshold_snapshot, input_excerpt, upstream_latency_ms, error, + violation_count, auto_banned, email_sent, queue_delay_ms +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, + $8, $9, $10, $11, $12, $13, $14, $15, + $16::jsonb, $17::jsonb, $18, $19, $20, + $21, $22, $23, $24 +) RETURNING id, created_at`, + log.RequestID, userID, log.UserEmail, apiKeyID, log.APIKeyName, groupID, log.GroupName, + log.Endpoint, log.Provider, log.Model, log.Mode, log.Action, log.Flagged, log.HighestCategory, log.HighestScore, + string(categoryScores), string(thresholdSnapshot), log.InputExcerpt, latency, log.Error, + log.ViolationCount, log.AutoBanned, log.EmailSent, nullableIntPtr(log.QueueDelayMS), + ).Scan(&log.ID, &log.CreatedAt) + if err != nil { + return fmt.Errorf("insert content moderation log: %w", err) + } + return nil +} + +func (r *contentModerationRepository) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) { + where, args := buildContentModerationLogWhere(filter) + whereSQL := "WHERE " + strings.Join(where, " AND ") + + var total int64 + if err := r.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM content_moderation_logs l "+whereSQL, args...).Scan(&total); err != nil { + return nil, nil, fmt.Errorf("count content moderation logs: %w", err) + } + + params := filter.Pagination + if params.Page <= 0 { + params.Page = 1 + } + if params.PageSize <= 0 { + params.PageSize = 20 + } + if params.PageSize > 100 { + params.PageSize = 100 + } + queryArgs := append([]any{}, args...) + queryArgs = append(queryArgs, params.Limit(), params.Offset()) + rows, err := r.db.QueryContext(ctx, ` +SELECT + l.id, l.request_id, l.user_id, l.user_email, l.api_key_id, l.api_key_name, l.group_id, l.group_name, + l.endpoint, l.provider, l.model, l.mode, l.action, l.flagged, l.highest_category, l.highest_score, + l.category_scores, l.threshold_snapshot, l.input_excerpt, l.upstream_latency_ms, l.error, + l.violation_count, l.auto_banned, l.email_sent, COALESCE(u.status, ''), l.queue_delay_ms, l.created_at +FROM content_moderation_logs l +LEFT JOIN users u ON u.id = l.user_id `+whereSQL+` +ORDER BY l.created_at DESC, l.id DESC +LIMIT $`+fmt.Sprint(len(queryArgs)-1)+` OFFSET $`+fmt.Sprint(len(queryArgs)), + queryArgs..., + ) + if err != nil { + return nil, nil, fmt.Errorf("list content moderation logs: %w", err) + } + defer func() { _ = rows.Close() }() + + items := make([]service.ContentModerationLog, 0) + for rows.Next() { + var item service.ContentModerationLog + var userID, apiKeyID, groupID, latency, queueDelay sql.NullInt64 + var scoresRaw, thresholdsRaw []byte + if err := rows.Scan( + &item.ID, + &item.RequestID, + &userID, + &item.UserEmail, + &apiKeyID, + &item.APIKeyName, + &groupID, + &item.GroupName, + &item.Endpoint, + &item.Provider, + &item.Model, + &item.Mode, + &item.Action, + &item.Flagged, + &item.HighestCategory, + &item.HighestScore, + &scoresRaw, + &thresholdsRaw, + &item.InputExcerpt, + &latency, + &item.Error, + &item.ViolationCount, + &item.AutoBanned, + &item.EmailSent, + &item.UserStatus, + &queueDelay, + &item.CreatedAt, + ); err != nil { + return nil, nil, fmt.Errorf("scan content moderation log: %w", err) + } + if userID.Valid { + v := userID.Int64 + item.UserID = &v + } + if apiKeyID.Valid { + v := apiKeyID.Int64 + item.APIKeyID = &v + } + if groupID.Valid { + v := groupID.Int64 + item.GroupID = &v + } + if latency.Valid { + v := int(latency.Int64) + item.UpstreamLatencyMS = &v + } + if queueDelay.Valid { + v := int(queueDelay.Int64) + item.QueueDelayMS = &v + } + item.CategoryScores = map[string]float64{} + _ = json.Unmarshal(scoresRaw, &item.CategoryScores) + item.ThresholdSnapshot = map[string]float64{} + _ = json.Unmarshal(thresholdsRaw, &item.ThresholdSnapshot) + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, nil, fmt.Errorf("iterate content moderation logs: %w", err) + } + return items, paginationResultFromTotal(total, params), nil +} + +func (r *contentModerationRepository) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) { + if userID <= 0 { + return 0, nil + } + var count int + err := r.db.QueryRowContext(ctx, ` +WITH last_auto_ban AS ( + SELECT MAX(created_at) AS at + FROM content_moderation_logs + WHERE user_id = $1 AND auto_banned = TRUE +) +SELECT COUNT(*) +FROM content_moderation_logs +WHERE user_id = $1 + AND flagged = TRUE + AND created_at >= $2 + AND created_at > COALESCE((SELECT at FROM last_auto_ban), '-infinity'::timestamptz) +`, userID, since).Scan(&count) + if err != nil { + return 0, fmt.Errorf("count user content moderation flagged logs: %w", err) + } + return count, nil +} + +func (r *contentModerationRepository) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*service.ContentModerationCleanupResult, error) { + result := &service.ContentModerationCleanupResult{FinishedAt: time.Now()} + if r == nil || r.db == nil { + return result, nil + } + hitExec, err := r.db.ExecContext(ctx, ` +DELETE FROM content_moderation_logs +WHERE flagged = TRUE AND created_at < $1 +`, hitBefore) + if err != nil { + return nil, fmt.Errorf("delete expired hit content moderation logs: %w", err) + } + result.DeletedHit, _ = hitExec.RowsAffected() + + nonHitExec, err := r.db.ExecContext(ctx, ` +DELETE FROM content_moderation_logs +WHERE flagged = FALSE AND created_at < $1 +`, nonHitBefore) + if err != nil { + return nil, fmt.Errorf("delete expired non-hit content moderation logs: %w", err) + } + result.DeletedNonHit, _ = nonHitExec.RowsAffected() + + result.FinishedAt = time.Now() + return result, nil +} + +func nullableIntPtr(value *int) any { + if value == nil { + return nil + } + return *value +} + +func buildContentModerationLogWhere(filter service.ContentModerationLogFilter) ([]string, []any) { + where := []string{"l.id IS NOT NULL"} + args := make([]any, 0) + add := func(expr string, value any) { + args = append(args, value) + where = append(where, fmt.Sprintf(expr, len(args))) + } + switch strings.ToLower(strings.TrimSpace(filter.Result)) { + case "hit", "flagged": + where = append(where, "l.flagged = TRUE") + case "blocked", "block": + where = append(where, "l.action = 'block'") + case "pass", "allow": + where = append(where, "l.flagged = FALSE AND l.error = ''") + case "error": + where = append(where, "l.error <> ''") + } + if filter.GroupID != nil { + add("l.group_id = $%d", *filter.GroupID) + } + if endpoint := strings.TrimSpace(filter.Endpoint); endpoint != "" { + add("l.endpoint = $%d", endpoint) + } + if search := strings.TrimSpace(filter.Search); search != "" { + like := "%" + search + "%" + args = append(args, like, like, like, like, like) + idx := len(args) - 4 + where = append(where, fmt.Sprintf("(l.request_id ILIKE $%d OR l.user_email ILIKE $%d OR l.api_key_name ILIKE $%d OR l.model ILIKE $%d OR l.input_excerpt ILIKE $%d)", idx, idx+1, idx+2, idx+3, idx+4)) + } + if filter.From != nil && !filter.From.IsZero() { + add("l.created_at >= $%d", *filter.From) + } + if filter.To != nil && !filter.To.IsZero() { + add("l.created_at <= $%d", *filter.To) + } + return where, args +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index f07bbb33..3c0ee9cb 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -91,6 +91,7 @@ var ProviderSet = wire.NewSet( NewChannelRepository, NewChannelMonitorRepository, NewChannelMonitorRequestTemplateRepository, + NewContentModerationRepository, NewAffiliateRepository, // Cache implementations @@ -119,6 +120,7 @@ var ProviderSet = wire.NewSet( NewRefreshTokenCache, NewErrorPassthroughCache, NewTLSFingerprintProfileCache, + NewContentModerationHashCache, // Encryptors NewAESEncryptor, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 34f560fc..37606d94 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -792,6 +792,7 @@ func TestAPIContracts(t *testing.T) { "channel_monitor_enabled": true, "channel_monitor_default_interval_seconds": 60, "available_channels_enabled": false, + "risk_control_enabled": false, "affiliate_enabled": false, "wechat_connect_enabled": false, "wechat_connect_app_id": "", @@ -983,6 +984,7 @@ func TestAPIContracts(t *testing.T) { "channel_monitor_enabled": true, "channel_monitor_default_interval_seconds": 60, "available_channels_enabled": false, + "risk_control_enabled": false, "affiliate_enabled": false, "wechat_connect_enabled": true, "wechat_connect_app_id": "wx-open-config", diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 20f5d619..a2d225e0 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -92,11 +92,28 @@ func RegisterAdminRoutes( // 渠道监控 registerChannelMonitorRoutes(admin, h) + // 风控中心 + registerContentModerationRoutes(admin, h) + // 邀请返利(专属用户管理) registerAffiliateRoutes(admin, h) } } +func registerContentModerationRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + risk := admin.Group("/risk-control") + { + risk.GET("/config", h.Admin.ContentModeration.GetConfig) + risk.PUT("/config", h.Admin.ContentModeration.UpdateConfig) + risk.POST("/api-keys/test", h.Admin.ContentModeration.TestAPIKeys) + risk.GET("/status", h.Admin.ContentModeration.GetStatus) + risk.GET("/logs", h.Admin.ContentModeration.ListLogs) + risk.POST("/users/:user_id/unban", h.Admin.ContentModeration.UnbanUser) + risk.DELETE("/hashes", h.Admin.ContentModeration.DeleteFlaggedHash) + risk.DELETE("/hashes/all", h.Admin.ContentModeration.ClearFlaggedHashes) + } +} + func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) { apiKeys := admin.Group("/api-keys") { diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 4432ad7d..3553a18a 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -8,6 +8,7 @@ type APIKeyAuthSnapshot struct { APIKeyID int64 `json:"api_key_id"` UserID int64 `json:"user_id"` GroupID *int64 `json:"group_id,omitempty"` + Name string `json:"name"` Status string `json:"status"` IPWhitelist []string `json:"ip_whitelist,omitempty"` IPBlacklist []string `json:"ip_blacklist,omitempty"` diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 0f9d4214..877888b1 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -14,7 +14,7 @@ import ( "github.com/dgraph-io/ristretto" ) -const apiKeyAuthSnapshotVersion = 8 // v8: added group image generation controls +const apiKeyAuthSnapshotVersion = 9 // v9: added API Key name for audit logs type apiKeyAuthCacheConfig struct { l1Size int @@ -210,6 +210,7 @@ func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey) APIKeyID: apiKey.ID, UserID: apiKey.UserID, GroupID: apiKey.GroupID, + Name: apiKey.Name, Status: apiKey.Status, IPWhitelist: apiKey.IPWhitelist, IPBlacklist: apiKey.IPBlacklist, @@ -286,6 +287,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho UserID: snapshot.UserID, GroupID: snapshot.GroupID, Key: key, + Name: snapshot.Name, Status: snapshot.Status, IPWhitelist: snapshot.IPWhitelist, IPBlacklist: snapshot.IPBlacklist, diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index 8cb1b8c4..eaac9a1c 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -235,6 +235,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t UserID: 2, GroupID: &groupID, Key: "k-roundtrip", + Name: "Audit Key", Status: StatusActive, User: &User{ ID: 2, @@ -267,6 +268,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot) require.NotNil(t, roundTrip) + require.Equal(t, apiKey.Name, roundTrip.Name) require.NotNil(t, roundTrip.Group) require.Equal(t, apiKey.Group.MessagesDispatchModelConfig, roundTrip.Group.MessagesDispatchModelConfig) } diff --git a/backend/internal/service/content_moderation.go b/backend/internal/service/content_moderation.go new file mode 100644 index 00000000..192946ce --- /dev/null +++ b/backend/internal/service/content_moderation.go @@ -0,0 +1,1982 @@ +package service + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "sort" + "strings" + "sync" + "sync/atomic" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +const ( + ContentModerationModeOff = "off" + ContentModerationModeObserve = "observe" + ContentModerationModePreBlock = "pre_block" + + ContentModerationActionAllow = "allow" + ContentModerationActionBlock = "block" + ContentModerationActionHashBlock = "hash_block" + ContentModerationActionError = "error" + + ContentModerationProtocolAnthropicMessages = "anthropic_messages" + ContentModerationProtocolOpenAIResponses = "openai_responses" + ContentModerationProtocolOpenAIChat = "openai_chat_completions" + ContentModerationProtocolGemini = "gemini" + ContentModerationProtocolOpenAIImages = "openai_images" + + defaultContentModerationBaseURL = "https://api.openai.com" + defaultContentModerationModel = "omni-moderation-latest" + defaultContentModerationTimeoutMS = 3000 + maxContentModerationTimeoutMS = 30000 + maxModerationInputRunes = 12000 + maxModerationExcerptRunes = 240 + + defaultContentModerationWorkerCount = 4 + maxContentModerationWorkerCount = 32 + defaultContentModerationQueueSize = 32768 + maxContentModerationQueueSize = 100000 + defaultContentModerationBanThreshold = 10 + defaultContentModerationViolationWindowHours = 720 + defaultContentModerationBlockHTTPStatus = http.StatusForbidden + defaultContentModerationBlockMessage = "内容审计命中风险规则,请调整输入后重试" + defaultContentModerationRetryCount = 2 + maxContentModerationRetryCount = 5 + defaultContentModerationHitRetentionDays = 180 + defaultContentModerationNonHitRetentionDays = 3 + maxContentModerationRetentionDays = 3650 + maxContentModerationNonHitRetentionDays = 3 + contentModerationKeyFailureFreezeThreshold = 3 + contentModerationKeyFreezeDuration = time.Minute + maxContentModerationTestImages = 4 + maxContentModerationTestImageBytes = 8 * 1024 * 1024 + maxContentModerationTestImageDataURLBytes = 12 * 1024 * 1024 + + contentModerationCleanupInterval = 24 * time.Hour + contentModerationCleanupTimeout = 30 * time.Minute + contentModerationCleanupDelay = 5 * time.Minute +) + +var contentModerationCategoryOrder = []string{ + "harassment", + "harassment/threatening", + "hate", + "hate/threatening", + "illicit", + "illicit/violent", + "self-harm", + "self-harm/intent", + "self-harm/instructions", + "sexual", + "sexual/minors", + "violence", + "violence/graphic", +} + +func ContentModerationDefaultThresholds() map[string]float64 { + return map[string]float64{ + "harassment": 0.98, + "harassment/threatening": 0.90, + "hate": 0.65, + "hate/threatening": 0.65, + "illicit": 0.95, + "illicit/violent": 0.95, + "self-harm": 0.65, + "self-harm/intent": 0.85, + "self-harm/instructions": 0.65, + "sexual": 0.65, + "sexual/minors": 0.65, + "violence": 0.95, + "violence/graphic": 0.95, + } +} + +func ContentModerationCategories() []string { + out := make([]string, len(contentModerationCategoryOrder)) + copy(out, contentModerationCategoryOrder) + return out +} + +type ContentModerationConfig struct { + Enabled bool `json:"enabled"` + Mode string `json:"mode"` + BaseURL string `json:"base_url"` + Model string `json:"model"` + APIKey string `json:"api_key,omitempty"` + APIKeys []string `json:"api_keys,omitempty"` + TimeoutMS int `json:"timeout_ms"` + SampleRate int `json:"sample_rate"` + AllGroups bool `json:"all_groups"` + GroupIDs []int64 `json:"group_ids"` + RecordNonHits bool `json:"record_non_hits"` + Thresholds map[string]float64 `json:"thresholds"` + WorkerCount int `json:"worker_count"` + QueueSize int `json:"queue_size"` + BlockStatus int `json:"block_status"` + BlockMessage string `json:"block_message"` + EmailOnHit bool `json:"email_on_hit"` + AutoBanEnabled bool `json:"auto_ban_enabled"` + BanThreshold int `json:"ban_threshold"` + ViolationWindowHours int `json:"violation_window_hours"` + RetryCount int `json:"retry_count"` + HitRetentionDays int `json:"hit_retention_days"` + NonHitRetentionDays int `json:"non_hit_retention_days"` + PreHashCheckEnabled bool `json:"pre_hash_check_enabled"` +} + +type ContentModerationConfigView struct { + Enabled bool `json:"enabled"` + Mode string `json:"mode"` + BaseURL string `json:"base_url"` + Model string `json:"model"` + APIKeyConfigured bool `json:"api_key_configured"` + APIKeyMasked string `json:"api_key_masked"` + APIKeyCount int `json:"api_key_count"` + APIKeyMasks []string `json:"api_key_masks"` + APIKeyStatuses []ContentModerationAPIKeyStatus `json:"api_key_statuses"` + TimeoutMS int `json:"timeout_ms"` + SampleRate int `json:"sample_rate"` + AllGroups bool `json:"all_groups"` + GroupIDs []int64 `json:"group_ids"` + RecordNonHits bool `json:"record_non_hits"` + WorkerCount int `json:"worker_count"` + QueueSize int `json:"queue_size"` + BlockStatus int `json:"block_status"` + BlockMessage string `json:"block_message"` + EmailOnHit bool `json:"email_on_hit"` + AutoBanEnabled bool `json:"auto_ban_enabled"` + BanThreshold int `json:"ban_threshold"` + ViolationWindowHours int `json:"violation_window_hours"` + RetryCount int `json:"retry_count"` + HitRetentionDays int `json:"hit_retention_days"` + NonHitRetentionDays int `json:"non_hit_retention_days"` + PreHashCheckEnabled bool `json:"pre_hash_check_enabled"` +} + +type ContentModerationAPIKeyStatus struct { + Index int `json:"index"` + KeyHash string `json:"key_hash"` + Masked string `json:"masked"` + Status string `json:"status"` + FailureCount int `json:"failure_count"` + SuccessCount int64 `json:"success_count"` + LastError string `json:"last_error"` + LastCheckedAt *time.Time `json:"last_checked_at,omitempty"` + FrozenUntil *time.Time `json:"frozen_until,omitempty"` + LastLatencyMS int `json:"last_latency_ms"` + LastHTTPStatus int `json:"last_http_status"` + LastTested bool `json:"last_tested"` + Configured bool `json:"configured"` +} + +type TestContentModerationAPIKeysInput struct { + APIKeys []string `json:"api_keys"` + BaseURL string `json:"base_url"` + Model string `json:"model"` + TimeoutMS int `json:"timeout_ms"` + Prompt string `json:"prompt"` + Images []string `json:"images"` +} + +type TestContentModerationAPIKeysResult struct { + Items []ContentModerationAPIKeyStatus `json:"items"` + AuditResult *ContentModerationTestAuditResult `json:"audit_result,omitempty"` + ImageCount int `json:"image_count"` +} + +type ContentModerationTestAuditResult struct { + Flagged bool `json:"flagged"` + HighestCategory string `json:"highest_category"` + HighestScore float64 `json:"highest_score"` + CompositeScore float64 `json:"composite_score"` + CategoryScores map[string]float64 `json:"category_scores"` + Thresholds map[string]float64 `json:"thresholds"` +} + +type UpdateContentModerationConfigInput struct { + Enabled *bool `json:"enabled"` + Mode *string `json:"mode"` + BaseURL *string `json:"base_url"` + Model *string `json:"model"` + APIKey *string `json:"api_key"` + APIKeys *[]string `json:"api_keys"` + ClearAPIKey bool `json:"clear_api_key"` + TimeoutMS *int `json:"timeout_ms"` + SampleRate *int `json:"sample_rate"` + AllGroups *bool `json:"all_groups"` + GroupIDs *[]int64 `json:"group_ids"` + RecordNonHits *bool `json:"record_non_hits"` + WorkerCount *int `json:"worker_count"` + QueueSize *int `json:"queue_size"` + BlockStatus *int `json:"block_status"` + BlockMessage *string `json:"block_message"` + EmailOnHit *bool `json:"email_on_hit"` + AutoBanEnabled *bool `json:"auto_ban_enabled"` + BanThreshold *int `json:"ban_threshold"` + ViolationWindowHours *int `json:"violation_window_hours"` + RetryCount *int `json:"retry_count"` + HitRetentionDays *int `json:"hit_retention_days"` + NonHitRetentionDays *int `json:"non_hit_retention_days"` + PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"` +} + +type ContentModerationCheckInput struct { + RequestID string + UserID int64 + UserEmail string + APIKeyID int64 + APIKeyName string + GroupID *int64 + GroupName string + Endpoint string + Provider string + Model string + Protocol string + Body []byte +} + +type ContentModerationInput struct { + Text string + Images []string +} + +func (in *ContentModerationInput) Normalize() { + if in == nil { + return + } + in.Text = trimRunes(normalizeContentModerationText(in.Text), maxModerationInputRunes) + in.Images = normalizeModerationImages(in.Images) +} + +func (in ContentModerationInput) IsEmpty() bool { + return strings.TrimSpace(in.Text) == "" && len(in.Images) == 0 +} + +func (in ContentModerationInput) ModerationInput() any { + if len(in.Images) == 0 { + return in.Text + } + parts := make([]moderationAPIInputPart, 0, len(in.Images)+1) + if strings.TrimSpace(in.Text) != "" { + parts = append(parts, moderationAPIInputPart{Type: "text", Text: in.Text}) + } + for _, image := range in.Images { + parts = append(parts, moderationAPIInputPart{ + Type: "image_url", + ImageURL: &moderationAPIImageURLRef{URL: image}, + }) + } + return parts +} + +func (in ContentModerationInput) ExcerptText() string { + return in.Text +} + +func (in ContentModerationInput) Hash() string { + h := sha256.New() + _, _ = h.Write([]byte("text:")) + _, _ = h.Write([]byte(in.Text)) + for _, image := range in.Images { + imageHash := sha256.Sum256([]byte(image)) + _, _ = h.Write([]byte("\nimage:")) + _, _ = h.Write([]byte(hex.EncodeToString(imageHash[:]))) + } + return hex.EncodeToString(h.Sum(nil)) +} + +type ContentModerationDecision struct { + Allowed bool `json:"allowed"` + Blocked bool `json:"blocked"` + Flagged bool `json:"flagged"` + Message string `json:"message"` + StatusCode int `json:"status_code"` + InputHash string `json:"input_hash,omitempty"` + HighestCategory string `json:"highest_category"` + HighestScore float64 `json:"highest_score"` + CategoryScores map[string]float64 `json:"category_scores"` + Action string `json:"action"` +} + +type ContentModerationLog struct { + ID int64 `json:"id"` + RequestID string `json:"request_id"` + UserID *int64 `json:"user_id,omitempty"` + UserEmail string `json:"user_email"` + APIKeyID *int64 `json:"api_key_id,omitempty"` + APIKeyName string `json:"api_key_name"` + GroupID *int64 `json:"group_id,omitempty"` + GroupName string `json:"group_name"` + Endpoint string `json:"endpoint"` + Provider string `json:"provider"` + Model string `json:"model"` + Mode string `json:"mode"` + Action string `json:"action"` + Flagged bool `json:"flagged"` + HighestCategory string `json:"highest_category"` + HighestScore float64 `json:"highest_score"` + CategoryScores map[string]float64 `json:"category_scores"` + ThresholdSnapshot map[string]float64 `json:"threshold_snapshot"` + InputExcerpt string `json:"input_excerpt"` + UpstreamLatencyMS *int `json:"upstream_latency_ms,omitempty"` + Error string `json:"error"` + ViolationCount int `json:"violation_count"` + AutoBanned bool `json:"auto_banned"` + EmailSent bool `json:"email_sent"` + UserStatus string `json:"user_status"` + QueueDelayMS *int `json:"queue_delay_ms,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +type ContentModerationLogFilter struct { + Pagination pagination.PaginationParams + Result string + GroupID *int64 + Endpoint string + Search string + From *time.Time + To *time.Time +} + +type ContentModerationCleanupResult struct { + DeletedHit int64 `json:"deleted_hit"` + DeletedNonHit int64 `json:"deleted_non_hit"` + FinishedAt time.Time `json:"finished_at"` +} + +type ContentModerationRuntimeStatus struct { + Enabled bool `json:"enabled"` + RiskControlEnabled bool `json:"risk_control_enabled"` + Mode string `json:"mode"` + WorkerCount int `json:"worker_count"` + MaxWorkers int `json:"max_workers"` + ActiveWorkers int `json:"active_workers"` + IdleWorkers int `json:"idle_workers"` + QueueSize int `json:"queue_size"` + QueueLength int `json:"queue_length"` + QueueUsagePercent float64 `json:"queue_usage_percent"` + Enqueued int64 `json:"enqueued"` + Dropped int64 `json:"dropped"` + Processed int64 `json:"processed"` + Errors int64 `json:"errors"` + APIKeyStatuses []ContentModerationAPIKeyStatus `json:"api_key_statuses"` + FlaggedHashCount int64 `json:"flagged_hash_count"` + LastCleanupAt *time.Time `json:"last_cleanup_at,omitempty"` + LastCleanupDeletedHit int64 `json:"last_cleanup_deleted_hit"` + LastCleanupDeletedNonHit int64 `json:"last_cleanup_deleted_non_hit"` +} + +type ContentModerationUnbanUserResult struct { + UserID int64 `json:"user_id"` + Status string `json:"status"` +} + +type ContentModerationDeleteHashResult struct { + InputHash string `json:"input_hash"` + Deleted bool `json:"deleted"` +} + +type ContentModerationClearHashesResult struct { + Deleted int64 `json:"deleted"` +} + +type ContentModerationRepository interface { + CreateLog(ctx context.Context, log *ContentModerationLog) error + ListLogs(ctx context.Context, filter ContentModerationLogFilter) ([]ContentModerationLog, *pagination.PaginationResult, error) + CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) + CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*ContentModerationCleanupResult, error) +} + +type ContentModerationHashCache interface { + RecordFlaggedInputHash(ctx context.Context, inputHash string) error + HasFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) + DeleteFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) + ClearFlaggedInputHashes(ctx context.Context) (int64, error) + CountFlaggedInputHashes(ctx context.Context) (int64, error) +} + +type ContentModerationService struct { + settingRepo SettingRepository + repo ContentModerationRepository + hashCache ContentModerationHashCache + groupRepo GroupRepository + userRepo UserRepository + authCacheInvalidator APIKeyAuthCacheInvalidator + emailService *EmailService + httpClient *http.Client + asyncQueue chan contentModerationTask + workerCount int + apiKeyCursor atomic.Uint64 + asyncActive atomic.Int64 + asyncEnqueued atomic.Int64 + asyncDropped atomic.Int64 + asyncProcessed atomic.Int64 + asyncErrors atomic.Int64 + lastCleanupUnix atomic.Int64 + lastCleanupDeletedHit atomic.Int64 + lastCleanupDeletedNonHit atomic.Int64 + keyHealthMu sync.Mutex + keyHealth map[string]*contentModerationKeyHealth +} + +type contentModerationTask struct { + input ContentModerationCheckInput + content ContentModerationInput + inputHash string + enqueuedAt time.Time +} + +type contentModerationKeyHealth struct { + Hash string + Masked string + FailureCount int + SuccessCount int64 + LastError string + LastCheckedAt time.Time + FrozenUntil time.Time + LastLatencyMS int + LastHTTPStatus int + LastTested bool +} + +func NewContentModerationService( + settingRepo SettingRepository, + repo ContentModerationRepository, + hashCache ContentModerationHashCache, + groupRepo GroupRepository, + userRepo UserRepository, + authCacheInvalidator APIKeyAuthCacheInvalidator, + emailService *EmailService, +) *ContentModerationService { + svc := &ContentModerationService{ + settingRepo: settingRepo, + repo: repo, + hashCache: hashCache, + groupRepo: groupRepo, + userRepo: userRepo, + authCacheInvalidator: authCacheInvalidator, + emailService: emailService, + httpClient: &http.Client{}, + workerCount: maxContentModerationWorkerCount, + asyncQueue: make(chan contentModerationTask, maxContentModerationQueueSize), + keyHealth: make(map[string]*contentModerationKeyHealth), + } + if settingRepo != nil && repo != nil { + for i := 0; i < svc.workerCount; i++ { + go svc.worker(i) + } + go svc.cleanupWorker() + } + return svc +} + +func (s *ContentModerationService) GetConfig(ctx context.Context) (*ContentModerationConfigView, error) { + cfg, err := s.loadConfig(ctx) + if err != nil { + return nil, err + } + return s.configView(cfg), nil +} + +func (s *ContentModerationService) UpdateConfig(ctx context.Context, input UpdateContentModerationConfigInput) (*ContentModerationConfigView, error) { + cfg, err := s.loadConfig(ctx) + if err != nil { + return nil, err + } + if input.Enabled != nil { + cfg.Enabled = *input.Enabled + } + if input.Mode != nil { + cfg.Mode = strings.TrimSpace(*input.Mode) + } + if input.BaseURL != nil { + cfg.BaseURL = strings.TrimSpace(*input.BaseURL) + } + if input.Model != nil { + cfg.Model = strings.TrimSpace(*input.Model) + } + if input.TimeoutMS != nil { + cfg.TimeoutMS = *input.TimeoutMS + } + if input.SampleRate != nil { + cfg.SampleRate = *input.SampleRate + } + if input.WorkerCount != nil { + cfg.WorkerCount = *input.WorkerCount + } + if input.QueueSize != nil { + cfg.QueueSize = *input.QueueSize + } + if input.BlockStatus != nil { + cfg.BlockStatus = *input.BlockStatus + } + if input.BlockMessage != nil { + cfg.BlockMessage = strings.TrimSpace(*input.BlockMessage) + } + if input.EmailOnHit != nil { + cfg.EmailOnHit = *input.EmailOnHit + } + if input.AutoBanEnabled != nil { + cfg.AutoBanEnabled = *input.AutoBanEnabled + } + if input.BanThreshold != nil { + cfg.BanThreshold = *input.BanThreshold + } + if input.ViolationWindowHours != nil { + cfg.ViolationWindowHours = *input.ViolationWindowHours + } + if input.RetryCount != nil { + cfg.RetryCount = *input.RetryCount + } + if input.HitRetentionDays != nil { + cfg.HitRetentionDays = *input.HitRetentionDays + } + if input.NonHitRetentionDays != nil { + cfg.NonHitRetentionDays = *input.NonHitRetentionDays + } + if input.PreHashCheckEnabled != nil { + cfg.PreHashCheckEnabled = *input.PreHashCheckEnabled + } + if input.AllGroups != nil { + cfg.AllGroups = *input.AllGroups + } + if input.GroupIDs != nil { + cfg.GroupIDs = normalizeInt64IDs(*input.GroupIDs) + } + if input.RecordNonHits != nil { + cfg.RecordNonHits = *input.RecordNonHits + } + if input.ClearAPIKey { + cfg.APIKey = "" + cfg.APIKeys = []string{} + } else { + if input.APIKeys != nil { + cfg.APIKeys = normalizeModerationAPIKeys(*input.APIKeys) + cfg.APIKey = "" + } + if input.APIKey != nil && strings.TrimSpace(*input.APIKey) != "" { + cfg.APIKeys = normalizeModerationAPIKeys(append(cfg.APIKeys, *input.APIKey)) + cfg.APIKey = "" + } + } + if err := s.validateConfig(ctx, cfg); err != nil { + return nil, err + } + cfg.normalize() + raw, err := json.Marshal(cfg) + if err != nil { + return nil, fmt.Errorf("marshal content moderation config: %w", err) + } + if err := s.settingRepo.Set(ctx, SettingKeyContentModerationConfig, string(raw)); err != nil { + return nil, fmt.Errorf("save content moderation config: %w", err) + } + return s.configView(cfg), nil +} + +func (s *ContentModerationService) TestAPIKeys(ctx context.Context, input TestContentModerationAPIKeysInput) (*TestContentModerationAPIKeysResult, error) { + cfg, err := s.loadConfig(ctx) + if err != nil { + return nil, err + } + keys := normalizeModerationAPIKeys(input.APIKeys) + configured := false + if len(keys) == 0 { + keys = cfg.apiKeys() + configured = true + } + if strings.TrimSpace(input.BaseURL) != "" { + cfg.BaseURL = input.BaseURL + } + if strings.TrimSpace(input.Model) != "" { + cfg.Model = input.Model + } + if input.TimeoutMS > 0 { + cfg.TimeoutMS = input.TimeoutMS + } + cfg.normalize() + testInput, imageCount, err := buildModerationTestInput(input.Prompt, input.Images) + if err != nil { + return nil, err + } + auditOnly := contentModerationTestHasAuditInput(input.Prompt, input.Images) + if configured && auditOnly { + key, ok := s.nextUsableAPIKey(cfg) + if !ok { + return &TestContentModerationAPIKeysResult{ + Items: s.apiKeyStatuses(keys), + ImageCount: imageCount, + }, nil + } + keys = []string{key} + } + if len(keys) == 0 { + return &TestContentModerationAPIKeysResult{Items: []ContentModerationAPIKeyStatus{}, ImageCount: imageCount}, nil + } + items := make([]ContentModerationAPIKeyStatus, 0, len(keys)) + var auditResult *ContentModerationTestAuditResult + for idx, key := range keys { + start := time.Now() + httpStatus := 0 + result, err := s.callModerationOnceWithInput(ctx, cfg, key, testInput, &httpStatus) + latency := int(time.Since(start).Milliseconds()) + keyHash := moderationAPIKeyHash(key) + if err != nil { + s.markAPIKeyFailure(key, err.Error(), latency, httpStatus) + } else { + s.markAPIKeySuccess(key, latency, httpStatus) + if auditResult == nil { + auditResult = buildContentModerationTestAuditResult(result, cfg.Thresholds) + } + } + status := s.apiKeyStatusForHash(idx, keyHash, maskSecretTail(key), configured) + status.LastTested = true + items = append(items, status) + } + return &TestContentModerationAPIKeysResult{Items: items, AuditResult: auditResult, ImageCount: imageCount}, nil +} + +func (s *ContentModerationService) Check(ctx context.Context, input ContentModerationCheckInput) (*ContentModerationDecision, error) { + allow := &ContentModerationDecision{Allowed: true, Action: ContentModerationActionAllow} + if s == nil || s.settingRepo == nil || s.repo == nil { + slog.Info("content_moderation.skip_unavailable", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol) + return allow, nil + } + if !s.isRiskControlEnabled(ctx) { + slog.Info("content_moderation.skip_feature_disabled", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol) + return allow, nil + } + cfg, err := s.loadConfig(ctx) + if err != nil { + slog.Warn("content_moderation.skip_config_load_failed", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "error", err) + return allow, nil + } + inScope := cfg.includesGroup(input.GroupID) + slog.Info("content_moderation.config_loaded", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "group_name", input.GroupName, + "endpoint", input.Endpoint, + "provider", input.Provider, + "protocol", input.Protocol, + "model", input.Model, + "enabled", cfg.Enabled, + "mode", cfg.Mode, + "all_groups", cfg.AllGroups, + "configured_group_ids", cfg.GroupIDs, + "in_scope", inScope, + "sample_rate", cfg.SampleRate, + "api_key_count", len(cfg.apiKeys()), + "pre_hash_check_enabled", cfg.PreHashCheckEnabled, + "record_non_hits", cfg.RecordNonHits) + if !cfg.Enabled { + slog.Info("content_moderation.skip_config_disabled", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol) + return allow, nil + } + if cfg.Mode == ContentModerationModeOff { + slog.Info("content_moderation.skip_mode_off", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol) + return allow, nil + } + if !inScope { + slog.Info("content_moderation.skip_group_out_of_scope", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "group_name", input.GroupName, + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "all_groups", cfg.AllGroups, + "configured_group_ids", cfg.GroupIDs) + return allow, nil + } + content := ExtractContentModerationInput(input.Protocol, input.Body) + if content.IsEmpty() { + slog.Info("content_moderation.skip_empty_input", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "body_bytes", len(input.Body)) + return allow, nil + } + content.Normalize() + slog.Info("content_moderation.input_extracted", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "text_runes", len([]rune(content.Text)), + "image_count", len(content.Images)) + hashText := content.Hash() + if cfg.PreHashCheckEnabled && s.hashCache != nil { + matched, err := s.hashCache.HasFlaggedInputHash(ctx, hashText) + if err != nil { + slog.Warn("content_moderation.hash_check_failed", "user_id", input.UserID, "endpoint", input.Endpoint, "error", err) + } + if matched { + slog.Info("content_moderation.hash_block", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "input_hash", hashText) + message := cfg.BlockMessage + if message != "" { + message = fmt.Sprintf("%s(hash: %s)", message, hashText) + } + return &ContentModerationDecision{ + Allowed: false, + Blocked: true, + Flagged: true, + Message: message, + StatusCode: cfg.BlockStatus, + InputHash: hashText, + Action: ContentModerationActionHashBlock, + }, nil + } + } + if !cfg.shouldSample(hashText) { + slog.Info("content_moderation.skip_sample_rate", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "sample_rate", cfg.SampleRate) + return allow, nil + } + if len(cfg.apiKeys()) == 0 { + slog.Warn("content_moderation.skip_no_audit_api_keys", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol) + return allow, nil + } + if cfg.Mode == ContentModerationModeObserve { + slog.Info("content_moderation.enqueue_observe", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "queue_len", len(s.asyncQueue)) + s.enqueueAsync(input, cfg, content, hashText) + return allow, nil + } + + return s.checkSync(ctx, input, cfg, content, hashText, nil, true), nil +} + +func (s *ContentModerationService) checkSync(ctx context.Context, input ContentModerationCheckInput, cfg *ContentModerationConfig, content ContentModerationInput, hashText string, queueDelay *int, allowBlock bool) *ContentModerationDecision { + allow := &ContentModerationDecision{Allowed: true, Action: ContentModerationActionAllow} + start := time.Now() + result, err := s.callModeration(ctx, cfg, content.ModerationInput()) + latency := int(time.Since(start).Milliseconds()) + if err != nil { + slog.Warn("content_moderation.audit_api_failed", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "mode", cfg.Mode, + "allow_block", allowBlock, + "queue_delay_ms", queueDelay, + "latency_ms", latency, + "error", err) + if queueDelay != nil { + s.asyncErrors.Add(1) + } + if cfg.RecordNonHits { + log := s.buildLog(input, cfg, ContentModerationActionError, false, "", 0, nil, content.ExcerptText(), &latency, queueDelay, err.Error()) + _ = s.repo.CreateLog(ctx, log) + } + return allow + } + + flagged, highestCategory, highestScore := evaluateModerationScores(result.CategoryScores, cfg.Thresholds) + action := ContentModerationActionAllow + blocked := false + if allowBlock && flagged && cfg.Mode == ContentModerationModePreBlock { + action = ContentModerationActionBlock + blocked = true + } + slog.Info("content_moderation.audit_result", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "group_name", input.GroupName, + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "mode", cfg.Mode, + "allow_block", allowBlock, + "flagged", flagged, + "blocked", blocked, + "action", action, + "highest_category", highestCategory, + "highest_score", highestScore, + "latency_ms", latency, + "queue_delay_ms", queueDelay) + if flagged || cfg.RecordNonHits { + log := s.buildLog(input, cfg, action, flagged, highestCategory, highestScore, result.CategoryScores, content.ExcerptText(), &latency, queueDelay, "") + if flagged && s.hashCache != nil { + if err := s.hashCache.RecordFlaggedInputHash(ctx, hashText); err != nil { + slog.Warn("content_moderation.record_hash_failed", "user_id", input.UserID, "endpoint", input.Endpoint, "error", err) + } + } + s.applyFlaggedSideEffects(ctx, cfg, log) + _ = s.repo.CreateLog(ctx, log) + } + if blocked { + return &ContentModerationDecision{ + Allowed: false, + Blocked: true, + Flagged: true, + Message: cfg.BlockMessage, + StatusCode: cfg.BlockStatus, + HighestCategory: highestCategory, + HighestScore: highestScore, + CategoryScores: result.CategoryScores, + Action: action, + } + } + return &ContentModerationDecision{ + Allowed: true, + Flagged: flagged, + Message: "", + HighestCategory: highestCategory, + HighestScore: highestScore, + CategoryScores: result.CategoryScores, + Action: action, + } +} + +func (s *ContentModerationService) enqueueAsync(input ContentModerationCheckInput, cfg *ContentModerationConfig, content ContentModerationInput, hashText string) { + if s == nil || s.asyncQueue == nil { + return + } + queueSize := defaultContentModerationQueueSize + if cfg != nil && cfg.QueueSize > 0 { + queueSize = cfg.QueueSize + } + if len(s.asyncQueue) >= queueSize { + slog.Warn("content_moderation.async_queue_full", "user_id", input.UserID, "endpoint", input.Endpoint, "queue_size", queueSize) + s.asyncDropped.Add(1) + return + } + task := contentModerationTask{ + input: input, + content: content, + inputHash: hashText, + enqueuedAt: time.Now(), + } + select { + case s.asyncQueue <- task: + s.asyncEnqueued.Add(1) + default: + slog.Warn("content_moderation.async_queue_full", "user_id", input.UserID, "endpoint", input.Endpoint) + s.asyncDropped.Add(1) + } +} + +func (s *ContentModerationService) worker(id int) { + for { + ctx, cancel := context.WithTimeout(context.Background(), maxContentModerationTimeoutMS*time.Millisecond+10*time.Second) + cfg, err := s.loadConfig(ctx) + if err != nil || !cfg.Enabled || cfg.Mode == ContentModerationModeOff || len(cfg.apiKeys()) == 0 || id >= cfg.WorkerCount { + cancel() + time.Sleep(time.Second) + continue + } + task, ok := s.dequeueAsyncTask(ctx, time.Second) + if !ok { + cancel() + continue + } + func() { + defer cancel() + defer func() { + if r := recover(); r != nil { + slog.Error("content_moderation.worker_panic", "worker_id", id, "recover", r) + } + }() + if !cfg.includesGroup(task.input.GroupID) { + return + } + s.asyncActive.Add(1) + defer s.asyncActive.Add(-1) + queueDelay := int(time.Since(task.enqueuedAt).Milliseconds()) + _ = s.checkSync(ctx, task.input, cfg, task.content, task.inputHash, &queueDelay, false) + s.asyncProcessed.Add(1) + }() + } +} + +func (s *ContentModerationService) dequeueAsyncTask(ctx context.Context, idleWait time.Duration) (contentModerationTask, bool) { + var zero contentModerationTask + if s == nil || s.asyncQueue == nil { + return zero, false + } + if idleWait <= 0 { + idleWait = time.Second + } + timer := time.NewTimer(idleWait) + defer timer.Stop() + select { + case task, ok := <-s.asyncQueue: + return task, ok + case <-ctx.Done(): + return zero, false + case <-timer.C: + return zero, false + } +} + +func (s *ContentModerationService) ListLogs(ctx context.Context, filter ContentModerationLogFilter) ([]ContentModerationLog, *pagination.PaginationResult, error) { + if filter.Pagination.Page <= 0 { + filter.Pagination.Page = 1 + } + if filter.Pagination.PageSize <= 0 { + filter.Pagination.PageSize = 20 + } + if filter.Pagination.PageSize > 100 { + filter.Pagination.PageSize = 100 + } + if filter.Pagination.SortOrder == "" { + filter.Pagination.SortOrder = pagination.SortOrderDesc + } + return s.repo.ListLogs(ctx, filter) +} + +func (s *ContentModerationService) UnbanUser(ctx context.Context, userID int64) (*ContentModerationUnbanUserResult, error) { + if s == nil || s.userRepo == nil { + return nil, infraerrors.InternalServer("CONTENT_MODERATION_USER_REPOSITORY_UNAVAILABLE", "用户仓储不可用") + } + if userID <= 0 { + return nil, infraerrors.BadRequest("INVALID_USER_ID", "用户 ID 无效") + } + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + return nil, infraerrors.NotFound("USER_NOT_FOUND", "用户不存在") + } + return nil, fmt.Errorf("get content moderation unban user: %w", err) + } + if user.Status != StatusActive { + user.Status = StatusActive + if err := s.userRepo.Update(ctx, user); err != nil { + return nil, fmt.Errorf("update content moderation unban user: %w", err) + } + } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + return &ContentModerationUnbanUserResult{ + UserID: userID, + Status: StatusActive, + }, nil +} + +func (s *ContentModerationService) DeleteFlaggedInputHash(ctx context.Context, inputHash string) (*ContentModerationDeleteHashResult, error) { + inputHash = normalizeContentModerationHash(inputHash) + if inputHash == "" { + return nil, infraerrors.BadRequest("INVALID_CONTENT_MODERATION_HASH", "风险输入哈希无效") + } + if s == nil || s.hashCache == nil { + return nil, infraerrors.InternalServer("CONTENT_MODERATION_HASH_CACHE_UNAVAILABLE", "内容审计哈希缓存不可用") + } + deleted, err := s.hashCache.DeleteFlaggedInputHash(ctx, inputHash) + if err != nil { + return nil, fmt.Errorf("delete content moderation flagged hash: %w", err) + } + return &ContentModerationDeleteHashResult{ + InputHash: inputHash, + Deleted: deleted, + }, nil +} + +func (s *ContentModerationService) ClearFlaggedInputHashes(ctx context.Context) (*ContentModerationClearHashesResult, error) { + if s == nil || s.hashCache == nil { + return nil, infraerrors.InternalServer("CONTENT_MODERATION_HASH_CACHE_UNAVAILABLE", "内容审计哈希缓存不可用") + } + deleted, err := s.hashCache.ClearFlaggedInputHashes(ctx) + if err != nil { + return nil, fmt.Errorf("clear content moderation flagged hashes: %w", err) + } + return &ContentModerationClearHashesResult{Deleted: deleted}, nil +} + +func (s *ContentModerationService) GetStatus(ctx context.Context) (*ContentModerationRuntimeStatus, error) { + if s == nil { + return &ContentModerationRuntimeStatus{}, nil + } + cfg, err := s.loadConfig(ctx) + if err != nil { + return nil, err + } + riskEnabled := s.isRiskControlEnabled(ctx) + active := int(s.asyncActive.Load()) + if active < 0 { + active = 0 + } + if active > cfg.WorkerCount { + active = cfg.WorkerCount + } + queueLength := 0 + if s.asyncQueue != nil { + queueLength = len(s.asyncQueue) + } + queueUsage := 0.0 + if cfg.QueueSize > 0 { + queueUsage = float64(queueLength) * 100 / float64(cfg.QueueSize) + } + var flaggedHashCount int64 + if s.hashCache != nil { + if n, err := s.hashCache.CountFlaggedInputHashes(ctx); err == nil { + flaggedHashCount = n + } else { + slog.Warn("content_moderation.hash_count_failed", "error", err) + } + } + var lastCleanupAt *time.Time + if unix := s.lastCleanupUnix.Load(); unix > 0 { + t := time.Unix(unix, 0) + lastCleanupAt = &t + } + return &ContentModerationRuntimeStatus{ + Enabled: cfg.Enabled, + RiskControlEnabled: riskEnabled, + Mode: cfg.Mode, + WorkerCount: cfg.WorkerCount, + MaxWorkers: maxContentModerationWorkerCount, + ActiveWorkers: active, + IdleWorkers: cfg.WorkerCount - active, + QueueSize: cfg.QueueSize, + QueueLength: queueLength, + QueueUsagePercent: queueUsage, + Enqueued: s.asyncEnqueued.Load(), + Dropped: s.asyncDropped.Load(), + Processed: s.asyncProcessed.Load(), + Errors: s.asyncErrors.Load(), + APIKeyStatuses: s.apiKeyStatuses(cfg.apiKeys()), + FlaggedHashCount: flaggedHashCount, + LastCleanupAt: lastCleanupAt, + LastCleanupDeletedHit: s.lastCleanupDeletedHit.Load(), + LastCleanupDeletedNonHit: s.lastCleanupDeletedNonHit.Load(), + }, nil +} + +func (s *ContentModerationService) cleanupWorker() { + timer := time.NewTimer(contentModerationCleanupDelay) + defer timer.Stop() + for { + <-timer.C + s.runCleanupOnce() + timer.Reset(contentModerationCleanupInterval) + } +} + +func (s *ContentModerationService) runCleanupOnce() { + if s == nil || s.repo == nil || s.settingRepo == nil { + return + } + ctx, cancel := context.WithTimeout(context.Background(), contentModerationCleanupTimeout) + defer cancel() + cfg, err := s.loadConfig(ctx) + if err != nil { + slog.Warn("content_moderation.cleanup_load_config_failed", "error", err) + return + } + now := time.Now() + hitBefore := now.AddDate(0, 0, -cfg.HitRetentionDays) + nonHitBefore := now.AddDate(0, 0, -cfg.NonHitRetentionDays) + result, err := s.repo.CleanupExpiredLogs(ctx, hitBefore, nonHitBefore) + if err != nil { + slog.Warn("content_moderation.cleanup_failed", "error", err) + return + } + if result == nil { + return + } + s.lastCleanupUnix.Store(result.FinishedAt.Unix()) + s.lastCleanupDeletedHit.Store(result.DeletedHit) + s.lastCleanupDeletedNonHit.Store(result.DeletedNonHit) +} + +func (s *ContentModerationService) loadConfig(ctx context.Context) (*ContentModerationConfig, error) { + cfg := defaultContentModerationConfig() + raw, err := s.settingRepo.GetValue(ctx, SettingKeyContentModerationConfig) + if err != nil { + if errors.Is(err, ErrSettingNotFound) { + cfg.normalize() + return cfg, nil + } + return nil, fmt.Errorf("get content moderation config: %w", err) + } + if strings.TrimSpace(raw) == "" { + cfg.normalize() + return cfg, nil + } + if err := json.Unmarshal([]byte(raw), cfg); err != nil { + return nil, infraerrors.BadRequest("INVALID_CONTENT_MODERATION_CONFIG", "内容审计配置不是有效 JSON") + } + cfg.normalize() + return cfg, nil +} + +func (s *ContentModerationService) isRiskControlEnabled(ctx context.Context) bool { + raw, err := s.settingRepo.GetValue(ctx, SettingKeyRiskControlEnabled) + if err != nil { + return false + } + return raw == "true" +} + +func (s *ContentModerationService) validateConfig(ctx context.Context, cfg *ContentModerationConfig) error { + if cfg == nil { + return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_CONFIG", "内容审计配置不能为空") + } + cfg.normalize() + switch cfg.Mode { + case ContentModerationModeOff, ContentModerationModeObserve, ContentModerationModePreBlock: + default: + return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_MODE", "内容审计模式无效") + } + if _, err := url.ParseRequestURI(cfg.BaseURL); err != nil { + return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_BASE_URL", "OpenAI Base URL 无效") + } + if cfg.BlockStatus < 400 || cfg.BlockStatus > 599 { + return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_BLOCK_STATUS", "拦截 HTTP 状态码必须在 400-599 之间") + } + if !cfg.AllGroups && len(cfg.GroupIDs) > 0 && s.groupRepo != nil { + for _, groupID := range cfg.GroupIDs { + if _, err := s.groupRepo.GetByIDLite(ctx, groupID); err != nil { + return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_GROUP", fmt.Sprintf("审计分组不存在: %d", groupID)) + } + } + } + return nil +} + +func (s *ContentModerationService) callModeration(ctx context.Context, cfg *ContentModerationConfig, input any) (*moderationAPIResult, error) { + attempts := cfg.RetryCount + 1 + if attempts <= 0 { + attempts = 1 + } + if attempts > maxContentModerationRetryCount+1 { + attempts = maxContentModerationRetryCount + 1 + } + var lastErr error + for attempt := 0; attempt < attempts; attempt++ { + key, ok := s.nextUsableAPIKey(cfg) + if !ok { + lastErr = errors.New("no moderation api key available") + break + } + start := time.Now() + httpStatus := 0 + result, err := s.callModerationOnceWithInput(ctx, cfg, key, input, &httpStatus) + latency := int(time.Since(start).Milliseconds()) + if err == nil { + s.markAPIKeySuccess(key, latency, httpStatus) + return result, nil + } + s.markAPIKeyFailure(key, err.Error(), latency, httpStatus) + lastErr = err + if attempt == attempts-1 { + break + } + wait := time.Duration(100*(attempt+1)) * time.Millisecond + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(wait): + } + } + return nil, lastErr +} + +func (s *ContentModerationService) callModerationOnceWithInput(ctx context.Context, cfg *ContentModerationConfig, apiKey string, input any, httpStatus *int) (*moderationAPIResult, error) { + base := strings.TrimRight(cfg.BaseURL, "/") + endpoint, err := url.JoinPath(base, "/v1/moderations") + if err != nil { + return nil, err + } + payload := moderationAPIRequest{ + Model: cfg.Model, + Input: input, + } + raw, err := json.Marshal(payload) + if err != nil { + return nil, err + } + timeout := time.Duration(cfg.TimeoutMS) * time.Millisecond + reqCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, endpoint, bytes.NewReader(raw)) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("Content-Type", "application/json") + + client := s.httpClient + if client == nil { + client = http.DefaultClient + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if httpStatus != nil { + *httpStatus = resp.StatusCode + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) + return nil, fmt.Errorf("moderation api status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + var out moderationAPIResponse + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return nil, err + } + if len(out.Results) == 0 { + return nil, errors.New("moderation api returned empty results") + } + return &out.Results[0], nil +} + +func (s *ContentModerationService) buildLog(input ContentModerationCheckInput, cfg *ContentModerationConfig, action string, flagged bool, highestCategory string, highestScore float64, scores map[string]float64, text string, latency *int, queueDelay *int, errText string) *ContentModerationLog { + var userID *int64 + if input.UserID > 0 { + userID = &input.UserID + } + var apiKeyID *int64 + if input.APIKeyID > 0 { + apiKeyID = &input.APIKeyID + } + return &ContentModerationLog{ + RequestID: input.RequestID, + UserID: userID, + UserEmail: input.UserEmail, + APIKeyID: apiKeyID, + APIKeyName: input.APIKeyName, + GroupID: cloneInt64Ptr(input.GroupID), + GroupName: input.GroupName, + Endpoint: input.Endpoint, + Provider: input.Provider, + Model: input.Model, + Mode: cfg.Mode, + Action: action, + Flagged: flagged, + HighestCategory: highestCategory, + HighestScore: highestScore, + CategoryScores: cloneFloatMap(scores), + ThresholdSnapshot: cloneFloatMap(cfg.Thresholds), + InputExcerpt: trimRunes(redactContentModerationSecrets(text), maxModerationExcerptRunes), + UpstreamLatencyMS: latency, + QueueDelayMS: queueDelay, + Error: errText, + } +} + +func (s *ContentModerationService) applyFlaggedSideEffects(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) { + if s == nil || cfg == nil || log == nil || !log.Flagged || log.UserID == nil || *log.UserID <= 0 { + return + } + count := 1 + if s.repo != nil && cfg.ViolationWindowHours > 0 { + since := time.Now().Add(-time.Duration(cfg.ViolationWindowHours) * time.Hour) + if n, err := s.repo.CountFlaggedByUserSince(ctx, *log.UserID, since); err == nil { + count = n + 1 + } + } + log.ViolationCount = count + autoBanJustApplied := false + if cfg.AutoBanEnabled && cfg.BanThreshold > 0 && count >= cfg.BanThreshold && s.userRepo != nil { + user, err := s.userRepo.GetByID(ctx, *log.UserID) + if err != nil { + slog.Warn("content_moderation.ban_get_user_failed", "user_id", *log.UserID, "error", err) + return + } + if user.Status != StatusDisabled { + user.Status = StatusDisabled + if err := s.userRepo.Update(ctx, user); err != nil { + slog.Warn("content_moderation.ban_update_user_failed", "user_id", *log.UserID, "error", err) + return + } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, *log.UserID) + } + autoBanJustApplied = true + } + log.AutoBanned = true + } + + if s.emailService == nil || strings.TrimSpace(log.UserEmail) == "" { + return + } + emailSent := false + if cfg.EmailOnHit { + if err := s.sendViolationEmail(ctx, cfg, log); err != nil { + slog.Warn("content_moderation.email_failed", "user_id", *log.UserID, "email", log.UserEmail, "error", err) + } else { + emailSent = true + } + } + if autoBanJustApplied { + if err := s.sendAccountDisabledEmail(ctx, cfg, log); err != nil { + slog.Warn("content_moderation.ban_email_failed", "user_id", *log.UserID, "email", log.UserEmail, "error", err) + } else { + emailSent = true + } + } + log.EmailSent = emailSent +} + +func (s *ContentModerationService) sendViolationEmail(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) error { + siteName := s.siteName(ctx) + subject := fmt.Sprintf("[%s] 账户风控提醒 / Risk Control Notice", sanitizeEmailHeader(siteName)) + body := buildContentModerationViolationEmailBody(siteName, log, cfg) + return s.emailService.SendEmail(ctx, log.UserEmail, subject, body) +} + +func (s *ContentModerationService) sendAccountDisabledEmail(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) error { + siteName := s.siteName(ctx) + subject := fmt.Sprintf("[%s] 账户已被禁用 / Account Disabled", sanitizeEmailHeader(siteName)) + body := buildContentModerationAccountDisabledEmailBody(siteName, log, cfg) + return s.emailService.SendEmail(ctx, log.UserEmail, subject, body) +} + +func (s *ContentModerationService) siteName(ctx context.Context) string { + if s == nil || s.settingRepo == nil { + return "Sub2API" + } + name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName) + if err != nil || strings.TrimSpace(name) == "" { + return "Sub2API" + } + return strings.TrimSpace(name) +} + +func defaultContentModerationConfig() *ContentModerationConfig { + return &ContentModerationConfig{ + Enabled: false, + Mode: ContentModerationModePreBlock, + BaseURL: defaultContentModerationBaseURL, + Model: defaultContentModerationModel, + TimeoutMS: defaultContentModerationTimeoutMS, + SampleRate: 100, + AllGroups: true, + GroupIDs: []int64{}, + RecordNonHits: false, + Thresholds: ContentModerationDefaultThresholds(), + WorkerCount: defaultContentModerationWorkerCount, + QueueSize: defaultContentModerationQueueSize, + BlockStatus: defaultContentModerationBlockHTTPStatus, + BlockMessage: defaultContentModerationBlockMessage, + EmailOnHit: true, + AutoBanEnabled: true, + BanThreshold: defaultContentModerationBanThreshold, + ViolationWindowHours: defaultContentModerationViolationWindowHours, + RetryCount: defaultContentModerationRetryCount, + HitRetentionDays: defaultContentModerationHitRetentionDays, + NonHitRetentionDays: defaultContentModerationNonHitRetentionDays, + PreHashCheckEnabled: false, + } +} + +func (cfg *ContentModerationConfig) normalize() { + if cfg.APIKey != "" { + cfg.APIKeys = normalizeModerationAPIKeys(append(cfg.APIKeys, cfg.APIKey)) + cfg.APIKey = "" + } else { + cfg.APIKeys = normalizeModerationAPIKeys(cfg.APIKeys) + } + if cfg.Mode == "" { + cfg.Mode = ContentModerationModePreBlock + } + if cfg.BaseURL == "" { + cfg.BaseURL = defaultContentModerationBaseURL + } + cfg.BaseURL = strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/") + if cfg.Model == "" { + cfg.Model = defaultContentModerationModel + } + cfg.Model = strings.TrimSpace(cfg.Model) + if cfg.TimeoutMS <= 0 { + cfg.TimeoutMS = defaultContentModerationTimeoutMS + } + if cfg.TimeoutMS > maxContentModerationTimeoutMS { + cfg.TimeoutMS = maxContentModerationTimeoutMS + } + if cfg.SampleRate < 0 { + cfg.SampleRate = 0 + } + if cfg.SampleRate > 100 { + cfg.SampleRate = 100 + } + if cfg.WorkerCount <= 0 { + cfg.WorkerCount = defaultContentModerationWorkerCount + } + if cfg.WorkerCount > maxContentModerationWorkerCount { + cfg.WorkerCount = maxContentModerationWorkerCount + } + if cfg.QueueSize <= 0 { + cfg.QueueSize = defaultContentModerationQueueSize + } + if cfg.QueueSize > maxContentModerationQueueSize { + cfg.QueueSize = maxContentModerationQueueSize + } + if strings.TrimSpace(cfg.BlockMessage) == "" { + cfg.BlockMessage = defaultContentModerationBlockMessage + } + cfg.BlockMessage = strings.TrimSpace(cfg.BlockMessage) + if cfg.BlockStatus <= 0 { + cfg.BlockStatus = defaultContentModerationBlockHTTPStatus + } + if cfg.BanThreshold <= 0 { + cfg.BanThreshold = defaultContentModerationBanThreshold + } + if cfg.ViolationWindowHours <= 0 { + cfg.ViolationWindowHours = defaultContentModerationViolationWindowHours + } + if cfg.RetryCount < 0 { + cfg.RetryCount = 0 + } + if cfg.RetryCount > maxContentModerationRetryCount { + cfg.RetryCount = maxContentModerationRetryCount + } + if cfg.HitRetentionDays <= 0 { + cfg.HitRetentionDays = defaultContentModerationHitRetentionDays + } + if cfg.HitRetentionDays > maxContentModerationRetentionDays { + cfg.HitRetentionDays = maxContentModerationRetentionDays + } + if cfg.NonHitRetentionDays <= 0 { + cfg.NonHitRetentionDays = defaultContentModerationNonHitRetentionDays + } + if cfg.NonHitRetentionDays > maxContentModerationNonHitRetentionDays { + cfg.NonHitRetentionDays = maxContentModerationNonHitRetentionDays + } + cfg.GroupIDs = normalizeInt64IDs(cfg.GroupIDs) + cfg.Thresholds = mergeContentModerationThresholds(ContentModerationDefaultThresholds(), cfg.Thresholds) +} + +func (cfg *ContentModerationConfig) includesGroup(groupID *int64) bool { + if cfg.AllGroups { + return true + } + if groupID == nil { + return false + } + for _, id := range cfg.GroupIDs { + if id == *groupID { + return true + } + } + return false +} + +func contentModerationLogGroupID(groupID *int64) int64 { + if groupID == nil { + return 0 + } + return *groupID +} + +func (cfg *ContentModerationConfig) shouldSample(hashText string) bool { + if cfg.SampleRate >= 100 { + return true + } + if cfg.SampleRate <= 0 { + return false + } + raw, err := hex.DecodeString(hashText) + if err != nil || len(raw) < 2 { + return true + } + return int(binary.BigEndian.Uint16(raw[:2])%100) < cfg.SampleRate +} + +func (cfg *ContentModerationConfig) apiKeys() []string { + if cfg == nil { + return nil + } + return normalizeModerationAPIKeys(cfg.APIKeys) +} + +func (s *ContentModerationService) nextUsableAPIKey(cfg *ContentModerationConfig) (string, bool) { + keys := cfg.apiKeys() + if len(keys) == 0 { + return "", false + } + now := time.Now() + for i := 0; i < len(keys); i++ { + idx := int(s.apiKeyCursor.Add(1)-1) % len(keys) + key := keys[idx] + if !s.isAPIKeyFrozen(key, now) { + return key, true + } + } + return "", false +} + +func (s *ContentModerationService) isAPIKeyFrozen(key string, now time.Time) bool { + hash := moderationAPIKeyHash(key) + if hash == "" || s == nil { + return false + } + s.keyHealthMu.Lock() + defer s.keyHealthMu.Unlock() + state := s.keyHealth[hash] + return state != nil && state.FrozenUntil.After(now) +} + +func (s *ContentModerationService) markAPIKeySuccess(key string, latencyMS int, httpStatus int) { + hash := moderationAPIKeyHash(key) + if hash == "" || s == nil { + return + } + s.keyHealthMu.Lock() + defer s.keyHealthMu.Unlock() + state := s.ensureAPIKeyHealthLocked(hash, maskSecretTail(key)) + state.FailureCount = 0 + state.SuccessCount++ + state.LastError = "" + state.LastCheckedAt = time.Now() + state.FrozenUntil = time.Time{} + state.LastLatencyMS = latencyMS + state.LastHTTPStatus = httpStatus + state.LastTested = true +} + +func (s *ContentModerationService) markAPIKeyFailure(key string, errText string, latencyMS int, httpStatus int) { + hash := moderationAPIKeyHash(key) + if hash == "" || s == nil { + return + } + s.keyHealthMu.Lock() + defer s.keyHealthMu.Unlock() + state := s.ensureAPIKeyHealthLocked(hash, maskSecretTail(key)) + state.FailureCount++ + state.LastError = trimRunes(errText, 180) + state.LastCheckedAt = time.Now() + state.LastLatencyMS = latencyMS + state.LastHTTPStatus = httpStatus + state.LastTested = true + if state.FailureCount >= contentModerationKeyFailureFreezeThreshold { + state.FrozenUntil = time.Now().Add(contentModerationKeyFreezeDuration) + } +} + +func (s *ContentModerationService) ensureAPIKeyHealthLocked(hash string, masked string) *contentModerationKeyHealth { + if s.keyHealth == nil { + s.keyHealth = make(map[string]*contentModerationKeyHealth) + } + state := s.keyHealth[hash] + if state == nil { + state = &contentModerationKeyHealth{Hash: hash} + s.keyHealth[hash] = state + } + if strings.TrimSpace(masked) != "" { + state.Masked = masked + } + return state +} + +func (s *ContentModerationService) configView(cfg *ContentModerationConfig) *ContentModerationConfigView { + keys := cfg.apiKeys() + masks := make([]string, 0, len(keys)) + for _, key := range keys { + masks = append(masks, maskSecretTail(key)) + } + apiKeyMasked := "" + if len(masks) > 0 { + apiKeyMasked = masks[0] + } + return &ContentModerationConfigView{ + Enabled: cfg.Enabled, + Mode: cfg.Mode, + BaseURL: cfg.BaseURL, + Model: cfg.Model, + APIKeyConfigured: len(keys) > 0, + APIKeyMasked: apiKeyMasked, + APIKeyCount: len(keys), + APIKeyMasks: masks, + APIKeyStatuses: s.apiKeyStatuses(keys), + TimeoutMS: cfg.TimeoutMS, + SampleRate: cfg.SampleRate, + AllGroups: cfg.AllGroups, + GroupIDs: append([]int64(nil), cfg.GroupIDs...), + RecordNonHits: cfg.RecordNonHits, + WorkerCount: cfg.WorkerCount, + QueueSize: cfg.QueueSize, + BlockStatus: cfg.BlockStatus, + BlockMessage: cfg.BlockMessage, + EmailOnHit: cfg.EmailOnHit, + AutoBanEnabled: cfg.AutoBanEnabled, + BanThreshold: cfg.BanThreshold, + ViolationWindowHours: cfg.ViolationWindowHours, + RetryCount: cfg.RetryCount, + HitRetentionDays: cfg.HitRetentionDays, + NonHitRetentionDays: cfg.NonHitRetentionDays, + PreHashCheckEnabled: cfg.PreHashCheckEnabled, + } +} + +func (s *ContentModerationService) apiKeyStatuses(keys []string) []ContentModerationAPIKeyStatus { + out := make([]ContentModerationAPIKeyStatus, 0, len(keys)) + for idx, key := range keys { + out = append(out, s.apiKeyStatusForHash(idx, moderationAPIKeyHash(key), maskSecretTail(key), true)) + } + return out +} + +func (s *ContentModerationService) apiKeyStatusForHash(index int, hash string, masked string, configured bool) ContentModerationAPIKeyStatus { + status := ContentModerationAPIKeyStatus{ + Index: index, + KeyHash: hash, + Masked: masked, + Status: "unknown", + Configured: configured, + } + if hash == "" || s == nil { + return status + } + now := time.Now() + s.keyHealthMu.Lock() + defer s.keyHealthMu.Unlock() + state := s.keyHealth[hash] + if state == nil { + return status + } + status.FailureCount = state.FailureCount + status.SuccessCount = state.SuccessCount + status.LastError = state.LastError + status.LastLatencyMS = state.LastLatencyMS + status.LastHTTPStatus = state.LastHTTPStatus + status.LastTested = state.LastTested + if !state.LastCheckedAt.IsZero() { + t := state.LastCheckedAt + status.LastCheckedAt = &t + } + if state.FrozenUntil.After(now) { + t := state.FrozenUntil + status.FrozenUntil = &t + status.Status = "frozen" + return status + } + if state.LastError != "" { + status.Status = "error" + return status + } + if state.SuccessCount > 0 || state.LastTested { + status.Status = "ok" + } + return status +} + +func moderationAPIKeyHash(key string) string { + key = strings.TrimSpace(key) + if key == "" { + return "" + } + sum := sha256.Sum256([]byte(key)) + return hex.EncodeToString(sum[:]) +} + +func buildModerationTestInput(prompt string, images []string) (any, int, error) { + prompt = trimRunes(normalizeContentModerationText(prompt), maxModerationInputRunes) + normalizedImages := make([]string, 0, len(images)) + for _, image := range images { + image = strings.TrimSpace(image) + if image == "" { + continue + } + if len(normalizedImages) >= maxContentModerationTestImages { + return nil, 0, infraerrors.BadRequest("TOO_MANY_MODERATION_TEST_IMAGES", fmt.Sprintf("最多上传 %d 张测试图片", maxContentModerationTestImages)) + } + if err := validateModerationTestImageDataURL(image); err != nil { + return nil, 0, err + } + normalizedImages = append(normalizedImages, image) + } + if prompt == "" && len(normalizedImages) == 0 { + return "hello", 0, nil + } + if len(normalizedImages) == 0 { + return prompt, 0, nil + } + parts := make([]moderationAPIInputPart, 0, len(normalizedImages)+1) + if prompt != "" { + parts = append(parts, moderationAPIInputPart{Type: "text", Text: prompt}) + } + for _, image := range normalizedImages { + parts = append(parts, moderationAPIInputPart{ + Type: "image_url", + ImageURL: &moderationAPIImageURLRef{URL: image}, + }) + } + return parts, len(normalizedImages), nil +} + +func contentModerationTestHasAuditInput(prompt string, images []string) bool { + if normalizeContentModerationText(prompt) != "" { + return true + } + for _, image := range images { + if strings.TrimSpace(image) != "" { + return true + } + } + return false +} + +func validateModerationTestImageDataURL(value string) error { + if len(value) > maxContentModerationTestImageDataURLBytes { + return infraerrors.BadRequest("MODERATION_TEST_IMAGE_TOO_LARGE", "测试图片不能超过 8MB") + } + if !strings.HasPrefix(value, "data:image/") { + return infraerrors.BadRequest("INVALID_MODERATION_TEST_IMAGE", "测试图片必须是 data:image/* base64") + } + parts := strings.SplitN(value, ",", 2) + if len(parts) != 2 || !strings.Contains(parts[0], ";base64") { + return infraerrors.BadRequest("INVALID_MODERATION_TEST_IMAGE", "测试图片必须是 base64 data URL") + } + raw, err := base64.StdEncoding.DecodeString(parts[1]) + if err != nil { + return infraerrors.BadRequest("INVALID_MODERATION_TEST_IMAGE", "测试图片 base64 无效") + } + if len(raw) > maxContentModerationTestImageBytes { + return infraerrors.BadRequest("MODERATION_TEST_IMAGE_TOO_LARGE", "测试图片不能超过 8MB") + } + return nil +} + +func buildContentModerationTestAuditResult(result *moderationAPIResult, thresholds map[string]float64) *ContentModerationTestAuditResult { + if result == nil { + return nil + } + scores := make(map[string]float64, len(result.CategoryScores)) + for category, score := range result.CategoryScores { + scores[category] = score + } + thresholdSnapshot := mergeContentModerationThresholds(ContentModerationDefaultThresholds(), thresholds) + flagged, highestCategory, highestScore := evaluateModerationScores(scores, thresholdSnapshot) + compositeScore := highestScore + return &ContentModerationTestAuditResult{ + Flagged: flagged, + HighestCategory: highestCategory, + HighestScore: highestScore, + CompositeScore: compositeScore, + CategoryScores: scores, + Thresholds: thresholdSnapshot, + } +} + +type moderationAPIRequest struct { + Model string `json:"model"` + Input any `json:"input"` +} + +type moderationAPIInputPart struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ImageURL *moderationAPIImageURLRef `json:"image_url,omitempty"` +} + +type moderationAPIImageURLRef struct { + URL string `json:"url"` +} + +type moderationAPIResponse struct { + Results []moderationAPIResult `json:"results"` +} + +type moderationAPIResult struct { + Flagged bool `json:"flagged"` + CategoryScores map[string]float64 `json:"category_scores"` +} + +func evaluateModerationScores(scores map[string]float64, thresholds map[string]float64) (bool, string, float64) { + flagged := false + highestCategory := "" + highestScore := 0.0 + for _, category := range contentModerationCategoryOrder { + score := scores[category] + if score > highestScore || highestCategory == "" { + highestScore = score + highestCategory = category + } + if score >= thresholds[category] { + flagged = true + } + } + for category, score := range scores { + if score > highestScore || highestCategory == "" { + highestScore = score + highestCategory = category + } + } + return flagged, highestCategory, highestScore +} + +func mergeContentModerationThresholds(base map[string]float64, override map[string]float64) map[string]float64 { + out := cloneFloatMap(base) + if out == nil { + out = map[string]float64{} + } + for _, category := range contentModerationCategoryOrder { + if v, ok := override[category]; ok { + if v < 0 { + v = 0 + } + if v > 1 { + v = 1 + } + out[category] = v + } + } + return out +} + +func normalizeInt64IDs(ids []int64) []int64 { + if len(ids) == 0 { + return []int64{} + } + seen := make(map[int64]struct{}, len(ids)) + out := make([]int64, 0, len(ids)) + for _, id := range ids { + if id <= 0 { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + out = append(out, id) + } + sort.Slice(out, func(i, j int) bool { return out[i] < out[j] }) + return out +} + +func normalizeModerationAPIKeys(keys []string) []string { + if len(keys) == 0 { + return []string{} + } + seen := make(map[string]struct{}, len(keys)) + out := make([]string, 0, len(keys)) + for _, key := range keys { + key = strings.TrimSpace(key) + if key == "" { + continue + } + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + out = append(out, key) + } + return out +} + +func normalizeContentModerationHash(inputHash string) string { + inputHash = strings.ToLower(strings.TrimSpace(inputHash)) + if len(inputHash) != sha256.Size*2 { + return "" + } + if _, err := hex.DecodeString(inputHash); err != nil { + return "" + } + return inputHash +} + +func cloneFloatMap(in map[string]float64) map[string]float64 { + if in == nil { + return map[string]float64{} + } + out := make(map[string]float64, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func cloneInt64Ptr(in *int64) *int64 { + if in == nil { + return nil + } + v := *in + return &v +} + +func trimRunes(text string, max int) string { + if max <= 0 { + return "" + } + runes := []rune(text) + if len(runes) <= max { + return text + } + return string(runes[:max]) +} + +func maskSecretTail(secret string) string { + secret = strings.TrimSpace(secret) + if secret == "" { + return "" + } + if len(secret) <= 4 { + return "****" + } + return strings.Repeat("*", 8) + secret[len(secret)-4:] +} diff --git a/backend/internal/service/content_moderation_email.go b/backend/internal/service/content_moderation_email.go new file mode 100644 index 00000000..e462ff88 --- /dev/null +++ b/backend/internal/service/content_moderation_email.go @@ -0,0 +1,117 @@ +package service + +import ( + "fmt" + "html" + "strings" + "time" +) + +func buildContentModerationViolationEmailBody(siteName string, log *ContentModerationLog, cfg *ContentModerationConfig) string { + if log == nil { + return "" + } + userName := strings.TrimSpace(log.UserEmail) + if userName == "" && log.UserID != nil { + userName = fmt.Sprintf("UID %d", *log.UserID) + } + threshold := cfg.BanThreshold + if threshold <= 0 { + threshold = defaultContentModerationBanThreshold + } + statusBlock := "" + if log.AutoBanned { + statusBlock = `
账户当前处于封禁状态,所有 API 请求将被拒绝
` + } + return fmt.Sprintf(` + + +
+
+
+
Risk Control / 风控提醒
+

账户触发内容审计规则

+

尊敬的用户 %s,您的 API 请求在内容审计中触发平台风控策略。详情如下。

+
+

触发详情

+ + + + + + +
触发时间%s
触发来源内容审核
所属分组%s
命中类别%s / %.3f
累计触发次数%d 次(阈值 %d)
+
+ %s +

此邮件由 %s 自动发送,请勿回复。

+
+
+ +`, + html.EscapeString(userName), + html.EscapeString(time.Now().Format("2006-01-02 15:04:05")), + html.EscapeString(defaultContentModerationString(log.GroupName, "-")), + html.EscapeString(defaultContentModerationString(log.HighestCategory, "-")), + log.HighestScore, + log.ViolationCount, + threshold, + statusBlock, + html.EscapeString(siteName), + ) +} + +func buildContentModerationAccountDisabledEmailBody(siteName string, log *ContentModerationLog, cfg *ContentModerationConfig) string { + if log == nil { + return "" + } + userName := strings.TrimSpace(log.UserEmail) + if userName == "" && log.UserID != nil { + userName = fmt.Sprintf("UID %d", *log.UserID) + } + threshold := cfg.BanThreshold + if threshold <= 0 { + threshold = defaultContentModerationBanThreshold + } + return fmt.Sprintf(` + + +
+
+
+
Risk Control / 账户封禁
+

账户已被自动禁用

+

尊敬的用户 %s,您的账户在计数周期内多次触发平台风控策略,系统已自动禁用该账户。详情如下。

+
+

封禁详情

+ + + + + + +
封禁时间%s
触发来源内容审核
所属分组%s
命中类别%s / %.3f
累计触发次数%d 次(阈值 %d)
+
+
账户当前处于封禁状态,所有 API 请求将被拒绝
+

如需申诉或恢复账号,请联系平台管理员处理。

+

此邮件由 %s 自动发送,请勿回复。

+
+
+ +`, + html.EscapeString(userName), + html.EscapeString(time.Now().Format("2006-01-02 15:04:05")), + html.EscapeString(defaultContentModerationString(log.GroupName, "-")), + html.EscapeString(defaultContentModerationString(log.HighestCategory, "-")), + log.HighestScore, + log.ViolationCount, + threshold, + html.EscapeString(siteName), + ) +} + +func defaultContentModerationString(value string, fallback string) string { + if strings.TrimSpace(value) == "" { + return fallback + } + return strings.TrimSpace(value) +} diff --git a/backend/internal/service/content_moderation_input.go b/backend/internal/service/content_moderation_input.go new file mode 100644 index 00000000..a0b3b663 --- /dev/null +++ b/backend/internal/service/content_moderation_input.go @@ -0,0 +1,307 @@ +package service + +import ( + "fmt" + "strings" + + "github.com/tidwall/gjson" +) + +func ExtractContentModerationText(protocol string, body []byte) string { + return ExtractContentModerationInput(protocol, body).Text +} + +func ExtractContentModerationInput(protocol string, body []byte) ContentModerationInput { + if len(body) == 0 || !gjson.ValidBytes(body) { + return ContentModerationInput{} + } + var parts []string + var images []string + switch protocol { + case ContentModerationProtocolAnthropicMessages: + collectLastAnthropicUserMessage(gjson.GetBytes(body, "messages"), &parts, &images) + case ContentModerationProtocolOpenAIChat: + collectLastRoleMessage(gjson.GetBytes(body, "messages"), "user", &parts, &images) + case ContentModerationProtocolOpenAIResponses: + collectLastResponsesInput(gjson.GetBytes(body, "input"), &parts, &images) + case ContentModerationProtocolGemini: + collectLastGeminiContent(gjson.GetBytes(body, "contents"), &parts, &images) + case ContentModerationProtocolOpenAIImages: + addModerationText(&parts, gjson.GetBytes(body, "prompt").String()) + collectContentValue(gjson.GetBytes(body, "images"), &parts, &images) + default: + collectLastResponsesInput(gjson.GetBytes(body, "input"), &parts, &images) + collectLastRoleMessage(gjson.GetBytes(body, "messages"), "user", &parts, &images) + collectLastGeminiContent(gjson.GetBytes(body, "contents"), &parts, &images) + } + out := ContentModerationInput{ + Text: normalizeContentModerationText(strings.Join(parts, "\n")), + Images: normalizeModerationImages(images), + } + out.Normalize() + return out +} + +func collectLastRoleMessage(messages gjson.Result, role string, parts *[]string, images *[]string) { + if !messages.IsArray() { + return + } + var lastParts []string + var lastImages []string + messages.ForEach(func(_, msg gjson.Result) bool { + if strings.ToLower(strings.TrimSpace(msg.Get("role").String())) == role { + var candidate []string + var candidateImages []string + collectContentValue(msg.Get("content"), &candidate, &candidateImages) + if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 { + lastParts = candidate + lastImages = candidateImages + } + } + return true + }) + *parts = append(*parts, lastParts...) + *images = append(*images, lastImages...) +} + +func collectLastAnthropicUserMessage(messages gjson.Result, parts *[]string, images *[]string) { + if !messages.IsArray() { + return + } + var lastParts []string + var lastImages []string + messages.ForEach(func(_, msg gjson.Result) bool { + if strings.ToLower(strings.TrimSpace(msg.Get("role").String())) == "user" { + var candidate []string + var candidateImages []string + collectAnthropicUserContentValue(msg.Get("content"), &candidate, &candidateImages) + if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 { + lastParts = candidate + lastImages = candidateImages + } + } + return true + }) + *parts = append(*parts, lastParts...) + *images = append(*images, lastImages...) +} + +func collectAnthropicUserContentValue(value gjson.Result, parts *[]string, images *[]string) { + switch { + case !value.Exists(): + return + case value.Type == gjson.String: + if !isAnthropicSystemReminderText(value.String()) { + addModerationText(parts, value.String()) + } + case value.IsArray(): + value.ForEach(func(_, item gjson.Result) bool { + collectAnthropicUserContentValue(item, parts, images) + return true + }) + case value.IsObject(): + typ := strings.ToLower(strings.TrimSpace(value.Get("type").String())) + switch typ { + case "", "text", "input_text", "message": + if value.Get("text").Exists() && !isAnthropicSystemReminderText(value.Get("text").String()) { + addModerationText(parts, value.Get("text").String()) + } + if value.Get("content").Exists() { + collectAnthropicUserContentValue(value.Get("content"), parts, images) + } + case "image_url", "input_image", "image": + collectContentValue(value, parts, images) + } + } +} + +func isAnthropicSystemReminderText(text string) bool { + return strings.HasPrefix(strings.TrimSpace(text), "") +} + +func collectLastResponsesInput(input gjson.Result, parts *[]string, images *[]string) { + switch { + case !input.Exists(): + return + case input.Type == gjson.String: + addModerationText(parts, input.String()) + case input.IsArray(): + var last gjson.Result + input.ForEach(func(_, item gjson.Result) bool { + if isResponsesUserTextItem(item) { + last = item + } + return true + }) + if last.Exists() { + collectContentValue(last.Get("content"), parts, images) + if last.Get("type").String() == "input_text" || last.Get("text").Exists() { + collectContentValue(last, parts, images) + } + } + case input.IsObject(): + if isResponsesUserTextItem(input) { + collectContentValue(input.Get("content"), parts, images) + if input.Get("type").String() == "input_text" || input.Get("text").Exists() { + collectContentValue(input, parts, images) + } + } + } +} + +func isResponsesUserTextItem(item gjson.Result) bool { + role := strings.ToLower(strings.TrimSpace(item.Get("role").String())) + if role == "user" { + return responseItemHasModerationText(item) + } + if role != "" { + return false + } + return responseItemHasModerationText(item) +} + +func responseItemHasModerationText(item gjson.Result) bool { + var parts []string + var images []string + collectContentValue(item.Get("content"), &parts, &images) + if item.Get("type").String() == "input_text" || item.Get("text").Exists() { + collectContentValue(item, &parts, &images) + } + return normalizeContentModerationText(strings.Join(parts, "\n")) != "" || len(images) > 0 +} + +func collectLastGeminiContent(contents gjson.Result, parts *[]string, images *[]string) { + if !contents.IsArray() { + return + } + var lastParts []string + var lastImages []string + contents.ForEach(func(_, content gjson.Result) bool { + role := strings.ToLower(strings.TrimSpace(content.Get("role").String())) + if role == "" || role == "user" { + var candidate []string + var candidateImages []string + if arr := content.Get("parts"); arr.IsArray() { + arr.ForEach(func(_, part gjson.Result) bool { + addModerationText(&candidate, part.Get("text").String()) + addGeminiModerationImage(&candidateImages, part) + return true + }) + } + if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 { + lastParts = candidate + lastImages = candidateImages + } + } + return true + }) + *parts = append(*parts, lastParts...) + *images = append(*images, lastImages...) +} + +func collectContentValue(value gjson.Result, parts *[]string, images *[]string) { + switch { + case !value.Exists(): + return + case value.Type == gjson.String: + addModerationText(parts, value.String()) + case value.IsArray(): + value.ForEach(func(_, item gjson.Result) bool { + collectContentValue(item, parts, images) + return true + }) + case value.IsObject(): + typ := strings.ToLower(strings.TrimSpace(value.Get("type").String())) + addModerationImage(images, value.Get("image_url.url").String()) + addModerationImage(images, value.Get("image_url").String()) + addModerationImage(images, value.Get("url").String()) + addModerationImageData(images, value.Get("source.media_type").String(), value.Get("source.data").String()) + addModerationImageData(images, value.Get("source.mediaType").String(), value.Get("source.data").String()) + addModerationImageData(images, value.Get("media_type").String(), value.Get("data").String()) + addModerationImageData(images, value.Get("mime_type").String(), value.Get("data").String()) + addModerationImageData(images, value.Get("mimeType").String(), value.Get("data").String()) + addModerationImage(images, value.Get("source.data").String()) + addModerationImage(images, value.Get("data").String()) + addModerationImage(images, value.Get("base64").String()) + switch typ { + case "", "text", "input_text", "message": + if value.Get("text").Exists() { + addModerationText(parts, value.Get("text").String()) + } + if value.Get("content").Exists() { + collectContentValue(value.Get("content"), parts, images) + } + case "image_url", "input_image", "image": + } + } +} + +func addGeminiModerationImage(images *[]string, part gjson.Result) { + if inlineData := part.Get("inline_data"); inlineData.IsObject() { + mimeType := strings.TrimSpace(inlineData.Get("mime_type").String()) + data := strings.TrimSpace(inlineData.Get("data").String()) + if mimeType != "" && data != "" { + addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data)) + } + } + if inlineData := part.Get("inlineData"); inlineData.IsObject() { + mimeType := strings.TrimSpace(inlineData.Get("mimeType").String()) + data := strings.TrimSpace(inlineData.Get("data").String()) + if mimeType != "" && data != "" { + addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data)) + } + } + addModerationImage(images, part.Get("file_data.file_uri").String()) + addModerationImage(images, part.Get("fileData.fileUri").String()) +} + +func addModerationImageData(images *[]string, mimeType string, data string) { + mimeType = strings.TrimSpace(mimeType) + data = strings.TrimSpace(data) + if mimeType == "" || data == "" { + return + } + addModerationImage(images, fmt.Sprintf("data:%s;base64,%s", mimeType, data)) +} + +func addModerationImage(images *[]string, image string) { + image = strings.TrimSpace(image) + if image == "" { + return + } + if strings.HasPrefix(image, "data:") || strings.HasPrefix(image, "http://") || strings.HasPrefix(image, "https://") { + *images = append(*images, image) + } +} + +func normalizeModerationImages(images []string) []string { + out := make([]string, 0, len(images)) + seen := make(map[string]struct{}, len(images)) + for _, image := range images { + image = strings.TrimSpace(image) + if image == "" { + continue + } + if _, ok := seen[image]; ok { + continue + } + seen[image] = struct{}{} + out = append(out, image) + } + return out +} + +func addModerationText(parts *[]string, text string) { + text = strings.TrimSpace(text) + if text == "" { + return + } + if strings.Contains(text, "") { + return + } + *parts = append(*parts, text) +} + +func normalizeContentModerationText(text string) string { + return strings.Join(strings.Fields(strings.TrimSpace(text)), " ") +} diff --git a/backend/internal/service/content_moderation_redact.go b/backend/internal/service/content_moderation_redact.go new file mode 100644 index 00000000..548cbeab --- /dev/null +++ b/backend/internal/service/content_moderation_redact.go @@ -0,0 +1,36 @@ +package service + +import ( + "regexp" + "strings" +) + +var contentModerationSecretPatterns = []*regexp.Regexp{ + regexp.MustCompile(`(?i)\b((?:api[_-]?key|apikey|access[_-]?token|refresh[_-]?token|id[_-]?token|session[_-]?token|token|session|cookie|set[_-]?cookie|authorization|bearer|password|passwd|pwd|secret|client[_-]?secret|private[_-]?key)\s*[:=]\s*)(["']?)[^"'\s,;,。;、]{6,}`), + regexp.MustCompile(`(?i)\b(Bearer\s+)[A-Za-z0-9._~+/=-]{12,}`), + regexp.MustCompile(`\beyJ[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{8,}\b`), + regexp.MustCompile(`(?i)\b(?:sk|sk-proj|sk-ant|sess|rk|pk|ak|api|key|token|secret)[_-][A-Za-z0-9._~+/=-]{12,}\b`), + regexp.MustCompile(`\b[0-9a-fA-F]{32,}\b`), + regexp.MustCompile(`\b[A-Za-z0-9_-]{48,}\b`), + regexp.MustCompile(`\b[A-Za-z0-9+/]{48,}={0,2}\b`), + regexp.MustCompile(`\b[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}\b`), +} + +func redactContentModerationSecrets(text string) string { + text = strings.TrimSpace(text) + if text == "" { + return "" + } + out := text + for idx, pattern := range contentModerationSecretPatterns { + switch idx { + case 0: + out = pattern.ReplaceAllString(out, `${1}${2}[已脱敏]`) + case 1: + out = pattern.ReplaceAllString(out, `${1}[已脱敏]`) + default: + out = pattern.ReplaceAllString(out, `[已脱敏]`) + } + } + return out +} diff --git a/backend/internal/service/content_moderation_test.go b/backend/internal/service/content_moderation_test.go new file mode 100644 index 00000000..0c1a39c5 --- /dev/null +++ b/backend/internal/service/content_moderation_test.go @@ -0,0 +1,811 @@ +package service + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type contentModerationTestSettingRepo struct { + values map[string]string +} + +func (r *contentModerationTestSettingRepo) Get(ctx context.Context, key string) (*Setting, error) { + if value, ok := r.values[key]; ok { + return &Setting{Key: key, Value: value}, nil + } + return nil, ErrSettingNotFound +} + +func (r *contentModerationTestSettingRepo) GetValue(ctx context.Context, key string) (string, error) { + if value, ok := r.values[key]; ok { + return value, nil + } + return "", ErrSettingNotFound +} + +func (r *contentModerationTestSettingRepo) Set(ctx context.Context, key, value string) error { + if r.values == nil { + r.values = map[string]string{} + } + r.values[key] = value + return nil +} + +func (r *contentModerationTestSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := map[string]string{} + for _, key := range keys { + if value, ok := r.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (r *contentModerationTestSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error { + if r.values == nil { + r.values = map[string]string{} + } + for key, value := range settings { + r.values[key] = value + } + return nil +} + +func (r *contentModerationTestSettingRepo) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(r.values)) + for key, value := range r.values { + out[key] = value + } + return out, nil +} + +func (r *contentModerationTestSettingRepo) Delete(ctx context.Context, key string) error { + delete(r.values, key) + return nil +} + +type contentModerationTestRepo struct { + logs []ContentModerationLog +} + +func (r *contentModerationTestRepo) CreateLog(ctx context.Context, log *ContentModerationLog) error { + if log != nil { + r.logs = append(r.logs, *log) + } + return nil +} + +func (r *contentModerationTestRepo) ListLogs(ctx context.Context, filter ContentModerationLogFilter) ([]ContentModerationLog, *pagination.PaginationResult, error) { + return nil, nil, nil +} + +func (r *contentModerationTestRepo) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) { + return 0, nil +} + +func (r *contentModerationTestRepo) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*ContentModerationCleanupResult, error) { + return &ContentModerationCleanupResult{}, nil +} + +type contentModerationTestHashCache struct { + hashes map[string]struct{} + recorded []string + checked []string + deleted []string + hasResult bool + hasResultUsed bool +} + +type contentModerationTestUserRepo struct { + user *User + updated []User +} + +func (r *contentModerationTestUserRepo) Create(ctx context.Context, user *User) error { + panic("unexpected Create call") +} + +func (r *contentModerationTestUserRepo) GetByID(ctx context.Context, id int64) (*User, error) { + if r.user == nil { + return nil, ErrUserNotFound + } + clone := *r.user + return &clone, nil +} + +func (r *contentModerationTestUserRepo) GetByEmail(ctx context.Context, email string) (*User, error) { + panic("unexpected GetByEmail call") +} + +func (r *contentModerationTestUserRepo) GetFirstAdmin(ctx context.Context) (*User, error) { + panic("unexpected GetFirstAdmin call") +} + +func (r *contentModerationTestUserRepo) Update(ctx context.Context, user *User) error { + if user == nil { + return nil + } + clone := *user + r.updated = append(r.updated, clone) + r.user = &clone + return nil +} + +func (r *contentModerationTestUserRepo) Delete(ctx context.Context, id int64) error { + panic("unexpected Delete call") +} + +func (r *contentModerationTestUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) { + panic("unexpected GetUserAvatar call") +} + +func (r *contentModerationTestUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) { + panic("unexpected UpsertUserAvatar call") +} + +func (r *contentModerationTestUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error { + panic("unexpected DeleteUserAvatar call") +} + +func (r *contentModerationTestUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (r *contentModerationTestUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (r *contentModerationTestUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) { + panic("unexpected GetLatestUsedAtByUserIDs call") +} + +func (r *contentModerationTestUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) { + panic("unexpected GetLatestUsedAtByUserID call") +} + +func (r *contentModerationTestUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error { + panic("unexpected UpdateUserLastActiveAt call") +} + +func (r *contentModerationTestUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error { + panic("unexpected UpdateBalance call") +} + +func (r *contentModerationTestUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error { + panic("unexpected DeductBalance call") +} + +func (r *contentModerationTestUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error { + panic("unexpected UpdateConcurrency call") +} + +func (r *contentModerationTestUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) { + panic("unexpected ExistsByEmail call") +} + +func (r *contentModerationTestUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { + panic("unexpected RemoveGroupFromAllowedGroups call") +} + +func (r *contentModerationTestUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error { + panic("unexpected AddGroupToAllowedGroups call") +} + +func (r *contentModerationTestUserRepo) RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error { + panic("unexpected RemoveGroupFromUserAllowedGroups call") +} + +func (r *contentModerationTestUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) { + panic("unexpected ListUserAuthIdentities call") +} + +func (r *contentModerationTestUserRepo) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) error { + panic("unexpected UnbindUserAuthProvider call") +} + +func (r *contentModerationTestUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { + panic("unexpected UpdateTotpSecret call") +} + +func (r *contentModerationTestUserRepo) EnableTotp(ctx context.Context, userID int64) error { + panic("unexpected EnableTotp call") +} + +func (r *contentModerationTestUserRepo) DisableTotp(ctx context.Context, userID int64) error { + panic("unexpected DisableTotp call") +} + +type contentModerationTestAuthCacheInvalidator struct { + userIDs []int64 +} + +func (i *contentModerationTestAuthCacheInvalidator) InvalidateAuthCacheByKey(ctx context.Context, key string) { +} + +func (i *contentModerationTestAuthCacheInvalidator) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) { + i.userIDs = append(i.userIDs, userID) +} + +func (i *contentModerationTestAuthCacheInvalidator) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) { +} + +func (c *contentModerationTestHashCache) RecordFlaggedInputHash(ctx context.Context, inputHash string) error { + if c.hashes == nil { + c.hashes = map[string]struct{}{} + } + c.hashes[inputHash] = struct{}{} + c.recorded = append(c.recorded, inputHash) + return nil +} + +func (c *contentModerationTestHashCache) HasFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) { + c.checked = append(c.checked, inputHash) + if c.hasResultUsed { + return c.hasResult, nil + } + _, ok := c.hashes[inputHash] + return ok, nil +} + +func (c *contentModerationTestHashCache) DeleteFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) { + c.deleted = append(c.deleted, inputHash) + if c.hashes == nil { + return false, nil + } + if _, ok := c.hashes[inputHash]; !ok { + return false, nil + } + delete(c.hashes, inputHash) + return true, nil +} + +func (c *contentModerationTestHashCache) ClearFlaggedInputHashes(ctx context.Context) (int64, error) { + deleted := int64(len(c.hashes)) + c.hashes = map[string]struct{}{} + return deleted, nil +} + +func (c *contentModerationTestHashCache) CountFlaggedInputHashes(ctx context.Context) (int64, error) { + return int64(len(c.hashes)), nil +} + +func TestBuildContentModerationLog_RedactsInputExcerpt(t *testing.T) { + svc := &ContentModerationService{} + cfg := defaultContentModerationConfig() + input := ContentModerationCheckInput{ + RequestID: "req-1", + Endpoint: "/v1/chat/completions", + Provider: "openai", + } + + log := svc.buildLog(input, cfg, ContentModerationActionAllow, true, "sexual", 0.8, map[string]float64{"sexual": 0.8}, "hello sk-proj-1234567890abcdef", nil, nil, "") + + require.NotContains(t, log.InputExcerpt, "sk-proj-1234567890abcdef") + require.Contains(t, log.InputExcerpt, "[已脱敏]") +} + +func TestRedactContentModerationSecrets_LongHexAndTokens(t *testing.T) { + input := "你哈市多大事cf5bbdc4cd508f3aaf0d2070d529d4a4ac29099f8ecc357f696df28e1df91554 token=abc123456789xyz Bearer eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.signaturepart" + + out := redactContentModerationSecrets(input) + + require.NotContains(t, out, "cf5bbdc4cd508f3aaf0d2070d529d4a4ac29099f8ecc357f696df28e1df91554") + require.NotContains(t, out, "abc123456789xyz") + require.NotContains(t, out, "eyJhbGciOiJIUzI1NiJ9") + require.Contains(t, out, "[已脱敏]") +} + +func TestContentModerationConfigNormalize_NonHitRetentionMaxThreeDays(t *testing.T) { + cfg := defaultContentModerationConfig() + cfg.NonHitRetentionDays = 30 + + cfg.normalize() + + require.Equal(t, 3, cfg.NonHitRetentionDays) +} + +func TestExtractContentModerationInput_AnthropicImageSourceOnlyParticipatesInMemory(t *testing.T) { + body := []byte(`{ + "messages": [ + {"role":"user","content":"old"}, + {"role":"assistant","content":"ok"}, + {"role":"user","content":[ + {"type":"text","text":"检查这张图"}, + {"type":"image","source":{"type":"base64","media_type":"image/png","data":"aGVsbG8="}} + ]} + ] + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body) + require.Equal(t, "检查这张图", input.Text) + require.Equal(t, []string{"data:image/png;base64,aGVsbG8="}, input.Images) + + log := (&ContentModerationService{}).buildLog(ContentModerationCheckInput{}, defaultContentModerationConfig(), ContentModerationActionAllow, false, "", 0, nil, input.ExcerptText(), nil, nil, "") + require.Equal(t, "检查这张图", log.InputExcerpt) + require.NotContains(t, log.InputExcerpt, "aGVsbG8=") +} + +func TestExtractContentModerationInput_AnthropicKeepsEphemeralUserTextAndSkipsSystemReminders(t *testing.T) { + body := []byte(`{ + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "工具说明"}, + {"type": "text", "text": "Ainder>\n\n"}, + {"type": "text", "text": "hid", "cache_control": {"type": "ephemeral"}} + ] + } + ] + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body) + + require.Equal(t, "hid", input.Text) + require.Empty(t, input.Images) +} + +func TestExtractContentModerationInput_OpenAIChatUsesLastUserMessage(t *testing.T) { + body := []byte(`{ + "model":"gpt-5.5", + "messages":[ + {"role":"system","content":"system prompt"}, + {"role":"user","content":"old user"}, + {"role":"assistant","content":"ok"}, + {"role":"user","content":[{"type":"text","text":"latest user"},{"type":"image_url","image_url":{"url":"https://example.com/a.png"}}]} + ] + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolOpenAIChat, body) + + require.Equal(t, "latest user", input.Text) + require.Equal(t, []string{"https://example.com/a.png"}, input.Images) + require.NotContains(t, input.Text, "old user") + require.NotContains(t, input.Text, "system prompt") +} + +func TestExtractContentModerationInput_OpenAIImagesIncludesPromptAndImages(t *testing.T) { + body := []byte(`{ + "prompt":"replace background", + "images":[ + {"image_url":"https://example.com/source.png"}, + {"image_url":"data:image/png;base64,aGVsbG8="} + ] + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolOpenAIImages, body) + + require.Equal(t, "replace background", input.Text) + require.Equal(t, []string{"https://example.com/source.png", "data:image/png;base64,aGVsbG8="}, input.Images) +} + +func TestExtractContentModerationInput_OpenAIResponsesCodexPayloadUsesLastUserMessage(t *testing.T) { + body := []byte(`{ + "model":"gpt-5.5", + "instructions":"instructions.....", + "input":[ + {"type":"message","role":"developer","content":[{"type":"input_text","text":"developer permissions sk-proj-1234567890abcdef"}]}, + {"type":"message","role":"user","content":[{"type":"input_text","text":"first user prompt"}]}, + {"type":"message","role":"user","content":[{"type":"input_text","text":"last user prompt"}]} + ], + "prompt_cache_key":"cache-key" + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolOpenAIResponses, body) + + require.Equal(t, "last user prompt", input.Text) + require.Empty(t, input.Images) + require.NotContains(t, input.Text, "developer permissions") + require.NotContains(t, input.Text, "first user prompt") +} + +func TestContentModerationCheck_OpenAIResponsesRecordsNonHitForCodexPayload(t *testing.T) { + var moderationRequest moderationAPIRequest + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/v1/moderations", r.URL.Path) + require.NoError(t, json.NewDecoder(r.Body).Decode(&moderationRequest)) + _ = json.NewEncoder(w).Encode(moderationAPIResponse{ + Results: []moderationAPIResult{{ + CategoryScores: map[string]float64{"sexual": 0.01}, + }}, + }) + })) + defer server.Close() + + cfg := defaultContentModerationConfig() + cfg.Enabled = true + cfg.Mode = ContentModerationModePreBlock + cfg.BaseURL = server.URL + cfg.APIKeys = []string{"sk-test"} + cfg.RecordNonHits = true + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + repo := &contentModerationTestRepo{} + svc := NewContentModerationService( + &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyRiskControlEnabled: "true", + SettingKeyContentModerationConfig: string(rawCfg), + }}, + repo, + &contentModerationTestHashCache{}, + nil, + nil, + nil, + nil, + ) + + body := []byte(`{ + "model":"gpt-5.5", + "input":[ + {"type":"message","role":"developer","content":[{"type":"input_text","text":"developer instructions should not be audited"}]}, + {"type":"message","role":"user","content":[{"type":"input_text","text":"first user prompt"}]}, + {"type":"message","role":"user","content":[{"type":"input_text","text":"last user prompt"}]} + ] + }`) + decision, err := svc.Check(context.Background(), ContentModerationCheckInput{ + UserID: 1001, + Endpoint: "/responses", + Provider: "openai", + Model: "gpt-5.5", + Protocol: ContentModerationProtocolOpenAIResponses, + Body: body, + }) + + require.NoError(t, err) + require.False(t, decision.Blocked) + require.Len(t, repo.logs, 1) + require.False(t, repo.logs[0].Flagged) + require.Equal(t, ContentModerationActionAllow, repo.logs[0].Action) + require.Equal(t, "/responses", repo.logs[0].Endpoint) + require.Equal(t, "last user prompt", repo.logs[0].InputExcerpt) + require.Equal(t, "last user prompt", moderationRequest.Input) +} + +func TestContentModerationCheck_PreBlockBlocksCodexResponsesLatestUserInput(t *testing.T) { + var moderationRequest moderationAPIRequest + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/v1/moderations", r.URL.Path) + require.NoError(t, json.NewDecoder(r.Body).Decode(&moderationRequest)) + _ = json.NewEncoder(w).Encode(moderationAPIResponse{ + Results: []moderationAPIResult{{ + CategoryScores: map[string]float64{"sexual": 0.9}, + }}, + }) + })) + defer server.Close() + + cfg := defaultContentModerationConfig() + cfg.Enabled = true + cfg.Mode = ContentModerationModePreBlock + cfg.BaseURL = server.URL + cfg.APIKeys = []string{"sk-test"} + cfg.BlockStatus = http.StatusUnavailableForLegalReasons + cfg.BlockMessage = "内容审计测试阻断" + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + repo := &contentModerationTestRepo{} + svc := NewContentModerationService( + &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyRiskControlEnabled: "true", + SettingKeyContentModerationConfig: string(rawCfg), + }}, + repo, + &contentModerationTestHashCache{}, + nil, + nil, + nil, + nil, + ) + + body := []byte(`{ + "model":"gpt-5.5", + "instructions":"instructions.....", + "input":[ + {"type":"message","role":"developer","content":[{"type":"input_text","text":"developer instructions should not be audited"}]}, + {"type":"message","role":"user","content":[{"type":"input_text","text":"environment context"}]}, + {"type":"message","role":"user","content":[{"type":"input_text","text":"latest blocked prompt"}]} + ] + }`) + decision, err := svc.Check(context.Background(), ContentModerationCheckInput{ + UserID: 1001, + Endpoint: "/responses", + Provider: "openai", + Model: "gpt-5.5", + Protocol: ContentModerationProtocolOpenAIResponses, + Body: body, + }) + + require.NoError(t, err) + require.True(t, decision.Blocked) + require.Equal(t, ContentModerationActionBlock, decision.Action) + require.Equal(t, http.StatusUnavailableForLegalReasons, decision.StatusCode) + require.Equal(t, "内容审计测试阻断", decision.Message) + require.Len(t, repo.logs, 1) + require.True(t, repo.logs[0].Flagged) + require.Equal(t, ContentModerationActionBlock, repo.logs[0].Action) + require.Equal(t, ContentModerationModePreBlock, repo.logs[0].Mode) + require.Equal(t, "latest blocked prompt", repo.logs[0].InputExcerpt) + require.Equal(t, "latest blocked prompt", moderationRequest.Input) +} + +func TestBuildContentModerationTestAuditResult_UsesConfiguredThresholdsOnly(t *testing.T) { + result := buildContentModerationTestAuditResult(&moderationAPIResult{ + Flagged: true, + CategoryScores: map[string]float64{ + "harassment": 0.65, + }, + }, nil) + + require.NotNil(t, result) + require.False(t, result.Flagged) + require.Equal(t, "harassment", result.HighestCategory) + require.Equal(t, 0.65, result.HighestScore) + require.Equal(t, 0.65, result.CompositeScore) + require.Equal(t, 0.98, result.Thresholds["harassment"]) +} + +func TestContentModerationCheck_PreHashUsesRedisHashCache(t *testing.T) { + cfg := defaultContentModerationConfig() + cfg.Enabled = true + cfg.PreHashCheckEnabled = true + cfg.APIKeys = []string{"sk-test"} + cfg.BlockStatus = http.StatusConflict + cfg.BlockMessage = "命中历史风险输入" + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + hashCache := &contentModerationTestHashCache{hashes: map[string]struct{}{}} + content := ContentModerationInput{Text: "blocked prompt"} + content.Normalize() + hashCache.hashes[content.Hash()] = struct{}{} + + svc := NewContentModerationService( + &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyRiskControlEnabled: "true", + SettingKeyContentModerationConfig: string(rawCfg), + }}, + &contentModerationTestRepo{}, + hashCache, + nil, + nil, + nil, + nil, + ) + + decision, err := svc.Check(context.Background(), ContentModerationCheckInput{ + Protocol: ContentModerationProtocolOpenAIChat, + Body: []byte(`{"messages":[{"role":"user","content":"blocked prompt"}]}`), + }) + require.NoError(t, err) + require.True(t, decision.Blocked) + require.Equal(t, ContentModerationActionHashBlock, decision.Action) + require.Equal(t, http.StatusConflict, decision.StatusCode) + require.Equal(t, content.Hash(), decision.InputHash) + require.Contains(t, decision.Message, "命中历史风险输入") + require.Contains(t, decision.Message, content.Hash()) + require.Len(t, hashCache.checked, 1) +} + +func TestContentModerationCheck_PreBlockFlaggedWritesRedisHashCache(t *testing.T) { + requestCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + _ = json.NewEncoder(w).Encode(moderationAPIResponse{ + Results: []moderationAPIResult{{ + CategoryScores: map[string]float64{"sexual": 0.9}, + }}, + }) + })) + defer server.Close() + + cfg := defaultContentModerationConfig() + cfg.Enabled = true + cfg.Mode = ContentModerationModePreBlock + cfg.PreHashCheckEnabled = true + cfg.BaseURL = server.URL + cfg.APIKeys = []string{"sk-test"} + cfg.BlockStatus = http.StatusConflict + cfg.BlockMessage = "命中风险输入" + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + repo := &contentModerationTestRepo{} + hashCache := &contentModerationTestHashCache{} + svc := NewContentModerationService( + &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyRiskControlEnabled: "true", + SettingKeyContentModerationConfig: string(rawCfg), + }}, + repo, + hashCache, + nil, + nil, + nil, + nil, + ) + + body := []byte(`{"messages":[{"role":"user","content":"repeat blocked prompt"}]}`) + decision, err := svc.Check(context.Background(), ContentModerationCheckInput{ + Protocol: ContentModerationProtocolOpenAIChat, + Body: body, + }) + require.NoError(t, err) + require.True(t, decision.Blocked) + require.Equal(t, ContentModerationActionBlock, decision.Action) + require.Equal(t, 1, requestCount) + require.Len(t, hashCache.recorded, 1) + require.Len(t, repo.logs, 1) + + decision, err = svc.Check(context.Background(), ContentModerationCheckInput{ + Protocol: ContentModerationProtocolOpenAIChat, + Body: body, + }) + require.NoError(t, err) + require.True(t, decision.Blocked) + require.Equal(t, ContentModerationActionHashBlock, decision.Action) + require.Equal(t, hashCache.recorded[0], decision.InputHash) + require.Equal(t, 1, requestCount) + require.Len(t, repo.logs, 1) +} + +func TestContentModerationDeleteFlaggedInputHash_NormalizesAndDeletes(t *testing.T) { + existingHash := strings.Repeat("a", 64) + hashCache := &contentModerationTestHashCache{hashes: map[string]struct{}{ + existingHash: {}, + }} + svc := &ContentModerationService{hashCache: hashCache} + + result, err := svc.DeleteFlaggedInputHash(context.Background(), strings.ToUpper(existingHash)) + + require.NoError(t, err) + require.Equal(t, existingHash, result.InputHash) + require.True(t, result.Deleted) + require.NotContains(t, hashCache.hashes, existingHash) + require.Equal(t, []string{existingHash}, hashCache.deleted) + + result, err = svc.DeleteFlaggedInputHash(context.Background(), existingHash) + + require.NoError(t, err) + require.Equal(t, existingHash, result.InputHash) + require.False(t, result.Deleted) +} + +func TestContentModerationClearFlaggedInputHashesAndStatusCount(t *testing.T) { + cfg := defaultContentModerationConfig() + cfg.Enabled = true + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + hashCache := &contentModerationTestHashCache{hashes: map[string]struct{}{ + strings.Repeat("a", 64): {}, + strings.Repeat("b", 64): {}, + }} + svc := &ContentModerationService{ + settingRepo: &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyRiskControlEnabled: "true", + SettingKeyContentModerationConfig: string(rawCfg), + }}, + hashCache: hashCache, + keyHealth: make(map[string]*contentModerationKeyHealth), + } + + status, err := svc.GetStatus(context.Background()) + require.NoError(t, err) + require.Equal(t, int64(2), status.FlaggedHashCount) + + result, err := svc.ClearFlaggedInputHashes(context.Background()) + require.NoError(t, err) + require.Equal(t, int64(2), result.Deleted) + + status, err = svc.GetStatus(context.Background()) + require.NoError(t, err) + require.Equal(t, int64(0), status.FlaggedHashCount) +} + +func TestContentModerationCheck_AsyncFlaggedWritesRedisHashCache(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(moderationAPIResponse{ + Results: []moderationAPIResult{{ + CategoryScores: map[string]float64{"sexual": 0.9}, + }}, + }) + })) + defer server.Close() + + cfg := defaultContentModerationConfig() + cfg.Enabled = true + cfg.Mode = ContentModerationModeObserve + cfg.BaseURL = server.URL + cfg.APIKeys = []string{"sk-test"} + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + repo := &contentModerationTestRepo{} + hashCache := &contentModerationTestHashCache{} + svc := NewContentModerationService( + &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyRiskControlEnabled: "true", + SettingKeyContentModerationConfig: string(rawCfg), + }}, + repo, + hashCache, + nil, + nil, + nil, + nil, + ) + + decision := svc.checkSync(context.Background(), ContentModerationCheckInput{ + Protocol: ContentModerationProtocolOpenAIChat, + Body: []byte(`{"messages":[{"role":"user","content":"bad prompt"}]}`), + }, cfg, ContentModerationInput{Text: "bad prompt"}, strings.Repeat("b", 64), contentModerationIntPtr(25), false) + + require.False(t, decision.Blocked) + require.Len(t, hashCache.recorded, 1) + require.Len(t, repo.logs, 1) +} + +func TestBuildContentModerationAccountDisabledEmailBody_ContainsBanDetails(t *testing.T) { + userID := int64(1001) + cfg := defaultContentModerationConfig() + cfg.BanThreshold = 10 + body := buildContentModerationAccountDisabledEmailBody("Sub2API ", &ContentModerationLog{ + UserID: &userID, + UserEmail: "user@example.com", + GroupName: "vip_2", + HighestCategory: "sexual", + HighestScore: 0.926, + ViolationCount: 10, + }, cfg) + + require.Contains(t, body, "账户已被自动禁用") + require.Contains(t, body, "封禁详情") + require.Contains(t, body, "账户当前处于封禁状态,所有 API 请求将被拒绝") + require.Contains(t, body, "10 次(阈值 10)") + require.Contains(t, body, "sexual / 0.926") + require.Contains(t, body, "Sub2API <Admin>") +} + +func TestContentModerationUnbanUser_ActivatesUserAndInvalidatesAuthCache(t *testing.T) { + userRepo := &contentModerationTestUserRepo{user: &User{ID: 1001, Email: "user@example.com", Status: StatusDisabled}} + invalidator := &contentModerationTestAuthCacheInvalidator{} + repo := &contentModerationTestRepo{} + svc := NewContentModerationService(nil, repo, nil, nil, userRepo, invalidator, nil) + + result, err := svc.UnbanUser(context.Background(), 1001) + + require.NoError(t, err) + require.Equal(t, int64(1001), result.UserID) + require.Equal(t, StatusActive, result.Status) + require.Len(t, userRepo.updated, 1) + require.Equal(t, StatusActive, userRepo.updated[0].Status) + require.Equal(t, []int64{1001}, invalidator.userIDs) +} + +func TestContentModerationUnbanUser_ActiveUserOnlyInvalidatesAuthCache(t *testing.T) { + userRepo := &contentModerationTestUserRepo{user: &User{ID: 1001, Email: "user@example.com", Status: StatusActive}} + invalidator := &contentModerationTestAuthCacheInvalidator{} + repo := &contentModerationTestRepo{} + svc := NewContentModerationService(nil, repo, nil, nil, userRepo, invalidator, nil) + + result, err := svc.UnbanUser(context.Background(), 1001) + + require.NoError(t, err) + require.Equal(t, StatusActive, result.Status) + require.Empty(t, userRepo.updated) + require.Equal(t, []int64{1001}, invalidator.userIDs) +} + +func contentModerationIntPtr(v int) *int { + return &v +} diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 24b511e3..14147fff 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -107,6 +107,8 @@ const ( SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期(小时,0=不冻结) SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期(天,0=永久) SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限(0=无上限) + SettingKeyRiskControlEnabled = "risk_control_enabled" // 是否启用风控中心入口与审计链路 + SettingKeyContentModerationConfig = "content_moderation_config" // 内容审计配置(JSON) // 邮件服务设置 SettingKeySMTPHost = "smtp_host" // SMTP服务器地址 diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go index 04be5164..afa94156 100644 --- a/backend/internal/service/openai_images.go +++ b/backend/internal/service/openai_images.go @@ -90,6 +90,69 @@ type OpenAIImagesRequest struct { bodyHash string } +func (r *OpenAIImagesRequest) ModerationBody() []byte { + if r == nil { + return nil + } + payload := map[string]any{} + if prompt := strings.TrimSpace(r.Prompt); prompt != "" { + payload["prompt"] = prompt + } + images := r.moderationImages() + if len(images) > 0 { + payload["images"] = images + } + if len(payload) == 0 { + return nil + } + body, err := json.Marshal(payload) + if err != nil { + return nil + } + return body +} + +func (r *OpenAIImagesRequest) moderationImages() []map[string]string { + if r == nil { + return nil + } + images := make([]map[string]string, 0, len(r.InputImageURLs)+len(r.Uploads)+1) + for _, imageURL := range r.InputImageURLs { + imageURL = strings.TrimSpace(imageURL) + if imageURL != "" { + images = append(images, map[string]string{"image_url": imageURL}) + } + } + for _, upload := range r.Uploads { + if dataURL := upload.ModerationDataURL(); dataURL != "" { + images = append(images, map[string]string{"image_url": dataURL}) + } + } + if maskURL := strings.TrimSpace(r.MaskImageURL); maskURL != "" { + images = append(images, map[string]string{"image_url": maskURL}) + } + if r.MaskUpload != nil { + if dataURL := r.MaskUpload.ModerationDataURL(); dataURL != "" { + images = append(images, map[string]string{"image_url": dataURL}) + } + } + return images +} + +func (u OpenAIImagesUpload) ModerationDataURL() string { + if len(u.Data) == 0 { + return "" + } + contentType := strings.TrimSpace(u.ContentType) + if contentType == "" { + contentType = http.DetectContentType(u.Data) + } + if !strings.HasPrefix(strings.ToLower(contentType), "image/") { + return "" + } + return fmt.Sprintf("data:%s;base64,%s", contentType, base64.StdEncoding.EncodeToString(u.Data)) +} + func (r *OpenAIImagesRequest) IsEdits() bool { return r != nil && r.Endpoint == openAIImagesEditsEndpoint } diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go index fa4a4415..45fb24e9 100644 --- a/backend/internal/service/openai_images_test.go +++ b/backend/internal/service/openai_images_test.go @@ -90,6 +90,51 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability) } +func TestOpenAIImagesRequestModerationBody_JSONEditIncludesInputImageURLs(t *testing.T) { + parsed := &OpenAIImagesRequest{ + Endpoint: openAIImagesEditsEndpoint, + Prompt: "replace background", + InputImageURLs: []string{"https://example.com/source.png"}, + MaskImageURL: "https://example.com/mask.png", + } + + input := ExtractContentModerationInput(ContentModerationProtocolOpenAIImages, parsed.ModerationBody()) + + require.Equal(t, "replace background", input.Text) + require.Equal(t, []string{"https://example.com/source.png", "https://example.com/mask.png"}, input.Images) +} + +func TestOpenAIImagesRequestModerationBody_MultipartEditIncludesUploadsInMemory(t *testing.T) { + parsed := &OpenAIImagesRequest{ + Endpoint: openAIImagesEditsEndpoint, + Prompt: "replace background", + Uploads: []OpenAIImagesUpload{{ + FieldName: "image", + FileName: "source.png", + ContentType: "image/png", + Data: []byte("fake-image-bytes"), + }}, + MaskUpload: &OpenAIImagesUpload{ + FieldName: "mask", + FileName: "mask.png", + ContentType: "image/png", + Data: []byte("fake-mask-bytes"), + }, + } + + input := ExtractContentModerationInput(ContentModerationProtocolOpenAIImages, parsed.ModerationBody()) + + require.Equal(t, "replace background", input.Text) + require.Equal(t, []string{ + "data:image/png;base64,ZmFrZS1pbWFnZS1ieXRlcw==", + "data:image/png;base64,ZmFrZS1tYXNrLWJ5dGVz", + }, input.Images) + + log := (&ContentModerationService{}).buildLog(ContentModerationCheckInput{}, defaultContentModerationConfig(), ContentModerationActionAllow, false, "", 0, nil, input.ExcerptText(), nil, nil, "") + require.Equal(t, "replace background", log.InputExcerpt) + require.NotContains(t, log.InputExcerpt, "ZmFrZS") +} + func TestOpenAIGatewayServiceParseOpenAIImagesRequest_NormalizesOfficialAndCustomSizes(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 784cdbe5..372f420f 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -223,6 +223,7 @@ type OpenAIWSIngressHooks struct { // 的 reasoning effort 后缀推导,禁止用于上游请求或计费模型。 InitialRequestModel string BeforeTurn func(turn int) error + BeforeRequest func(turn int, payload []byte, originalModel string) error AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error) } @@ -3222,6 +3223,11 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( return true } for { + if turn > 1 && !skipBeforeTurn && hooks != nil && hooks.BeforeRequest != nil { + if err := hooks.BeforeRequest(turn, currentPayload, currentOriginalModel); err != nil { + return err + } + } if !skipBeforeTurn && hooks != nil && hooks.BeforeTurn != nil { if err := hooks.BeforeTurn(turn); err != nil { return err diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter.go b/backend/internal/service/openai_ws_v2_passthrough_adapter.go index 8bc17d42..e2760725 100644 --- a/backend/internal/service/openai_ws_v2_passthrough_adapter.go +++ b/backend/internal/service/openai_ws_v2_passthrough_adapter.go @@ -387,6 +387,19 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( if msgType != coderws.MessageText { return payload, nil, nil } + if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" && hooks != nil && hooks.BeforeRequest != nil { + turnNo := int(completedTurns.Load()) + 1 + if turnNo < 2 { + turnNo = 2 + } + requestModel := usageMeta.requestModelForFrame(payload) + if requestModel == "" { + requestModel = capturedSessionModel + } + if err := hooks.BeforeRequest(turnNo, payload, requestModel); err != nil { + return payload, nil, err + } + } // 在评估策略前先刷新 capturedSessionModel:客户端可能通过 // session.update 修改 session-level model(Realtime / // Responses WS 协议允许),如果不刷新就会出现 diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index a5d65ad7..bf6294db 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -456,6 +456,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyChannelMonitorDefaultIntervalSeconds, SettingKeyAvailableChannelsEnabled, SettingKeyAffiliateEnabled, + SettingKeyRiskControlEnabled, } settings, err := s.settingRepo.GetMultiple(ctx, keys) @@ -545,6 +546,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings AvailableChannelsEnabled: settings[SettingKeyAvailableChannelsEnabled] == "true", AffiliateEnabled: settings[SettingKeyAffiliateEnabled] == "true", + + RiskControlEnabled: settings[SettingKeyRiskControlEnabled] == "true", }, nil } @@ -692,6 +695,7 @@ type PublicSettingsInjectionPayload struct { ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"` AvailableChannelsEnabled bool `json:"available_channels_enabled"` AffiliateEnabled bool `json:"affiliate_enabled"` + RiskControlEnabled bool `json:"risk_control_enabled"` } // GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection. @@ -745,6 +749,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds, AvailableChannelsEnabled: settings.AvailableChannelsEnabled, AffiliateEnabled: settings.AffiliateEnabled, + RiskControlEnabled: settings.RiskControlEnabled, }, nil } @@ -1232,6 +1237,9 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting // Affiliate (邀请返利) feature switch updates[SettingKeyAffiliateEnabled] = strconv.FormatBool(settings.AffiliateEnabled) + // 风控中心功能开关 + updates[SettingKeyRiskControlEnabled] = strconv.FormatBool(settings.RiskControlEnabled) + // Claude Code version check updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion updates[SettingKeyMaxClaudeCodeVersion] = settings.MaxClaudeCodeVersion @@ -1903,6 +1911,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // Affiliate (邀请返利) feature (default disabled; opt-in) SettingKeyAffiliateEnabled: "false", + // 风控中心功能(默认关闭,显式启用) + SettingKeyRiskControlEnabled: "false", + // Claude Code version check (default: empty = disabled) SettingKeyMinClaudeCodeVersion: "", SettingKeyMaxClaudeCodeVersion: "", @@ -2242,6 +2253,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin // Affiliate (邀请返利) feature (default: disabled; strict true) result.AffiliateEnabled = settings[SettingKeyAffiliateEnabled] == "true" + // 风控中心功能(默认关闭,严格 true 才启用) + result.RiskControlEnabled = settings[SettingKeyRiskControlEnabled] == "true" + // Claude Code version check result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion] result.MaxClaudeCodeVersion = settings[SettingKeyMaxClaudeCodeVersion] diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index aaf837bd..46a8c5a8 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -106,6 +106,7 @@ type SystemSettings struct { DefaultConcurrency int DefaultBalance float64 + RiskControlEnabled bool AffiliateEnabled bool AffiliateRebateRate float64 AffiliateRebateFreezeHours int @@ -233,6 +234,9 @@ type PublicSettings struct { // Affiliate (邀请返利) feature toggle AffiliateEnabled bool `json:"affiliate_enabled"` + + // 风控中心功能开关 + RiskControlEnabled bool `json:"risk_control_enabled"` } type WeChatConnectOAuthConfig struct { diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 0f36412b..dc96be0c 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -509,6 +509,7 @@ var ProviderSet = wire.NewSet( NewGroupCapacityService, NewChannelService, NewModelPricingResolver, + NewContentModerationService, NewAffiliateService, ProvidePaymentConfigService, NewPaymentService, diff --git a/backend/migrations/135_content_moderation.sql b/backend/migrations/135_content_moderation.sql new file mode 100644 index 00000000..4873bbf2 --- /dev/null +++ b/backend/migrations/135_content_moderation.sql @@ -0,0 +1,45 @@ +-- 风控中心内容审计配置与记录 + +INSERT INTO settings (key, value, updated_at) +VALUES ('risk_control_enabled', 'false', NOW()) +ON CONFLICT (key) DO NOTHING; + +CREATE TABLE IF NOT EXISTS content_moderation_logs ( + id BIGSERIAL PRIMARY KEY, + request_id VARCHAR(128) NOT NULL DEFAULT '', + user_id BIGINT REFERENCES users(id) ON DELETE SET NULL, + user_email VARCHAR(255) NOT NULL DEFAULT '', + api_key_id BIGINT REFERENCES api_keys(id) ON DELETE SET NULL, + api_key_name VARCHAR(100) NOT NULL DEFAULT '', + group_id BIGINT REFERENCES groups(id) ON DELETE SET NULL, + group_name VARCHAR(255) NOT NULL DEFAULT '', + endpoint VARCHAR(128) NOT NULL DEFAULT '', + provider VARCHAR(64) NOT NULL DEFAULT '', + model VARCHAR(255) NOT NULL DEFAULT '', + mode VARCHAR(32) NOT NULL DEFAULT '', + action VARCHAR(32) NOT NULL DEFAULT '', + flagged BOOLEAN NOT NULL DEFAULT FALSE, + highest_category VARCHAR(64) NOT NULL DEFAULT '', + highest_score DECIMAL(8, 6) NOT NULL DEFAULT 0, + category_scores JSONB NOT NULL DEFAULT '{}'::jsonb, + threshold_snapshot JSONB NOT NULL DEFAULT '{}'::jsonb, + input_excerpt TEXT NOT NULL DEFAULT '', + upstream_latency_ms INT, + error TEXT NOT NULL DEFAULT '', + violation_count INT NOT NULL DEFAULT 0, + auto_banned BOOLEAN NOT NULL DEFAULT FALSE, + email_sent BOOLEAN NOT NULL DEFAULT FALSE, + queue_delay_ms INT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +ALTER TABLE content_moderation_logs ADD COLUMN IF NOT EXISTS violation_count INT NOT NULL DEFAULT 0; +ALTER TABLE content_moderation_logs ADD COLUMN IF NOT EXISTS auto_banned BOOLEAN NOT NULL DEFAULT FALSE; +ALTER TABLE content_moderation_logs ADD COLUMN IF NOT EXISTS email_sent BOOLEAN NOT NULL DEFAULT FALSE; +ALTER TABLE content_moderation_logs ADD COLUMN IF NOT EXISTS queue_delay_ms INT; +CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_created_at ON content_moderation_logs(created_at DESC); +CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_group_created_at ON content_moderation_logs(group_id, created_at DESC); +CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_flagged_created_at ON content_moderation_logs(flagged, created_at DESC); +CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_user_created_at ON content_moderation_logs(user_id, created_at DESC); +CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_api_key_created_at ON content_moderation_logs(api_key_id, created_at DESC); +CREATE INDEX IF NOT EXISTS idx_content_moderation_logs_endpoint_created_at ON content_moderation_logs(endpoint, created_at DESC); diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts index 80241794..384e3796 100644 --- a/frontend/src/api/admin/index.ts +++ b/frontend/src/api/admin/index.ts @@ -30,6 +30,7 @@ import channelMonitorAPI from './channelMonitor' import channelMonitorTemplateAPI from './channelMonitorTemplate' import adminPaymentAPI from './payment' import affiliatesAPI from './affiliates' +import riskControlAPI from './riskControl' /** * Unified admin API object for convenient access @@ -61,7 +62,8 @@ export const adminAPI = { channelMonitor: channelMonitorAPI, channelMonitorTemplate: channelMonitorTemplateAPI, payment: adminPaymentAPI, - affiliates: affiliatesAPI + affiliates: affiliatesAPI, + riskControl: riskControlAPI } export { @@ -91,7 +93,8 @@ export { channelMonitorAPI, channelMonitorTemplateAPI, adminPaymentAPI, - affiliatesAPI + affiliatesAPI, + riskControlAPI } export default adminAPI @@ -101,3 +104,4 @@ export type { BalanceHistoryItem } from './users' export type { ErrorPassthroughRule, CreateRuleRequest, UpdateRuleRequest } from './errorPassthrough' export type { BackupAgentHealth, DataManagementConfig } from './dataManagement' export type { TLSFingerprintProfile, CreateProfileRequest, UpdateProfileRequest } from './tlsFingerprintProfile' +export type { ContentModerationConfig, ContentModerationLog, ModerationMode } from './riskControl' diff --git a/frontend/src/api/admin/riskControl.ts b/frontend/src/api/admin/riskControl.ts new file mode 100644 index 00000000..5f42b01d --- /dev/null +++ b/frontend/src/api/admin/riskControl.ts @@ -0,0 +1,251 @@ +import { apiClient } from '../client' + +export type ModerationMode = 'off' | 'observe' | 'pre_block' + +export interface ContentModerationConfig { + enabled: boolean + mode: ModerationMode + base_url: string + model: string + api_key_configured: boolean + api_key_masked: string + api_key_count: number + api_key_masks: string[] + api_key_statuses: ContentModerationAPIKeyStatus[] + timeout_ms: number + sample_rate: number + all_groups: boolean + group_ids: number[] + record_non_hits: boolean + worker_count: number + queue_size: number + block_status: number + block_message: string + email_on_hit: boolean + auto_ban_enabled: boolean + ban_threshold: number + violation_window_hours: number + retry_count: number + hit_retention_days: number + non_hit_retention_days: number + pre_hash_check_enabled: boolean +} + +export type ContentModerationAPIKeyStatusValue = 'unknown' | 'ok' | 'error' | 'frozen' + +export interface ContentModerationAPIKeyStatus { + index: number + key_hash: string + masked: string + status: ContentModerationAPIKeyStatusValue + failure_count: number + success_count: number + last_error: string + last_checked_at?: string + frozen_until?: string + last_latency_ms: number + last_http_status: number + last_tested: boolean + configured: boolean +} + +export interface TestContentModerationAPIKeysPayload { + api_keys?: string[] + base_url?: string + model?: string + timeout_ms?: number + prompt?: string + images?: string[] +} + +export interface TestContentModerationAPIKeysResponse { + items: ContentModerationAPIKeyStatus[] + audit_result?: ContentModerationTestAuditResult + image_count: number +} + +export interface ContentModerationTestAuditResult { + flagged: boolean + highest_category: string + highest_score: number + composite_score: number + category_scores: Record + thresholds: Record +} + +export interface UpdateContentModerationConfig { + enabled?: boolean + mode?: ModerationMode + base_url?: string + model?: string + api_key?: string + api_keys?: string[] + clear_api_key?: boolean + timeout_ms?: number + sample_rate?: number + all_groups?: boolean + group_ids?: number[] + record_non_hits?: boolean + worker_count?: number + queue_size?: number + block_status?: number + block_message?: string + email_on_hit?: boolean + auto_ban_enabled?: boolean + ban_threshold?: number + violation_window_hours?: number + retry_count?: number + hit_retention_days?: number + non_hit_retention_days?: number + pre_hash_check_enabled?: boolean +} + +export interface ContentModerationRuntimeStatus { + enabled: boolean + risk_control_enabled: boolean + mode: ModerationMode + worker_count: number + max_workers: number + active_workers: number + idle_workers: number + queue_size: number + queue_length: number + queue_usage_percent: number + enqueued: number + dropped: number + processed: number + errors: number + api_key_statuses: ContentModerationAPIKeyStatus[] + flagged_hash_count: number + last_cleanup_at?: string + last_cleanup_deleted_hit: number + last_cleanup_deleted_non_hit: number +} + +export interface ContentModerationLog { + id: number + request_id: string + user_id: number | null + user_email: string + api_key_id: number | null + api_key_name: string + group_id: number | null + group_name: string + endpoint: string + provider: string + model: string + mode: string + action: string + flagged: boolean + highest_category: string + highest_score: number + category_scores: Record + threshold_snapshot: Record + input_excerpt: string + upstream_latency_ms: number | null + error: string + violation_count: number + auto_banned: boolean + email_sent: boolean + user_status: string + queue_delay_ms: number | null + created_at: string +} + +export interface ListContentModerationLogsParams { + page?: number + page_size?: number + result?: string + group_id?: number + endpoint?: string + search?: string + from?: string + to?: string +} + +export interface ContentModerationLogsResponse { + items: ContentModerationLog[] + total: number + page: number + page_size: number + pages: number +} + +export interface ContentModerationUnbanUserResponse { + user_id: number + status: string +} + +export interface DeleteFlaggedHashResponse { + input_hash: string + deleted: boolean +} + +export interface ClearFlaggedHashesResponse { + deleted: number +} + +export async function getConfig(): Promise { + const { data } = await apiClient.get('/admin/risk-control/config') + return data +} + +export async function updateConfig( + payload: UpdateContentModerationConfig +): Promise { + const { data } = await apiClient.put('/admin/risk-control/config', payload) + return data +} + +export async function getStatus(): Promise { + const { data } = await apiClient.get('/admin/risk-control/status') + return data +} + +export async function testAPIKeys( + payload: TestContentModerationAPIKeysPayload = {} +): Promise { + const { data } = await apiClient.post('/admin/risk-control/api-keys/test', payload) + return data +} + +export async function listLogs( + params: ListContentModerationLogsParams = {} +): Promise { + const { data } = await apiClient.get('/admin/risk-control/logs', { + params, + }) + return data +} + +export async function unbanUser(userID: number): Promise { + const { data } = await apiClient.post( + `/admin/risk-control/users/${userID}/unban` + ) + return data +} + +export async function deleteFlaggedHash(inputHash: string): Promise { + const { data } = await apiClient.delete('/admin/risk-control/hashes', { + data: { input_hash: inputHash }, + }) + return data +} + +export async function clearFlaggedHashes(): Promise { + const { data } = await apiClient.delete('/admin/risk-control/hashes/all') + return data +} + +export const riskControlAPI = { + getConfig, + updateConfig, + getStatus, + testAPIKeys, + listLogs, + unbanUser, + deleteFlaggedHash, + clearFlaggedHashes, +} + +export default riskControlAPI diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 4b4f7c23..77841561 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -444,6 +444,7 @@ export interface SystemSettings { // Payment configuration payment_enabled: boolean; + risk_control_enabled: boolean; payment_min_amount: number; payment_max_amount: number; payment_daily_limit: number; @@ -613,6 +614,7 @@ export interface UpdateSettingsRequest { enable_anthropic_cache_ttl_1h_injection?: boolean; // Payment configuration payment_enabled?: boolean; + risk_control_enabled?: boolean; payment_min_amount?: number; payment_max_amount?: number; payment_daily_limit?: number; diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue index 4488bf60..3d7f1604 100644 --- a/frontend/src/components/layout/AppSidebar.vue +++ b/frontend/src/components/layout/AppSidebar.vue @@ -593,6 +593,21 @@ const SignalIcon = { ) } +const ShieldIcon = { + render: () => + h( + 'svg', + { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' }, + [ + h('path', { + 'stroke-linecap': 'round', + 'stroke-linejoin': 'round', + d: 'M9 12.75L11.25 15 15 9.75m-3-7.036A11.959 11.959 0 013.598 6 11.99 11.99 0 003 9.749c0 5.592 3.824 10.29 9 11.623 5.176-1.332 9-6.03 9-11.622 0-1.31-.21-2.571-.598-3.751h-.152c-3.196 0-6.1-1.248-8.25-3.285z' + }) + ] + ) +} + const PriceTagIcon = { render: () => h( @@ -635,6 +650,7 @@ const flagChannelMonitor = makeSidebarFlag(FeatureFlags.channelMonitor) const flagPayment = makeSidebarFlag(FeatureFlags.payment) const flagAvailableChannels = makeSidebarFlag(FeatureFlags.availableChannels) const flagAffiliate = makeSidebarFlag(FeatureFlags.affiliate) +const flagRiskControl = makeSidebarFlag(FeatureFlags.riskControl) const flagOpsMonitoring = () => adminSettingsStore.opsMonitoringEnabled const flagAdminPayment = () => adminSettingsStore.paymentEnabled @@ -719,6 +735,7 @@ const adminNavItems = computed((): NavItem[] => { { path: '/admin/accounts', label: t('nav.accounts'), icon: GlobeIcon }, { path: '/admin/announcements', label: t('nav.announcements'), icon: BellIcon }, { path: '/admin/proxies', label: t('nav.proxies'), icon: ServerIcon }, + { path: '/admin/risk-control', label: t('nav.riskControl'), icon: ShieldIcon, hideInSimpleMode: true, featureFlag: flagRiskControl }, { path: '/admin/redeem', label: t('nav.redeemCodes'), icon: TicketIcon, hideInSimpleMode: true }, { path: '/admin/promo-codes', label: t('nav.promoCodes'), icon: GiftIcon, hideInSimpleMode: true }, { diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 5f968ac7..5c2cdfe1 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -382,6 +382,7 @@ export default { channelPricing: 'Channel Pricing', channelMonitor: 'Channel Monitor', channelStatus: 'Channel Status', + riskControl: 'Risk Control', }, // Auth @@ -410,6 +411,9 @@ export default { passwordRequired: 'Password is required', passwordMinLength: 'Password must be at least 6 characters', loginFailed: 'Login failed. Please check your credentials and try again.', + errors: { + USER_NOT_ACTIVE: 'Account has been disabled.', + }, registrationFailed: 'Registration failed. Please try again.', emailSuffixNotAllowed: 'This email domain is not allowed for registration.', emailSuffixNotAllowedWithAllowed: @@ -2305,6 +2309,200 @@ export default { } }, + riskControl: { + title: 'Risk Control', + description: 'Configure content moderation and review audit records', + loadFailed: 'Failed to load risk control', + saveFailed: 'Failed to save content moderation config', + logsFailed: 'Failed to load audit records', + saved: 'Content moderation config saved', + refresh: 'Refresh', + config: 'Content Moderation Config', + configHint: 'Use OpenAI Moderations to score request content and handle threshold hits by mode.', + openSettings: 'Moderation Settings', + settingsTitle: 'Content Moderation Settings', + refreshStatus: 'Refresh Status', + records: 'Audit Records', + recordsHint: 'Shows hits, blocks, errors, and sampled records.', + saveConfig: 'Save Moderation Config', + statusFailed: 'Failed to load runtime status', + enabled: 'Enable Content Moderation', + enabledHint: 'When off, gateway requests are not moderated even if the menu is enabled.', + mode: 'Global Mode', + modePreBlock: 'Pre-Block', + modePreBlockDesc: 'Synchronously reviews the latest user input before every request and rejects hits immediately.', + modeObserve: 'Observe Only', + modeObserveDesc: 'Requests pass through while the latest user input is queued for async review; hits are recorded, notified, and counted.', + modeOff: 'Off', + modeOffDesc: 'Content moderation is disabled and no audit records are written.', + baseUrl: 'OpenAI Base URL', + model: 'Model', + apiKey: 'OpenAI API Key', + apiKeys: 'OpenAI API Keys', + apiKeyCount: '{count} keys', + apiKeyPlaceholder: 'Enter API Key', + apiKeysPlaceholder: 'One API Key per line', + apiKeysPlaceholderKeep: 'Leave empty to keep stored keys; enter values to replace them', + apiKeysHint: '{count} keys are currently stored. Values entered here replace stored keys; leave empty to keep them.', + apiKeyPlaceholderKeep: 'Leave empty to keep current key', + apiKeyWillClear: 'Configured key will be cleared on save', + apiKeyConfigured: 'Configured', + apiKeyTemporary: 'Pending', + inputApiKeyCount: '{count} keys in input', + storedApiKeyCount: '{count} stored keys', + testInputApiKeys: 'Test input keys', + testStoredApiKeys: 'Test stored keys', + testContentWithStoredApiKey: 'Test content with stored key', + testingApiKeys: 'Testing', + apiKeyTestNoInput: 'Enter OpenAI API Keys to test first', + apiKeyTestDone: 'Key test completed for {count} keys', + apiKeyTestFailed: 'Failed to test OpenAI API Keys', + apiKeyHealth: 'Key Availability', + apiKeyFreezeRule: 'Three consecutive failures freeze a key for 1 minute; moderation rotation skips frozen keys.', + apiKeyRows: '{count} keys', + apiKeyHealthEmpty: 'No key status yet', + apiKeyHealthEmptyHint: 'Save keys or test input keys to see availability.', + apiKeyStatusOk: 'Available', + apiKeyStatusError: 'Error', + apiKeyStatusFrozen: 'Frozen', + apiKeyStatusUnknown: 'Untested', + apiKeyFailureCount: '{count} failures', + apiKeyLatency: '{ms} ms', + apiKeyHTTPStatus: 'HTTP {status}', + apiKeyFrozenUntil: 'Frozen until {time}', + apiKeyLastChecked: 'Checked at {time}', + apiKeyNotTested: 'Not tested', + auditTestInput: 'Audit Test Input', + auditTestInputHint: 'Enter a prompt and upload or paste images; images are sent as base64 and are not stored.', + auditTestPromptPlaceholder: 'Enter a user prompt to test; leave empty to only test key availability.', + auditTestImages: 'Test Images', + auditTestImagesHint: 'Upload, drag, or paste images. Up to 4 images, 8MB each.', + addAuditTestImage: 'Add image', + clearAuditTest: 'Clear test', + auditTestImageLimit: 'You can add up to {count} test images', + auditTestImageTooLarge: 'Each test image must be 8MB or smaller', + auditTestImageReadFailed: 'Failed to read test image', + auditTestResult: 'Audit Test Result', + auditTestHighest: 'Top category {category}, score {score}', + auditTestComposite: 'Composite score', + auditTestFlagged: 'Threshold hit', + auditTestPassed: 'Pass', + notConfigured: 'Not configured', + clearApiKey: 'Clear stored key', + keepApiKey: 'Keep stored key', + timeoutMs: 'HTTP Timeout (ms)', + retryCount: 'Retry Count', + sampleRate: 'Sample Rate', + recordNonHits: 'Record Non-Hits', + recordNonHitsHint: 'When enabled, sampled non-hit request summaries are redacted before storage.', + preHashCheck: 'Enable Pre-Hash Check', + preHashCheckHint: 'Hashes from async hits are blocked before moderation; this does not send email or increment ban counters.', + flaggedHashCount: 'Current hash collection size: {count}', + flaggedHashHint: 'Hashes are stored permanently in Redis; paste a full 64-character hash to remove a false block, or clear all stored hashes.', + flaggedHashPlaceholder: 'Paste full 64-character input hash', + deleteFlaggedHash: 'Delete hash', + clearFlaggedHashes: 'Clear all', + clearFlaggedHashesConfirm: 'Clear all risk input hashes? This does not delete audit records, but removes all historical hash blocks.', + flaggedHashDeleted: 'Risk hash deleted', + flaggedHashNotFound: 'Risk hash not found', + flaggedHashDeleteFailed: 'Failed to delete risk hash', + flaggedHashesCleared: 'Cleared {count} risk hashes', + flaggedHashesClearFailed: 'Failed to clear risk hashes', + workerCount: 'Worker Count', + queueSize: 'Async Queue Size', + blockStatus: 'Block HTTP Status', + blockMessage: 'Custom Block Message', + emailOnHit: 'Email on Hit', + emailOnHitHint: 'When enabled, send a risk-control email on every hit; auto-ban notices are always sent.', + autoBan: 'Auto Ban User', + autoBanHint: 'Disable the user, invalidate auth cache, and send a ban notice after the hit threshold is reached.', + banThreshold: 'Ban Threshold', + violationWindowHours: 'Count Window (hours)', + hitRetentionDays: 'Hit Record Retention (days)', + nonHitRetentionDays: 'Non-Hit Record Retention (days, max 3)', + violationCount: '{count} hits', + emailSent: 'Email sent', + emailNotSent: 'No email', + autoBanned: 'Banned', + unbanUser: 'Unban', + unbanSuccess: 'User has been unbanned', + unbanFailed: 'Failed to unban user', + inputDetailTitle: 'Input Summary Detail', + inputDetailContent: 'Full Content', + queueDelay: 'Queued {ms} ms', + allGroups: 'All Groups', + allGroupsHint: 'Auditing all groups', + selectedGroupsHint: 'Auditing selected groups', + groupScope: 'Audit Groups', + groupScopeHint: 'Switch on for all groups, or turn off to choose specific groups.', + selectedGroups: 'Selected Groups', + searchGroups: 'Search group name or platform', + noGroups: 'No groups available', + emptyLogs: 'No audit records', + workerStatus: 'Worker Runtime', + workerStatusHint: 'Queue and worker pool status for asynchronous observation tasks.', + workerPool: 'Worker Pool', + workerPoolMeta: '{active} processing, {idle} idle and ready, {total} total', + queueUsage: 'Queue Usage', + activeWorkers: 'Processing', + idleWorkers: 'Idle Ready', + workerActive: 'Processing an asynchronous audit task', + workerIdle: 'Started, idle and ready', + workerDisabled: 'Risk control or content audit is disabled', + processed: 'Processed', + droppedErrors: 'Dropped / Errors', + autoRefresh: 'Auto refresh every 15s', + lastCleanup: 'Last cleanup: {time}', + cleanupStats: 'Last cleanup deleted {hit} hits and {nonHit} non-hits', + riskSwitchOff: 'System switch off', + tabs: { + basic: 'Basic', + scope: 'Scope', + runtime: 'Runtime', + response: 'Hit Notice', + retention: 'Retention', + }, + overview: { + status: 'Status', + enabled: 'Enabled', + disabled: 'Disabled', + apiKey: 'API Key', + groupScope: 'Scope', + logs: 'Audit Records', + currentFilter: 'Current filter', + }, + filters: { + search: 'Search user/key/summary', + from: 'From', + to: 'To', + allGroups: 'All Groups', + allEndpoints: 'All Endpoints', + }, + table: { + time: 'Time', + group: 'Group', + user: 'User', + apiKey: 'API Key', + endpoint: 'Endpoint', + result: 'Result', + highest: 'Highest', + actionMeta: 'Action', + latency: 'Latency', + input: 'Input Summary', + }, + result: { + all: 'All Results', + hit: 'Hit', + blocked: 'Blocked', + pass: 'Pass', + error: 'Error', + }, + action: { + block: 'Blocked', + error: 'Error', + }, + }, + // Channel Monitor channelMonitor: { title: 'Channel Monitor', @@ -4862,6 +5060,13 @@ export default { enabled: 'Enable Available Channels', enabledHint: 'When off, the sidebar entry is hidden and the endpoint returns an empty list.', }, + riskControl: { + title: 'Risk Control', + description: 'Enable the content moderation menu and gateway audit entry point. Disabled by default.', + configureLink: 'Configure content moderation in Risk Control', + enabled: 'Enable Risk Control', + enabledHint: 'When off, the admin sidebar entry is hidden and gateway moderation is skipped.', + }, affiliate: { title: 'Affiliate (Invite Rebate)', description: 'Existing users invite new ones; the inviter earns a percentage rebate on the invitee’s recharges. Disabled by default.', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index a37a9786..3d188825 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -382,6 +382,7 @@ export default { channelPricing: '渠道定价', channelMonitor: '渠道监控', channelStatus: '渠道状态', + riskControl: '风控中心', }, // Auth @@ -410,6 +411,9 @@ export default { passwordRequired: '请输入密码', passwordMinLength: '密码至少需要 6 个字符', loginFailed: '登录失败,请检查您的凭据后重试。', + errors: { + USER_NOT_ACTIVE: '账号已被禁用', + }, registrationFailed: '注册失败,请重试。', emailSuffixNotAllowed: '该邮箱域名不在允许注册范围内。', emailSuffixNotAllowedWithAllowed: '该邮箱域名不被允许。可用域名:{suffixes}', @@ -2382,6 +2386,200 @@ export default { } }, + riskControl: { + title: '风控中心', + description: '配置内容审计策略并查看审核记录', + loadFailed: '加载风控中心失败', + saveFailed: '保存内容审计配置失败', + logsFailed: '加载审核记录失败', + saved: '内容审计配置已保存', + refresh: '刷新', + config: '内容审计配置', + configHint: '调用 OpenAI Moderations 进行请求内容评分,命中阈值后按模式处理。', + openSettings: '内容审计设置', + settingsTitle: '内容审计设置', + refreshStatus: '刷新状态', + records: '审核记录', + recordsHint: '展示命中、拦截、异常和已采样记录。', + saveConfig: '保存内容审计配置', + statusFailed: '加载运行状态失败', + enabled: '开启内容审计', + enabledHint: '关闭后即使风控中心菜单启用,也不会审核网关请求。', + mode: '全局模式', + modePreBlock: '前置拦截', + modePreBlockDesc: '每次请求先同步审核最新用户输入,命中后立即拒绝请求。', + modeObserve: '仅观察', + modeObserveDesc: '请求直接放行,最新用户输入进入异步审核队列;命中后只记录、通知和按规则累计。', + modeOff: '关闭', + modeOffDesc: '不执行内容审计,也不会写入审核记录。', + baseUrl: 'OpenAI Base URL', + model: '模型名', + apiKey: 'OpenAI API Key', + apiKeys: 'OpenAI API Keys', + apiKeyCount: '{count} 个 Key', + apiKeyPlaceholder: '请输入 API Key', + apiKeysPlaceholder: '每行一个 API Key', + apiKeysPlaceholderKeep: '留空保持已保存的 Key;填写后将替换为这些 Key', + apiKeysHint: '当前已保存 {count} 个 Key;填写文本框会替换已保存 Key,留空则保持不变。', + apiKeyPlaceholderKeep: '留空保持不变', + apiKeyWillClear: '保存后清除已配置 Key', + apiKeyConfigured: '已配置', + apiKeyTemporary: '待保存', + inputApiKeyCount: '输入区 {count} 个 Key', + storedApiKeyCount: '已保存 {count} 个 Key', + testInputApiKeys: '测试输入区 Key', + testStoredApiKeys: '测试已保存 Key', + testContentWithStoredApiKey: '用已保存 Key 试跑内容', + testingApiKeys: '测试中', + apiKeyTestNoInput: '请先输入需要测试的 OpenAI API Key', + apiKeyTestDone: 'Key 测试完成,共 {count} 个', + apiKeyTestFailed: '测试 OpenAI API Key 失败', + apiKeyHealth: 'Key 可用状态', + apiKeyFreezeRule: '连续 3 次失败会冻结 1 分钟,审计轮询会自动跳过。', + apiKeyRows: '{count} 个', + apiKeyHealthEmpty: '暂无 Key 状态', + apiKeyHealthEmptyHint: '保存 Key 或测试输入区 Key 后会显示可用性。', + apiKeyStatusOk: '可用', + apiKeyStatusError: '异常', + apiKeyStatusFrozen: '冻结', + apiKeyStatusUnknown: '未测试', + apiKeyFailureCount: '失败 {count} 次', + apiKeyLatency: '{ms} ms', + apiKeyHTTPStatus: 'HTTP {status}', + apiKeyFrozenUntil: '冻结至 {time}', + apiKeyLastChecked: '检查于 {time}', + apiKeyNotTested: '尚未测试', + auditTestInput: '审计试跑输入', + auditTestInputHint: '可填写提示词并上传或粘贴图片;图片以 base64 发送,不会保存文件。', + auditTestPromptPlaceholder: '输入要测试的用户提示词;留空时仅测试 Key 可用性。', + auditTestImages: '测试图片', + auditTestImagesHint: '支持上传、拖拽或粘贴图片,最多 4 张,每张不超过 8MB。', + addAuditTestImage: '添加图片', + clearAuditTest: '清空试跑', + auditTestImageLimit: '最多只能添加 {count} 张测试图片', + auditTestImageTooLarge: '单张测试图片不能超过 8MB', + auditTestImageReadFailed: '读取测试图片失败', + auditTestResult: '审计试跑结果', + auditTestHighest: '最高分类 {category},分数 {score}', + auditTestComposite: '综合评分', + auditTestFlagged: '命中阈值', + auditTestPassed: '未命中', + notConfigured: '未配置', + clearApiKey: '清除已保存 Key', + keepApiKey: '保留已保存 Key', + timeoutMs: 'HTTP 超时 (ms)', + retryCount: '失败重试次数', + sampleRate: '采样率', + recordNonHits: '记录未命中输入', + recordNonHitsHint: '开启后会记录抽样但未命中的请求摘要,摘要会先脱敏再入库。', + preHashCheck: '启用前置哈希比对', + preHashCheckHint: '异步审核命中过的输入哈希会被前置拦截;该拦截不发送邮件,也不累计封禁次数。', + flaggedHashCount: '当前哈希集合数量:{count} 个', + flaggedHashHint: '哈希永久保存在 Redis 集合中;可粘贴完整 64 位哈希删除误拦截项,或一键清空全部风险哈希。', + flaggedHashPlaceholder: '粘贴完整 64 位输入哈希', + deleteFlaggedHash: '删除指定哈希', + clearFlaggedHashes: '一键清空', + clearFlaggedHashesConfirm: '确定要清空全部风险输入哈希吗?此操作不会删除审核记录,但会取消所有历史哈希拦截。', + flaggedHashDeleted: '风险哈希已删除', + flaggedHashNotFound: '该风险哈希不存在', + flaggedHashDeleteFailed: '删除风险哈希失败', + flaggedHashesCleared: '已清空 {count} 个风险哈希', + flaggedHashesClearFailed: '清空风险哈希失败', + workerCount: 'Worker 数', + queueSize: '异步队列大小', + blockStatus: '拦截 HTTP 状态码', + blockMessage: '自定义拦截提示', + emailOnHit: '命中后发送邮件', + emailOnHitHint: '开启后每次达到阈值都会向用户发送风控提醒邮件;自动封禁通知始终发送。', + autoBan: '自动封禁用户', + autoBanHint: '命中次数达到阈值后将禁用用户账号、刷新认证缓存并发送封禁通知邮件。', + banThreshold: '封禁触发次数', + violationWindowHours: '累计窗口(小时)', + hitRetentionDays: '命中记录保留(天)', + nonHitRetentionDays: '未命中记录保留(天,最多 3 天)', + violationCount: '{count} 次', + emailSent: '已发邮件', + emailNotSent: '未发邮件', + autoBanned: '已封禁', + unbanUser: '解封', + unbanSuccess: '用户已解封', + unbanFailed: '解封用户失败', + inputDetailTitle: '输入摘要详情', + inputDetailContent: '完整内容', + queueDelay: '排队 {ms} ms', + allGroups: '全部分组', + allGroupsHint: '当前审计全部分组', + selectedGroupsHint: '当前审计指定分组', + groupScope: '审计分组', + groupScopeHint: '开启右侧开关表示全部分组,关闭后选择指定分组。', + selectedGroups: '指定分组', + searchGroups: '搜索分组名称或平台', + noGroups: '暂无可用分组', + emptyLogs: '暂无审核记录', + workerStatus: 'Worker 运行状态', + workerStatusHint: '异步观察任务的队列和 worker 池状态。', + workerPool: 'Worker 池', + workerPoolMeta: '{active} 个处理中,{idle} 个空闲可用,共 {total} 个', + queueUsage: '队列占用', + activeWorkers: '处理中', + idleWorkers: '空闲可用', + workerActive: '正在处理异步审计任务', + workerIdle: '已启动,当前空闲可用', + workerDisabled: '风控或内容审计未启用', + processed: '已处理', + droppedErrors: '丢弃/异常', + autoRefresh: '每 15 秒自动刷新', + lastCleanup: '上次清理:{time}', + cleanupStats: '上次清理删除命中 {hit} 条,未命中 {nonHit} 条', + riskSwitchOff: '系统开关关闭', + tabs: { + basic: '基础', + scope: '审计范围', + runtime: '运行队列', + response: '命中通知', + retention: '日志保留', + }, + overview: { + status: '运行状态', + enabled: '已启用', + disabled: '未启用', + apiKey: 'API Key', + groupScope: '审计范围', + logs: '审核记录', + currentFilter: '当前筛选结果', + }, + filters: { + search: '按用户/Key/摘要搜索', + from: '开始时间', + to: '结束时间', + allGroups: '全部分组', + allEndpoints: '全部端点', + }, + table: { + time: '时间', + group: '分组', + user: '用户', + apiKey: 'API Key', + endpoint: '端点', + result: '结果', + highest: '最高分', + actionMeta: '处置', + latency: '上游耗时', + input: '输入摘要', + }, + result: { + all: '全部结果', + hit: '命中', + blocked: '已拦截', + pass: '未命中', + error: '异常', + }, + action: { + block: '拦截', + error: '异常', + }, + }, + // Channel Monitor channelMonitor: { title: '渠道监控', @@ -5025,6 +5223,13 @@ export default { enabled: '启用可用渠道', enabledHint: '关闭后用户端侧边栏入口隐藏,接口返回空数组。', }, + riskControl: { + title: '风控中心', + description: '启用内容审计菜单和全端点请求审核入口。默认关闭。', + configureLink: '前往 风控中心 配置内容审计', + enabled: '启用风控中心', + enabledHint: '关闭后管理员侧边栏入口隐藏,网关内容审计不会执行。', + }, affiliate: { title: '邀请返利', description: '老用户邀请新用户注册,新用户充值后老用户按比例获得返利额度。默认关闭。', diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 238f6a71..36d92289 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -505,6 +505,19 @@ const routes: RouteRecordRaw[] = [ descriptionKey: 'admin.settings.description' } }, + { + path: '/admin/risk-control', + name: 'AdminRiskControl', + component: () => import('@/views/admin/RiskControlView.vue'), + meta: { + requiresAuth: true, + requiresAdmin: true, + title: 'Risk Control', + titleKey: 'admin.riskControl.title', + descriptionKey: 'admin.riskControl.description', + requiresRiskControl: true + } + }, { path: '/admin/usage', name: 'AdminUsage', @@ -747,6 +760,14 @@ router.beforeEach((to, _from, next) => { } } + if (to.meta.requiresRiskControl) { + const riskControlEnabled = appStore.cachedPublicSettings?.risk_control_enabled === true + if (!riskControlEnabled) { + next(authStore.isAdmin ? '/admin/settings' : '/dashboard') + return + } + } + // 简易模式下限制访问某些页面 if (authStore.isSimpleMode) { const restrictedPaths = [ diff --git a/frontend/src/router/meta.d.ts b/frontend/src/router/meta.d.ts index 7b2777c2..5c468016 100644 --- a/frontend/src/router/meta.d.ts +++ b/frontend/src/router/meta.d.ts @@ -49,6 +49,12 @@ declare module 'vue-router' { */ requiresPayment?: boolean + /** + * 是否要求风控中心功能开关已启用 + * @default false + */ + requiresRiskControl?: boolean + /** * i18n key for the page title */ diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts index 876ab5c0..b4329e7f 100644 --- a/frontend/src/stores/app.ts +++ b/frontend/src/stores/app.ts @@ -355,6 +355,7 @@ export const useAppStore = defineStore('app', () => { channel_monitor_enabled: true, channel_monitor_default_interval_seconds: 60, available_channels_enabled: false, + risk_control_enabled: false, affiliate_enabled: false, } } diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 79530c99..727b9436 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -197,6 +197,7 @@ export interface PublicSettings { home_content: string hide_ccs_import_button: boolean payment_enabled: boolean + risk_control_enabled: boolean table_default_page_size: number table_page_size_options: number[] custom_menu_items: CustomMenuItem[] diff --git a/frontend/src/utils/featureFlags.ts b/frontend/src/utils/featureFlags.ts index e0668694..403e7cdc 100644 --- a/frontend/src/utils/featureFlags.ts +++ b/frontend/src/utils/featureFlags.ts @@ -109,6 +109,11 @@ export const FeatureFlags = { mode: 'opt-out', label: 'Payment', }), + riskControl: defineFlag({ + key: 'risk_control_enabled', + mode: 'opt-in', + label: 'Risk Control', + }), affiliate: defineFlag({ key: 'affiliate_enabled', mode: 'opt-in', diff --git a/frontend/src/views/admin/RiskControlView.vue b/frontend/src/views/admin/RiskControlView.vue new file mode 100644 index 00000000..0041cd8e --- /dev/null +++ b/frontend/src/views/admin/RiskControlView.vue @@ -0,0 +1,1574 @@ + + + diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index e32dd30e..e47392a4 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -4264,6 +4264,39 @@ +
+
+

+ {{ t('admin.settings.features.riskControl.title') }} +

+

+ {{ t('admin.settings.features.riskControl.description') }} +

+

+ + {{ t('admin.settings.features.riskControl.configureLink') }} + + +

+
+
+
+
+ +

+ {{ t('admin.settings.features.riskControl.enabledHint') }} +

+
+ +
+
+
+
@@ -5828,6 +5861,7 @@ const form = reactive({ backend_mode_enabled: false, hide_ccs_import_button: false, payment_enabled: false, + risk_control_enabled: false, payment_min_amount: 1, payment_max_amount: 10000, payment_daily_limit: 50000, @@ -6863,6 +6897,7 @@ async function saveSettings() { form.enable_anthropic_cache_ttl_1h_injection, // Payment configuration payment_enabled: form.payment_enabled, + risk_control_enabled: form.risk_control_enabled, payment_min_amount: Number(form.payment_min_amount) || 0, payment_max_amount: Number(form.payment_max_amount) || 0, payment_daily_limit: Number(form.payment_daily_limit) || 0, diff --git a/frontend/src/views/auth/LoginView.vue b/frontend/src/views/auth/LoginView.vue index 78ba4b9d..3601c666 100644 --- a/frontend/src/views/auth/LoginView.vue +++ b/frontend/src/views/auth/LoginView.vue @@ -186,6 +186,7 @@ import TurnstileWidget from '@/components/TurnstileWidget.vue' import { useAuthStore, useAppStore } from '@/stores' import { getPublicSettings, isTotp2FARequired, isWeChatWebOAuthEnabled } from '@/api/auth' import type { TotpLoginResponse } from '@/types' +import { extractI18nErrorMessage } from '@/utils/apiError' import { clearAllAffiliateReferralCodes } from '@/utils/oauthAffiliate' const { t } = useI18n() @@ -369,16 +370,7 @@ async function handleLogin(): Promise { turnstileToken.value = '' } - // Handle login error - const err = error as { message?: string; response?: { data?: { detail?: string } } } - - if (err.response?.data?.detail) { - errorMessage.value = err.response.data.detail - } else if (err.message) { - errorMessage.value = err.message - } else { - errorMessage.value = t('auth.loginFailed') - } + errorMessage.value = extractI18nErrorMessage(error, t, 'auth.errors', t('auth.loginFailed')) // Also show error toast appStore.showError(errorMessage.value)