feat(refresh,repo): add singleflight to dedupe concurrent token refresh and unschedulable writes

Two anti-thundering-herd improvements:

1. OAuthRefreshAPI.RefreshIfNeeded
   Wrap the existing distributed-lock + DB-reread + executor.Refresh
   pipeline in a per-process singleflight keyed by cacheKey+window.
   Without this, N concurrent goroutines on the same account each pay
   one Redis lock RTT and one DB reread; with it, only the leader pays
   and the rest share the result.

   The refreshWindow is part of the key so a long background-refresh
   window cannot starve a short foreground-refresh window.

2. accountRepository.SetTempUnschedulable
   Wrap the same path (UPDATE + scheduler outbox enqueue + scheduler
   cache sync) in a per-process singleflight keyed by id+until+reason.
   The SQL guard (existing < new) already makes the UPDATE idempotent,
   but N callers still cost N round-trips and N outbox inserts. With
   singleflight, an upstream 401 burst that hits the same account
   collapses to one execution.

Tests cover dedup behavior, key separation by account / refresh window,
and that the SQL exec count drops from N to <=2 (UPDATE + outbox).
This commit is contained in:
win 2026-04-29 00:43:23 +08:00
parent 110902ad4b
commit 5c8c15cdb1
4 changed files with 345 additions and 7 deletions

View File

