sub2api/backend/internal/service/risk_repository.go
win f25dd04e0b
Some checks failed
CI / test (push) Failing after 1m31s
CI / golangci-lint (push) Failing after 3s
Security Scan / backend-security (push) Failing after 3s
Security Scan / frontend-security (push) Failing after 2s
feat(risk): 风控数据管道与风控中心
- DB Migration 081: 新增 account_behavior_hourly / account_risk_scores 表
- 行为采集:Gateway/OpenAI Gateway RecordUsage 注入 fire-and-forget CollectBehaviorAsync
- SQL 打分引擎:CTE 加权特征向量 → risk_score [0-1],UPSERT 保留 idle_override
- RiskSettings:Redis 缓存 → DB fallback → 默认值(observe 模式)
- REST API:/admin/risk/summary|accounts|accounts/:id|settings
- 前端:Pinia store + RiskControlView + 6 子组件(donut/radar/line 纯 SVG 图表)
- 侧边栏新增 Risk Control 入口(ShieldExclamationIcon)
- 反风控优化:移除 Antigravity 后台定时刷新,改为按需刷新避免 idle 封号
2026-03-28 03:07:17 +08:00

495 lines
16 KiB
Go

package service
import (
"context"
"database/sql"
"encoding/json"
"errors"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/redis/go-redis/v9"
)
// Errors returned by the risk repository.
var (
	// ErrRiskAccountNotFound is returned for a non-positive account ID or when
	// the risk queries in this file find no matching active oauth/setup_token
	// account (or no score row for it).
	ErrRiskAccountNotFound = infraerrors.NotFound("RISK_ACCOUNT_NOT_FOUND", "risk account not found")
	// ErrRiskLevelInvalid is returned by OverrideRiskLevel when the requested
	// level is not LOW, MEDIUM or HIGH.
	ErrRiskLevelInvalid = infraerrors.BadRequest("RISK_LEVEL_INVALID", "risk level must be LOW, MEDIUM, or HIGH")
)
// RiskRepository persists the data used by risk control: hourly behaviour
// counters per account, computed risk scores, and manual level overrides.
type RiskRepository interface {
	// UpsertBehaviorHour adds delta to accountID's behaviour counters for the
	// hour bucket containing hour.
	UpsertBehaviorHour(ctx context.Context, accountID int64, hour time.Time, delta RiskBehaviorHourDelta) error

	// GetOrCreateRiskScore recomputes, stores and returns the risk score of accountID.
	GetOrCreateRiskScore(ctx context.Context, accountID int64) (*RiskScoreRecord, error)

	// OverrideRiskLevel manually pins accountID to level, recording reason.
	OverrideRiskLevel(ctx context.Context, accountID int64, level, reason string) error

	// GetRiskSummary returns per-level counts, the average score and the active settings.
	GetRiskSummary(ctx context.Context) (*RiskSummary, error)

	// ListRiskAccounts returns a filtered, paginated list of scored accounts.
	ListRiskAccounts(ctx context.Context, filter RiskAccountFilter) (*RiskAccountList, error)

	// GetRiskAccountDetail returns one account's score with its recent hourly behaviour.
	GetRiskAccountDetail(ctx context.Context, accountID int64) (*RiskAccountDetail, error)
}
// pgRiskRepository is the RiskRepository implementation returned by
// NewRiskRepository; it issues raw SQL through database/sql.
type pgRiskRepository struct {
	// db executes every risk query in this file.
	db *sql.DB
	// settingRepo is passed to loadRiskSettings when scoring thresholds are needed.
	settingRepo SettingRepository
	// redis is passed to loadRiskSettings alongside settingRepo
	// (presumably as the settings cache — confirm in loadRiskSettings).
	redis *redis.Client
}
// NewRiskRepository returns the SQL-backed RiskRepository.
//
// db runs the risk queries; settingRepo and redisClient are handed to
// loadRiskSettings whenever risk thresholds have to be resolved.
func NewRiskRepository(db *sql.DB, settingRepo SettingRepository, redisClient *redis.Client) RiskRepository {
	repo := new(pgRiskRepository)
	repo.db = db
	repo.settingRepo = settingRepo
	repo.redis = redisClient
	return repo
}
// riskUpsertBehaviorHourSQL inserts or accumulates one (account_id, hour_bucket)
// row of account_behavior_hourly.
//
// Parameters: $1 account_id, $2 hour_bucket, $3 api_call_count, $4 stream_count,
// $5 total_input_tokens, $6 total_output_tokens, $7 total_duration_ms,
// $8 p50_duration_ms (may be NULL).
//
// On conflict the count, token and duration columns are added to the existing
// row. p50_duration_ms is not aggregated: a non-NULL incoming value replaces the
// stored one, and a NULL keeps the stored one. NOTE(review): the column therefore
// holds the most recently reported p50, not a true hourly median — confirm intended.
const riskUpsertBehaviorHourSQL = `
INSERT INTO account_behavior_hourly (
account_id, hour_bucket, api_call_count, stream_count,
total_input_tokens, total_output_tokens, total_duration_ms, p50_duration_ms,
created_at, updated_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, NOW(), NOW()
)
ON CONFLICT (account_id, hour_bucket) DO UPDATE SET
api_call_count = account_behavior_hourly.api_call_count + EXCLUDED.api_call_count,
stream_count = account_behavior_hourly.stream_count + EXCLUDED.stream_count,
total_input_tokens = account_behavior_hourly.total_input_tokens + EXCLUDED.total_input_tokens,
total_output_tokens = account_behavior_hourly.total_output_tokens + EXCLUDED.total_output_tokens,
total_duration_ms = account_behavior_hourly.total_duration_ms + EXCLUDED.total_duration_ms,
p50_duration_ms = COALESCE(EXCLUDED.p50_duration_ms, account_behavior_hourly.p50_duration_ms),
updated_at = NOW()
`
// riskSummarySQL counts scored accounts per risk level and reports the mean
// risk_score and the latest scored_at.
//
// Only score rows whose account is not soft-deleted and whose type is 'oauth'
// or 'setup_token' are considered; accounts without a row in
// account_risk_scores are not counted. When nothing matches, average_score is 0
// (COALESCE) and last_scored_at is NULL.
const riskSummarySQL = `
SELECT
COUNT(rs.account_id)::bigint AS total_accounts,
COUNT(*) FILTER (WHERE rs.risk_level = 'LOW')::bigint AS low_count,
COUNT(*) FILTER (WHERE rs.risk_level = 'MEDIUM')::bigint AS medium_count,
COUNT(*) FILTER (WHERE rs.risk_level = 'HIGH')::bigint AS high_count,
COALESCE(AVG(rs.risk_score), 0)::double precision AS average_score,
MAX(rs.scored_at) AS last_scored_at
FROM account_risk_scores rs
JOIN accounts a ON a.id = rs.account_id
WHERE a.deleted_at IS NULL
AND a.type IN ('oauth', 'setup_token')
`
// riskListSQL returns one page of scored, active oauth/setup_token accounts
// ordered by risk_score DESC, then scored_at DESC, then account_id DESC.
//
// Parameters: $1 risk level filter ('' = any level), $2 platform filter
// ('' = any platform), $3 LIMIT, $4 OFFSET.
//
// total_count is COUNT(*) OVER(), evaluated before LIMIT/OFFSET, so it is the
// full filtered count. It is only observable when the page has at least one row.
// last_hour_calls / last_hour_tokens come from the behaviour row of the current
// (in-progress) UTC hour, not the previous full hour, and are 0 when that row is absent.
const riskListSQL = `
WITH current_hour AS (
SELECT account_id, api_call_count,
total_input_tokens + total_output_tokens AS total_tokens
FROM account_behavior_hourly
WHERE hour_bucket = date_trunc('hour', NOW() AT TIME ZONE 'UTC') AT TIME ZONE 'UTC'
)
SELECT
COUNT(*) OVER()::bigint AS total_count,
rs.account_id,
a.name,
a.platform,
rs.risk_score,
rs.risk_level,
rs.risk_reasons,
rs.feature_vector,
rs.idle_override,
rs.scored_at,
COALESCE(ch.api_call_count, 0)::bigint AS last_hour_calls,
COALESCE(ch.total_tokens, 0)::bigint AS last_hour_tokens
FROM account_risk_scores rs
JOIN accounts a ON a.id = rs.account_id
LEFT JOIN current_hour ch ON ch.account_id = rs.account_id
WHERE a.deleted_at IS NULL
AND a.type IN ('oauth', 'setup_token')
AND ($1 = '' OR rs.risk_level = $1)
AND ($2 = '' OR a.platform = $2)
ORDER BY rs.risk_score DESC, rs.scored_at DESC, rs.account_id DESC
LIMIT $3 OFFSET $4
`
// riskDetailSQL returns the score row of one account ($1) LEFT JOINed with its
// hourly behaviour rows whose hour_bucket is at or after 24 hours before the
// start of the current UTC hour, newest bucket first.
//
// That window covers the current hour plus the 24 preceding ones. The score
// columns repeat on every returned row. An account with no behaviour in the
// window yields a single row whose behaviour columns are NULL. No row is returned
// when the account is soft-deleted, is not 'oauth'/'setup_token', or has no score row.
const riskDetailSQL = `
SELECT
rs.account_id,
a.name,
a.platform,
rs.risk_score,
rs.risk_level,
rs.risk_reasons,
rs.feature_vector,
rs.idle_override,
rs.scored_at,
rs.model_version,
bh.hour_bucket,
bh.api_call_count,
bh.stream_count,
bh.total_input_tokens,
bh.total_output_tokens,
bh.total_duration_ms,
bh.p50_duration_ms
FROM account_risk_scores rs
JOIN accounts a ON a.id = rs.account_id
LEFT JOIN account_behavior_hourly bh
ON bh.account_id = rs.account_id
AND bh.hour_bucket >= (date_trunc('hour', NOW() AT TIME ZONE 'UTC') AT TIME ZONE 'UTC') - INTERVAL '24 hours'
WHERE a.deleted_at IS NULL
AND a.type IN ('oauth', 'setup_token')
AND rs.account_id = $1
ORDER BY bh.hour_bucket DESC NULLS LAST
`
// riskOverrideSQL manually pins the risk level stored for one account.
//
// Parameters: $1 account_id, $2 level, $3 free-text reason.
//
// It sets risk_level and marks idle_override = TRUE. It merges
// {"manual_override": {"level", "reason", "at"}} into risk_reasons, treating a
// NULL risk_reasons as an empty object. RETURNING updated_at yields no row when
// the account has no score row.
//
// Inside jsonb_build_object the placeholders are cast to text. That function's
// arguments are VARIADIC "any", so PostgreSQL cannot infer the type of a bare
// placeholder there. A prepared statement would then be rejected with
// "could not determine data type of parameter $3".
const riskOverrideSQL = `
UPDATE account_risk_scores
SET
risk_level = $2,
idle_override = TRUE,
risk_reasons = COALESCE(risk_reasons, '{}'::jsonb) || jsonb_build_object(
'manual_override', jsonb_build_object('level', $2::text, 'reason', $3::text, 'at', NOW())
),
updated_at = NOW()
WHERE account_id = $1
RETURNING updated_at
`
// riskScoreRefreshSQL recomputes the risk score of one account from its hourly
// behaviour and upserts it into account_risk_scores, returning the stored row.
//
// Parameters: $1 account_id, $2 MEDIUM threshold, $3 HIGH threshold.
// risk_score >= $3 is HIGH, >= $2 is MEDIUM, otherwise LOW.
//
// Pipeline:
//   - valid_account: the account only if it is not soft-deleted and its type is
//     'oauth' or 'setup_token'. Otherwise every later CTE is empty and the
//     statement returns no row.
//   - behavior: aggregates account_behavior_hourly rows whose hour_bucket is at or
//     after 24 hours before the start of the current UTC hour. NOTE(review): AVG
//     and stddev_pop run only over hours that have a row, so calls_per_hour_24h is
//     the mean of active hours (idle hours are not counted as 0). "hourly_entropy"
//     is actually the population standard deviation of hourly call counts, not an
//     entropy — confirm both are intended.
//   - features: derives stream ratio and input tokens per request (0 when there
//     were no calls).
//   - scored: weighted sum of five features, each capped at 1.0 (weights
//     0.25/0.20/0.15/0.20/0.20, summing to 1.0). It also builds the
//     feature_vector and a risk_reasons.auto list of the thresholds that were crossed.
//
// On conflict, when the existing row has idle_override set, its risk_level is kept
// and any existing "manual_override" entry is carried into the new risk_reasons.
// risk_score, feature_vector, scored_at and model_version (always 1) are refreshed
// unconditionally, and idle_override is preserved.
const riskScoreRefreshSQL = `
WITH valid_account AS (
SELECT id FROM accounts
WHERE id = $1 AND deleted_at IS NULL AND type IN ('oauth', 'setup_token')
),
behavior AS (
SELECT
va.id AS account_id,
COALESCE(SUM(abh.api_call_count), 0)::double precision AS total_calls_24h,
COALESCE(AVG(abh.api_call_count), 0)::double precision AS calls_per_hour_24h,
COALESCE(SUM(abh.stream_count), 0)::double precision AS stream_calls_24h,
COALESCE(SUM(abh.total_input_tokens), 0)::double precision AS total_input_tokens_24h,
COALESCE(
percentile_cont(0.50) WITHIN GROUP (ORDER BY abh.p50_duration_ms)
FILTER (WHERE abh.p50_duration_ms IS NOT NULL),
0
)::double precision AS duration_p50_ms,
COALESCE(stddev_pop(abh.api_call_count), 0)::double precision AS hourly_entropy
FROM valid_account va
LEFT JOIN account_behavior_hourly abh
ON abh.account_id = va.id
AND abh.hour_bucket >= (date_trunc('hour', NOW() AT TIME ZONE 'UTC') AT TIME ZONE 'UTC') - INTERVAL '24 hours'
GROUP BY va.id
),
features AS (
SELECT
b.account_id,
b.calls_per_hour_24h,
COALESCE(b.stream_calls_24h / NULLIF(b.total_calls_24h, 0), 0) AS stream_ratio_24h,
COALESCE(b.total_input_tokens_24h / NULLIF(b.total_calls_24h, 0), 0) AS token_per_request_avg,
b.duration_p50_ms,
b.hourly_entropy,
b.total_calls_24h
FROM behavior b
),
scored AS (
SELECT
f.account_id,
LEAST(1.0,
(0.25 * LEAST(f.calls_per_hour_24h / 50.0, 1.0)) +
(0.20 * LEAST(GREATEST(1.0 - f.stream_ratio_24h, 0.0), 1.0)) +
(0.15 * LEAST(f.token_per_request_avg / 50000.0, 1.0)) +
(0.20 * LEAST(f.duration_p50_ms / 30000.0, 1.0)) +
(0.20 * LEAST(f.total_calls_24h / 500.0, 1.0))
) AS risk_score,
jsonb_build_object(
'calls_per_hour_24h', ROUND(f.calls_per_hour_24h::numeric, 6),
'stream_ratio_24h', ROUND(f.stream_ratio_24h::numeric, 6),
'token_per_request_avg', ROUND(f.token_per_request_avg::numeric, 6),
'duration_p50_ms', ROUND(f.duration_p50_ms::numeric, 6),
'hourly_entropy', ROUND(f.hourly_entropy::numeric, 6),
'total_calls_24h', ROUND(f.total_calls_24h::numeric, 6)
) AS feature_vector,
jsonb_build_object(
'auto', to_jsonb(array_remove(ARRAY[
CASE WHEN f.calls_per_hour_24h >= 50 THEN 'high_calls_per_hour' END,
CASE WHEN f.stream_ratio_24h <= 0.20 THEN 'low_stream_ratio' END,
CASE WHEN f.token_per_request_avg >= 50000 THEN 'high_token_per_request' END,
CASE WHEN f.duration_p50_ms >= 30000 THEN 'high_latency_p50' END,
CASE WHEN f.total_calls_24h >= 500 THEN 'high_volume_24h' END
], NULL))
) AS risk_reasons
FROM features f
)
INSERT INTO account_risk_scores (
account_id, risk_score, risk_level, risk_reasons, feature_vector,
scored_at, model_version, idle_override, created_at, updated_at
)
SELECT
s.account_id,
s.risk_score,
CASE
WHEN s.risk_score >= $3 THEN 'HIGH'
WHEN s.risk_score >= $2 THEN 'MEDIUM'
ELSE 'LOW'
END,
s.risk_reasons,
s.feature_vector,
NOW(), 1, FALSE, NOW(), NOW()
FROM scored s
ON CONFLICT (account_id) DO UPDATE SET
risk_score = EXCLUDED.risk_score,
risk_level = CASE
WHEN account_risk_scores.idle_override THEN account_risk_scores.risk_level
ELSE EXCLUDED.risk_level
END,
risk_reasons = CASE
WHEN account_risk_scores.idle_override THEN
COALESCE(EXCLUDED.risk_reasons, '{}'::jsonb) ||
CASE WHEN account_risk_scores.risk_reasons ? 'manual_override'
THEN jsonb_build_object('manual_override', account_risk_scores.risk_reasons -> 'manual_override')
ELSE '{}'::jsonb
END
ELSE EXCLUDED.risk_reasons
END,
feature_vector = EXCLUDED.feature_vector,
scored_at = EXCLUDED.scored_at,
model_version = EXCLUDED.model_version,
idle_override = account_risk_scores.idle_override,
updated_at = NOW()
RETURNING id, account_id, risk_score, risk_level, risk_reasons, feature_vector,
scored_at, model_version, idle_override, created_at, updated_at
`
// UpsertBehaviorHour adds delta to accountID's behaviour counters for the hour
// that contains hour.
//
// hour is converted to UTC and truncated to the hour, so every write within the
// same hour targets the same (account_id, hour_bucket) row. Counters are summed
// on conflict. delta.P50DurationMs replaces the stored value only when non-nil.
func (r *pgRiskRepository) UpsertBehaviorHour(ctx context.Context, accountID int64, hour time.Time, delta RiskBehaviorHourDelta) error {
	bucket := hour.UTC().Truncate(time.Hour)
	args := []any{
		accountID,
		bucket,
		delta.APICallCount,
		delta.StreamCount,
		delta.TotalInputTokens,
		delta.TotalOutputTokens,
		delta.TotalDurationMs,
		riskNullableInt(delta.P50DurationMs),
	}
	if _, execErr := r.db.ExecContext(ctx, riskUpsertBehaviorHourSQL, args...); execErr != nil {
		return execErr
	}
	return nil
}
// GetRiskSummary returns per-level counts and the mean score over scored,
// non-deleted oauth/setup_token accounts, plus the active risk settings.
//
// Settings are read with loadRiskSettings. If that fails, DefaultRiskSettings()
// is used so the summary can still be served. LastScoredAt stays nil when no
// account has been scored.
func (r *pgRiskRepository) GetRiskSummary(ctx context.Context) (*RiskSummary, error) {
	var (
		total, low, medium, high int64
		avgScore                 float64
		newest                   sql.NullTime
	)
	row := r.db.QueryRowContext(ctx, riskSummarySQL)
	if scanErr := row.Scan(&total, &low, &medium, &high, &avgScore, &newest); scanErr != nil {
		return nil, scanErr
	}

	settings, loadErr := loadRiskSettings(ctx, r.settingRepo, r.redis)
	if loadErr != nil {
		// Best effort: fall back to built-in defaults rather than failing the summary.
		settings = DefaultRiskSettings()
	}

	out := &RiskSummary{
		TotalAccounts: total,
		LowCount:      low,
		MediumCount:   medium,
		HighCount:     high,
		AverageScore:  avgScore,
		Settings:      settings,
	}
	if newest.Valid {
		scoredAt := newest.Time
		out.LastScoredAt = &scoredAt
	}
	return out, nil
}
// ListRiskAccounts returns one page of scored accounts, highest risk_score first.
//
// filter.Level is trimmed and upper-cased, and filter.Platform is trimmed. An
// empty value disables that filter. A non-positive PageSize defaults to 20 and a
// non-positive Page defaults to 1. Total is read from the query's COUNT(*) OVER()
// column, so it stays 0 when the requested page lies past the last matching row.
// NOTE(review): PageSize has no upper bound here — confirm the caller caps it.
func (r *pgRiskRepository) ListRiskAccounts(ctx context.Context, filter RiskAccountFilter) (*RiskAccountList, error) {
	// maxPrealloc bounds the initial slice capacity. PageSize is caller-controlled,
	// so make(..., 0, PageSize) with a huge value would try to allocate gigabytes
	// (or panic with "cap out of range") before a single row is read. append still
	// grows the slice for larger pages.
	const maxPrealloc = 100

	level := strings.ToUpper(strings.TrimSpace(filter.Level))
	platform := strings.TrimSpace(filter.Platform)

	limit := filter.PageSize
	if limit <= 0 {
		limit = 20
	}
	page := filter.Page
	if page <= 0 {
		page = 1
	}
	offset := (page - 1) * limit

	rows, err := r.db.QueryContext(ctx, riskListSQL, level, platform, limit, offset)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	capacity := limit
	if capacity > maxPrealloc {
		capacity = maxPrealloc
	}
	result := &RiskAccountList{
		Items:    make([]*RiskAccountListItem, 0, capacity),
		Page:     page,
		PageSize: limit,
	}
	for rows.Next() {
		var (
			totalCount    int64
			accountName   string
			platformName  string
			riskReasons   []byte
			featureVector []byte
			item          RiskAccountListItem
		)
		if err := rows.Scan(
			&totalCount,
			&item.AccountID, &accountName, &platformName,
			&item.RiskScore, &item.RiskLevel,
			&riskReasons, &featureVector,
			&item.IdleOverride, &item.ScoredAt,
			&item.LastHourCalls, &item.LastHourTokens,
		); err != nil {
			return nil, err
		}
		item.AccountName = accountName
		item.Platform = platformName
		item.RiskReasons = riskDecodeJSON(riskReasons)
		item.FeatureVector = riskDecodeJSON(featureVector)
		// Every row carries the same window total; keeping the last one is enough.
		result.Total = totalCount
		result.Items = append(result.Items, &item)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
// GetRiskAccountDetail refreshes one account's risk score and returns it together
// with the account's hourly behaviour rows, newest bucket first.
//
// The score is recomputed through GetOrCreateRiskScore before reading, and any
// error from it is returned unchanged. Behaviour rows are those whose hour_bucket
// is at or after 24 hours before the start of the current UTC hour.
// ErrRiskAccountNotFound is returned for a non-positive accountID or when the
// detail query yields no row.
func (r *pgRiskRepository) GetRiskAccountDetail(ctx context.Context, accountID int64) (*RiskAccountDetail, error) {
	if accountID <= 0 {
		return nil, ErrRiskAccountNotFound
	}
	if _, refreshErr := r.GetOrCreateRiskScore(ctx, accountID); refreshErr != nil {
		return nil, refreshErr
	}

	rows, err := r.db.QueryContext(ctx, riskDetailSQL, accountID)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	out := &RiskAccountDetail{HourlyBehavior: make([]RiskBehaviorHour, 0, 24)}
	sawRow := false
	for rows.Next() {
		sawRow = true

		var (
			name, platform            string
			reasonsJSON, featuresJSON []byte
			bucket                    sql.NullTime
			calls, streams            sql.NullInt64
			inTokens, outTokens       sql.NullInt64
			durationMs, p50Ms         sql.NullInt64
		)
		// The score columns repeat on every joined row; rescanning them is harmless.
		if scanErr := rows.Scan(
			&out.AccountID, &name, &platform,
			&out.RiskScore, &out.RiskLevel,
			&reasonsJSON, &featuresJSON,
			&out.IdleOverride, &out.ScoredAt, &out.ModelVersion,
			&bucket, &calls, &streams,
			&inTokens, &outTokens, &durationMs, &p50Ms,
		); scanErr != nil {
			return nil, scanErr
		}
		out.AccountName = name
		out.Platform = platform
		out.RiskReasons = riskDecodeJSON(reasonsJSON)
		out.FeatureVector = riskDecodeJSON(featuresJSON)

		// A NULL bucket means the LEFT JOIN found no behaviour rows in the window.
		if !bucket.Valid {
			continue
		}
		var p50 *int
		if p50Ms.Valid {
			v := int(p50Ms.Int64)
			p50 = &v
		}
		out.HourlyBehavior = append(out.HourlyBehavior, RiskBehaviorHour{
			HourBucket:        bucket.Time,
			APICallCount:      calls.Int64,
			StreamCount:       streams.Int64,
			TotalInputTokens:  inTokens.Int64,
			TotalOutputTokens: outTokens.Int64,
			TotalDurationMs:   durationMs.Int64,
			P50DurationMs:     p50,
		})
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	if !sawRow {
		return nil, ErrRiskAccountNotFound
	}
	return out, nil
}
// OverrideRiskLevel manually pins an account's risk level.
//
// level is trimmed and upper-cased and must equal RiskLevelLow, RiskLevelMedium
// or RiskLevelHigh; otherwise ErrRiskLevelInvalid is returned. The score row is
// first refreshed (and created if missing) via GetOrCreateRiskScore, whose errors
// are returned unchanged. The row is then updated with the new level,
// idle_override = TRUE and a manual_override entry holding the trimmed reason.
// ErrRiskAccountNotFound is returned when the update matches no row.
func (r *pgRiskRepository) OverrideRiskLevel(ctx context.Context, accountID int64, level, reason string) error {
	normalized := strings.ToUpper(strings.TrimSpace(level))
	if normalized != RiskLevelLow && normalized != RiskLevelMedium && normalized != RiskLevelHigh {
		return ErrRiskLevelInvalid
	}
	if _, refreshErr := r.GetOrCreateRiskScore(ctx, accountID); refreshErr != nil {
		return refreshErr
	}

	// The returned timestamp is not used; RETURNING makes a missing row surface
	// as sql.ErrNoRows.
	var updatedAt time.Time
	row := r.db.QueryRowContext(ctx, riskOverrideSQL, accountID, normalized, strings.TrimSpace(reason))
	switch scanErr := row.Scan(&updatedAt); {
	case scanErr == nil:
		return nil
	case errors.Is(scanErr, sql.ErrNoRows):
		return ErrRiskAccountNotFound
	default:
		return scanErr
	}
}
// GetOrCreateRiskScore recomputes the risk score of accountID from its recent
// hourly behaviour, upserts it into account_risk_scores and returns the stored row.
//
// Despite the name, an existing row is always refreshed, not returned as-is. A
// manual override (idle_override) keeps its level across refreshes. Level
// thresholds come from loadRiskSettings, falling back to DefaultRiskSettings()
// when loading fails. ErrRiskAccountNotFound is returned for a non-positive ID or
// when the account is missing, soft-deleted, or not of type oauth/setup_token.
func (r *pgRiskRepository) GetOrCreateRiskScore(ctx context.Context, accountID int64) (*RiskScoreRecord, error) {
	if accountID <= 0 {
		return nil, ErrRiskAccountNotFound
	}

	settings, loadErr := loadRiskSettings(ctx, r.settingRepo, r.redis)
	if loadErr != nil {
		// Best effort: score with built-in thresholds rather than failing.
		settings = DefaultRiskSettings()
	}

	var (
		rec          RiskScoreRecord
		reasonsJSON  []byte
		featuresJSON []byte
	)
	row := r.db.QueryRowContext(ctx, riskScoreRefreshSQL, accountID, settings.MediumThreshold, settings.HighThreshold)
	scanErr := row.Scan(
		&rec.ID, &rec.AccountID,
		&rec.RiskScore, &rec.RiskLevel,
		&reasonsJSON, &featuresJSON,
		&rec.ScoredAt, &rec.ModelVersion, &rec.IdleOverride,
		&rec.CreatedAt, &rec.UpdatedAt,
	)
	switch {
	case errors.Is(scanErr, sql.ErrNoRows):
		// The statement yields no row when the account fails the validity filter.
		return nil, ErrRiskAccountNotFound
	case scanErr != nil:
		return nil, scanErr
	}

	rec.RiskReasons = riskDecodeJSON(reasonsJSON)
	rec.FeatureVector = riskDecodeJSON(featuresJSON)
	return &rec, nil
}
// riskNullableInt converts an optional int into a SQL argument: the pointed-to
// value when v is set, or an untyped nil (SQL NULL) when v is nil.
func riskNullableInt(v *int) any {
	if v != nil {
		return *v
	}
	return nil
}
// riskDecodeJSON wraps a scanned JSON column as json.RawMessage. A nil or empty
// input (for example a SQL NULL scanned into []byte) becomes the empty object `{}`.
// Any other input is returned unchanged and is not validated.
func riskDecodeJSON(raw []byte) json.RawMessage {
	if len(raw) > 0 {
		return json.RawMessage(raw)
	}
	return json.RawMessage(`{}`)
}