sub2api/backend/internal/service/ops_service.go
wucm667 c9caadb378 fix(account): address second-round review on quota auto-pause
- TopK initial filter now drops quota-paused accounts: fold the quota check
  into isAccountRequestCompatible so session-hash, TopK pool, and per-candidate
  rechecks all skip paused accounts. Previously the candidate pool was built
  without the quota check, so paused accounts could fill TopK and leave the
  scheduler returning "no available accounts" even with healthy ones available.
- Add per-account explicit disable flags auto_pause_5h_disabled /
  auto_pause_7d_disabled with toggles in EditAccountModal. Without these,
  leaving the account threshold blank silently falls back to the global default,
  so admins could not exempt a single account once a global default existed.
  Disable is per-window: an account can opt out of 5h auto-pause while still
  honoring 7d. Schedule snapshot whitelist includes the new fields, i18n EN/ZH
  updated, threshold-hint text revised to explain "blank = global default".
- Move quota auto-pause settings off the request hot path: replace the per-repo
  TTL+singleflight sync DB read with a per-SettingService stale-while-revalidate
  in-memory snapshot. Get is non-blocking (atomic.Pointer load + async refresh
  on staleness); writes via UpdateOpsAdvancedSettings push directly into the
  cache through an injected sink; wire warms the cache at startup. Adds Warm
  (sync) for tests/init and SetOpenAIQuotaAutoPauseSettings (sink target).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-29 14:32:45 +08:00