@ -29,6 +29,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
"golang.org/x/sync/singleflight"
entsql "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqljson"
@ -49,6 +50,13 @@ type accountRepository struct {
// Used to proactively sync account snapshot to cache when status changes,
// ensuring sticky sessions can promptly detect unavailable accounts.
schedulerCache service.SchedulerCache
// tempUnschedSF 在进程内合并对同一账号的并发 SetTempUnschedulable 调用。
// 上游 401/限流爆发时N 个 in-flight 请求会同时调用此方法;底层 SQL
// 已经做了 (until < $1) 的 idempotent 保护,不会重复改 row但 N 次
// SQL RTT + N 次 outbox enqueue + N 次缓存同步仍然可观。singleflight
// 把这些并发合并成 1 次实际执行,其余 caller 共享同一结果。
tempUnschedSF singleflight.Group
}
var schedulerNeutralExtraKeyPrefixes = []string{
@ -1029,6 +1037,17 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
}
func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
// 进程内合并并发调用key 包含 until 和 reason确保不同窗口/原因独立去重。
// until 用毫秒粒度足够:同一爆发窗口内 caller 算的 until 几乎一致;
// 哪怕略有偏差SQL 的 (existing < new) 条件保证语义安全。
sfKey := strconv.FormatInt(id, 10) + ":" + strconv.FormatInt(until.UnixMilli(), 10) + ":" + reason
_, err, _ := r.tempUnschedSF.Do(sfKey, func() (interface{}, error) {
return nil, r.setTempUnschedulableOnce(ctx, id, until, reason)
})
return err
}
func (r *accountRepository) setTempUnschedulableOnce(ctx context.Context, id int64, until time.Time, reason string) error {
_, err := r.sql.ExecContext(ctx, `
UPDATE accounts
SET temp_unschedulable_until = $1,

View File

@ -0,0 +1,119 @@
package repository
import (
"context"
"database/sql"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// blockingExecutor 是一个最小化的 sqlExecutor 实现,用于精确控制并发时序。
// ExecContext 会等待 release 信号才返回,便于让多个 goroutine 集中堆积在
// singleflight 的同一窗口内。
type blockingExecutor struct {
mu sync.Mutex
execCalls int32
queryCalls int32
release chan struct{}
concurrent int32
maxObserved int32
}
func newBlockingExecutor() *blockingExecutor {
return &blockingExecutor{release: make(chan struct{})}
}
func (e *blockingExecutor) Release() { close(e.release) }
func (e *blockingExecutor) ExecContext(_ context.Context, _ string, _ ...any) (sql.Result, error) {
atomic.AddInt32(&e.execCalls, 1)
c := atomic.AddInt32(&e.concurrent, 1)
for {
old := atomic.LoadInt32(&e.maxObserved)
if c <= old || atomic.CompareAndSwapInt32(&e.maxObserved, old, c) {
break
}
}
defer atomic.AddInt32(&e.concurrent, -1)
<-e.release
return driverResult{}, nil
}
func (e *blockingExecutor) QueryContext(_ context.Context, _ string, _ ...any) (*sql.Rows, error) {
atomic.AddInt32(&e.queryCalls, 1)
return nil, sql.ErrNoRows
}
// driverResult 是一个零值 sql.Result用于测试。
type driverResult struct{}
func (driverResult) LastInsertId() (int64, error) { return 0, nil }
func (driverResult) RowsAffected() (int64, error) { return 1, nil }
func TestSetTempUnschedulable_SingleflightDedupesConcurrentCallers(t *testing.T) {
// 同一账号 + 同一 until + 同一 reason 的 N 个并发调用,应只触发一次实际
// SQL 路径UPDATE + outbox INSERT = 2 次 ExecContext
exec := newBlockingExecutor()
repo := newAccountRepositoryWithSQL(nil, exec, nil)
const callers = 30
until := time.Now().Add(10 * time.Minute)
const reason = "OAuth 401: invalid_grant"
var wg sync.WaitGroup
wg.Add(callers)
for i := 0; i < callers; i++ {
go func() {
defer wg.Done()
_ = repo.SetTempUnschedulable(context.Background(), 42, until, reason)
}()
}
// 等首个 ExecContext 进入阻塞,确认 sf 已聚拢调用。
deadline := time.Now().Add(2 * time.Second)
for atomic.LoadInt32(&exec.concurrent) == 0 && time.Now().Before(deadline) {
time.Sleep(5 * time.Millisecond)
}
require.Equal(t, int32(1), atomic.LoadInt32(&exec.concurrent),
"singleflight should serialize the SQL call to exactly one in-flight execution")
exec.Release()
wg.Wait()
// 1 次 UPDATE + 1 次 outbox INSERT = 2 次 exec其余 29 个 caller 共享结果。
require.LessOrEqual(t, atomic.LoadInt32(&exec.execCalls), int32(2),
"expected at most 2 ExecContext calls (UPDATE + outbox), got %d", exec.execCalls)
require.Equal(t, int32(1), atomic.LoadInt32(&exec.maxObserved),
"no two SQL execs should run concurrently for the same singleflight key")
}
func TestSetTempUnschedulable_DifferentAccountsRunInParallel(t *testing.T) {
// 不同 account 应分属不同 sf key能并行写库。
exec := newBlockingExecutor()
repo := newAccountRepositoryWithSQL(nil, exec, nil)
until := time.Now().Add(10 * time.Minute)
var wg sync.WaitGroup
for i := int64(1); i <= 3; i++ {
i := i
wg.Add(1)
go func() {
defer wg.Done()
_ = repo.SetTempUnschedulable(context.Background(), i, until, "different reason")
}()
}
deadline := time.Now().Add(2 * time.Second)
for atomic.LoadInt32(&exec.concurrent) < 3 && time.Now().Before(deadline) {
time.Sleep(5 * time.Millisecond)
}
require.Equal(t, int32(3), atomic.LoadInt32(&exec.maxObserved),
"different accounts should be able to write in parallel")
exec.Release()
wg.Wait()
}

View File

@ -6,6 +6,8 @@ import (
"log/slog"
"strconv"
"time"
"golang.org/x/sync/singleflight"
)
// OAuthRefreshExecutor 各平台实现的 OAuth 刷新执行器
@ -29,9 +31,19 @@ type OAuthRefreshResult struct {
// OAuthRefreshAPI 统一的 OAuth Token 刷新入口
// 封装分布式锁、DB 重读、已刷新检查等通用逻辑
//
// 双层去重设计:
// 1. 进程内 singleflight合并同一 cacheKey 的并发调用(避免 100 个 goroutine
// 都去抢同一把分布式锁、都重读一次 DB
// 2. 跨进程分布式锁Redis保证集群范围内只有一个 worker 真正发起 OAuth
// 刷新请求。
//
// 进程内去重在分布式锁之外做,避免无谓的 Redis RTT跨进程锁仍是必需的
// singleflight 解决不了多 pod 同时刷新。
type OAuthRefreshAPI struct {
accountRepo AccountRepository
tokenCache GeminiTokenCache // 可选nil = 无锁
sf singleflight.Group
}
// NewOAuthRefreshAPI 创建统一刷新 API
@ -42,15 +54,19 @@ func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCac
}
}
// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token
// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token。
//
// 同一 cacheKey 在同一进程内并发调用会被 singleflight 合并;只有"领导者"
// 调用会真正进入下层流程,其余调用共享相同的 *OAuthRefreshResult / error。
//
// 流程:
// 1. 获取分布式锁
// 2. 从 DB 重读最新 account防止使用过时的 refresh_token
// 3. 二次检查是否仍需刷新
// 4. 调用 executor.Refresh() 执行平台特定刷新逻辑
// 5. 设置 _token_version + 更新 DB
// 6. 释放锁
// 1. singleflight 合并同 cacheKey 并发调用
// 2. 获取分布式锁(跨进程)
// 3. 从 DB 重读最新 account防止使用过时的 refresh_token
// 4. 二次检查是否仍需刷新
// 5. 调用 executor.Refresh() 执行平台特定刷新逻辑
// 6. 设置 _token_version + 更新 DB
// 7. 释放锁
func (api *OAuthRefreshAPI) RefreshIfNeeded(
ctx context.Context,
account *Account,
@ -59,6 +75,30 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
) (*OAuthRefreshResult, error) {
cacheKey := executor.CacheKey(account)
// singleflight key 同时区分 cacheKey 和 refreshWindow
// 不同的刷新窗口(前台短窗口 / 后台长窗口)应当分开判断 NeedsRefresh
// 否则后台长窗口的"已经在刷"会让前台短窗口误以为已刷新而立刻拿到旧值。
sfKey := cacheKey + "|" + refreshWindow.String()
v, err, _ := api.sf.Do(sfKey, func() (interface{}, error) {
return api.refreshOnce(ctx, account, executor, refreshWindow, cacheKey)
})
if err != nil {
return nil, err
}
result, _ := v.(*OAuthRefreshResult)
return result, nil
}
// refreshOnce 是 RefreshIfNeeded 的实际工作函数,仅由 singleflight 领导者调用。
// 拆出来便于直接做锁/重读/刷新的单元测试,并避免在 sf.Do 闭包里管理多重 defer。
func (api *OAuthRefreshAPI) refreshOnce(
ctx context.Context,
account *Account,
executor OAuthRefreshExecutor,
refreshWindow time.Duration,
cacheKey string,
) (*OAuthRefreshResult, error) {
// 1. 获取分布式锁
lockAcquired := false
if api.tokenCache != nil {

View File

@ -0,0 +1,160 @@
//go:build unit
package service
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// blockingExecutor 在 Refresh 中等待 release 信号,便于精确控制并发时序。
type blockingExecutor struct {
refreshAPIExecutorStub
release chan struct{}
concurrent int32 // 当前正在 Refresh 的 goroutine 数
maxObserved int32 // 观察到的最大并发数
calls int32
}
func (e *blockingExecutor) Refresh(_ context.Context, _ *Account) (map[string]any, error) {
atomic.AddInt32(&e.calls, 1)
c := atomic.AddInt32(&e.concurrent, 1)
for {
old := atomic.LoadInt32(&e.maxObserved)
if c <= old || atomic.CompareAndSwapInt32(&e.maxObserved, old, c) {
break
}
}
defer atomic.AddInt32(&e.concurrent, -1)
<-e.release
return e.credentials, e.err
}
func TestOAuthRefreshAPI_SingleflightDedupesConcurrentCallers(t *testing.T) {
// 同一 cacheKey 同时进入 N 个 goroutine应只触发 1 次 executor.Refresh。
repo := &refreshAPIAccountRepo{account: &Account{ID: 42, Platform: "claude"}}
cache := &refreshAPICacheStub{lockResult: true}
exec := &blockingExecutor{
refreshAPIExecutorStub: refreshAPIExecutorStub{
needsRefresh: true,
credentials: map[string]any{"access_token": "new"},
},
release: make(chan struct{}),
}
api := NewOAuthRefreshAPI(repo, cache)
const callers = 20
results := make([]*OAuthRefreshResult, callers)
errs := make([]error, callers)
var wg sync.WaitGroup
wg.Add(callers)
for i := 0; i < callers; i++ {
i := i
go func() {
defer wg.Done()
r, err := api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 5*time.Minute)
results[i] = r
errs[i] = err
}()
}
// 等所有 goroutine 都进入 sf 闭包,确保它们集中在同一窗口里抢同一 key。
deadline := time.Now().Add(2 * time.Second)
for atomic.LoadInt32(&exec.concurrent) == 0 && time.Now().Before(deadline) {
time.Sleep(10 * time.Millisecond)
}
require.Equal(t, int32(1), atomic.LoadInt32(&exec.concurrent), "singleflight should serialize callers into one Refresh")
close(exec.release)
wg.Wait()
require.Equal(t, int32(1), atomic.LoadInt32(&exec.calls), "executor.Refresh must be called exactly once")
require.Equal(t, int32(1), atomic.LoadInt32(&exec.maxObserved), "no two goroutines should be inside Refresh simultaneously")
// 所有 caller 应拿到等价结果不必同实例singleflight Shared 标志会让多个 caller 共享)。
for i := 0; i < callers; i++ {
require.NoError(t, errs[i])
require.NotNil(t, results[i])
require.True(t, results[i].Refreshed)
}
}
func TestOAuthRefreshAPI_SingleflightSeparatesDifferentCacheKeys(t *testing.T) {
// 不同账号有不同 cacheKey应能并行刷新而非互相阻塞。
repo := &refreshAPIAccountRepo{account: &Account{ID: 1, Platform: "claude"}}
cache := &refreshAPICacheStub{lockResult: true}
exec := &blockingExecutor{
refreshAPIExecutorStub: refreshAPIExecutorStub{
needsRefresh: true,
credentials: map[string]any{"access_token": "new"},
},
release: make(chan struct{}),
}
api := NewOAuthRefreshAPI(repo, cache)
var wg sync.WaitGroup
for i := 0; i < 3; i++ {
platform := "p" + string(rune('a'+i))
wg.Add(1)
go func() {
defer wg.Done()
_, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 1, Platform: platform}, exec, 5*time.Minute)
}()
}
deadline := time.Now().Add(2 * time.Second)
for atomic.LoadInt32(&exec.concurrent) < 3 && time.Now().Before(deadline) {
time.Sleep(10 * time.Millisecond)
}
require.Equal(t, int32(3), atomic.LoadInt32(&exec.maxObserved), "different cacheKeys should run in parallel")
close(exec.release)
wg.Wait()
}
func TestOAuthRefreshAPI_SingleflightSeparatesDifferentRefreshWindows(t *testing.T) {
// 同 cacheKey 但不同 refreshWindow前台短窗口 vs 后台长窗口)应分开判断
// NeedsRefresh避免后台长窗口的"已经在刷"让前台短窗口拿到旧值。
repo := &refreshAPIAccountRepo{account: &Account{ID: 42, Platform: "claude"}}
cache := &refreshAPICacheStub{lockResult: true}
exec := &blockingExecutor{
refreshAPIExecutorStub: refreshAPIExecutorStub{
needsRefresh: true,
credentials: map[string]any{"access_token": "new"},
},
release: make(chan struct{}),
}
api := NewOAuthRefreshAPI(repo, cache)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
_, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 5*time.Minute)
}()
go func() {
defer wg.Done()
_, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 1*time.Hour)
}()
deadline := time.Now().Add(2 * time.Second)
for atomic.LoadInt32(&exec.concurrent) < 2 && time.Now().Before(deadline) {
time.Sleep(10 * time.Millisecond)
}
require.Equal(t, int32(2), atomic.LoadInt32(&exec.maxObserved), "different refreshWindow should not be merged")
close(exec.release)
wg.Wait()
}