sub2api/backend/internal/service/oauth_refresh_api_singleflight_test.go
win 5c8c15cdb1 feat(refresh,repo): add singleflight to dedupe concurrent token refresh and unschedulable writes
Two anti-thundering-herd improvements:

1. OAuthRefreshAPI.RefreshIfNeeded
   Wrap the existing distributed-lock + DB-reread + executor.Refresh
   pipeline in a per-process singleflight keyed by cacheKey+window.
   Without this, N concurrent goroutines on the same account each pay
   one Redis lock RTT and one DB reread; with it, only the leader pays
   and the rest share the result.

   The refreshWindow is part of the key so a long background-refresh
   window cannot starve a short foreground-refresh window.

2. accountRepository.SetTempUnschedulable
   Wrap the same path (UPDATE + scheduler outbox enqueue + scheduler
   cache sync) in a per-process singleflight keyed by id+until+reason.
   The SQL guard (existing < new) already makes the UPDATE idempotent,
   but N callers still cost N round-trips and N outbox inserts. With
   singleflight, an upstream 401 burst that hits the same account
   collapses to one execution.

Tests cover dedup behavior, key separation by account / refresh window,
and that the SQL exec count drops from N to <=2 (UPDATE + outbox).
2026-04-29 00:43:23 +08:00