699 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"database/sql"
"encoding/json"
"errors"
"log"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
var ErrOpsDisabled = infraerrors.NotFound("OPS_DISABLED", "Ops monitoring is disabled")
const (
opsMaxStoredErrorBodyBytes = 20 * 1024
)
// OpsService provides ingestion and query APIs for the Ops monitoring module.
type OpsService struct {
opsRepo OpsRepository
settingRepo SettingRepository
cfg *config.Config
accountRepo AccountRepository
userRepo UserRepository
// getAccountAvailability is a unit-test hook for overriding account availability lookup.
getAccountAvailability func(ctx context.Context, platformFilter string, groupIDFilter *int64) (*OpsAccountAvailability, error)
concurrencyService *ConcurrencyService
gatewayService *GatewayService
openAIGatewayService *OpenAIGatewayService
geminiCompatService *GeminiMessagesCompatService
antigravityGatewayService *AntigravityGatewayService
systemLogSink *OpsSystemLogSink
// cleanupReloader 由 wire 在 OpsCleanupService 构造完成后通过 SetCleanupReloader 注入。
// 解耦避免 OpsService -> OpsCleanupService 的硬依赖cleanup 也读 settings会循环
cleanupReloader CleanupReloader
// quotaAutoPauseSink 由 wire 注入(通常是 SettingService.SetOpenAIQuotaAutoPauseSettings
// UpdateOpsAdvancedSettings 写入新配置后调用,把最新的 quota auto-pause 全局默认阈值
// 立即同步到调度热路径读取的内存缓存,避免下次请求才能感知新值。
quotaAutoPauseSink func(OpsOpenAIAccountQuotaAutoPauseSettings)
}
// CleanupReloader 由 OpsCleanupService 实现。
// UpdateOpsAdvancedSettings 写入新配置后调用 Reload让 schedule/enabled 改动立刻生效。
type CleanupReloader interface {
Reload(ctx context.Context) error
}
// SetCleanupReloader 由 wire 注入 cleanup hook构造期循环依赖的解耦点
func (s *OpsService) SetCleanupReloader(r CleanupReloader) {
if s == nil {
return
}
s.cleanupReloader = r
}
// SetOpenAIQuotaAutoPauseSettingsSink 由 wire 注入,把最新的 quota auto-pause 全局默认
// 阈值 push 到调度热路径读取的内存缓存。同 SetCleanupReloader 的解耦目的:避免 OpsService
// 持有 *SettingService 引入循环依赖。
func (s *OpsService) SetOpenAIQuotaAutoPauseSettingsSink(sink func(OpsOpenAIAccountQuotaAutoPauseSettings)) {
if s == nil {
return
}
s.quotaAutoPauseSink = sink
}
func NewOpsService(
opsRepo OpsRepository,
settingRepo SettingRepository,
cfg *config.Config,
accountRepo AccountRepository,
userRepo UserRepository,
concurrencyService *ConcurrencyService,
gatewayService *GatewayService,
openAIGatewayService *OpenAIGatewayService,
geminiCompatService *GeminiMessagesCompatService,
antigravityGatewayService *AntigravityGatewayService,
systemLogSink *OpsSystemLogSink,
) *OpsService {
svc := &OpsService{
opsRepo: opsRepo,
settingRepo: settingRepo,
cfg: cfg,
accountRepo: accountRepo,
userRepo: userRepo,
concurrencyService: concurrencyService,
gatewayService: gatewayService,
openAIGatewayService: openAIGatewayService,
geminiCompatService: geminiCompatService,
antigravityGatewayService: antigravityGatewayService,
systemLogSink: systemLogSink,
}
svc.applyRuntimeLogConfigOnStartup(context.Background())
return svc
}
func (s *OpsService) RequireMonitoringEnabled(ctx context.Context) error {
if s.IsMonitoringEnabled(ctx) {
return nil
}
return ErrOpsDisabled
}
func (s *OpsService) IsMonitoringEnabled(ctx context.Context) bool {
// Hard switch: disable ops entirely.
if s.cfg != nil && !s.cfg.Ops.Enabled {
return false
}
if s.settingRepo == nil {
return true
}
value, err := s.settingRepo.GetValue(ctx, SettingKeyOpsMonitoringEnabled)
if err != nil {
// Default enabled when key is missing, and fail-open on transient errors
// (ops should never block gateway traffic).
if errors.Is(err, ErrSettingNotFound) {
return true
}
return true
}
switch strings.ToLower(strings.TrimSpace(value)) {
case "false", "0", "off", "disabled":
return false
default:
return true
}
}
func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogInput) error {
prepared, ok, err := s.prepareErrorLogInput(ctx, entry)
if err != nil {
log.Printf("[Ops] RecordError prepare failed: %v", err)
return err
}
if !ok {
return nil
}
if _, err := s.opsRepo.InsertErrorLog(ctx, prepared); err != nil {
// Never bubble up to gateway; best-effort logging.
log.Printf("[Ops] RecordError failed: %v", err)
return err
}
return nil
}
func (s *OpsService) RecordErrorBatch(ctx context.Context, entries []*OpsInsertErrorLogInput) error {
if len(entries) == 0 {
return nil
}
prepared := make([]*OpsInsertErrorLogInput, 0, len(entries))
for _, entry := range entries {
item, ok, err := s.prepareErrorLogInput(ctx, entry)
if err != nil {
log.Printf("[Ops] RecordErrorBatch prepare failed: %v", err)
continue
}
if ok {
prepared = append(prepared, item)
}
}
if len(prepared) == 0 {
return nil
}
if len(prepared) == 1 {
_, err := s.opsRepo.InsertErrorLog(ctx, prepared[0])
if err != nil {
log.Printf("[Ops] RecordErrorBatch single insert failed: %v", err)
}
return err
}
if _, err := s.opsRepo.BatchInsertErrorLogs(ctx, prepared); err != nil {
log.Printf("[Ops] RecordErrorBatch failed, fallback to single inserts: %v", err)
var firstErr error
for _, entry := range prepared {
if _, insertErr := s.opsRepo.InsertErrorLog(ctx, entry); insertErr != nil {
log.Printf("[Ops] RecordErrorBatch fallback insert failed: %v", insertErr)
if firstErr == nil {
firstErr = insertErr
}
}
}
return firstErr
}
return nil
}
func (s *OpsService) prepareErrorLogInput(ctx context.Context, entry *OpsInsertErrorLogInput) (*OpsInsertErrorLogInput, bool, error) {
if entry == nil {
return nil, false, nil
}
if !s.IsMonitoringEnabled(ctx) {
return nil, false, nil
}
if s.opsRepo == nil {
return nil, false, nil
}
// Ensure timestamps are always populated.
if entry.CreatedAt.IsZero() {
entry.CreatedAt = time.Now()
}
// Ensure required fields exist (DB has NOT NULL constraints).
entry.ErrorPhase = strings.TrimSpace(entry.ErrorPhase)
entry.ErrorType = strings.TrimSpace(entry.ErrorType)
if entry.ErrorPhase == "" {
entry.ErrorPhase = "internal"
}
if entry.ErrorType == "" {
entry.ErrorType = "api_error"
}
// Sanitize + truncate error_body to avoid storing sensitive data.
if strings.TrimSpace(entry.ErrorBody) != "" {
sanitized, _ := sanitizeErrorBodyForStorage(entry.ErrorBody, opsMaxStoredErrorBodyBytes)
entry.ErrorBody = sanitized
}
// Sanitize upstream error context if provided by gateway services.
if entry.UpstreamStatusCode != nil && *entry.UpstreamStatusCode <= 0 {
entry.UpstreamStatusCode = nil
}
if entry.UpstreamErrorMessage != nil {
msg := strings.TrimSpace(*entry.UpstreamErrorMessage)
msg = sanitizeUpstreamErrorMessage(msg)
msg = truncateString(msg, 2048)
if strings.TrimSpace(msg) == "" {
entry.UpstreamErrorMessage = nil
} else {
entry.UpstreamErrorMessage = &msg
}
}
if entry.UpstreamErrorDetail != nil {
detail := strings.TrimSpace(*entry.UpstreamErrorDetail)
if detail == "" {
entry.UpstreamErrorDetail = nil
} else {
sanitized, _ := sanitizeErrorBodyForStorage(detail, opsMaxStoredErrorBodyBytes)
if strings.TrimSpace(sanitized) == "" {
entry.UpstreamErrorDetail = nil
} else {
entry.UpstreamErrorDetail = &sanitized
}
}
}
if err := sanitizeOpsUpstreamErrors(entry); err != nil {
return nil, false, err
}
return entry, true, nil
}
func sanitizeOpsUpstreamErrors(entry *OpsInsertErrorLogInput) error {
if entry == nil || len(entry.UpstreamErrors) == 0 {
return nil
}
const maxEvents = 32
events := entry.UpstreamErrors
if len(events) > maxEvents {
events = events[len(events)-maxEvents:]
}
sanitized := make([]*OpsUpstreamErrorEvent, 0, len(events))
for _, ev := range events {
if ev == nil {
continue
}
out := *ev
out.Platform = strings.TrimSpace(out.Platform)
out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128)
out.Kind = truncateString(strings.TrimSpace(out.Kind), 64)
if out.AccountID < 0 {
out.AccountID = 0
}
if out.UpstreamStatusCode < 0 {
out.UpstreamStatusCode = 0
}
if out.AtUnixMs < 0 {
out.AtUnixMs = 0
}
msg := sanitizeUpstreamErrorMessage(strings.TrimSpace(out.Message))
msg = truncateString(msg, 2048)
out.Message = msg
detail := strings.TrimSpace(out.Detail)
if detail != "" {
// Keep upstream detail small; request bodies are not stored here, only upstream error payloads.
sanitizedDetail, _ := sanitizeErrorBodyForStorage(detail, opsMaxStoredErrorBodyBytes)
out.Detail = sanitizedDetail
} else {
out.Detail = ""
}
// Drop fully-empty events (can happen if only status code was known).
if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" {
continue
}
evCopy := out
sanitized = append(sanitized, &evCopy)
}
entry.UpstreamErrorsJSON = marshalOpsUpstreamErrors(sanitized)
entry.UpstreamErrors = nil
return nil
}
func (s *OpsService) GetErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return nil, err
}
if s.opsRepo == nil {
return &OpsErrorLogList{Errors: []*OpsErrorLog{}, Total: 0, Page: 1, PageSize: 20}, nil
}
result, err := s.opsRepo.ListErrorLogs(ctx, filter)
if err != nil {
log.Printf("[Ops] GetErrorLogs failed: %v", err)
return nil, err
}
return result, nil
}
func (s *OpsService) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return nil, err
}
if s.opsRepo == nil {
return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
}
detail, err := s.opsRepo.GetErrorLogByID(ctx, id)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
}
return nil, infraerrors.InternalServer("OPS_ERROR_LOAD_FAILED", "Failed to load ops error log").WithCause(err)
}
return detail, nil
}
func (s *OpsService) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64) error {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return err
}
if s.opsRepo == nil {
return infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
}
if errorID <= 0 {
return infraerrors.BadRequest("OPS_ERROR_INVALID_ID", "invalid error id")
}
// Best-effort ensure the error exists
if _, err := s.opsRepo.GetErrorLogByID(ctx, errorID); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
}
return infraerrors.InternalServer("OPS_ERROR_LOAD_FAILED", "Failed to load ops error log").WithCause(err)
}
return s.opsRepo.UpdateErrorResolution(ctx, errorID, resolved, resolvedByUserID, nil)
}
func sanitizeAndTrimJSONPayload(raw []byte, maxBytes int) (jsonString string, truncated bool, bytesLen int) {
bytesLen = len(raw)
if len(raw) == 0 {
return "", false, 0
}
var decoded any
if err := json.Unmarshal(raw, &decoded); err != nil {
// If it is not valid JSON, fall back to the caller's non-JSON handling.
return "", false, bytesLen
}
decoded = redactSensitiveJSON(decoded)
encoded, err := json.Marshal(decoded)
if err != nil {
return "", false, bytesLen
}
if len(encoded) <= maxBytes {
return string(encoded), false, bytesLen
}
// Trim conversation history to keep the most recent context.
if root, ok := decoded.(map[string]any); ok {
if trimmed, ok := trimConversationArrays(root, maxBytes); ok {
encoded2, err2 := json.Marshal(trimmed)
if err2 == nil && len(encoded2) <= maxBytes {
return string(encoded2), true, bytesLen
}
// Fallthrough: keep shrinking.
decoded = trimmed
}
essential := shrinkToEssentials(root)
encoded3, err3 := json.Marshal(essential)
if err3 == nil && len(encoded3) <= maxBytes {
return string(encoded3), true, bytesLen
}
}
// Last resort: keep JSON shape but drop big fields.
// This avoids downstream code that expects certain top-level keys from crashing.
if root, ok := decoded.(map[string]any); ok {
placeholder := shallowCopyMap(root)
placeholder["payload_truncated"] = true
// Replace potentially huge arrays/strings, but keep the keys present.
for _, k := range []string{"messages", "contents", "input", "prompt"} {
if _, exists := placeholder[k]; exists {
placeholder[k] = []any{}
}
}
for _, k := range []string{"text"} {
if _, exists := placeholder[k]; exists {
placeholder[k] = ""
}
}
encoded4, err4 := json.Marshal(placeholder)
if err4 == nil {
if len(encoded4) <= maxBytes {
return string(encoded4), true, bytesLen
}
}
}
// Final fallback: minimal valid JSON.
encoded4, err4 := json.Marshal(map[string]any{"payload_truncated": true})
if err4 != nil {
return "", true, bytesLen
}
return string(encoded4), true, bytesLen
}
func redactSensitiveJSON(v any) any {
switch t := v.(type) {
case map[string]any:
out := make(map[string]any, len(t))
for k, vv := range t {
if isSensitiveKey(k) {
out[k] = "[REDACTED]"
continue
}
out[k] = redactSensitiveJSON(vv)
}
return out
case []any:
out := make([]any, 0, len(t))
for _, vv := range t {
out = append(out, redactSensitiveJSON(vv))
}
return out
default:
return v
}
}
func isSensitiveKey(key string) bool {
k := strings.ToLower(strings.TrimSpace(key))
if k == "" {
return false
}
// Token 计数 / 预算字段不是凭据,应保留用于排错。
// 白名单保持尽量窄,避免误把真实敏感信息"反脱敏"。
switch k {
case "max_tokens",
"max_output_tokens",
"max_input_tokens",
"max_completion_tokens",
"max_tokens_to_sample",
"budget_tokens",
"prompt_tokens",
"completion_tokens",
"input_tokens",
"output_tokens",
"total_tokens",
"token_count",
"cache_creation_input_tokens",
"cache_read_input_tokens":
return false
}
// Exact matches (common credential fields).
switch k {
case "authorization",
"proxy-authorization",
"x-api-key",
"api_key",
"apikey",
"access_token",
"refresh_token",
"id_token",
"session_token",
"token",
"password",
"passwd",
"passphrase",
"secret",
"client_secret",
"private_key",
"jwt",
"signature",
"accesskeyid",
"secretaccesskey":
return true
}
// Suffix matches.
for _, suffix := range []string{
"_secret",
"_token",
"_id_token",
"_session_token",
"_password",
"_passwd",
"_passphrase",
"_key",
"secret_key",
"private_key",
} {
if strings.HasSuffix(k, suffix) {
return true
}
}
// Substring matches (conservative, but errs on the side of privacy).
for _, sub := range []string{
"secret",
"token",
"password",
"passwd",
"passphrase",
"privatekey",
"private_key",
"apikey",
"api_key",
"accesskeyid",
"secretaccesskey",
"bearer",
"cookie",
"credential",
"session",
"jwt",
"signature",
} {
if strings.Contains(k, sub) {
return true
}
}
return false
}
func trimConversationArrays(root map[string]any, maxBytes int) (map[string]any, bool) {
// Supported: anthropic/openai: messages; gemini: contents.
if out, ok := trimArrayField(root, "messages", maxBytes); ok {
return out, true
}
if out, ok := trimArrayField(root, "contents", maxBytes); ok {
return out, true
}
return root, false
}
func trimArrayField(root map[string]any, field string, maxBytes int) (map[string]any, bool) {
raw, ok := root[field]
if !ok {
return nil, false
}
arr, ok := raw.([]any)
if !ok || len(arr) == 0 {
return nil, false
}
// Keep at least the last message/content. Use binary search so we don't marshal O(n) times.
// We are dropping from the *front* of the array (oldest context first).
lo := 0
hi := len(arr) - 1 // inclusive; hi ensures at least one item remains
var best map[string]any
found := false
for lo <= hi {
mid := (lo + hi) / 2
candidateArr := arr[mid:]
if len(candidateArr) == 0 {
lo = mid + 1
continue
}
next := shallowCopyMap(root)
next[field] = candidateArr
encoded, err := json.Marshal(next)
if err != nil {
// If marshal fails, try dropping more.
lo = mid + 1
continue
}
if len(encoded) <= maxBytes {
best = next
found = true
// Try to keep more context by dropping fewer items.
hi = mid - 1
continue
}
// Need to drop more.
lo = mid + 1
}
if found {
return best, true
}
// Nothing fit (even with only one element); return the smallest slice and let the
// caller fall back to shrinkToEssentials().
next := shallowCopyMap(root)
next[field] = arr[len(arr)-1:]
return next, true
}
func shrinkToEssentials(root map[string]any) map[string]any {
out := make(map[string]any)
for _, key := range []string{
"model",
"stream",
"max_tokens",
"max_output_tokens",
"max_input_tokens",
"max_completion_tokens",
"thinking",
"temperature",
"top_p",
"top_k",
} {
if v, ok := root[key]; ok {
out[key] = v
}
}
// Keep only the last element of the conversation array.
if v, ok := root["messages"]; ok {
if arr, ok := v.([]any); ok && len(arr) > 0 {
out["messages"] = []any{arr[len(arr)-1]}
}
}
if v, ok := root["contents"]; ok {
if arr, ok := v.([]any); ok && len(arr) > 0 {
out["contents"] = []any{arr[len(arr)-1]}
}
}
return out
}
func shallowCopyMap(m map[string]any) map[string]any {
out := make(map[string]any, len(m))
for k, v := range m {
out[k] = v
}
return out
}
func sanitizeErrorBodyForStorage(raw string, maxBytes int) (sanitized string, truncated bool) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", false
}
// Prefer JSON-safe sanitization when possible.
if out, trunc, _ := sanitizeAndTrimJSONPayload([]byte(raw), maxBytes); out != "" {
return out, trunc
}
// Non-JSON: best-effort truncate.
if maxBytes > 0 && len(raw) > maxBytes {
return truncateString(raw, maxBytes), true
}
return raw, false
}