package repository import ( "context" "database/sql" "encoding/json" "fmt" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" ) type contentModerationRepository struct { db *sql.DB } func NewContentModerationRepository(db *sql.DB) service.ContentModerationRepository { return &contentModerationRepository{db: db} } func (r *contentModerationRepository) CreateLog(ctx context.Context, log *service.ContentModerationLog) error { if log == nil { return nil } categoryScores, err := json.Marshal(log.CategoryScores) if err != nil { return fmt.Errorf("marshal moderation category scores: %w", err) } thresholdSnapshot, err := json.Marshal(log.ThresholdSnapshot) if err != nil { return fmt.Errorf("marshal moderation thresholds: %w", err) } var userID any if log.UserID != nil { userID = *log.UserID } var apiKeyID any if log.APIKeyID != nil { apiKeyID = *log.APIKeyID } var groupID any if log.GroupID != nil { groupID = *log.GroupID } var latency any if log.UpstreamLatencyMS != nil { latency = *log.UpstreamLatencyMS } err = r.db.QueryRowContext(ctx, ` INSERT INTO content_moderation_logs ( request_id, user_id, user_email, api_key_id, api_key_name, group_id, group_name, endpoint, provider, model, mode, action, flagged, highest_category, highest_score, category_scores, threshold_snapshot, input_excerpt, upstream_latency_ms, error, violation_count, auto_banned, email_sent, queue_delay_ms ) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16::jsonb, $17::jsonb, $18, $19, $20, $21, $22, $23, $24 ) RETURNING id, created_at`, log.RequestID, userID, log.UserEmail, apiKeyID, log.APIKeyName, groupID, log.GroupName, log.Endpoint, log.Provider, log.Model, log.Mode, log.Action, log.Flagged, log.HighestCategory, log.HighestScore, string(categoryScores), string(thresholdSnapshot), log.InputExcerpt, latency, log.Error, log.ViolationCount, log.AutoBanned, log.EmailSent, nullableIntPtr(log.QueueDelayMS), ).Scan(&log.ID, &log.CreatedAt) if err != nil { return fmt.Errorf("insert content moderation log: %w", err) } return nil } func (r *contentModerationRepository) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) { where, args := buildContentModerationLogWhere(filter) whereSQL := "WHERE " + strings.Join(where, " AND ") var total int64 if err := r.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM content_moderation_logs l "+whereSQL, args...).Scan(&total); err != nil { return nil, nil, fmt.Errorf("count content moderation logs: %w", err) } params := filter.Pagination if params.Page <= 0 { params.Page = 1 } if params.PageSize <= 0 { params.PageSize = 20 } if params.PageSize > 100 { params.PageSize = 100 } queryArgs := append([]any{}, args...) queryArgs = append(queryArgs, params.Limit(), params.Offset()) rows, err := r.db.QueryContext(ctx, ` SELECT l.id, l.request_id, l.user_id, l.user_email, l.api_key_id, l.api_key_name, l.group_id, l.group_name, l.endpoint, l.provider, l.model, l.mode, l.action, l.flagged, l.highest_category, l.highest_score, l.category_scores, l.threshold_snapshot, l.input_excerpt, l.upstream_latency_ms, l.error, l.violation_count, l.auto_banned, l.email_sent, COALESCE(u.status, ''), l.queue_delay_ms, l.created_at FROM content_moderation_logs l LEFT JOIN users u ON u.id = l.user_id `+whereSQL+` ORDER BY l.created_at DESC, l.id DESC LIMIT $`+fmt.Sprint(len(queryArgs)-1)+` OFFSET $`+fmt.Sprint(len(queryArgs)), queryArgs..., ) if err != nil { return nil, nil, fmt.Errorf("list content moderation logs: %w", err) } defer func() { _ = rows.Close() }() items := make([]service.ContentModerationLog, 0) for rows.Next() { var item service.ContentModerationLog var userID, apiKeyID, groupID, latency, queueDelay sql.NullInt64 var scoresRaw, thresholdsRaw []byte if err := rows.Scan( &item.ID, &item.RequestID, &userID, &item.UserEmail, &apiKeyID, &item.APIKeyName, &groupID, &item.GroupName, &item.Endpoint, &item.Provider, &item.Model, &item.Mode, &item.Action, &item.Flagged, &item.HighestCategory, &item.HighestScore, &scoresRaw, &thresholdsRaw, &item.InputExcerpt, &latency, &item.Error, &item.ViolationCount, &item.AutoBanned, &item.EmailSent, &item.UserStatus, &queueDelay, &item.CreatedAt, ); err != nil { return nil, nil, fmt.Errorf("scan content moderation log: %w", err) } if userID.Valid { v := userID.Int64 item.UserID = &v } if apiKeyID.Valid { v := apiKeyID.Int64 item.APIKeyID = &v } if groupID.Valid { v := groupID.Int64 item.GroupID = &v } if latency.Valid { v := int(latency.Int64) item.UpstreamLatencyMS = &v } if queueDelay.Valid { v := int(queueDelay.Int64) item.QueueDelayMS = &v } item.CategoryScores = map[string]float64{} _ = json.Unmarshal(scoresRaw, &item.CategoryScores) item.ThresholdSnapshot = map[string]float64{} _ = json.Unmarshal(thresholdsRaw, &item.ThresholdSnapshot) items = append(items, item) } if err := rows.Err(); err != nil { return nil, nil, fmt.Errorf("iterate content moderation logs: %w", err) } return items, paginationResultFromTotal(total, params), nil } func (r *contentModerationRepository) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) { if userID <= 0 { return 0, nil } var count int err := r.db.QueryRowContext(ctx, ` WITH last_auto_ban AS ( SELECT MAX(created_at) AS at FROM content_moderation_logs WHERE user_id = $1 AND auto_banned = TRUE ) SELECT COUNT(*) FROM content_moderation_logs WHERE user_id = $1 AND flagged = TRUE AND action <> 'hash_block' AND created_at >= $2 AND created_at > COALESCE((SELECT at FROM last_auto_ban), '-infinity'::timestamptz) `, userID, since).Scan(&count) if err != nil { return 0, fmt.Errorf("count user content moderation flagged logs: %w", err) } return count, nil } func (r *contentModerationRepository) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*service.ContentModerationCleanupResult, error) { result := &service.ContentModerationCleanupResult{FinishedAt: time.Now()} if r == nil || r.db == nil { return result, nil } hitExec, err := r.db.ExecContext(ctx, ` DELETE FROM content_moderation_logs WHERE flagged = TRUE AND created_at < $1 `, hitBefore) if err != nil { return nil, fmt.Errorf("delete expired hit content moderation logs: %w", err) } result.DeletedHit, _ = hitExec.RowsAffected() nonHitExec, err := r.db.ExecContext(ctx, ` DELETE FROM content_moderation_logs WHERE flagged = FALSE AND created_at < $1 `, nonHitBefore) if err != nil { return nil, fmt.Errorf("delete expired non-hit content moderation logs: %w", err) } result.DeletedNonHit, _ = nonHitExec.RowsAffected() result.FinishedAt = time.Now() return result, nil } func nullableIntPtr(value *int) any { if value == nil { return nil } return *value } func buildContentModerationLogWhere(filter service.ContentModerationLogFilter) ([]string, []any) { where := []string{"l.id IS NOT NULL"} args := make([]any, 0) add := func(expr string, value any) { args = append(args, value) where = append(where, fmt.Sprintf(expr, len(args))) } switch strings.ToLower(strings.TrimSpace(filter.Result)) { case "hit", "flagged": where = append(where, "l.flagged = TRUE") case "blocked", "block": where = append(where, "l.action IN ('block', 'keyword_block', 'hash_block')") case "pass", "allow": where = append(where, "l.flagged = FALSE AND l.error = ''") case "error": where = append(where, "l.error <> ''") } if filter.GroupID != nil { add("l.group_id = $%d", *filter.GroupID) } if endpoint := strings.TrimSpace(filter.Endpoint); endpoint != "" { add("l.endpoint = $%d", endpoint) } if search := strings.TrimSpace(filter.Search); search != "" { like := "%" + search + "%" args = append(args, like, like, like, like, like) idx := len(args) - 4 where = append(where, fmt.Sprintf("(l.request_id ILIKE $%d OR l.user_email ILIKE $%d OR l.api_key_name ILIKE $%d OR l.model ILIKE $%d OR l.input_excerpt ILIKE $%d)", idx, idx+1, idx+2, idx+3, idx+4)) } if filter.From != nil && !filter.From.IsZero() { add("l.created_at >= $%d", *filter.From) } if filter.To != nil && !filter.To.IsZero() { add("l.created_at <= $%d", *filter.To) } return where, args }