226 lines
5.5 KiB
Go
226 lines
5.5 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
|
|
)
|
|
|
|
// defaultLSPoolBootstrapConcurrency is the default number of LS worker
// bootstraps that may run in parallel at startup. bootstrapConcurrency lowers
// it when the configured MaxActive is positive and smaller.
const defaultLSPoolBootstrapConcurrency = 4
|
|
|
|
type lsBootstrapAccountReader interface {
|
|
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
|
|
}
|
|
|
|
// LSPoolBootstrapService pre-creates LS workers for eligible Antigravity accounts on startup.
|
|
type LSPoolBootstrapService struct {
|
|
accountReader lsBootstrapAccountReader
|
|
backend lspool.Backend
|
|
cfg *config.Config
|
|
logger *slog.Logger
|
|
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
|
|
once sync.Once
|
|
wg sync.WaitGroup
|
|
}
|
|
|
|
func NewLSPoolBootstrapService(accountReader lsBootstrapAccountReader, backend lspool.Backend, cfg *config.Config) *LSPoolBootstrapService {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
return &LSPoolBootstrapService{
|
|
accountReader: accountReader,
|
|
backend: backend,
|
|
cfg: cfg,
|
|
logger: slog.Default().With("component", "service.lspool_bootstrap"),
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}
|
|
}
|
|
|
|
// ProvideLSPoolBootstrapService creates and starts the LS pool bootstrap worker.
|
|
func ProvideLSPoolBootstrapService(accountRepo AccountRepository, cfg *config.Config) *LSPoolBootstrapService {
|
|
svc := NewLSPoolBootstrapService(accountRepo, lspool.GlobalPool(cfg), cfg)
|
|
svc.Start()
|
|
return svc
|
|
}
|
|
|
|
func (s *LSPoolBootstrapService) Start() {
|
|
if s == nil {
|
|
return
|
|
}
|
|
s.once.Do(func() {
|
|
if s.backend == nil {
|
|
if lspool.IsLSModeEnabled() {
|
|
s.logger.Warn("startup bootstrap skipped: ls backend unavailable")
|
|
}
|
|
return
|
|
}
|
|
s.wg.Add(1)
|
|
go func() {
|
|
defer s.wg.Done()
|
|
s.bootstrap(s.ctx)
|
|
}()
|
|
})
|
|
}
|
|
|
|
func (s *LSPoolBootstrapService) Stop() {
|
|
if s == nil {
|
|
return
|
|
}
|
|
s.cancel()
|
|
s.wg.Wait()
|
|
}
|
|
|
|
func (s *LSPoolBootstrapService) bootstrap(ctx context.Context) {
|
|
if s.backend == nil || s.accountReader == nil {
|
|
return
|
|
}
|
|
|
|
accounts, err := s.accountReader.ListByPlatform(ctx, PlatformAntigravity)
|
|
if err != nil {
|
|
s.logger.Warn("load antigravity accounts for ls bootstrap failed", "error", err)
|
|
return
|
|
}
|
|
|
|
now := time.Now()
|
|
candidates := make([]Account, 0, len(accounts))
|
|
for i := range accounts {
|
|
if shouldBootstrapLSPoolAccount(&accounts[i], now) {
|
|
candidates = append(candidates, accounts[i])
|
|
}
|
|
}
|
|
|
|
if len(candidates) == 0 {
|
|
s.logger.Info("startup bootstrap skipped: no eligible antigravity accounts")
|
|
return
|
|
}
|
|
|
|
s.logger.Info("starting ls worker bootstrap",
|
|
"accounts_total", len(accounts),
|
|
"accounts_eligible", len(candidates),
|
|
"concurrency", s.bootstrapConcurrency())
|
|
|
|
var (
|
|
mu sync.Mutex
|
|
started int
|
|
failed int
|
|
)
|
|
sem := make(chan struct{}, s.bootstrapConcurrency())
|
|
var wg sync.WaitGroup
|
|
|
|
loop:
|
|
for i := range candidates {
|
|
account := candidates[i]
|
|
select {
|
|
case <-ctx.Done():
|
|
break loop
|
|
case sem <- struct{}{}:
|
|
}
|
|
|
|
wg.Add(1)
|
|
go func(account Account) {
|
|
defer wg.Done()
|
|
defer func() { <-sem }()
|
|
|
|
if err := s.bootstrapAccount(&account); err != nil {
|
|
mu.Lock()
|
|
failed++
|
|
mu.Unlock()
|
|
s.logger.Warn("bootstrap ls worker failed", "account_id", account.ID, "error", err)
|
|
return
|
|
}
|
|
|
|
mu.Lock()
|
|
started++
|
|
mu.Unlock()
|
|
s.logger.Info("bootstrap ls worker ready", "account_id", account.ID)
|
|
}(account)
|
|
}
|
|
|
|
wg.Wait()
|
|
s.logger.Info("ls worker bootstrap completed",
|
|
"accounts_total", len(accounts),
|
|
"accounts_eligible", len(candidates),
|
|
"workers_ready", started,
|
|
"workers_failed", failed,
|
|
"canceled", ctx.Err() != nil)
|
|
}
|
|
|
|
func (s *LSPoolBootstrapService) bootstrapAccount(account *Account) error {
|
|
if s.backend == nil {
|
|
return fmt.Errorf("ls backend unavailable")
|
|
}
|
|
if account == nil {
|
|
return fmt.Errorf("account is nil")
|
|
}
|
|
|
|
accountKey := strconv.FormatInt(account.ID, 10)
|
|
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
|
|
if accessToken == "" {
|
|
return fmt.Errorf("missing access token")
|
|
}
|
|
refreshToken := strings.TrimSpace(account.GetCredential("refresh_token"))
|
|
|
|
expiresAt := time.Time{}
|
|
if ts := account.GetCredentialAsTime("expires_at"); ts != nil {
|
|
expiresAt = ts.UTC()
|
|
}
|
|
|
|
s.backend.SetAccountToken(accountKey, accessToken, refreshToken, expiresAt)
|
|
availableCredits, minimumCreditAmount := resolveLSPoolModelCreditsState(account)
|
|
s.backend.SetAccountModelCredits(accountKey, account.IsOveragesEnabled(), availableCredits, minimumCreditAmount)
|
|
|
|
proxyURL := ""
|
|
if account.Proxy != nil {
|
|
proxyURL = account.Proxy.URL()
|
|
}
|
|
|
|
if _, err := s.backend.GetOrCreate(accountKey, "", proxyURL); err != nil {
|
|
return fmt.Errorf("get or create ls worker: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *LSPoolBootstrapService) bootstrapConcurrency() int {
|
|
parallelism := defaultLSPoolBootstrapConcurrency
|
|
if s.cfg != nil && s.cfg.Gateway.AntigravityLSWorker.MaxActive > 0 && s.cfg.Gateway.AntigravityLSWorker.MaxActive < parallelism {
|
|
parallelism = s.cfg.Gateway.AntigravityLSWorker.MaxActive
|
|
}
|
|
if parallelism < 1 {
|
|
return 1
|
|
}
|
|
return parallelism
|
|
}
|
|
|
|
func shouldBootstrapLSPoolAccount(account *Account, now time.Time) bool {
|
|
if account == nil {
|
|
return false
|
|
}
|
|
if account.Platform != PlatformAntigravity {
|
|
return false
|
|
}
|
|
if account.Type != AccountTypeOAuth {
|
|
return false
|
|
}
|
|
if account.Status != StatusActive || !account.Schedulable {
|
|
return false
|
|
}
|
|
if account.AutoPauseOnExpired && account.ExpiresAt != nil && !now.Before(*account.ExpiresAt) {
|
|
return false
|
|
}
|
|
if strings.TrimSpace(account.GetCredential("access_token")) == "" {
|
|
return false
|
|
}
|
|
return strings.TrimSpace(account.GetCredential("project_id")) != ""
|
|
}
|