diff --git a/backend/internal/handler/admin/content_moderation_handler.go b/backend/internal/handler/admin/content_moderation_handler.go index 6f0f2aab..defcd29d 100644 --- a/backend/internal/handler/admin/content_moderation_handler.go +++ b/backend/internal/handler/admin/content_moderation_handler.go @@ -20,34 +20,35 @@ func NewContentModerationHandler(svc *service.ContentModerationService) *Content } type contentModerationConfigRequest struct { - Enabled *bool `json:"enabled"` - Mode *string `json:"mode"` - BaseURL *string `json:"base_url"` - Model *string `json:"model"` - APIKey *string `json:"api_key"` - APIKeys *[]string `json:"api_keys"` - APIKeysMode string `json:"api_keys_mode"` - DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"` - ClearAPIKey bool `json:"clear_api_key"` - TimeoutMS *int `json:"timeout_ms"` - SampleRate *int `json:"sample_rate"` - AllGroups *bool `json:"all_groups"` - GroupIDs *[]int64 `json:"group_ids"` - RecordNonHits *bool `json:"record_non_hits"` - WorkerCount *int `json:"worker_count"` - QueueSize *int `json:"queue_size"` - BlockStatus *int `json:"block_status"` - BlockMessage *string `json:"block_message"` - EmailOnHit *bool `json:"email_on_hit"` - AutoBanEnabled *bool `json:"auto_ban_enabled"` - BanThreshold *int `json:"ban_threshold"` - ViolationWindowHours *int `json:"violation_window_hours"` - RetryCount *int `json:"retry_count"` - HitRetentionDays *int `json:"hit_retention_days"` - NonHitRetentionDays *int `json:"non_hit_retention_days"` - PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"` - BlockedKeywords *[]string `json:"blocked_keywords"` - KeywordBlockingMode *string `json:"keyword_blocking_mode"` + Enabled *bool `json:"enabled"` + Mode *string `json:"mode"` + BaseURL *string `json:"base_url"` + Model *string `json:"model"` + APIKey *string `json:"api_key"` + APIKeys *[]string `json:"api_keys"` + APIKeysMode string `json:"api_keys_mode"` + DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"` + 
ClearAPIKey bool `json:"clear_api_key"` + TimeoutMS *int `json:"timeout_ms"` + SampleRate *int `json:"sample_rate"` + AllGroups *bool `json:"all_groups"` + GroupIDs *[]int64 `json:"group_ids"` + RecordNonHits *bool `json:"record_non_hits"` + WorkerCount *int `json:"worker_count"` + QueueSize *int `json:"queue_size"` + BlockStatus *int `json:"block_status"` + BlockMessage *string `json:"block_message"` + EmailOnHit *bool `json:"email_on_hit"` + AutoBanEnabled *bool `json:"auto_ban_enabled"` + BanThreshold *int `json:"ban_threshold"` + ViolationWindowHours *int `json:"violation_window_hours"` + RetryCount *int `json:"retry_count"` + HitRetentionDays *int `json:"hit_retention_days"` + NonHitRetentionDays *int `json:"non_hit_retention_days"` + PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"` + BlockedKeywords *[]string `json:"blocked_keywords"` + KeywordBlockingMode *string `json:"keyword_blocking_mode"` + ModelFilter *service.ContentModerationModelFilter `json:"model_filter"` } type contentModerationAPIKeyTestRequest struct { @@ -107,6 +108,7 @@ func (h *ContentModerationHandler) UpdateConfig(c *gin.Context) { PreHashCheckEnabled: req.PreHashCheckEnabled, BlockedKeywords: req.BlockedKeywords, KeywordBlockingMode: req.KeywordBlockingMode, + ModelFilter: req.ModelFilter, }) if err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/service/content_moderation.go b/backend/internal/service/content_moderation.go index 2d066298..b5a889e1 100644 --- a/backend/internal/service/content_moderation.go +++ b/backend/internal/service/content_moderation.go @@ -44,6 +44,10 @@ const ( ContentModerationKeywordModeKeywordAndAPI = "keyword_and_api" ContentModerationKeywordModeAPIOnly = "api_only" + ContentModerationModelFilterAll = "all" + ContentModerationModelFilterInclude = "include" + ContentModerationModelFilterExclude = "exclude" + ContentModerationProtocolAnthropicMessages = "anthropic_messages" ContentModerationProtocolOpenAIResponses = "openai_responses" 
ContentModerationProtocolOpenAIChat = "openai_chat_completions" @@ -80,6 +84,8 @@ const ( maxContentModerationTestImageDataURLBytes = 12 * 1024 * 1024 maxContentModerationBlockedKeywords = 10000 maxContentModerationBlockedKeywordRunes = 200 + maxContentModerationModelFilterModels = 1000 + maxContentModerationModelFilterRunes = 200 contentModerationCleanupInterval = 24 * time.Hour contentModerationCleanupTimeout = 30 * time.Minute @@ -127,32 +133,33 @@ func ContentModerationCategories() []string { } type ContentModerationConfig struct { - Enabled bool `json:"enabled"` - Mode string `json:"mode"` - BaseURL string `json:"base_url"` - Model string `json:"model"` - APIKey string `json:"api_key,omitempty"` - APIKeys []string `json:"api_keys,omitempty"` - TimeoutMS int `json:"timeout_ms"` - SampleRate int `json:"sample_rate"` - AllGroups bool `json:"all_groups"` - GroupIDs []int64 `json:"group_ids"` - RecordNonHits bool `json:"record_non_hits"` - Thresholds map[string]float64 `json:"thresholds"` - WorkerCount int `json:"worker_count"` - QueueSize int `json:"queue_size"` - BlockStatus int `json:"block_status"` - BlockMessage string `json:"block_message"` - EmailOnHit bool `json:"email_on_hit"` - AutoBanEnabled bool `json:"auto_ban_enabled"` - BanThreshold int `json:"ban_threshold"` - ViolationWindowHours int `json:"violation_window_hours"` - RetryCount int `json:"retry_count"` - HitRetentionDays int `json:"hit_retention_days"` - NonHitRetentionDays int `json:"non_hit_retention_days"` - PreHashCheckEnabled bool `json:"pre_hash_check_enabled"` - BlockedKeywords []string `json:"blocked_keywords"` - KeywordBlockingMode string `json:"keyword_blocking_mode"` + Enabled bool `json:"enabled"` + Mode string `json:"mode"` + BaseURL string `json:"base_url"` + Model string `json:"model"` + APIKey string `json:"api_key,omitempty"` + APIKeys []string `json:"api_keys,omitempty"` + TimeoutMS int `json:"timeout_ms"` + SampleRate int `json:"sample_rate"` + AllGroups bool `json:"all_groups"` 
+ GroupIDs []int64 `json:"group_ids"` + RecordNonHits bool `json:"record_non_hits"` + Thresholds map[string]float64 `json:"thresholds"` + WorkerCount int `json:"worker_count"` + QueueSize int `json:"queue_size"` + BlockStatus int `json:"block_status"` + BlockMessage string `json:"block_message"` + EmailOnHit bool `json:"email_on_hit"` + AutoBanEnabled bool `json:"auto_ban_enabled"` + BanThreshold int `json:"ban_threshold"` + ViolationWindowHours int `json:"violation_window_hours"` + RetryCount int `json:"retry_count"` + HitRetentionDays int `json:"hit_retention_days"` + NonHitRetentionDays int `json:"non_hit_retention_days"` + PreHashCheckEnabled bool `json:"pre_hash_check_enabled"` + BlockedKeywords []string `json:"blocked_keywords"` + KeywordBlockingMode string `json:"keyword_blocking_mode"` + ModelFilter ContentModerationModelFilter `json:"model_filter"` } type ContentModerationConfigView struct { @@ -184,6 +191,7 @@ type ContentModerationConfigView struct { PreHashCheckEnabled bool `json:"pre_hash_check_enabled"` BlockedKeywords []string `json:"blocked_keywords"` KeywordBlockingMode string `json:"keyword_blocking_mode"` + ModelFilter ContentModerationModelFilter `json:"model_filter"` } type ContentModerationAPIKeyStatus struct { @@ -227,34 +235,40 @@ type ContentModerationTestAuditResult struct { } type UpdateContentModerationConfigInput struct { - Enabled *bool `json:"enabled"` - Mode *string `json:"mode"` - BaseURL *string `json:"base_url"` - Model *string `json:"model"` - APIKey *string `json:"api_key"` - APIKeys *[]string `json:"api_keys"` - APIKeysMode string `json:"api_keys_mode"` - DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"` - ClearAPIKey bool `json:"clear_api_key"` - TimeoutMS *int `json:"timeout_ms"` - SampleRate *int `json:"sample_rate"` - AllGroups *bool `json:"all_groups"` - GroupIDs *[]int64 `json:"group_ids"` - RecordNonHits *bool `json:"record_non_hits"` - WorkerCount *int `json:"worker_count"` - QueueSize *int `json:"queue_size"` 
- BlockStatus *int `json:"block_status"` - BlockMessage *string `json:"block_message"` - EmailOnHit *bool `json:"email_on_hit"` - AutoBanEnabled *bool `json:"auto_ban_enabled"` - BanThreshold *int `json:"ban_threshold"` - ViolationWindowHours *int `json:"violation_window_hours"` - RetryCount *int `json:"retry_count"` - HitRetentionDays *int `json:"hit_retention_days"` - NonHitRetentionDays *int `json:"non_hit_retention_days"` - PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"` - BlockedKeywords *[]string `json:"blocked_keywords"` - KeywordBlockingMode *string `json:"keyword_blocking_mode"` + Enabled *bool `json:"enabled"` + Mode *string `json:"mode"` + BaseURL *string `json:"base_url"` + Model *string `json:"model"` + APIKey *string `json:"api_key"` + APIKeys *[]string `json:"api_keys"` + APIKeysMode string `json:"api_keys_mode"` + DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"` + ClearAPIKey bool `json:"clear_api_key"` + TimeoutMS *int `json:"timeout_ms"` + SampleRate *int `json:"sample_rate"` + AllGroups *bool `json:"all_groups"` + GroupIDs *[]int64 `json:"group_ids"` + RecordNonHits *bool `json:"record_non_hits"` + WorkerCount *int `json:"worker_count"` + QueueSize *int `json:"queue_size"` + BlockStatus *int `json:"block_status"` + BlockMessage *string `json:"block_message"` + EmailOnHit *bool `json:"email_on_hit"` + AutoBanEnabled *bool `json:"auto_ban_enabled"` + BanThreshold *int `json:"ban_threshold"` + ViolationWindowHours *int `json:"violation_window_hours"` + RetryCount *int `json:"retry_count"` + HitRetentionDays *int `json:"hit_retention_days"` + NonHitRetentionDays *int `json:"non_hit_retention_days"` + PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"` + BlockedKeywords *[]string `json:"blocked_keywords"` + KeywordBlockingMode *string `json:"keyword_blocking_mode"` + ModelFilter *ContentModerationModelFilter `json:"model_filter"` +} + +type ContentModerationModelFilter struct { + Type string `json:"type"` + Models []string 
`json:"models"` } type ContentModerationCheckInput struct { @@ -581,6 +595,9 @@ func (s *ContentModerationService) UpdateConfig(ctx context.Context, input Updat if input.KeywordBlockingMode != nil { cfg.KeywordBlockingMode = strings.TrimSpace(*input.KeywordBlockingMode) } + if input.ModelFilter != nil { + cfg.ModelFilter = *input.ModelFilter + } if input.AllGroups != nil { cfg.AllGroups = *input.AllGroups } @@ -719,7 +736,8 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer "error", err) return allow, nil } - inScope := cfg.includesGroup(input.GroupID) + inGroupScope := cfg.includesGroup(input.GroupID) + inModelScope := cfg.includesModel(input.Model) slog.Info("content_moderation.config_loaded", "user_id", input.UserID, "api_key_id", input.APIKeyID, @@ -733,7 +751,10 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer "mode", cfg.Mode, "all_groups", cfg.AllGroups, "configured_group_ids", cfg.GroupIDs, - "in_scope", inScope, + "in_group_scope", inGroupScope, + "model_filter_type", cfg.ModelFilter.Type, + "configured_models", cfg.ModelFilter.Models, + "in_model_scope", inModelScope, "sample_rate", cfg.SampleRate, "api_key_count", len(cfg.apiKeys()), "pre_hash_check_enabled", cfg.PreHashCheckEnabled, @@ -756,7 +777,7 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer "protocol", input.Protocol) return allow, nil } - if !inScope { + if !inGroupScope { slog.Info("content_moderation.skip_group_out_of_scope", "user_id", input.UserID, "api_key_id", input.APIKeyID, @@ -768,6 +789,19 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer "configured_group_ids", cfg.GroupIDs) return allow, nil } + if !inModelScope { + slog.Info("content_moderation.skip_model_out_of_scope", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "group_name", input.GroupName, + "endpoint", input.Endpoint, + 
"protocol", input.Protocol, + "model", input.Model, + "model_filter_type", cfg.ModelFilter.Type, + "configured_models", cfg.ModelFilter.Models) + return allow, nil + } content := ExtractContentModerationInput(input.Protocol, input.Body) if content.IsEmpty() { slog.Info("content_moderation.skip_empty_input", @@ -1025,6 +1059,9 @@ func (s *ContentModerationService) worker(id int) { if !cfg.includesGroup(task.input.GroupID) { return } + if !cfg.includesModel(task.input.Model) { + return + } s.asyncActive.Add(1) defer s.asyncActive.Add(-1) queueDelay := int(time.Since(task.enqueuedAt).Milliseconds()) @@ -1270,6 +1307,9 @@ func (s *ContentModerationService) validateConfig(ctx context.Context, cfg *Cont if cfg.BlockStatus < 400 || cfg.BlockStatus > 599 { return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_BLOCK_STATUS", "拦截 HTTP 状态码必须在 400-599 之间") } + if cfg.ModelFilter.Type != ContentModerationModelFilterAll && len(cfg.ModelFilter.Models) == 0 { + return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_MODEL_FILTER", "指定或排除模型时至少需要配置 1 个模型") + } if !cfg.AllGroups && len(cfg.GroupIDs) > 0 && s.groupRepo != nil { for _, groupID := range cfg.GroupIDs { if _, err := s.groupRepo.GetByIDLite(ctx, groupID); err != nil { @@ -1590,6 +1630,10 @@ func defaultContentModerationConfig() *ContentModerationConfig { PreHashCheckEnabled: false, BlockedKeywords: []string{}, KeywordBlockingMode: ContentModerationKeywordModeKeywordAndAPI, + ModelFilter: ContentModerationModelFilter{ + Type: ContentModerationModelFilterAll, + Models: []string{}, + }, } } @@ -1670,6 +1714,7 @@ func (cfg *ContentModerationConfig) normalize() { cfg.Thresholds = mergeContentModerationThresholds(ContentModerationDefaultThresholds(), cfg.Thresholds) cfg.BlockedKeywords = normalizeBlockedKeywords(cfg.BlockedKeywords) cfg.KeywordBlockingMode = normalizeKeywordBlockingMode(cfg.KeywordBlockingMode) + cfg.ModelFilter = normalizeContentModerationModelFilter(cfg.ModelFilter) } func (cfg *ContentModerationConfig) 
includesGroup(groupID *int64) bool { @@ -1687,6 +1732,21 @@ func (cfg *ContentModerationConfig) includesGroup(groupID *int64) bool { return false } +func (cfg *ContentModerationConfig) includesModel(model string) bool { + if cfg == nil { + return true + } + filter := normalizeContentModerationModelFilter(cfg.ModelFilter) + switch filter.Type { + case ContentModerationModelFilterInclude: + return contentModerationModelListContains(filter.Models, model) + case ContentModerationModelFilterExclude: + return !contentModerationModelListContains(filter.Models, model) + default: + return true + } +} + func contentModerationLogGroupID(groupID *int64) int64 { if groupID == nil { return 0 @@ -1848,6 +1908,7 @@ func (s *ContentModerationService) configView(cfg *ContentModerationConfig) *Con PreHashCheckEnabled: cfg.PreHashCheckEnabled, BlockedKeywords: append([]string(nil), cfg.BlockedKeywords...), KeywordBlockingMode: cfg.KeywordBlockingMode, + ModelFilter: cloneContentModerationModelFilter(cfg.ModelFilter), } } @@ -2125,6 +2186,73 @@ func normalizeKeywordBlockingMode(mode string) string { } } +func normalizeContentModerationModelFilter(filter ContentModerationModelFilter) ContentModerationModelFilter { + out := ContentModerationModelFilter{ + Type: normalizeContentModerationModelFilterType(filter.Type), + Models: normalizeContentModerationModelNames(filter.Models), + } + if out.Type == ContentModerationModelFilterAll { + out.Models = []string{} + } + return out +} + +func cloneContentModerationModelFilter(filter ContentModerationModelFilter) ContentModerationModelFilter { + normalized := normalizeContentModerationModelFilter(filter) + normalized.Models = append([]string(nil), normalized.Models...) 
+ return normalized +} + +func normalizeContentModerationModelFilterType(filterType string) string { + switch strings.ToLower(strings.TrimSpace(filterType)) { + case ContentModerationModelFilterInclude: + return ContentModerationModelFilterInclude + case ContentModerationModelFilterExclude: + return ContentModerationModelFilterExclude + case ContentModerationModelFilterAll: + return ContentModerationModelFilterAll + default: + return ContentModerationModelFilterAll + } +} + +func normalizeContentModerationModelNames(models []string) []string { + if len(models) == 0 { + return []string{} + } + out := make([]string, 0, len(models)) + seen := make(map[string]struct{}, len(models)) + for _, raw := range models { + model := trimRunes(strings.TrimSpace(raw), maxContentModerationModelFilterRunes) + if model == "" { + continue + } + key := strings.ToLower(model) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + out = append(out, model) + if len(out) >= maxContentModerationModelFilterModels { + break + } + } + return out +} + +func contentModerationModelListContains(models []string, model string) bool { + model = strings.ToLower(strings.TrimSpace(model)) + if model == "" { + return false + } + for _, candidate := range models { + if strings.ToLower(strings.TrimSpace(candidate)) == model { + return true + } + } + return false +} + func matchBlockedKeyword(text string, keywords []string) (string, bool) { if text == "" || len(keywords) == 0 { return "", false diff --git a/backend/internal/service/content_moderation_test.go b/backend/internal/service/content_moderation_test.go index 30578ca5..60a99318 100644 --- a/backend/internal/service/content_moderation_test.go +++ b/backend/internal/service/content_moderation_test.go @@ -530,6 +530,147 @@ func TestNormalizeKeywordBlockingMode_UnknownFallsBackToDefault(t *testing.T) { require.Equal(t, ContentModerationKeywordModeAPIOnly, normalizeKeywordBlockingMode("api_only")) } +func 
TestContentModerationCheck_ModelFilterAllAuditsEveryModel(t *testing.T) { + cfg := defaultContentModerationModelFilterTestConfig() + cfg.ModelFilter = ContentModerationModelFilter{Type: ContentModerationModelFilterAll} + svc, repo := newContentModerationModelFilterTestService(t, cfg) + + for _, model := range []string{"gpt-5.5", "gpt-5.4"} { + decision, err := svc.Check(context.Background(), ContentModerationCheckInput{ + Model: model, + Protocol: ContentModerationProtocolOpenAIChat, + Body: []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`), + }) + require.NoError(t, err) + require.True(t, decision.Blocked) + require.Equal(t, ContentModerationActionKeywordBlock, decision.Action) + } + require.Len(t, repo.logs, 2) +} + +func TestContentModerationCheck_ModelFilterIncludeOnlyAuditsListedModels(t *testing.T) { + cfg := defaultContentModerationModelFilterTestConfig() + cfg.ModelFilter = ContentModerationModelFilter{Type: ContentModerationModelFilterInclude, Models: []string{"gpt-5.5"}} + svc, repo := newContentModerationModelFilterTestService(t, cfg) + + decision, err := svc.Check(context.Background(), ContentModerationCheckInput{ + Model: "gpt-5.5", + Protocol: ContentModerationProtocolOpenAIChat, + Body: []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`), + }) + require.NoError(t, err) + require.True(t, decision.Blocked) + require.Equal(t, ContentModerationActionKeywordBlock, decision.Action) + + decision, err = svc.Check(context.Background(), ContentModerationCheckInput{ + Model: "gpt-5.4", + Protocol: ContentModerationProtocolOpenAIChat, + Body: []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`), + }) + require.NoError(t, err) + require.True(t, decision.Allowed) + require.False(t, decision.Blocked) + require.Equal(t, ContentModerationActionAllow, decision.Action) + require.Len(t, repo.logs, 1) + require.Equal(t, "gpt-5.5", repo.logs[0].Model) +} + +func 
TestContentModerationCheck_ModelFilterExcludeSkipsListedModels(t *testing.T) { + cfg := defaultContentModerationModelFilterTestConfig() + cfg.ModelFilter = ContentModerationModelFilter{Type: ContentModerationModelFilterExclude, Models: []string{"gpt-5.4"}} + svc, repo := newContentModerationModelFilterTestService(t, cfg) + + decision, err := svc.Check(context.Background(), ContentModerationCheckInput{ + Model: "gpt-5.5", + Protocol: ContentModerationProtocolOpenAIChat, + Body: []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`), + }) + require.NoError(t, err) + require.True(t, decision.Blocked) + require.Equal(t, ContentModerationActionKeywordBlock, decision.Action) + + decision, err = svc.Check(context.Background(), ContentModerationCheckInput{ + Model: "gpt-5.4", + Protocol: ContentModerationProtocolOpenAIChat, + Body: []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`), + }) + require.NoError(t, err) + require.True(t, decision.Allowed) + require.False(t, decision.Blocked) + require.Equal(t, ContentModerationActionAllow, decision.Action) + require.Len(t, repo.logs, 1) + require.Equal(t, "gpt-5.5", repo.logs[0].Model) +} + +func TestContentModerationLoadConfig_LegacyConfigDefaultsModelFilterToAll(t *testing.T) { + raw := `{"enabled":true,"mode":"pre_block","base_url":"https://api.openai.com","model":"omni-moderation-latest","blocked_keywords":["secret-token"]}` + svc := NewContentModerationService( + &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyContentModerationConfig: raw, + }}, + nil, + nil, + nil, + nil, + nil, + nil, + ) + + cfg, err := svc.loadConfig(context.Background()) + + require.NoError(t, err) + require.Equal(t, ContentModerationModelFilterAll, cfg.ModelFilter.Type) + require.Empty(t, cfg.ModelFilter.Models) + require.True(t, cfg.includesModel("gpt-5.5")) + require.True(t, cfg.includesModel("gpt-5.4")) +} + +func 
TestContentModerationCheck_ModelFilterUsesRequestedModelNotBodyModel(t *testing.T) { + cfg := defaultContentModerationModelFilterTestConfig() + cfg.ModelFilter = ContentModerationModelFilter{Type: ContentModerationModelFilterInclude, Models: []string{"gpt-5.5"}} + svc, repo := newContentModerationModelFilterTestService(t, cfg) + + decision, err := svc.Check(context.Background(), ContentModerationCheckInput{ + Model: "gpt-5.5", + Protocol: ContentModerationProtocolOpenAIChat, + Body: []byte(`{"model":"mapped-upstream-model","messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`), + }) + + require.NoError(t, err) + require.True(t, decision.Blocked) + require.Equal(t, ContentModerationActionKeywordBlock, decision.Action) + require.Len(t, repo.logs, 1) + require.Equal(t, "gpt-5.5", repo.logs[0].Model) +} + +func defaultContentModerationModelFilterTestConfig() *ContentModerationConfig { + cfg := defaultContentModerationConfig() + cfg.Enabled = true + cfg.Mode = ContentModerationModePreBlock + cfg.BlockedKeywords = []string{"secret-token"} + return cfg +} + +func newContentModerationModelFilterTestService(t *testing.T, cfg *ContentModerationConfig) (*ContentModerationService, *contentModerationTestRepo) { + t.Helper() + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + repo := &contentModerationTestRepo{} + svc := NewContentModerationService( + &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyRiskControlEnabled: "true", + SettingKeyContentModerationConfig: string(rawCfg), + }}, + repo, + &contentModerationTestHashCache{}, + nil, + nil, + nil, + nil, + ) + return svc, repo +} + func TestContentModerationUpdateConfig_AppendsAndDeletesAPIKeys(t *testing.T) { cfg := defaultContentModerationConfig() cfg.APIKeys = []string{"sk-old-a", "sk-old-b"} diff --git a/frontend/src/api/admin/riskControl.ts b/frontend/src/api/admin/riskControl.ts index 4dad1f58..fbba96be 100644 --- a/frontend/src/api/admin/riskControl.ts +++ 
b/frontend/src/api/admin/riskControl.ts @@ -2,6 +2,12 @@ import { apiClient } from '../client' export type ModerationMode = 'off' | 'observe' | 'pre_block' export type KeywordBlockingMode = 'keyword_only' | 'keyword_and_api' | 'api_only' +export type ContentModerationModelFilterType = 'all' | 'include' | 'exclude' + +export interface ContentModerationModelFilter { + type: ContentModerationModelFilterType + models: string[] +} export interface ContentModerationConfig { enabled: boolean @@ -32,6 +38,7 @@ export interface ContentModerationConfig { pre_hash_check_enabled: boolean blocked_keywords: string[] keyword_blocking_mode: KeywordBlockingMode + model_filter: ContentModerationModelFilter } export type ContentModerationAPIKeyStatusValue = 'unknown' | 'ok' | 'error' | 'frozen' @@ -105,6 +112,7 @@ export interface UpdateContentModerationConfig { pre_hash_check_enabled?: boolean blocked_keywords?: string[] keyword_blocking_mode?: KeywordBlockingMode + model_filter?: ContentModerationModelFilter } export interface ContentModerationRuntimeStatus { diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 1b538b92..468a80ba 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -2532,6 +2532,20 @@ export default { selectedGroups: 'Selected Groups', searchGroups: 'Search group name or platform', noGroups: 'No groups available', + modelFilter: 'Model scope', + modelFilterHint: 'Moderate by the client-requested model name; channel model mappings do not change this match.', + modelFilterAll: 'All models', + modelFilterAllDesc: 'All model requests go through content moderation.', + modelFilterInclude: 'Only selected', + modelFilterIncludeDesc: 'Only listed models go through content moderation.', + modelFilterExclude: 'Exclude selected', + modelFilterExcludeDesc: 'Listed models skip content moderation; other models are moderated.', + modelFilterModels: 'Model list', + modelFilterModelCount: '{count} models configured', 
+ modelFilterModelsRequired: 'This model scope requires at least 1 model', + modelFilterAllSummary: 'Applies to all models', + modelFilterIncludeSummary: 'Applies to {count} models', + modelFilterExcludeSummary: 'Excludes {count} models', emptyLogs: 'No audit records', workerStatus: 'Worker Runtime', workerStatusHint: 'Queue and worker pool status for asynchronous observation tasks.', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 67285c3f..fed3304a 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -2609,6 +2609,20 @@ export default { selectedGroups: '指定分组', searchGroups: '搜索分组名称或平台', noGroups: '暂无可用分组', + modelFilter: '模型范围', + modelFilterHint: '按客户端请求的模型名决定是否执行内容审计,模型映射后仍以请求模型判断。', + modelFilterAll: '所有模型', + modelFilterAllDesc: '所有模型请求都会进入内容审计。', + modelFilterInclude: '仅指定模型', + modelFilterIncludeDesc: '只有列表中的模型会执行内容审计。', + modelFilterExclude: '排除指定模型', + modelFilterExcludeDesc: '列表中的模型跳过内容审计,其余模型执行审计。', + modelFilterModels: '模型列表', + modelFilterModelCount: '已配置 {count} 个模型', + modelFilterModelsRequired: '当前模型范围至少需要配置 1 个模型', + modelFilterAllSummary: '全部模型生效', + modelFilterIncludeSummary: '仅 {count} 个模型生效', + modelFilterExcludeSummary: '排除 {count} 个模型', emptyLogs: '暂无审核记录', workerStatus: 'Worker 运行状态', workerStatusHint: '异步观察任务的队列和 worker 池状态。', diff --git a/frontend/src/views/admin/RiskControlView.vue b/frontend/src/views/admin/RiskControlView.vue index acfcec77..4d56b492 100644 --- a/frontend/src/views/admin/RiskControlView.vue +++ b/frontend/src/views/admin/RiskControlView.vue @@ -145,6 +145,26 @@ +
+
+ + {{ t('admin.riskControl.modelFilter') }} + {{ modelFilterSummary }} +
+
+ + {{ model }} + + + +{{ hiddenModelFilterModelCount }} + +
+
+
@@ -628,6 +648,52 @@

{{ t('admin.riskControl.noGroups') }}

+ +
+
+
+

{{ t('admin.riskControl.modelFilter') }}

+

{{ t('admin.riskControl.modelFilterHint') }}

+
+ + {{ modelFilterSummary }} + +
+ +
+ +
+ +
+ + +

+ {{ t('admin.riskControl.modelFilterModelCount', { count: modelFilterModelCount }) }} +

+
+
@@ -887,11 +953,14 @@ import Icon from '@/components/icons/Icon.vue' import Select from '@/components/common/Select.vue' import Toggle from '@/components/common/Toggle.vue' import Pagination from '@/components/common/Pagination.vue' +import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue' import { adminAPI } from '@/api/admin' import type { ContentModerationAPIKeyStatus, ContentModerationConfig, ContentModerationLog, + ContentModerationModelFilter, + ContentModerationModelFilterType, ContentModerationRuntimeStatus, ContentModerationTestAuditResult, KeywordBlockingMode, @@ -987,6 +1056,8 @@ const configForm = reactive({ pre_hash_check_enabled: false, blocked_keywords_text: '', keyword_blocking_mode: 'keyword_and_api' as KeywordBlockingMode, + model_filter_type: 'all' as ContentModerationModelFilterType, + model_filter_models: [] as string[], }) const pagination = reactive({ @@ -1038,6 +1109,24 @@ const keywordBlockingModeOptions = computed>(() => [ + { + value: 'all', + label: t('admin.riskControl.modelFilterAll'), + description: t('admin.riskControl.modelFilterAllDesc'), + }, + { + value: 'include', + label: t('admin.riskControl.modelFilterInclude'), + description: t('admin.riskControl.modelFilterIncludeDesc'), + }, + { + value: 'exclude', + label: t('admin.riskControl.modelFilterExclude'), + description: t('admin.riskControl.modelFilterExcludeDesc'), + }, +]) + type KeywordNoticeView = { title: string description: string @@ -1120,6 +1209,22 @@ const groupFilterOptions = computed(() => [ const selectedGroupCount = computed(() => String(configForm.group_ids.length)) +const modelFilterModelCount = computed(() => configForm.model_filter_models.length) + +const modelFilterSummary = computed(() => { + if (configForm.model_filter_type === 'include') { + return t('admin.riskControl.modelFilterIncludeSummary', { count: modelFilterModelCount.value }) + } + if (configForm.model_filter_type === 'exclude') { + return 
t('admin.riskControl.modelFilterExcludeSummary', { count: modelFilterModelCount.value }) + } + return t('admin.riskControl.modelFilterAllSummary') +}) + +const modelFilterPreviewModels = computed(() => configForm.model_filter_models.slice(0, 6)) + +const hiddenModelFilterModelCount = computed(() => Math.max(0, configForm.model_filter_models.length - modelFilterPreviewModels.value.length)) + const filteredGroups = computed(() => { const keyword = groupSearch.value.trim().toLowerCase() if (!keyword) return groups.value @@ -1238,7 +1343,7 @@ const overviewItems = computed(() => [ key: 'scope', label: t('admin.riskControl.overview.groupScope'), value: configForm.all_groups ? t('admin.riskControl.allGroups') : selectedGroupCount.value, - meta: configForm.all_groups ? t('admin.riskControl.allGroupsHint') : t('admin.riskControl.selectedGroupsHint'), + meta: modelFilterSummary.value, icon: 'users', iconClass: 'bg-violet-50 text-violet-600 dark:bg-violet-900/20 dark:text-violet-300', }, @@ -1342,6 +1447,9 @@ function applyConfig(config: ContentModerationConfig) { configForm.pre_hash_check_enabled = config.pre_hash_check_enabled ?? false configForm.blocked_keywords_text = Array.isArray(config.blocked_keywords) ? 
config.blocked_keywords.join('\n') : '' configForm.keyword_blocking_mode = normalizeKeywordBlockingMode(config.keyword_blocking_mode) + const modelFilter = normalizeModelFilter(config.model_filter) + configForm.model_filter_type = modelFilter.type + configForm.model_filter_models = modelFilter.models } async function loadAll() { @@ -1388,6 +1496,11 @@ async function loadStatus(silent = true) { async function saveConfig() { saving.value = true try { + const modelFilterPayload = buildModelFilterPayload() + if (modelFilterPayload.type !== 'all' && modelFilterPayload.models.length === 0) { + appStore.showError(t('admin.riskControl.modelFilterModelsRequired')) + return + } const payload: UpdateContentModerationConfig = { enabled: configForm.enabled, mode: configForm.mode, @@ -1413,6 +1526,7 @@ async function saveConfig() { pre_hash_check_enabled: configForm.pre_hash_check_enabled, blocked_keywords: blockedKeywordList.value, keyword_blocking_mode: configForm.keyword_blocking_mode, + model_filter: modelFilterPayload, } const keys = parseApiKeys(configForm.api_keys_text) if (!payload.clear_api_key && configForm.api_keys_mode === 'replace' && keys.length === 0) { @@ -1568,6 +1682,13 @@ function setAPIKeysMode(mode: APIKeysWriteMode) { } } +function setModelFilterType(type: ContentModerationModelFilterType) { + configForm.model_filter_type = type + if (type === 'all') { + configForm.model_filter_models = [] + } +} + async function testApiKeys(useInputKeys: boolean) { const keys = useInputKeys ? 
parseApiKeys(configForm.api_keys_text) : [] if (useInputKeys && keys.length === 0) {
@@ -1824,6 +1945,60 @@ function normalizeKeywordBlockingMode(value: unknown): KeywordBlockingMode {
   return 'keyword_and_api'
 }
 
+/**
+ * Coerces an arbitrary value into a well-formed model filter.
+ * Non-object input yields the match-all filter; the model list is dropped
+ * whenever the resolved type is 'all'.
+ */
+function normalizeModelFilter(value: unknown): ContentModerationModelFilter {
+  if (!value || typeof value !== 'object') {
+    return { type: 'all', models: [] }
+  }
+  const raw = value as Partial<ContentModerationModelFilter>
+  const type = normalizeModelFilterType(raw.type)
+  const models = type === 'all' ? [] : normalizeModelNames(raw.models)
+  return { type, models }
+}
+
+/** Returns `value` when it is a known filter type, otherwise falls back to 'all'. */
+function normalizeModelFilterType(value: unknown): ContentModerationModelFilterType {
+  if (value === 'include' || value === 'exclude' || value === 'all') {
+    return value
+  }
+  return 'all'
+}
+
+/**
+ * Trims model names, drops empty entries and removes case-insensitive
+ * duplicates, keeping the first spelling and the input order. Non-array input yields [].
+ */
+function normalizeModelNames(models: unknown): string[] {
+  if (!Array.isArray(models)) return []
+  const seen = new Set<string>()
+  const out: string[] = []
+  for (const item of models) {
+    const model = String(item ?? '').trim()
+    if (!model) continue
+    const key = model.toLowerCase()
+    if (seen.has(key)) continue
+    seen.add(key)
+    out.push(model)
+  }
+  return out
+}
+
+/** Builds a normalized model filter from `configForm.model_filter_type` / `model_filter_models`. */
+function buildModelFilterPayload(): ContentModerationModelFilter {
+  const type = normalizeModelFilterType(configForm.model_filter_type)
+  if (type === 'all') {
+    return { type: 'all', models: [] }
+  }
+  return {
+    type,
+    models: normalizeModelNames(configForm.model_filter_models),
+  }
+}
+
 function parseBlockedKeywords(value: string): string[] {
   const seen = new Set<string>()
   const out: string[] = []
diff --git a/frontend/src/views/admin/__tests__/RiskControlView.spec.ts b/frontend/src/views/admin/__tests__/RiskControlView.spec.ts
new file mode 100644
index 00000000..b528a278
--- /dev/null
+++ b/frontend/src/views/admin/__tests__/RiskControlView.spec.ts
@@ -0,0 +1,236 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { defineComponent, h } from 'vue'
+import { flushPromises, mount } from '@vue/test-utils'
+import type { DOMWrapper, VueWrapper } from '@vue/test-utils'
+
+import RiskControlView from '../RiskControlView.vue'
+import type { ContentModerationConfig, UpdateContentModerationConfig } from '@/api/admin/riskControl'
+
+// Mocks are created through vi.hoisted so the vi.mock factories below can reference them.
+const {
+  getConfig,
+  updateConfig,
+  getStatus,
+  listLogs,
+  getGroups,
+  showError,
+  showSuccess,
+} = vi.hoisted(() => ({
+  getConfig: vi.fn(),
+  updateConfig: vi.fn(),
+  getStatus: vi.fn(),
+  listLogs: vi.fn(),
+  getGroups: vi.fn(),
+  showError: vi.fn(),
+  showSuccess: vi.fn(),
+}))
+
+vi.mock('@/api/admin', () => ({
+  adminAPI: {
+    riskControl: {
+      getConfig,
+      updateConfig,
+      getStatus,
+      listLogs,
+      testAPIKeys: vi.fn(),
+      deleteFlaggedHash: vi.fn(),
+      clearFlaggedHashes: vi.fn(),
+      unbanUser: vi.fn(),
+    },
+    groups: {
+      getAll: getGroups,
+    },
+  },
+}))
+
+vi.mock('@/stores/app', () => ({
+  useAppStore: () => ({
+    showError,
+    showSuccess,
+  }),
+}))
+
+vi.mock('@/utils/apiError', () => ({
+  extractApiErrorMessage: (_err: unknown, fallback: string) => fallback,
+}))
+
+// Keep the real vue-i18n exports, but make t() return the key itself with {token} params substituted.
+vi.mock('vue-i18n', async () => {
+  const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
+  return {
+    ...actual,
+    useI18n: () => ({
+      t: (key: string, params?: Record<string, unknown>) =>
+        key.replace(/\{(\w+)\}/g, (_, token) => String(params?.[token] ?? `{${token}}`)),
+    }),
+  }
+})
+
+// Baseline configuration resolved by the mocked getConfig.
+const baseConfig = (): ContentModerationConfig => ({
+  enabled: true,
+  mode: 'pre_block',
+  base_url: 'https://api.openai.com',
+  model: 'omni-moderation-latest',
+  api_key_configured: false,
+  api_key_masked: '',
+  api_key_count: 0,
+  api_key_masks: [],
+  api_key_statuses: [],
+  timeout_ms: 3000,
+  sample_rate: 100,
+  all_groups: true,
+  group_ids: [],
+  record_non_hits: false,
+  worker_count: 4,
+  queue_size: 32768,
+  block_status: 403,
+  block_message: '内容审计命中风险规则,请调整输入后重试',
+  email_on_hit: true,
+  auto_ban_enabled: true,
+  ban_threshold: 10,
+  violation_window_hours: 720,
+  retry_count: 2,
+  hit_retention_days: 180,
+  non_hit_retention_days: 3,
+  pre_hash_check_enabled: false,
+  blocked_keywords: [],
+  keyword_blocking_mode: 'keyword_and_api',
+  model_filter: {
+    type: 'all',
+    models: [],
+  },
+})
+
+// Runtime status resolved by the mocked getStatus.
+const runtimeStatus = () => ({
+  enabled: true,
+  risk_control_enabled: true,
+  mode: 'pre_block',
+  worker_count: 4,
+  max_workers: 32,
+  active_workers: 0,
+  idle_workers: 4,
+  queue_size: 32768,
+  queue_length: 0,
+  queue_usage_percent: 0,
+  enqueued: 0,
+  dropped: 0,
+  processed: 0,
+  errors: 0,
+  api_key_statuses: [],
+  flagged_hash_count: 0,
+  last_cleanup_deleted_hit: 0,
+  last_cleanup_deleted_non_hit: 0,
+})
+
+// NOTE(review): the stub templates (and the generic type arguments in this file) were stripped during
+// extraction and are reconstructed here -- confirm against the repository before applying.
+const AppLayoutStub = { template: '<div><slot /></div>' }
+const BaseDialogStub = defineComponent({
+  props: {
+    show: {
+      type: Boolean,
+      default: false,
+    },
+  },
+  template: '<div v-if="show"><slot /><slot name="footer" /></div>',
+})
+// Renders a plain input; typed text is split on commas/newlines and emitted as update:modelValue.
+const ModelWhitelistSelectorStub = defineComponent({
+  props: {
+    modelValue: {
+      type: Array,
+      default: () => [],
+    },
+  },
+  emits: ['update:modelValue'],
+  setup(props, { emit }) {
+    const onInput = (event: Event) => {
+      const value = (event.target as HTMLInputElement).value
+      emit(
+        'update:modelValue',
+        value
+          .split(/[,\n]/)
+          .map((item) => item.trim())
+          .filter(Boolean)
+      )
+    }
+    return () =>
+      h('input', {
+        'data-test': 'model-filter-input',
+        value: (props.modelValue as string[]).join('\n'),
+        onInput,
+      })
+  },
+})
+
+// Returns the first button whose text contains `text`; throws when none matches.
+function findButtonByText(wrapper: VueWrapper<any>, text: string): DOMWrapper<Element> {
+  const button = wrapper.findAll('button').find((item) => item.text().includes(text))
+  if (!button) {
+    throw new Error(`button not found: ${text}`)
+  }
+  return button
+}
+
+describe('admin RiskControlView', () => {
+  beforeEach(() => {
+    getConfig.mockReset()
+    updateConfig.mockReset()
+    getStatus.mockReset()
+    listLogs.mockReset()
+    getGroups.mockReset()
+    showError.mockReset()
+    showSuccess.mockReset()
+
+    getConfig.mockResolvedValue(baseConfig())
+    getStatus.mockResolvedValue(runtimeStatus())
+    listLogs.mockResolvedValue({ items: [], total: 0, page: 1, page_size: 20, pages: 1 })
+    getGroups.mockResolvedValue([])
+    // Echo the submitted payload over the baseline config, forcing the API-key fields back to empty.
+    updateConfig.mockImplementation(async (payload: UpdateContentModerationConfig) => ({
+      ...baseConfig(),
+      ...payload,
+      model_filter: payload.model_filter ?? baseConfig().model_filter,
+      api_key_configured: false,
+      api_key_masked: '',
+      api_key_count: 0,
+      api_key_masks: [],
+      api_key_statuses: [],
+    }))
+  })
+
+  it('saves the selected model filter mode and models', async () => {
+    const wrapper = mount(RiskControlView, {
+      global: {
+        stubs: {
+          AppLayout: AppLayoutStub,
+          BaseDialog: BaseDialogStub,
+          Icon: true,
+          Select: true,
+          Toggle: true,
+          Pagination: true,
+          ModelWhitelistSelector: ModelWhitelistSelectorStub,
+        },
+      },
+    })
+
+    await flushPromises()
+
+    await findButtonByText(wrapper, 'admin.riskControl.openSettings').trigger('click')
+    await findButtonByText(wrapper, 'admin.riskControl.tabs.scope').trigger('click')
+    await findButtonByText(wrapper, 'admin.riskControl.modelFilterInclude').trigger('click')
+    await wrapper.get('[data-test="model-filter-input"]').setValue('gpt-5.5, gpt-5.4')
+    await findButtonByText(wrapper, 'admin.riskControl.saveConfig').trigger('click')
+    await flushPromises()
+
+    expect(updateConfig).toHaveBeenCalledWith(expect.objectContaining({
+      model_filter: {
+        type: 'include',
+        models: ['gpt-5.5', 'gpt-5.4'],
+      },
+    }))
+    expect(showError).not.toHaveBeenCalled()
+  })
+})