feat(refresh,repo): add singleflight to dedupe concurrent token refresh and unschedulable writes
Two anti-thundering-herd improvements: (1) OAuthRefreshAPI.RefreshIfNeeded: wrap the existing distributed-lock + DB-reread + executor.Refresh pipeline in a per-process singleflight keyed by cacheKey+window. Without this, N concurrent goroutines on the same account each pay one Redis lock RTT and one DB reread; with it, only the leader pays and the rest share the result. The refreshWindow is part of the key so a long background-refresh window cannot starve a short foreground-refresh window. (2) accountRepository.SetTempUnschedulable: wrap the write path (UPDATE + scheduler outbox enqueue + scheduler cache sync) in a per-process singleflight keyed by id+until+reason. The SQL guard (existing < new) already makes the UPDATE idempotent, but N callers still cost N round-trips and N outbox inserts. With singleflight, an upstream 401 burst that hits the same account collapses to one execution. Tests cover dedup behavior, key separation by account / refresh window, and that the SQL exec count drops from N to <=2 (UPDATE + outbox).
This commit is contained in:
parent
110902ad4b
commit
5c8c15cdb1
@ -29,6 +29,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqljson"
|
||||
@ -49,6 +50,13 @@ type accountRepository struct {
|
||||
// Used to proactively sync account snapshot to cache when status changes,
|
||||
// ensuring sticky sessions can promptly detect unavailable accounts.
|
||||
schedulerCache service.SchedulerCache
|
||||
|
||||
// tempUnschedSF 在进程内合并对同一账号的并发 SetTempUnschedulable 调用。
|
||||
// 上游 401/限流爆发时,N 个 in-flight 请求会同时调用此方法;底层 SQL
|
||||
// 已经做了 (until < $1) 的 idempotent 保护,不会重复改 row,但 N 次
|
||||
// SQL RTT + N 次 outbox enqueue + N 次缓存同步仍然可观。singleflight
|
||||
// 把这些并发合并成 1 次实际执行,其余 caller 共享同一结果。
|
||||
tempUnschedSF singleflight.Group
|
||||
}
|
||||
|
||||
var schedulerNeutralExtraKeyPrefixes = []string{
|
||||
@ -1029,6 +1037,17 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
// 进程内合并并发调用:key 包含 until 和 reason,确保不同窗口/原因独立去重。
|
||||
// until 用毫秒粒度足够:同一爆发窗口内 caller 算的 until 几乎一致;
|
||||
// 哪怕略有偏差,SQL 的 (existing < new) 条件保证语义安全。
|
||||
sfKey := strconv.FormatInt(id, 10) + ":" + strconv.FormatInt(until.UnixMilli(), 10) + ":" + reason
|
||||
_, err, _ := r.tempUnschedSF.Do(sfKey, func() (interface{}, error) {
|
||||
return nil, r.setTempUnschedulableOnce(ctx, id, until, reason)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *accountRepository) setTempUnschedulableOnce(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE accounts
|
||||
SET temp_unschedulable_until = $1,
|
||||
|
||||
119
backend/internal/repository/account_repo_singleflight_test.go
Normal file
119
backend/internal/repository/account_repo_singleflight_test.go
Normal file
@ -0,0 +1,119 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// blockingExecutor is a minimal stand-in for the repository's SQL executor
// that gives tests precise control over concurrent timing. ExecContext parks
// until the release channel is closed, which lets many goroutines pile up
// inside the same singleflight window.
//
// All counters are accessed through sync/atomic, so no mutex is needed.
type blockingExecutor struct {
	execCalls   int32         // number of ExecContext invocations
	queryCalls  int32         // number of QueryContext invocations
	release     chan struct{} // closed to let parked ExecContext calls return
	concurrent  int32         // ExecContext calls currently in flight
	maxObserved int32         // highest value concurrent has reached
}
|
||||
|
||||
func newBlockingExecutor() *blockingExecutor {
|
||||
return &blockingExecutor{release: make(chan struct{})}
|
||||
}
|
||||
|
||||
// Release closes the release channel, letting every parked ExecContext call
// return. It must be called at most once: closing a closed channel panics.
func (e *blockingExecutor) Release() {
	close(e.release)
}
|
||||
|
||||
func (e *blockingExecutor) ExecContext(_ context.Context, _ string, _ ...any) (sql.Result, error) {
|
||||
atomic.AddInt32(&e.execCalls, 1)
|
||||
c := atomic.AddInt32(&e.concurrent, 1)
|
||||
for {
|
||||
old := atomic.LoadInt32(&e.maxObserved)
|
||||
if c <= old || atomic.CompareAndSwapInt32(&e.maxObserved, old, c) {
|
||||
break
|
||||
}
|
||||
}
|
||||
defer atomic.AddInt32(&e.concurrent, -1)
|
||||
<-e.release
|
||||
return driverResult{}, nil
|
||||
}
|
||||
|
||||
func (e *blockingExecutor) QueryContext(_ context.Context, _ string, _ ...any) (*sql.Rows, error) {
|
||||
atomic.AddInt32(&e.queryCalls, 1)
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
|
||||
// driverResult is a stateless sql.Result used by tests: it reports insert id
// 0 and exactly one affected row, and never fails.
type driverResult struct{}

// LastInsertId always reports 0 with a nil error.
func (driverResult) LastInsertId() (int64, error) {
	const id int64 = 0
	return id, nil
}

// RowsAffected always reports 1 with a nil error.
func (driverResult) RowsAffected() (int64, error) {
	const affected int64 = 1
	return affected, nil
}
|
||||
|
||||
func TestSetTempUnschedulable_SingleflightDedupesConcurrentCallers(t *testing.T) {
|
||||
// 同一账号 + 同一 until + 同一 reason 的 N 个并发调用,应只触发一次实际
|
||||
// SQL 路径(UPDATE + outbox INSERT = 2 次 ExecContext)。
|
||||
exec := newBlockingExecutor()
|
||||
repo := newAccountRepositoryWithSQL(nil, exec, nil)
|
||||
|
||||
const callers = 30
|
||||
until := time.Now().Add(10 * time.Minute)
|
||||
const reason = "OAuth 401: invalid_grant"
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(callers)
|
||||
for i := 0; i < callers; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = repo.SetTempUnschedulable(context.Background(), 42, until, reason)
|
||||
}()
|
||||
}
|
||||
|
||||
// 等首个 ExecContext 进入阻塞,确认 sf 已聚拢调用。
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for atomic.LoadInt32(&exec.concurrent) == 0 && time.Now().Before(deadline) {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&exec.concurrent),
|
||||
"singleflight should serialize the SQL call to exactly one in-flight execution")
|
||||
|
||||
exec.Release()
|
||||
wg.Wait()
|
||||
|
||||
// 1 次 UPDATE + 1 次 outbox INSERT = 2 次 exec;其余 29 个 caller 共享结果。
|
||||
require.LessOrEqual(t, atomic.LoadInt32(&exec.execCalls), int32(2),
|
||||
"expected at most 2 ExecContext calls (UPDATE + outbox), got %d", exec.execCalls)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&exec.maxObserved),
|
||||
"no two SQL execs should run concurrently for the same singleflight key")
|
||||
}
|
||||
|
||||
func TestSetTempUnschedulable_DifferentAccountsRunInParallel(t *testing.T) {
|
||||
// 不同 account 应分属不同 sf key,能并行写库。
|
||||
exec := newBlockingExecutor()
|
||||
repo := newAccountRepositoryWithSQL(nil, exec, nil)
|
||||
|
||||
until := time.Now().Add(10 * time.Minute)
|
||||
var wg sync.WaitGroup
|
||||
for i := int64(1); i <= 3; i++ {
|
||||
i := i
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = repo.SetTempUnschedulable(context.Background(), i, until, "different reason")
|
||||
}()
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for atomic.LoadInt32(&exec.concurrent) < 3 && time.Now().Before(deadline) {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
require.Equal(t, int32(3), atomic.LoadInt32(&exec.maxObserved),
|
||||
"different accounts should be able to write in parallel")
|
||||
|
||||
exec.Release()
|
||||
wg.Wait()
|
||||
}
|
||||
@ -6,6 +6,8 @@ import (
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
// OAuthRefreshExecutor 各平台实现的 OAuth 刷新执行器
|
||||
@ -29,9 +31,19 @@ type OAuthRefreshResult struct {
|
||||
|
||||
// OAuthRefreshAPI is the unified entry point for OAuth token refresh.
// It encapsulates the shared logic: distributed locking, re-reading the
// account from the DB, and the already-refreshed check.
//
// Two layers of de-duplication:
//  1. In-process singleflight: merges concurrent calls for the same cacheKey
//     (so that 100 goroutines do not all contend for the same distributed lock
//     and each re-read the DB).
//  2. Cross-process distributed lock (Redis): guarantees that only one worker
//     in the whole cluster actually issues the OAuth refresh request.
//
// The in-process de-duplication sits outside the distributed lock to avoid
// needless Redis round-trips; the cross-process lock is still required because
// singleflight cannot stop several pods from refreshing at the same time.
type OAuthRefreshAPI struct {
	accountRepo AccountRepository
	tokenCache  GeminiTokenCache // optional; nil = no locking
	sf          singleflight.Group
}
|
||||
|
||||
// NewOAuthRefreshAPI 创建统一刷新 API
|
||||
@ -42,15 +54,19 @@ func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCac
|
||||
}
|
||||
}
|
||||
|
||||
// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token
|
||||
// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token。
|
||||
//
|
||||
// 同一 cacheKey 在同一进程内并发调用会被 singleflight 合并;只有"领导者"
|
||||
// 调用会真正进入下层流程,其余调用共享相同的 *OAuthRefreshResult / error。
|
||||
//
|
||||
// 流程:
|
||||
// 1. 获取分布式锁
|
||||
// 2. 从 DB 重读最新 account(防止使用过时的 refresh_token)
|
||||
// 3. 二次检查是否仍需刷新
|
||||
// 4. 调用 executor.Refresh() 执行平台特定刷新逻辑
|
||||
// 5. 设置 _token_version + 更新 DB
|
||||
// 6. 释放锁
|
||||
// 1. singleflight 合并同 cacheKey 并发调用
|
||||
// 2. 获取分布式锁(跨进程)
|
||||
// 3. 从 DB 重读最新 account(防止使用过时的 refresh_token)
|
||||
// 4. 二次检查是否仍需刷新
|
||||
// 5. 调用 executor.Refresh() 执行平台特定刷新逻辑
|
||||
// 6. 设置 _token_version + 更新 DB
|
||||
// 7. 释放锁
|
||||
func (api *OAuthRefreshAPI) RefreshIfNeeded(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
@ -59,6 +75,30 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
|
||||
) (*OAuthRefreshResult, error) {
|
||||
cacheKey := executor.CacheKey(account)
|
||||
|
||||
// singleflight key 同时区分 cacheKey 和 refreshWindow:
|
||||
// 不同的刷新窗口(前台短窗口 / 后台长窗口)应当分开判断 NeedsRefresh,
|
||||
// 否则后台长窗口的"已经在刷"会让前台短窗口误以为已刷新而立刻拿到旧值。
|
||||
sfKey := cacheKey + "|" + refreshWindow.String()
|
||||
|
||||
v, err, _ := api.sf.Do(sfKey, func() (interface{}, error) {
|
||||
return api.refreshOnce(ctx, account, executor, refreshWindow, cacheKey)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result, _ := v.(*OAuthRefreshResult)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// refreshOnce 是 RefreshIfNeeded 的实际工作函数,仅由 singleflight 领导者调用。
|
||||
// 拆出来便于直接做锁/重读/刷新的单元测试,并避免在 sf.Do 闭包里管理多重 defer。
|
||||
func (api *OAuthRefreshAPI) refreshOnce(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
executor OAuthRefreshExecutor,
|
||||
refreshWindow time.Duration,
|
||||
cacheKey string,
|
||||
) (*OAuthRefreshResult, error) {
|
||||
// 1. 获取分布式锁
|
||||
lockAcquired := false
|
||||
if api.tokenCache != nil {
|
||||
|
||||
160
backend/internal/service/oauth_refresh_api_singleflight_test.go
Normal file
160
backend/internal/service/oauth_refresh_api_singleflight_test.go
Normal file
@ -0,0 +1,160 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// blockingExecutor waits for the release signal inside Refresh, which gives
// tests precise control over concurrent timing. Its counters are updated with
// sync/atomic.
type blockingExecutor struct {
	refreshAPIExecutorStub
	release     chan struct{}
	concurrent  int32 // number of goroutines currently inside Refresh
	maxObserved int32 // highest concurrency observed so far
	calls       int32 // total number of Refresh invocations
}
|
||||
|
||||
func (e *blockingExecutor) Refresh(_ context.Context, _ *Account) (map[string]any, error) {
|
||||
atomic.AddInt32(&e.calls, 1)
|
||||
c := atomic.AddInt32(&e.concurrent, 1)
|
||||
for {
|
||||
old := atomic.LoadInt32(&e.maxObserved)
|
||||
if c <= old || atomic.CompareAndSwapInt32(&e.maxObserved, old, c) {
|
||||
break
|
||||
}
|
||||
}
|
||||
defer atomic.AddInt32(&e.concurrent, -1)
|
||||
|
||||
<-e.release
|
||||
return e.credentials, e.err
|
||||
}
|
||||
|
||||
func TestOAuthRefreshAPI_SingleflightDedupesConcurrentCallers(t *testing.T) {
|
||||
// 同一 cacheKey 同时进入 N 个 goroutine,应只触发 1 次 executor.Refresh。
|
||||
repo := &refreshAPIAccountRepo{account: &Account{ID: 42, Platform: "claude"}}
|
||||
cache := &refreshAPICacheStub{lockResult: true}
|
||||
|
||||
exec := &blockingExecutor{
|
||||
refreshAPIExecutorStub: refreshAPIExecutorStub{
|
||||
needsRefresh: true,
|
||||
credentials: map[string]any{"access_token": "new"},
|
||||
},
|
||||
release: make(chan struct{}),
|
||||
}
|
||||
|
||||
api := NewOAuthRefreshAPI(repo, cache)
|
||||
|
||||
const callers = 20
|
||||
results := make([]*OAuthRefreshResult, callers)
|
||||
errs := make([]error, callers)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(callers)
|
||||
|
||||
for i := 0; i < callers; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
r, err := api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 5*time.Minute)
|
||||
results[i] = r
|
||||
errs[i] = err
|
||||
}()
|
||||
}
|
||||
|
||||
// 等所有 goroutine 都进入 sf 闭包,确保它们集中在同一窗口里抢同一 key。
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for atomic.LoadInt32(&exec.concurrent) == 0 && time.Now().Before(deadline) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&exec.concurrent), "singleflight should serialize callers into one Refresh")
|
||||
|
||||
close(exec.release)
|
||||
wg.Wait()
|
||||
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&exec.calls), "executor.Refresh must be called exactly once")
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&exec.maxObserved), "no two goroutines should be inside Refresh simultaneously")
|
||||
|
||||
// 所有 caller 应拿到等价结果(不必同实例,singleflight Shared 标志会让多个 caller 共享)。
|
||||
for i := 0; i < callers; i++ {
|
||||
require.NoError(t, errs[i])
|
||||
require.NotNil(t, results[i])
|
||||
require.True(t, results[i].Refreshed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthRefreshAPI_SingleflightSeparatesDifferentCacheKeys(t *testing.T) {
|
||||
// 不同账号有不同 cacheKey,应能并行刷新而非互相阻塞。
|
||||
repo := &refreshAPIAccountRepo{account: &Account{ID: 1, Platform: "claude"}}
|
||||
cache := &refreshAPICacheStub{lockResult: true}
|
||||
|
||||
exec := &blockingExecutor{
|
||||
refreshAPIExecutorStub: refreshAPIExecutorStub{
|
||||
needsRefresh: true,
|
||||
credentials: map[string]any{"access_token": "new"},
|
||||
},
|
||||
release: make(chan struct{}),
|
||||
}
|
||||
|
||||
api := NewOAuthRefreshAPI(repo, cache)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 3; i++ {
|
||||
platform := "p" + string(rune('a'+i))
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 1, Platform: platform}, exec, 5*time.Minute)
|
||||
}()
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for atomic.LoadInt32(&exec.concurrent) < 3 && time.Now().Before(deadline) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
require.Equal(t, int32(3), atomic.LoadInt32(&exec.maxObserved), "different cacheKeys should run in parallel")
|
||||
|
||||
close(exec.release)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestOAuthRefreshAPI_SingleflightSeparatesDifferentRefreshWindows(t *testing.T) {
|
||||
// 同 cacheKey 但不同 refreshWindow(前台短窗口 vs 后台长窗口)应分开判断
|
||||
// NeedsRefresh,避免后台长窗口的"已经在刷"让前台短窗口拿到旧值。
|
||||
repo := &refreshAPIAccountRepo{account: &Account{ID: 42, Platform: "claude"}}
|
||||
cache := &refreshAPICacheStub{lockResult: true}
|
||||
|
||||
exec := &blockingExecutor{
|
||||
refreshAPIExecutorStub: refreshAPIExecutorStub{
|
||||
needsRefresh: true,
|
||||
credentials: map[string]any{"access_token": "new"},
|
||||
},
|
||||
release: make(chan struct{}),
|
||||
}
|
||||
api := NewOAuthRefreshAPI(repo, cache)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 5*time.Minute)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 1*time.Hour)
|
||||
}()
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for atomic.LoadInt32(&exec.concurrent) < 2 && time.Now().Before(deadline) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
require.Equal(t, int32(2), atomic.LoadInt32(&exec.maxObserved), "different refreshWindow should not be merged")
|
||||
|
||||
close(exec.release)
|
||||
wg.Wait()
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user