package service import ( "context" "encoding/json" "fmt" "log/slog" "strings" "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/redis/go-redis/v9" ) const riskSettingsCacheTTL = 5 * time.Minute var ( ErrRiskOverrideReasonRequired = infraerrors.BadRequest("RISK_OVERRIDE_REASON_REQUIRED", "override reason is required") ErrRiskSettingsInvalid = infraerrors.BadRequest("RISK_SETTINGS_INVALID", "risk settings are invalid") ) type RiskService struct { repo RiskRepository settingRepo SettingRepository redis *redis.Client } func NewRiskService(repo RiskRepository, settingRepo SettingRepository, redisClient *redis.Client) *RiskService { return &RiskService{ repo: repo, settingRepo: settingRepo, redis: redisClient, } } func (s *RiskService) CollectBehaviorAsync(ctx context.Context, account *Account, usageLog *UsageLog) { if s == nil || s.repo == nil || account == nil || usageLog == nil { return } if !account.IsOAuth() { return } go func() { bg, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() createdAt := usageLog.CreatedAt if createdAt.IsZero() { createdAt = time.Now() } delta := RiskBehaviorHourDelta{ APICallCount: 1, StreamCount: riskBoolToInt64(usageLog.Stream), TotalInputTokens: int64(usageLog.InputTokens), TotalOutputTokens: int64(usageLog.OutputTokens), TotalDurationMs: riskIntPtrToInt64(usageLog.DurationMs), P50DurationMs: usageLog.DurationMs, } if err := s.repo.UpsertBehaviorHour(bg, usageLog.AccountID, createdAt, delta); err != nil { slog.Warn("risk behavior upsert failed", "account_id", usageLog.AccountID, "error", err) return } settings, err := loadRiskSettings(bg, s.settingRepo, s.redis) if err != nil { settings = DefaultRiskSettings() } if settings.Phase == RiskPhaseOff { return } if _, err := s.repo.GetOrCreateRiskScore(bg, usageLog.AccountID); err != nil { slog.Warn("risk score refresh failed", "account_id", usageLog.AccountID, "error", err) } }() } func (s *RiskService) GetSummary(ctx context.Context) (*RiskSummary, error) { if s == nil || s.repo == nil { return nil, fmt.Errorf("risk service not initialized") } return s.repo.GetRiskSummary(ctx) } func (s *RiskService) ListAccounts(ctx context.Context, filter RiskAccountFilter) (*RiskAccountList, error) { if s == nil || s.repo == nil { return nil, fmt.Errorf("risk service not initialized") } if filter.Page <= 0 { filter.Page = 1 } if filter.PageSize <= 0 { filter.PageSize = 20 } if filter.PageSize > 200 { filter.PageSize = 200 } filter.Level = strings.ToUpper(strings.TrimSpace(filter.Level)) filter.Platform = strings.TrimSpace(filter.Platform) return s.repo.ListRiskAccounts(ctx, filter) } func (s *RiskService) GetAccountDetail(ctx context.Context, accountID int64) (*RiskAccountDetail, error) { if s == nil || s.repo == nil { return nil, fmt.Errorf("risk service not initialized") } if accountID <= 0 { return nil, ErrRiskAccountNotFound } return s.repo.GetRiskAccountDetail(ctx, accountID) } func (s *RiskService) OverrideRiskLevel(ctx context.Context, accountID int64, level, reason string) error { if s == nil || s.repo == nil { return fmt.Errorf("risk service not initialized") } if accountID <= 0 { return ErrRiskAccountNotFound } level = strings.ToUpper(strings.TrimSpace(level)) reason = strings.TrimSpace(reason) if reason == "" { return ErrRiskOverrideReasonRequired } switch level { case RiskLevelLow, RiskLevelMedium, RiskLevelHigh: default: return ErrRiskLevelInvalid } return s.repo.OverrideRiskLevel(ctx, accountID, level, reason) } func (s *RiskService) GetSettings(ctx context.Context) (*RiskSettings, error) { if s == nil || s.settingRepo == nil { return DefaultRiskSettings(), nil } return loadRiskSettings(ctx, s.settingRepo, s.redis) } func (s *RiskService) UpdateSettings(ctx context.Context, settings *RiskSettings) (*RiskSettings, error) { if s == nil || s.settingRepo == nil { return nil, fmt.Errorf("risk service not initialized") } normalized, err := normalizeRiskSettings(settings) if err != nil { return nil, err } data, err := json.Marshal(normalized) if err != nil { return nil, err } if err := s.settingRepo.Set(ctx, SettingKeyRiskSettings, string(data)); err != nil { return nil, err } if s.redis != nil { _ = s.redis.Del(ctx, riskSettingsCacheKey).Err() } return normalized, nil } func loadRiskSettings(ctx context.Context, settingRepo SettingRepository, redisClient *redis.Client) (*RiskSettings, error) { if ctx == nil { ctx = context.Background() } if redisClient != nil { if raw, err := redisClient.Get(ctx, riskSettingsCacheKey).Result(); err == nil && strings.TrimSpace(raw) != "" { settings := DefaultRiskSettings() if err := json.Unmarshal([]byte(raw), settings); err == nil { if normalized, err := normalizeRiskSettings(settings); err == nil { return normalized, nil } } } } settings := DefaultRiskSettings() if settingRepo != nil { if raw, err := settingRepo.GetValue(ctx, SettingKeyRiskSettings); err == nil && strings.TrimSpace(raw) != "" { if unmarshalErr := json.Unmarshal([]byte(raw), settings); unmarshalErr != nil { slog.Warn("risk settings json invalid; using defaults", "error", unmarshalErr) settings = DefaultRiskSettings() } } } normalized, err := normalizeRiskSettings(settings) if err != nil { normalized = DefaultRiskSettings() } if redisClient != nil { if data, marshalErr := json.Marshal(normalized); marshalErr == nil { _ = redisClient.Set(ctx, riskSettingsCacheKey, string(data), riskSettingsCacheTTL).Err() } } return normalized, nil } func normalizeRiskSettings(settings *RiskSettings) (*RiskSettings, error) { if settings == nil { return DefaultRiskSettings(), nil } out := &RiskSettings{ MediumThreshold: settings.MediumThreshold, HighThreshold: settings.HighThreshold, Phase: strings.ToLower(strings.TrimSpace(settings.Phase)), } if out.MediumThreshold == 0 && out.HighThreshold == 0 && out.Phase == "" { return DefaultRiskSettings(), nil } if out.Phase == "" { out.Phase = RiskPhaseObserve } if out.MediumThreshold < 0 || out.MediumThreshold > 1 { return nil, ErrRiskSettingsInvalid.WithCause(fmt.Errorf("medium_threshold must be between 0 and 1")) } if out.HighThreshold < 0 || out.HighThreshold > 1 { return nil, ErrRiskSettingsInvalid.WithCause(fmt.Errorf("high_threshold must be between 0 and 1")) } if out.MediumThreshold >= out.HighThreshold { return nil, ErrRiskSettingsInvalid.WithCause(fmt.Errorf("medium_threshold must be less than high_threshold")) } switch out.Phase { case RiskPhaseOff, RiskPhaseObserve, RiskPhaseEnforce: default: return nil, ErrRiskSettingsInvalid.WithCause(fmt.Errorf("phase must be one of: %s, %s, %s", RiskPhaseOff, RiskPhaseObserve, RiskPhaseEnforce)) } return out, nil } func riskBoolToInt64(v bool) int64 { if v { return 1 } return 0 } func riskIntPtrToInt64(v *int) int64 { if v == nil { return 0 } return int64(*v) }