diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index d45e8a12..23db17d1 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -29,6 +29,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/lib/pq" + "golang.org/x/sync/singleflight" entsql "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqljson" @@ -49,6 +50,13 @@ type accountRepository struct { // Used to proactively sync account snapshot to cache when status changes, // ensuring sticky sessions can promptly detect unavailable accounts. schedulerCache service.SchedulerCache + + // tempUnschedSF 在进程内合并对同一账号的并发 SetTempUnschedulable 调用。 + // 上游 401/限流爆发时,N 个 in-flight 请求会同时调用此方法;底层 SQL + // 已经做了 (until < $1) 的 idempotent 保护,不会重复改 row,但 N 次 + // SQL RTT + N 次 outbox enqueue + N 次缓存同步仍然可观。singleflight + // 把这些并发合并成 1 次实际执行,其余 caller 共享同一结果。 + tempUnschedSF singleflight.Group } var schedulerNeutralExtraKeyPrefixes = []string{ @@ -1029,6 +1037,17 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t } func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { + // 进程内合并并发调用:key 包含 until 和 reason,确保不同窗口/原因独立去重。 + // until 用毫秒粒度足够:同一爆发窗口内 caller 算的 until 几乎一致; + // 哪怕略有偏差,SQL 的 (existing < new) 条件保证语义安全。 + sfKey := strconv.FormatInt(id, 10) + ":" + strconv.FormatInt(until.UnixMilli(), 10) + ":" + reason + _, err, _ := r.tempUnschedSF.Do(sfKey, func() (interface{}, error) { + return nil, r.setTempUnschedulableOnce(ctx, id, until, reason) + }) + return err +} + +func (r *accountRepository) setTempUnschedulableOnce(ctx context.Context, id int64, until time.Time, reason string) error { _, err := r.sql.ExecContext(ctx, ` UPDATE accounts SET temp_unschedulable_until = $1, diff --git a/backend/internal/repository/account_repo_singleflight_test.go b/backend/internal/repository/account_repo_singleflight_test.go new file mode 100644 index 00000000..5d5d1201 --- /dev/null +++ b/backend/internal/repository/account_repo_singleflight_test.go @@ -0,0 +1,119 @@ +package repository + +import ( + "context" + "database/sql" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// blockingExecutor 是一个最小化的 sqlExecutor 实现,用于精确控制并发时序。 +// ExecContext 会等待 release 信号才返回,便于让多个 goroutine 集中堆积在 +// singleflight 的同一窗口内。 +type blockingExecutor struct { + mu sync.Mutex + execCalls int32 + queryCalls int32 + release chan struct{} + concurrent int32 + maxObserved int32 +} + +func newBlockingExecutor() *blockingExecutor { + return &blockingExecutor{release: make(chan struct{})} +} + +func (e *blockingExecutor) Release() { close(e.release) } + +func (e *blockingExecutor) ExecContext(_ context.Context, _ string, _ ...any) (sql.Result, error) { + atomic.AddInt32(&e.execCalls, 1) + c := atomic.AddInt32(&e.concurrent, 1) + for { + old := atomic.LoadInt32(&e.maxObserved) + if c <= old || atomic.CompareAndSwapInt32(&e.maxObserved, old, c) { + break + } + } + defer atomic.AddInt32(&e.concurrent, -1) + <-e.release + return driverResult{}, nil +} + +func (e *blockingExecutor) QueryContext(_ context.Context, _ string, _ ...any) (*sql.Rows, error) { + atomic.AddInt32(&e.queryCalls, 1) + return nil, sql.ErrNoRows +} + +// driverResult 是一个零值 sql.Result,用于测试。 +type driverResult struct{} + +func (driverResult) LastInsertId() (int64, error) { return 0, nil } +func (driverResult) RowsAffected() (int64, error) { return 1, nil } + +func TestSetTempUnschedulable_SingleflightDedupesConcurrentCallers(t *testing.T) { + // 同一账号 + 同一 until + 同一 reason 的 N 个并发调用,应只触发一次实际 + // SQL 路径(UPDATE + outbox INSERT = 2 次 ExecContext)。 + exec := newBlockingExecutor() + repo := newAccountRepositoryWithSQL(nil, exec, nil) + + const callers = 30 + until := time.Now().Add(10 * time.Minute) + const reason = "OAuth 401: invalid_grant" + + var wg sync.WaitGroup + wg.Add(callers) + for i := 0; i < callers; i++ { + go func() { + defer wg.Done() + _ = repo.SetTempUnschedulable(context.Background(), 42, until, reason) + }() + } + + // 等首个 ExecContext 进入阻塞,确认 sf 已聚拢调用。 + deadline := time.Now().Add(2 * time.Second) + for atomic.LoadInt32(&exec.concurrent) == 0 && time.Now().Before(deadline) { + time.Sleep(5 * time.Millisecond) + } + require.Equal(t, int32(1), atomic.LoadInt32(&exec.concurrent), + "singleflight should serialize the SQL call to exactly one in-flight execution") + + exec.Release() + wg.Wait() + + // 1 次 UPDATE + 1 次 outbox INSERT = 2 次 exec;其余 29 个 caller 共享结果。 + require.LessOrEqual(t, atomic.LoadInt32(&exec.execCalls), int32(2), + "expected at most 2 ExecContext calls (UPDATE + outbox), got %d", exec.execCalls) + require.Equal(t, int32(1), atomic.LoadInt32(&exec.maxObserved), + "no two SQL execs should run concurrently for the same singleflight key") +} + +func TestSetTempUnschedulable_DifferentAccountsRunInParallel(t *testing.T) { + // 不同 account 应分属不同 sf key,能并行写库。 + exec := newBlockingExecutor() + repo := newAccountRepositoryWithSQL(nil, exec, nil) + + until := time.Now().Add(10 * time.Minute) + var wg sync.WaitGroup + for i := int64(1); i <= 3; i++ { + i := i + wg.Add(1) + go func() { + defer wg.Done() + _ = repo.SetTempUnschedulable(context.Background(), i, until, "different reason") + }() + } + + deadline := time.Now().Add(2 * time.Second) + for atomic.LoadInt32(&exec.concurrent) < 3 && time.Now().Before(deadline) { + time.Sleep(5 * time.Millisecond) + } + require.Equal(t, int32(3), atomic.LoadInt32(&exec.maxObserved), + "different accounts should be able to write in parallel") + + exec.Release() + wg.Wait() +} diff --git a/backend/internal/service/oauth_refresh_api.go b/backend/internal/service/oauth_refresh_api.go index 5dbba638..545db3ea 100644 --- a/backend/internal/service/oauth_refresh_api.go +++ b/backend/internal/service/oauth_refresh_api.go @@ -6,6 +6,8 @@ import ( "log/slog" "strconv" "time" + + "golang.org/x/sync/singleflight" ) // OAuthRefreshExecutor 各平台实现的 OAuth 刷新执行器 @@ -29,9 +31,19 @@ type OAuthRefreshResult struct { // OAuthRefreshAPI 统一的 OAuth Token 刷新入口 // 封装分布式锁、DB 重读、已刷新检查等通用逻辑 +// +// 双层去重设计: +// 1. 进程内 singleflight:合并同一 cacheKey 的并发调用(避免 100 个 goroutine +// 都去抢同一把分布式锁、都重读一次 DB)。 +// 2. 跨进程分布式锁(Redis):保证集群范围内只有一个 worker 真正发起 OAuth +// 刷新请求。 +// +// 进程内去重在分布式锁之外做,避免无谓的 Redis RTT;跨进程锁仍是必需的, +// singleflight 解决不了多 pod 同时刷新。 type OAuthRefreshAPI struct { accountRepo AccountRepository tokenCache GeminiTokenCache // 可选,nil = 无锁 + sf singleflight.Group } // NewOAuthRefreshAPI 创建统一刷新 API @@ -42,15 +54,19 @@ func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCac } } -// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token +// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token。 +// +// 同一 cacheKey 在同一进程内并发调用会被 singleflight 合并;只有"领导者" +// 调用会真正进入下层流程,其余调用共享相同的 *OAuthRefreshResult / error。 // // 流程: -// 1. 获取分布式锁 -// 2. 从 DB 重读最新 account(防止使用过时的 refresh_token) -// 3. 二次检查是否仍需刷新 -// 4. 调用 executor.Refresh() 执行平台特定刷新逻辑 -// 5. 设置 _token_version + 更新 DB -// 6. 释放锁 +// 1. singleflight 合并同 cacheKey 并发调用 +// 2. 获取分布式锁(跨进程) +// 3. 从 DB 重读最新 account(防止使用过时的 refresh_token) +// 4. 二次检查是否仍需刷新 +// 5. 调用 executor.Refresh() 执行平台特定刷新逻辑 +// 6. 设置 _token_version + 更新 DB +// 7. 释放锁 func (api *OAuthRefreshAPI) RefreshIfNeeded( ctx context.Context, account *Account, @@ -59,6 +75,30 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded( ) (*OAuthRefreshResult, error) { cacheKey := executor.CacheKey(account) + // singleflight key 同时区分 cacheKey 和 refreshWindow: + // 不同的刷新窗口(前台短窗口 / 后台长窗口)应当分开判断 NeedsRefresh, + // 否则后台长窗口的"已经在刷"会让前台短窗口误以为已刷新而立刻拿到旧值。 + sfKey := cacheKey + "|" + refreshWindow.String() + + v, err, _ := api.sf.Do(sfKey, func() (interface{}, error) { + return api.refreshOnce(ctx, account, executor, refreshWindow, cacheKey) + }) + if err != nil { + return nil, err + } + result, _ := v.(*OAuthRefreshResult) + return result, nil +} + +// refreshOnce 是 RefreshIfNeeded 的实际工作函数,仅由 singleflight 领导者调用。 +// 拆出来便于直接做锁/重读/刷新的单元测试,并避免在 sf.Do 闭包里管理多重 defer。 +func (api *OAuthRefreshAPI) refreshOnce( + ctx context.Context, + account *Account, + executor OAuthRefreshExecutor, + refreshWindow time.Duration, + cacheKey string, +) (*OAuthRefreshResult, error) { // 1. 获取分布式锁 lockAcquired := false if api.tokenCache != nil { diff --git a/backend/internal/service/oauth_refresh_api_singleflight_test.go b/backend/internal/service/oauth_refresh_api_singleflight_test.go new file mode 100644 index 00000000..3b834943 --- /dev/null +++ b/backend/internal/service/oauth_refresh_api_singleflight_test.go @@ -0,0 +1,160 @@ +//go:build unit + +package service + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// blockingExecutor 在 Refresh 中等待 release 信号,便于精确控制并发时序。 +type blockingExecutor struct { + refreshAPIExecutorStub + release chan struct{} + concurrent int32 // 当前正在 Refresh 的 goroutine 数 + maxObserved int32 // 观察到的最大并发数 + calls int32 +} + +func (e *blockingExecutor) Refresh(_ context.Context, _ *Account) (map[string]any, error) { + atomic.AddInt32(&e.calls, 1) + c := atomic.AddInt32(&e.concurrent, 1) + for { + old := atomic.LoadInt32(&e.maxObserved) + if c <= old || atomic.CompareAndSwapInt32(&e.maxObserved, old, c) { + break + } + } + defer atomic.AddInt32(&e.concurrent, -1) + + <-e.release + return e.credentials, e.err +} + +func TestOAuthRefreshAPI_SingleflightDedupesConcurrentCallers(t *testing.T) { + // 同一 cacheKey 同时进入 N 个 goroutine,应只触发 1 次 executor.Refresh。 + repo := &refreshAPIAccountRepo{account: &Account{ID: 42, Platform: "claude"}} + cache := &refreshAPICacheStub{lockResult: true} + + exec := &blockingExecutor{ + refreshAPIExecutorStub: refreshAPIExecutorStub{ + needsRefresh: true, + credentials: map[string]any{"access_token": "new"}, + }, + release: make(chan struct{}), + } + + api := NewOAuthRefreshAPI(repo, cache) + + const callers = 20 + results := make([]*OAuthRefreshResult, callers) + errs := make([]error, callers) + var wg sync.WaitGroup + wg.Add(callers) + + for i := 0; i < callers; i++ { + i := i + go func() { + defer wg.Done() + r, err := api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 5*time.Minute) + results[i] = r + errs[i] = err + }() + } + + // 等所有 goroutine 都进入 sf 闭包,确保它们集中在同一窗口里抢同一 key。 + deadline := time.Now().Add(2 * time.Second) + for atomic.LoadInt32(&exec.concurrent) == 0 && time.Now().Before(deadline) { + time.Sleep(10 * time.Millisecond) + } + require.Equal(t, int32(1), atomic.LoadInt32(&exec.concurrent), "singleflight should serialize callers into one Refresh") + + close(exec.release) + wg.Wait() + + require.Equal(t, int32(1), atomic.LoadInt32(&exec.calls), "executor.Refresh must be called exactly once") + require.Equal(t, int32(1), atomic.LoadInt32(&exec.maxObserved), "no two goroutines should be inside Refresh simultaneously") + + // 所有 caller 应拿到等价结果(不必同实例,singleflight Shared 标志会让多个 caller 共享)。 + for i := 0; i < callers; i++ { + require.NoError(t, errs[i]) + require.NotNil(t, results[i]) + require.True(t, results[i].Refreshed) + } +} + +func TestOAuthRefreshAPI_SingleflightSeparatesDifferentCacheKeys(t *testing.T) { + // 不同账号有不同 cacheKey,应能并行刷新而非互相阻塞。 + repo := &refreshAPIAccountRepo{account: &Account{ID: 1, Platform: "claude"}} + cache := &refreshAPICacheStub{lockResult: true} + + exec := &blockingExecutor{ + refreshAPIExecutorStub: refreshAPIExecutorStub{ + needsRefresh: true, + credentials: map[string]any{"access_token": "new"}, + }, + release: make(chan struct{}), + } + + api := NewOAuthRefreshAPI(repo, cache) + + var wg sync.WaitGroup + for i := 0; i < 3; i++ { + platform := "p" + string(rune('a'+i)) + wg.Add(1) + go func() { + defer wg.Done() + _, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 1, Platform: platform}, exec, 5*time.Minute) + }() + } + + deadline := time.Now().Add(2 * time.Second) + for atomic.LoadInt32(&exec.concurrent) < 3 && time.Now().Before(deadline) { + time.Sleep(10 * time.Millisecond) + } + require.Equal(t, int32(3), atomic.LoadInt32(&exec.maxObserved), "different cacheKeys should run in parallel") + + close(exec.release) + wg.Wait() +} + +func TestOAuthRefreshAPI_SingleflightSeparatesDifferentRefreshWindows(t *testing.T) { + // 同 cacheKey 但不同 refreshWindow(前台短窗口 vs 后台长窗口)应分开判断 + // NeedsRefresh,避免后台长窗口的"已经在刷"让前台短窗口拿到旧值。 + repo := &refreshAPIAccountRepo{account: &Account{ID: 42, Platform: "claude"}} + cache := &refreshAPICacheStub{lockResult: true} + + exec := &blockingExecutor{ + refreshAPIExecutorStub: refreshAPIExecutorStub{ + needsRefresh: true, + credentials: map[string]any{"access_token": "new"}, + }, + release: make(chan struct{}), + } + api := NewOAuthRefreshAPI(repo, cache) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + _, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 5*time.Minute) + }() + go func() { + defer wg.Done() + _, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 1*time.Hour) + }() + + deadline := time.Now().Add(2 * time.Second) + for atomic.LoadInt32(&exec.concurrent) < 2 && time.Now().Before(deadline) { + time.Sleep(10 * time.Millisecond) + } + require.Equal(t, int32(2), atomic.LoadInt32(&exec.maxObserved), "different refreshWindow should not be merged") + + close(exec.release) + wg.Wait() +}