package service import ( "context" "database/sql" "encoding/json" "errors" "strings" "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/redis/go-redis/v9" ) var ( ErrRiskAccountNotFound = infraerrors.NotFound("RISK_ACCOUNT_NOT_FOUND", "risk account not found") ErrRiskLevelInvalid = infraerrors.BadRequest("RISK_LEVEL_INVALID", "risk level must be LOW, MEDIUM, or HIGH") ) type RiskRepository interface { UpsertBehaviorHour(ctx context.Context, accountID int64, hour time.Time, delta RiskBehaviorHourDelta) error GetRiskSummary(ctx context.Context) (*RiskSummary, error) ListRiskAccounts(ctx context.Context, filter RiskAccountFilter) (*RiskAccountList, error) GetRiskAccountDetail(ctx context.Context, accountID int64) (*RiskAccountDetail, error) OverrideRiskLevel(ctx context.Context, accountID int64, level, reason string) error GetOrCreateRiskScore(ctx context.Context, accountID int64) (*RiskScoreRecord, error) } type pgRiskRepository struct { db *sql.DB settingRepo SettingRepository redis *redis.Client } func NewRiskRepository(db *sql.DB, settingRepo SettingRepository, redisClient *redis.Client) RiskRepository { return &pgRiskRepository{ db: db, settingRepo: settingRepo, redis: redisClient, } } const riskUpsertBehaviorHourSQL = ` INSERT INTO account_behavior_hourly ( account_id, hour_bucket, api_call_count, stream_count, total_input_tokens, total_output_tokens, total_duration_ms, p50_duration_ms, created_at, updated_at ) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, NOW(), NOW() ) ON CONFLICT (account_id, hour_bucket) DO UPDATE SET api_call_count = account_behavior_hourly.api_call_count + EXCLUDED.api_call_count, stream_count = account_behavior_hourly.stream_count + EXCLUDED.stream_count, total_input_tokens = account_behavior_hourly.total_input_tokens + EXCLUDED.total_input_tokens, total_output_tokens = account_behavior_hourly.total_output_tokens + EXCLUDED.total_output_tokens, total_duration_ms = account_behavior_hourly.total_duration_ms + EXCLUDED.total_duration_ms, p50_duration_ms = COALESCE(EXCLUDED.p50_duration_ms, account_behavior_hourly.p50_duration_ms), updated_at = NOW() ` const riskSummarySQL = ` SELECT COUNT(rs.account_id)::bigint AS total_accounts, COUNT(*) FILTER (WHERE rs.risk_level = 'LOW')::bigint AS low_count, COUNT(*) FILTER (WHERE rs.risk_level = 'MEDIUM')::bigint AS medium_count, COUNT(*) FILTER (WHERE rs.risk_level = 'HIGH')::bigint AS high_count, COALESCE(AVG(rs.risk_score), 0)::double precision AS average_score, MAX(rs.scored_at) AS last_scored_at FROM account_risk_scores rs JOIN accounts a ON a.id = rs.account_id WHERE a.deleted_at IS NULL AND a.type IN ('oauth', 'setup_token') ` const riskListSQL = ` WITH current_hour AS ( SELECT account_id, api_call_count, total_input_tokens + total_output_tokens AS total_tokens FROM account_behavior_hourly WHERE hour_bucket = date_trunc('hour', NOW() AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' ) SELECT COUNT(*) OVER()::bigint AS total_count, rs.account_id, a.name, a.platform, rs.risk_score, rs.risk_level, rs.risk_reasons, rs.feature_vector, rs.idle_override, rs.scored_at, COALESCE(ch.api_call_count, 0)::bigint AS last_hour_calls, COALESCE(ch.total_tokens, 0)::bigint AS last_hour_tokens FROM account_risk_scores rs JOIN accounts a ON a.id = rs.account_id LEFT JOIN current_hour ch ON ch.account_id = rs.account_id WHERE a.deleted_at IS NULL AND a.type IN ('oauth', 'setup_token') AND ($1 = '' OR rs.risk_level = $1) AND ($2 = '' OR a.platform = $2) ORDER BY rs.risk_score DESC, rs.scored_at DESC, rs.account_id DESC LIMIT $3 OFFSET $4 ` const riskDetailSQL = ` SELECT rs.account_id, a.name, a.platform, rs.risk_score, rs.risk_level, rs.risk_reasons, rs.feature_vector, rs.idle_override, rs.scored_at, rs.model_version, bh.hour_bucket, bh.api_call_count, bh.stream_count, bh.total_input_tokens, bh.total_output_tokens, bh.total_duration_ms, bh.p50_duration_ms FROM account_risk_scores rs JOIN accounts a ON a.id = rs.account_id LEFT JOIN account_behavior_hourly bh ON bh.account_id = rs.account_id AND bh.hour_bucket >= (date_trunc('hour', NOW() AT TIME ZONE 'UTC') AT TIME ZONE 'UTC') - INTERVAL '24 hours' WHERE a.deleted_at IS NULL AND a.type IN ('oauth', 'setup_token') AND rs.account_id = $1 ORDER BY bh.hour_bucket DESC NULLS LAST ` const riskOverrideSQL = ` UPDATE account_risk_scores SET risk_level = $2, idle_override = TRUE, risk_reasons = COALESCE(risk_reasons, '{}'::jsonb) || jsonb_build_object( 'manual_override', jsonb_build_object('level', $2, 'reason', $3, 'at', NOW()) ), updated_at = NOW() WHERE account_id = $1 RETURNING updated_at ` const riskScoreRefreshSQL = ` WITH valid_account AS ( SELECT id FROM accounts WHERE id = $1 AND deleted_at IS NULL AND type IN ('oauth', 'setup_token') ), behavior AS ( SELECT va.id AS account_id, COALESCE(SUM(abh.api_call_count), 0)::double precision AS total_calls_24h, COALESCE(AVG(abh.api_call_count), 0)::double precision AS calls_per_hour_24h, COALESCE(SUM(abh.stream_count), 0)::double precision AS stream_calls_24h, COALESCE(SUM(abh.total_input_tokens), 0)::double precision AS total_input_tokens_24h, COALESCE( percentile_cont(0.50) WITHIN GROUP (ORDER BY abh.p50_duration_ms) FILTER (WHERE abh.p50_duration_ms IS NOT NULL), 0 )::double precision AS duration_p50_ms, COALESCE(stddev_pop(abh.api_call_count), 0)::double precision AS hourly_entropy FROM valid_account va LEFT JOIN account_behavior_hourly abh ON abh.account_id = va.id AND abh.hour_bucket >= (date_trunc('hour', NOW() AT TIME ZONE 'UTC') AT TIME ZONE 'UTC') - INTERVAL '24 hours' GROUP BY va.id ), features AS ( SELECT b.account_id, b.calls_per_hour_24h, COALESCE(b.stream_calls_24h / NULLIF(b.total_calls_24h, 0), 0) AS stream_ratio_24h, COALESCE(b.total_input_tokens_24h / NULLIF(b.total_calls_24h, 0), 0) AS token_per_request_avg, b.duration_p50_ms, b.hourly_entropy, b.total_calls_24h FROM behavior b ), scored AS ( SELECT f.account_id, LEAST(1.0, (0.25 * LEAST(f.calls_per_hour_24h / 50.0, 1.0)) + (0.20 * LEAST(GREATEST(1.0 - f.stream_ratio_24h, 0.0), 1.0)) + (0.15 * LEAST(f.token_per_request_avg / 50000.0, 1.0)) + (0.20 * LEAST(f.duration_p50_ms / 30000.0, 1.0)) + (0.20 * LEAST(f.total_calls_24h / 500.0, 1.0)) ) AS risk_score, jsonb_build_object( 'calls_per_hour_24h', ROUND(f.calls_per_hour_24h::numeric, 6), 'stream_ratio_24h', ROUND(f.stream_ratio_24h::numeric, 6), 'token_per_request_avg', ROUND(f.token_per_request_avg::numeric, 6), 'duration_p50_ms', ROUND(f.duration_p50_ms::numeric, 6), 'hourly_entropy', ROUND(f.hourly_entropy::numeric, 6), 'total_calls_24h', ROUND(f.total_calls_24h::numeric, 6) ) AS feature_vector, jsonb_build_object( 'auto', to_jsonb(array_remove(ARRAY[ CASE WHEN f.calls_per_hour_24h >= 50 THEN 'high_calls_per_hour' END, CASE WHEN f.stream_ratio_24h <= 0.20 THEN 'low_stream_ratio' END, CASE WHEN f.token_per_request_avg >= 50000 THEN 'high_token_per_request' END, CASE WHEN f.duration_p50_ms >= 30000 THEN 'high_latency_p50' END, CASE WHEN f.total_calls_24h >= 500 THEN 'high_volume_24h' END ], NULL)) ) AS risk_reasons FROM features f ) INSERT INTO account_risk_scores ( account_id, risk_score, risk_level, risk_reasons, feature_vector, scored_at, model_version, idle_override, created_at, updated_at ) SELECT s.account_id, s.risk_score, CASE WHEN s.risk_score >= $3 THEN 'HIGH' WHEN s.risk_score >= $2 THEN 'MEDIUM' ELSE 'LOW' END, s.risk_reasons, s.feature_vector, NOW(), 1, FALSE, NOW(), NOW() FROM scored s ON CONFLICT (account_id) DO UPDATE SET risk_score = EXCLUDED.risk_score, risk_level = CASE WHEN account_risk_scores.idle_override THEN account_risk_scores.risk_level ELSE EXCLUDED.risk_level END, risk_reasons = CASE WHEN account_risk_scores.idle_override THEN COALESCE(EXCLUDED.risk_reasons, '{}'::jsonb) || CASE WHEN account_risk_scores.risk_reasons ? 'manual_override' THEN jsonb_build_object('manual_override', account_risk_scores.risk_reasons -> 'manual_override') ELSE '{}'::jsonb END ELSE EXCLUDED.risk_reasons END, feature_vector = EXCLUDED.feature_vector, scored_at = EXCLUDED.scored_at, model_version = EXCLUDED.model_version, idle_override = account_risk_scores.idle_override, updated_at = NOW() RETURNING id, account_id, risk_score, risk_level, risk_reasons, feature_vector, scored_at, model_version, idle_override, created_at, updated_at ` func (r *pgRiskRepository) UpsertBehaviorHour(ctx context.Context, accountID int64, hour time.Time, delta RiskBehaviorHourDelta) error { hour = hour.UTC().Truncate(time.Hour) _, err := r.db.ExecContext( ctx, riskUpsertBehaviorHourSQL, accountID, hour, delta.APICallCount, delta.StreamCount, delta.TotalInputTokens, delta.TotalOutputTokens, delta.TotalDurationMs, riskNullableInt(delta.P50DurationMs), ) return err } func (r *pgRiskRepository) GetRiskSummary(ctx context.Context) (*RiskSummary, error) { var ( totalAccounts int64 lowCount int64 mediumCount int64 highCount int64 averageScore float64 lastScoredAt sql.NullTime ) if err := r.db.QueryRowContext(ctx, riskSummarySQL).Scan( &totalAccounts, &lowCount, &mediumCount, &highCount, &averageScore, &lastScoredAt, ); err != nil { return nil, err } settings, err := loadRiskSettings(ctx, r.settingRepo, r.redis) if err != nil { settings = DefaultRiskSettings() } summary := &RiskSummary{ TotalAccounts: totalAccounts, LowCount: lowCount, MediumCount: mediumCount, HighCount: highCount, AverageScore: averageScore, Settings: settings, } if lastScoredAt.Valid { ts := lastScoredAt.Time summary.LastScoredAt = &ts } return summary, nil } func (r *pgRiskRepository) ListRiskAccounts(ctx context.Context, filter RiskAccountFilter) (*RiskAccountList, error) { level := strings.ToUpper(strings.TrimSpace(filter.Level)) platform := strings.TrimSpace(filter.Platform) limit := filter.PageSize if limit <= 0 { limit = 20 } page := filter.Page if page <= 0 { page = 1 } offset := (page - 1) * limit rows, err := r.db.QueryContext(ctx, riskListSQL, level, platform, limit, offset) if err != nil { return nil, err } defer func() { _ = rows.Close() }() result := &RiskAccountList{ Items: make([]*RiskAccountListItem, 0, limit), Page: page, PageSize: limit, } for rows.Next() { var ( totalCount int64 accountName string platformName string riskReasons []byte featureVector []byte item RiskAccountListItem ) if err := rows.Scan( &totalCount, &item.AccountID, &accountName, &platformName, &item.RiskScore, &item.RiskLevel, &riskReasons, &featureVector, &item.IdleOverride, &item.ScoredAt, &item.LastHourCalls, &item.LastHourTokens, ); err != nil { return nil, err } item.AccountName = accountName item.Platform = platformName item.RiskReasons = riskDecodeJSON(riskReasons) item.FeatureVector = riskDecodeJSON(featureVector) result.Total = totalCount result.Items = append(result.Items, &item) } if err := rows.Err(); err != nil { return nil, err } return result, nil } func (r *pgRiskRepository) GetRiskAccountDetail(ctx context.Context, accountID int64) (*RiskAccountDetail, error) { if accountID <= 0 { return nil, ErrRiskAccountNotFound } if _, err := r.GetOrCreateRiskScore(ctx, accountID); err != nil { return nil, err } rows, err := r.db.QueryContext(ctx, riskDetailSQL, accountID) if err != nil { return nil, err } defer func() { _ = rows.Close() }() var detail *RiskAccountDetail for rows.Next() { var ( accountName string platformName string riskReasons []byte featureVector []byte hourBucket sql.NullTime apiCallCount sql.NullInt64 streamCount sql.NullInt64 totalInputTokens sql.NullInt64 totalOutputTokens sql.NullInt64 totalDurationMs sql.NullInt64 p50DurationMs sql.NullInt64 ) if detail == nil { detail = &RiskAccountDetail{HourlyBehavior: make([]RiskBehaviorHour, 0, 24)} } if err := rows.Scan( &detail.AccountID, &accountName, &platformName, &detail.RiskScore, &detail.RiskLevel, &riskReasons, &featureVector, &detail.IdleOverride, &detail.ScoredAt, &detail.ModelVersion, &hourBucket, &apiCallCount, &streamCount, &totalInputTokens, &totalOutputTokens, &totalDurationMs, &p50DurationMs, ); err != nil { return nil, err } detail.AccountName = accountName detail.Platform = platformName detail.RiskReasons = riskDecodeJSON(riskReasons) detail.FeatureVector = riskDecodeJSON(featureVector) if hourBucket.Valid { var p50 *int if p50DurationMs.Valid { v := int(p50DurationMs.Int64) p50 = &v } detail.HourlyBehavior = append(detail.HourlyBehavior, RiskBehaviorHour{ HourBucket: hourBucket.Time, APICallCount: apiCallCount.Int64, StreamCount: streamCount.Int64, TotalInputTokens: totalInputTokens.Int64, TotalOutputTokens: totalOutputTokens.Int64, TotalDurationMs: totalDurationMs.Int64, P50DurationMs: p50, }) } } if err := rows.Err(); err != nil { return nil, err } if detail == nil { return nil, ErrRiskAccountNotFound } return detail, nil } func (r *pgRiskRepository) OverrideRiskLevel(ctx context.Context, accountID int64, level, reason string) error { level = strings.ToUpper(strings.TrimSpace(level)) switch level { case RiskLevelLow, RiskLevelMedium, RiskLevelHigh: default: return ErrRiskLevelInvalid } if _, err := r.GetOrCreateRiskScore(ctx, accountID); err != nil { return err } var updatedAt time.Time err := r.db.QueryRowContext(ctx, riskOverrideSQL, accountID, level, strings.TrimSpace(reason)).Scan(&updatedAt) if err != nil { if errors.Is(err, sql.ErrNoRows) { return ErrRiskAccountNotFound } return err } _ = updatedAt return nil } func (r *pgRiskRepository) GetOrCreateRiskScore(ctx context.Context, accountID int64) (*RiskScoreRecord, error) { if accountID <= 0 { return nil, ErrRiskAccountNotFound } settings, err := loadRiskSettings(ctx, r.settingRepo, r.redis) if err != nil { settings = DefaultRiskSettings() } record := &RiskScoreRecord{} var riskReasons, featureVector []byte err = r.db.QueryRowContext( ctx, riskScoreRefreshSQL, accountID, settings.MediumThreshold, settings.HighThreshold, ).Scan( &record.ID, &record.AccountID, &record.RiskScore, &record.RiskLevel, &riskReasons, &featureVector, &record.ScoredAt, &record.ModelVersion, &record.IdleOverride, &record.CreatedAt, &record.UpdatedAt, ) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrRiskAccountNotFound } return nil, err } record.RiskReasons = riskDecodeJSON(riskReasons) record.FeatureVector = riskDecodeJSON(featureVector) return record, nil } func riskNullableInt(v *int) any { if v == nil { return nil } return *v } func riskDecodeJSON(raw []byte) json.RawMessage { if len(raw) == 0 { return json.RawMessage(`{}`) } return json.RawMessage(raw) }