161 lines
4.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//go:build unit
package service
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// blockingExecutor is a refresh executor whose Refresh blocks until the test
// closes release. This lets a test hold a refresh "in flight" for as long as it
// needs and observe how many goroutines are inside Refresh at the same time.
type blockingExecutor struct {
	refreshAPIExecutorStub

	// Counters below are read and written only through sync/atomic.
	calls       int32 // total number of Refresh invocations
	concurrent  int32 // goroutines currently inside Refresh
	maxObserved int32 // highest value concurrent has reached

	release chan struct{} // closed by the test to unblock every pending Refresh
}
// Refresh counts the call and tracks how many goroutines are inside Refresh at
// once, recording the maximum in maxObserved. It then blocks until e.release is
// closed and returns the stub's credentials and error.
func (e *blockingExecutor) Refresh(_ context.Context, _ *Account) (map[string]any, error) {
	atomic.AddInt32(&e.calls, 1)
	inside := atomic.AddInt32(&e.concurrent, 1)
	defer atomic.AddInt32(&e.concurrent, -1)

	// Raise the high-water mark with a CAS loop. Stop as soon as the stored
	// value is already >= ours, whether it was there before or another
	// goroutine raced us to a larger one.
	for seen := atomic.LoadInt32(&e.maxObserved); inside > seen; seen = atomic.LoadInt32(&e.maxObserved) {
		if atomic.CompareAndSwapInt32(&e.maxObserved, seen, inside) {
			break
		}
	}

	<-e.release
	return e.credentials, e.err
}
func TestOAuthRefreshAPI_SingleflightDedupesConcurrentCallers(t *testing.T) {
	// N goroutines refreshing the same cacheKey at the same time must trigger
	// exactly one executor.Refresh; the followers share the leader's result.
	repo := &refreshAPIAccountRepo{account: &Account{ID: 42, Platform: "claude"}}
	cache := &refreshAPICacheStub{lockResult: true}
	exec := &blockingExecutor{
		refreshAPIExecutorStub: refreshAPIExecutorStub{
			needsRefresh: true,
			credentials:  map[string]any{"access_token": "new"},
		},
		release: make(chan struct{}),
	}

	// Closing release unblocks every pending Refresh. It is once-guarded and
	// deferred so a failing assertion below (which exits via FailNow) cannot
	// leave the caller goroutines blocked on the channel forever.
	var releaseOnce sync.Once
	release := func() { releaseOnce.Do(func() { close(exec.release) }) }
	defer release()

	api := NewOAuthRefreshAPI(repo, cache)

	const callers = 20
	results := make([]*OAuthRefreshResult, callers)
	errs := make([]error, callers)

	var started, done sync.WaitGroup
	started.Add(callers)
	done.Add(callers)
	for i := 0; i < callers; i++ {
		i := i
		go func() {
			defer done.Done()
			started.Done()
			results[i], errs[i] = api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 5*time.Minute)
		}()
	}

	// Every goroutine must be running before the flight is released. Otherwise
	// a latecomer arriving after release would start a second flight and make
	// the calls==1 assertion below flaky.
	started.Wait()

	// Wait for the leader to park inside Refresh.
	deadline := time.Now().Add(2 * time.Second)
	for atomic.LoadInt32(&exec.concurrent) == 0 && time.Now().Before(deadline) {
		time.Sleep(5 * time.Millisecond)
	}

	// started.Wait only proves each goroutine is about to call RefreshIfNeeded.
	// Give the followers a short grace period to join the in-flight call
	// before checking that nobody else entered Refresh.
	time.Sleep(50 * time.Millisecond)
	require.Equal(t, int32(1), atomic.LoadInt32(&exec.concurrent), "singleflight should serialize callers into one Refresh")

	release()
	done.Wait()

	require.Equal(t, int32(1), atomic.LoadInt32(&exec.calls), "executor.Refresh must be called exactly once")
	require.Equal(t, int32(1), atomic.LoadInt32(&exec.maxObserved), "no two goroutines should be inside Refresh simultaneously")

	// Every caller should get an equivalent successful result. It need not be
	// the same instance, since singleflight shares the leader's return value.
	for i := 0; i < callers; i++ {
		require.NoError(t, errs[i])
		require.NotNil(t, results[i])
		require.True(t, results[i].Refreshed)
	}
}
func TestOAuthRefreshAPI_SingleflightSeparatesDifferentCacheKeys(t *testing.T) {
	// Accounts with different cacheKeys (here: different platforms) must
	// refresh in parallel rather than queueing behind one another's flight.
	repo := &refreshAPIAccountRepo{account: &Account{ID: 1, Platform: "claude"}}
	cache := &refreshAPICacheStub{lockResult: true}
	exec := &blockingExecutor{
		refreshAPIExecutorStub: refreshAPIExecutorStub{
			needsRefresh: true,
			credentials:  map[string]any{"access_token": "new"},
		},
		release: make(chan struct{}),
	}

	// Once-guarded, deferred release so a failing assertion cannot leave the
	// refresh goroutines blocked on the channel forever.
	var releaseOnce sync.Once
	release := func() { releaseOnce.Do(func() { close(exec.release) }) }
	defer release()

	api := NewOAuthRefreshAPI(repo, cache)

	const keys = 3
	var wg sync.WaitGroup
	wg.Add(keys)
	for i := 0; i < keys; i++ {
		platform := "p" + string(rune('a'+i))
		go func() {
			defer wg.Done()
			// Only Refresh concurrency is under test here; the result and
			// error are deliberately ignored.
			_, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 1, Platform: platform}, exec, 5*time.Minute)
		}()
	}

	deadline := time.Now().Add(2 * time.Second)
	for atomic.LoadInt32(&exec.concurrent) < keys && time.Now().Before(deadline) {
		time.Sleep(10 * time.Millisecond)
	}
	require.Equal(t, int32(keys), atomic.LoadInt32(&exec.maxObserved), "different cacheKeys should run in parallel")

	release()
	wg.Wait()

	require.Equal(t, int32(keys), atomic.LoadInt32(&exec.calls), "each cacheKey should trigger its own Refresh")
}
func TestOAuthRefreshAPI_SingleflightSeparatesDifferentRefreshWindows(t *testing.T) {
	// Same cacheKey but different refreshWindow (short foreground window vs.
	// long background window) must be evaluated separately. Otherwise a
	// background flight could be joined by a foreground caller that then
	// receives a result computed for the wrong window.
	repo := &refreshAPIAccountRepo{account: &Account{ID: 42, Platform: "claude"}}
	cache := &refreshAPICacheStub{lockResult: true}
	exec := &blockingExecutor{
		refreshAPIExecutorStub: refreshAPIExecutorStub{
			needsRefresh: true,
			credentials:  map[string]any{"access_token": "new"},
		},
		release: make(chan struct{}),
	}

	// Once-guarded, deferred release so a failing assertion cannot leave the
	// refresh goroutines blocked on the channel forever.
	var releaseOnce sync.Once
	release := func() { releaseOnce.Do(func() { close(exec.release) }) }
	defer release()

	api := NewOAuthRefreshAPI(repo, cache)

	windows := []time.Duration{5 * time.Minute, 1 * time.Hour}
	var wg sync.WaitGroup
	wg.Add(len(windows))
	for _, window := range windows {
		window := window
		go func() {
			defer wg.Done()
			// Only Refresh concurrency is under test here; the result and
			// error are deliberately ignored.
			_, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, window)
		}()
	}

	want := int32(len(windows))
	deadline := time.Now().Add(2 * time.Second)
	for atomic.LoadInt32(&exec.concurrent) < want && time.Now().Before(deadline) {
		time.Sleep(10 * time.Millisecond)
	}
	require.Equal(t, want, atomic.LoadInt32(&exec.maxObserved), "different refreshWindow should not be merged")

	release()
	wg.Wait()

	require.Equal(t, want, atomic.LoadInt32(&exec.calls), "each refreshWindow should trigger its own Refresh")
}