sub2api/backend/internal/pkg/lspool/worker_manager.go
win 0cda0e0b96
Some checks failed
CI / test (push) Failing after 8s
CI / golangci-lint (push) Failing after 5s
Security Scan / backend-security (push) Failing after 7s
Security Scan / frontend-security (push) Failing after 6s
feat: add dockerized antigravity ls worker mode
2026-03-30 23:57:25 +08:00

650 lines
18 KiB
Go

package lspool
import (
"bytes"
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/client"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
)
// Docker labels stamped on every LS worker container created by
// createWorkerLocked, plus the worker's control port.
const (
// lsWorkerManagedByLabel=lsWorkerManagedByValue marks containers owned by
// this manager; reconcileManagedContainers removes every container carrying it.
lsWorkerManagedByLabel = "managed-by"
lsWorkerManagedByValue = "sub2api"
// Per-container metadata labels: owning account ID, proxy hash, and the
// configured image reference.
lsWorkerAccountLabel = "account_id"
lsWorkerProxyHashLabel = "proxy_hash"
lsWorkerImageTagLabel = "image_tag"
// lsWorkerControlPort is the port the worker is told to listen on
// (LSWORKER_LISTEN_ADDR) and the port used in its resolved control address.
lsWorkerControlPort = 18081
)
// workerManagerConfig holds the resolved settings for a workerManager.
// NewWorkerManagerFromConfig fills in defaults for any unset field.
type workerManagerConfig struct {
// Image is the container image used for LS workers.
Image string
// Network is the Docker network workers attach to; the control address is
// resolved on this network first.
Network string
// DockerSocket is the Docker daemon host handed to client.WithHost.
DockerSocket string
// IdleTTL is how long a worker may go unused before collectIdleWorkers removes it.
IdleTTL time.Duration
// MaxActive caps the number of tracked workers; GetOrCreate refuses to create more.
MaxActive int
// StartupTimeout bounds the /healthz and /readyz polling loops.
StartupTimeout time.Duration
// RequestTimeout is the HTTP client timeout and the /account/state push timeout.
RequestTimeout time.Duration
}
// dockerClient is the subset of the Docker SDK client used by workerManager.
// newWorkerManager accepts any implementation of it, so a fake can stand in
// for the real Docker client (e.g. in tests).
type dockerClient interface {
ContainerList(ctx context.Context, options container.ListOptions) ([]container.Summary, error)
ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error)
ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error
ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error)
ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error
ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error
Close() error
}
// workerManager runs one Docker container per (account, proxy hash) pair and
// hands out Instances that talk to those containers over HTTP.
type workerManager struct {
cfg workerManagerConfig
docker dockerClient
// http is shared by all control requests to workers; its transport has
// Proxy: nil, so environment proxy settings are not applied.
http *http.Client
// mu guards workers, state, and each handle's LastUsed field.
mu sync.Mutex
// workers maps buildWorkerKey(accountID, proxyHash) to the tracked handle.
workers map[string]*workerHandle
// state holds the latest token/credit info per account ID; GetOrCreate
// clones it and pushes it to the worker via syncWorkerState.
state map[string]*workerAccountState
// ctx is cancelled by Close; it stops cleanupLoop and is the parent for the
// Docker calls made while creating containers.
ctx context.Context
cancel context.CancelFunc
logger *slog.Logger
}
// workerHandle describes one LS worker container created by the manager.
type workerHandle struct {
// Key is buildWorkerKey(AccountID, ProxyHash), the handle's key in m.workers.
Key string
AccountID string
// ProxyURL is the normalized SOCKS proxy URL; ProxyHash is proxyHash(ProxyURL).
ProxyURL string
ProxyHash string
// ContainerID and Container are the Docker container ID and name.
ContainerID string
Container string
// Address is the worker's control endpoint as host:port.
Address string
// AuthToken is sent as the X-Worker-Token header on every control request.
AuthToken string
// LastUsed is refreshed by GetOrCreate and compared against IdleTTL.
LastUsed time.Time
// LastStateSHA is the hex SHA-256 of the last state body successfully
// posted to /account/state; syncWorkerState skips the push when unchanged.
LastStateSHA string
}
// workerAccountState is the per-account payload that syncWorkerState
// marshals to JSON and POSTs to a worker's /account/state endpoint.
type workerAccountState struct {
// HasToken is set once SetAccountToken has been called for the account.
HasToken bool `json:"has_token"`
AccessToken string `json:"access_token,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
// ExpiresAt is stored in UTC; nil means SetAccountToken got a zero time.
ExpiresAt *time.Time `json:"expires_at,omitempty"`
// HasModelCredits is set once SetAccountModelCredits has been called.
HasModelCredits bool `json:"has_model_credits"`
UseAICredits bool `json:"use_ai_credits"`
AvailableCredits *int32 `json:"available_credits,omitempty"`
MinimumCreditAmount *int32 `json:"minimum_credit_amount,omitempty"`
}
// NewWorkerManagerFromConfig builds the Docker-backed LS worker Backend from
// cfg.Gateway.AntigravityLSWorker, applying defaults for unset fields:
// image "weishaw/sub2api-lsworker:latest", network "sub2api-network",
// Docker host "unix:///var/run/docker.sock", 15m idle TTL, 50 max active
// workers, 45s startup timeout and 60s request timeout.
//
// It fails when cfg is nil, when the Docker client cannot be created, or when
// the startup cleanup of stale managed containers fails. On failure the
// returned Backend is always an untyped nil.
func NewWorkerManagerFromConfig(cfg *config.Config) (Backend, error) {
	if cfg == nil {
		return nil, fmt.Errorf("config is nil")
	}
	managerCfg := workerManagerConfig{
		Image:          strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Image),
		Network:        strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Network),
		DockerSocket:   strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.DockerSocket),
		IdleTTL:        cfg.Gateway.AntigravityLSWorker.IdleTTL,
		MaxActive:      cfg.Gateway.AntigravityLSWorker.MaxActive,
		StartupTimeout: cfg.Gateway.AntigravityLSWorker.StartupTimeout,
		RequestTimeout: cfg.Gateway.AntigravityLSWorker.RequestTimeout,
	}
	if managerCfg.Image == "" {
		managerCfg.Image = "weishaw/sub2api-lsworker:latest"
	}
	if managerCfg.Network == "" {
		managerCfg.Network = "sub2api-network"
	}
	if managerCfg.DockerSocket == "" {
		managerCfg.DockerSocket = "unix:///var/run/docker.sock"
	}
	if managerCfg.IdleTTL <= 0 {
		managerCfg.IdleTTL = 15 * time.Minute
	}
	if managerCfg.MaxActive < 1 {
		managerCfg.MaxActive = 50
	}
	if managerCfg.StartupTimeout <= 0 {
		managerCfg.StartupTimeout = 45 * time.Second
	}
	if managerCfg.RequestTimeout <= 0 {
		managerCfg.RequestTimeout = 60 * time.Second
	}
	// DockerSocket is never empty at this point (defaulted above), so the
	// daemon host is always set explicitly via WithHost.
	cli, err := client.NewClientWithOpts(
		client.WithAPIVersionNegotiation(),
		client.WithHost(managerCfg.DockerSocket),
	)
	if err != nil {
		return nil, fmt.Errorf("create docker client: %w", err)
	}
	mgr, err := newWorkerManager(managerCfg, cli)
	if err != nil {
		// Returning mgr here would wrap a nil *workerManager in a non-nil
		// Backend interface value; return an untyped nil instead.
		return nil, err
	}
	return mgr, nil
}
// newWorkerManager wraps an existing Docker client in a workerManager.
// It first force-removes containers left behind by a previous run, then
// starts the background idle-collection loop. If that startup cleanup
// fails, it cancels the manager context, closes the Docker client, and
// returns the error.
func newWorkerManager(cfg workerManagerConfig, docker dockerClient) (*workerManager, error) {
	rootCtx, stop := context.WithCancel(context.Background())
	dialer := &net.Dialer{
		Timeout:   5 * time.Second,
		KeepAlive: 30 * time.Second,
	}
	// Worker control traffic goes straight to container IPs, never via a proxy.
	transport := &http.Transport{
		Proxy:               nil,
		DialContext:         dialer.DialContext,
		MaxIdleConnsPerHost: 8,
	}
	mgr := &workerManager{
		cfg:     cfg,
		docker:  docker,
		http:    &http.Client{Timeout: cfg.RequestTimeout, Transport: transport},
		workers: make(map[string]*workerHandle),
		state:   make(map[string]*workerAccountState),
		ctx:     rootCtx,
		cancel:  stop,
		logger:  slog.Default().With("component", "lspool-worker-manager"),
	}
	if err := mgr.reconcileManagedContainers(rootCtx); err != nil {
		stop()
		_ = docker.Close()
		return nil, err
	}
	go mgr.cleanupLoop()
	return mgr, nil
}
// Close stops the background cleanup loop, removes every tracked worker
// container, and closes the Docker client. Removal failures are logged
// rather than returned. Any leftover container still carries the
// managed-by label, so reconcileManagedContainers removes it on the next
// startup.
func (m *workerManager) Close() {
	m.cancel()
	m.mu.Lock()
	workers := make([]*workerHandle, 0, len(m.workers))
	for _, handle := range m.workers {
		workers = append(workers, handle)
	}
	m.workers = make(map[string]*workerHandle)
	m.mu.Unlock()
	for _, handle := range workers {
		// m.ctx is already cancelled, so teardown uses a fresh context.
		if err := m.removeWorkerContainer(context.Background(), handle); err != nil {
			m.logger.Warn("failed to remove ls worker on close", "container", handle.Container, "error", err)
		}
	}
	if m.docker != nil {
		_ = m.docker.Close()
	}
}
// Stats reports how many accounts have cached state and how many worker
// containers are tracked. Every tracked worker counts as active, so
// "total" and "active" always carry the same number.
func (m *workerManager) Stats() map[string]any {
	m.mu.Lock()
	accountCount, workerCount := len(m.state), len(m.workers)
	m.mu.Unlock()
	return map[string]any{
		"accounts": accountCount,
		"total":    workerCount,
		"active":   workerCount,
	}
}
// SetAccountToken records the OAuth tokens for accountID and marks the
// account as having a token. A zero expiresAt clears the stored expiry;
// any other value is stored in UTC. The cached state is pushed to the
// account's worker by the next GetOrCreate call.
func (m *workerManager) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) {
	var expiry *time.Time
	if !expiresAt.IsZero() {
		utc := expiresAt.UTC()
		expiry = &utc
	}
	m.mu.Lock()
	defer m.mu.Unlock()
	st := m.ensureStateLocked(accountID)
	st.HasToken = true
	st.AccessToken = accessToken
	st.RefreshToken = refreshToken
	st.ExpiresAt = expiry
}
// SetAccountModelCredits records the AI-credit settings for accountID and
// marks them as present. The credit pointers are copied, so later changes
// through the caller's pointers do not affect the stored state.
func (m *workerManager) SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) {
	available := cloneInt32Ptr(availableCredits)
	minimum := cloneInt32Ptr(minimumCreditAmountForUsage)
	m.mu.Lock()
	defer m.mu.Unlock()
	st := m.ensureStateLocked(accountID)
	st.HasModelCredits, st.UseAICredits = true, useAICredits
	st.AvailableCredits, st.MinimumCreditAmount = available, minimum
}
// GetOrCreate returns an Instance bound to the LS worker container for
// accountID behind the given proxy (only proxyURL[0] is used).
//
// The proxy must resolve to a socks5/socks5h URL. Workers are keyed by
// account ID plus proxyHash of the normalized proxy URL. A new container is
// created when no worker exists for that key and fewer than cfg.MaxActive
// workers are tracked.
//
// Before returning, GetOrCreate:
//   - waits for the worker's /healthz endpoint,
//   - pushes the account's cached state,
//   - waits for /readyz for routingKey.
//
// If the health check fails, the worker is evicted and its container is
// removed. The next call then starts a fresh container instead of retrying a
// dead one. Without eviction, every call refreshes LastUsed, so the idle
// collector would never remove the dead worker.
func (m *workerManager) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error) {
	rawProxy := ""
	if len(proxyURL) > 0 {
		rawProxy = proxyURL[0]
	}
	normalizedProxy, parsedProxy, err := resolveWorkerProxy(rawProxy)
	if err != nil {
		return nil, err
	}
	if parsedProxy == nil {
		return nil, fmt.Errorf("ls worker requires a socks5/socks5h proxy for account %s", accountID)
	}
	replica := replicaSlotIndex(routingKey, parseLSReplicaCount())
	// Named hash so the proxyHash function is not shadowed in this scope.
	hash := proxyHash(normalizedProxy)
	workerKey := buildWorkerKey(accountID, hash)

	m.mu.Lock()
	// Copy the state so it can be used after the lock is released.
	state := cloneWorkerAccountState(m.state[accountID])
	if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" {
		m.mu.Unlock()
		return nil, fmt.Errorf("ls worker missing access token for account %s", accountID)
	}
	handle := m.workers[workerKey]
	if handle == nil {
		if len(m.workers) >= m.cfg.MaxActive {
			m.mu.Unlock()
			return nil, fmt.Errorf("ls worker limit reached (%d active)", m.cfg.MaxActive)
		}
		// Creating under m.mu prevents concurrent callers from creating
		// duplicate containers for the same key or exceeding MaxActive.
		handle, err = m.createWorkerLocked(accountID, normalizedProxy, hash, parsedProxy)
		if err != nil {
			m.mu.Unlock()
			return nil, err
		}
		m.workers[workerKey] = handle
	}
	handle.LastUsed = time.Now()
	m.mu.Unlock()

	if err := m.waitForWorkerHealthy(handle); err != nil {
		m.evictUnhealthyWorker(workerKey, handle)
		return nil, err
	}
	if err := m.syncWorkerState(handle, state); err != nil {
		return nil, err
	}
	if err := m.waitForWorkerReady(handle, routingKey); err != nil {
		return nil, err
	}
	return &Instance{
		AccountID:     accountID,
		Replica:       replica,
		Address:       handle.Address,
		client:        m.http,
		healthy:       true,
		lastUsed:      time.Now(),
		modelMapReady: 1,
		remote:        true,
		workerToken:   handle.AuthToken,
		routingKey:    routingKey,
	}, nil
}

// evictUnhealthyWorker stops tracking handle under key and removes its
// container. It does nothing if key now maps to a different handle (for
// example, a replacement created by a concurrent caller) or to none at all.
// Removal failures are logged.
func (m *workerManager) evictUnhealthyWorker(key string, handle *workerHandle) {
	m.mu.Lock()
	if m.workers[key] != handle {
		m.mu.Unlock()
		return
	}
	delete(m.workers, key)
	m.mu.Unlock()
	if err := m.removeWorkerContainer(context.Background(), handle); err != nil {
		m.logger.Warn("failed to remove unhealthy ls worker", "container", handle.Container, "error", err)
	}
}
// cleanupLoop calls collectIdleWorkers once per minute until the manager's
// context is cancelled by Close.
func (m *workerManager) cleanupLoop() {
	tick := time.NewTicker(time.Minute)
	defer tick.Stop()
	done := m.ctx.Done()
	for {
		select {
		case <-tick.C:
			m.collectIdleWorkers()
		case <-done:
			return
		}
	}
}
// collectIdleWorkers untracks every worker whose LastUsed is more than
// cfg.IdleTTL in the past, and drops nil entries. It then removes the
// expired containers outside the lock, so Docker latency never blocks other
// manager calls. Removal failures are logged rather than silently dropped.
func (m *workerManager) collectIdleWorkers() {
	now := time.Now()
	var expired []*workerHandle
	m.mu.Lock()
	for key, handle := range m.workers {
		if handle != nil && now.Sub(handle.LastUsed) <= m.cfg.IdleTTL {
			continue
		}
		// Deleting during range is safe for Go maps.
		delete(m.workers, key)
		if handle != nil {
			expired = append(expired, handle)
		}
	}
	m.mu.Unlock()
	for _, handle := range expired {
		if err := m.removeWorkerContainer(context.Background(), handle); err != nil {
			m.logger.Warn("failed to remove idle ls worker", "container", handle.Container, "error", err)
		}
	}
}
// reconcileManagedContainers force-removes every container, running or
// stopped, labelled managed-by=sub2api. This discards workers left behind by
// a previous process. It stops at the first removal error and returns it.
func (m *workerManager) reconcileManagedContainers(ctx context.Context) error {
	selector := lsWorkerManagedByLabel + "=" + lsWorkerManagedByValue
	stale, err := m.docker.ContainerList(ctx, container.ListOptions{
		All:     true,
		Filters: filters.NewArgs(filters.Arg("label", selector)),
	})
	if err != nil {
		return fmt.Errorf("list managed ls workers: %w", err)
	}
	for i := range stale {
		orphan := &workerHandle{
			ContainerID: stale[i].ID,
			Container:   strings.TrimPrefix(firstContainerName(stale[i].Names), "/"),
		}
		if rmErr := m.removeWorkerContainer(ctx, orphan); rmErr != nil {
			return rmErr
		}
	}
	return nil
}
// createWorkerLocked creates and starts a new LS worker container for
// accountID that routes through the given SOCKS proxy, then resolves the
// container's control address on cfg.Network. The caller holds m.mu (see
// GetOrCreate).
//
// Container setup:
//   - it gets a freshly generated auth token and the proxy details as
//     LSWORKER_* environment variables;
//   - it is labelled so reconcileManagedContainers can find it;
//   - it is granted NET_ADMIN.
//
// Docker calls are bounded by cfg.StartupTimeout (when positive), so a hung
// daemon cannot hold m.mu indefinitely. If a step after creation fails, the
// container is force-removed using a context that is independent of m.ctx.
// The create context may already be expired or cancelled by Close; reusing
// it would make the cleanup fail and leak the container.
func (m *workerManager) createWorkerLocked(accountID, proxyURL, proxyHash string, parsedProxy *url.URL) (*workerHandle, error) {
	// proxyHash can be shorter than 8 characters (proxyHash() returns
	// "direct" for a blank proxy), so guard the slice instead of panicking.
	shortHash := proxyHash
	if len(shortHash) > 8 {
		shortHash = shortHash[:8]
	}
	containerName := fmt.Sprintf("sub2api-ls-%s-%s", accountID, shortHash)
	authToken := generateUUID()
	proxyHost := parsedProxy.Hostname()
	proxyPort := parsedProxy.Port()
	if proxyPort == "" {
		proxyPort = "1080"
	}
	// url.Userinfo accessors are nil-safe: no credentials yields "".
	proxyUser := parsedProxy.User.Username()
	proxyPass, _ := parsedProxy.User.Password()
	labels := map[string]string{
		lsWorkerManagedByLabel: lsWorkerManagedByValue,
		lsWorkerAccountLabel:   accountID,
		lsWorkerProxyHashLabel: proxyHash,
		lsWorkerImageTagLabel:  m.cfg.Image,
	}
	env := []string{
		"ANTIGRAVITY_APP_ROOT=/app/ls",
		fmt.Sprintf("LSWORKER_ACCOUNT_ID=%s", accountID),
		fmt.Sprintf("LSWORKER_AUTH_TOKEN=%s", authToken),
		fmt.Sprintf("LSWORKER_LISTEN_ADDR=0.0.0.0:%d", lsWorkerControlPort),
		fmt.Sprintf("LSWORKER_NETWORK_READY_FILE=%s", "/run/lsworker/network-ready"),
		fmt.Sprintf("LSWORKER_PROXY_URL=%s", proxyURL),
		fmt.Sprintf("LSWORKER_PROXY_HOST=%s", proxyHost),
		fmt.Sprintf("LSWORKER_PROXY_PORT=%s", proxyPort),
		fmt.Sprintf("LSWORKER_PROXY_USER=%s", proxyUser),
		fmt.Sprintf("LSWORKER_PROXY_PASS=%s", proxyPass),
		fmt.Sprintf("LSWORKER_CONTROL_PORT=%d", lsWorkerControlPort),
		fmt.Sprintf("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=%d", parseLSReplicaCount()),
	}
	if tz := strings.TrimSpace(os.Getenv("TZ")); tz != "" {
		env = append(env, "TZ="+tz)
	}

	// Bound the Docker calls: this runs with m.mu held, so a hung daemon
	// would otherwise stall every caller of the manager.
	opCtx := m.ctx
	if m.cfg.StartupTimeout > 0 {
		var cancel context.CancelFunc
		opCtx, cancel = context.WithTimeout(m.ctx, m.cfg.StartupTimeout)
		defer cancel()
	}

	createResp, err := m.docker.ContainerCreate(opCtx, &container.Config{
		Image:  m.cfg.Image,
		Labels: labels,
		Env:    env,
	}, &container.HostConfig{
		CapAdd: []string{"NET_ADMIN"},
	}, &network.NetworkingConfig{
		EndpointsConfig: map[string]*network.EndpointSettings{
			m.cfg.Network: {},
		},
	}, nil, containerName)
	if err != nil {
		return nil, fmt.Errorf("create ls worker container: %w", err)
	}

	// discard force-removes the half-initialised container on its own
	// context, independent of opCtx, and logs any failure.
	discard := func() {
		const cleanupTimeout = 30 * time.Second
		rmCtx, rmCancel := context.WithTimeout(context.Background(), cleanupTimeout)
		defer rmCancel()
		if rmErr := m.docker.ContainerRemove(rmCtx, createResp.ID, container.RemoveOptions{Force: true}); rmErr != nil {
			m.logger.Warn("failed to remove ls worker after startup error", "container", containerName, "error", rmErr)
		}
	}

	if err := m.docker.ContainerStart(opCtx, createResp.ID, container.StartOptions{}); err != nil {
		discard()
		return nil, fmt.Errorf("start ls worker container: %w", err)
	}
	inspect, err := m.docker.ContainerInspect(opCtx, createResp.ID)
	if err != nil {
		discard()
		return nil, fmt.Errorf("inspect ls worker container: %w", err)
	}
	address, err := workerAddressFromInspect(inspect, m.cfg.Network)
	if err != nil {
		discard()
		return nil, err
	}
	m.logger.Info("created ls worker",
		"account", shortAccountID(accountID),
		"container", containerName,
		"address", address,
		"proxy_hash", shortHash)
	return &workerHandle{
		Key:         buildWorkerKey(accountID, proxyHash),
		AccountID:   accountID,
		ProxyURL:    proxyURL,
		ProxyHash:   proxyHash,
		ContainerID: createResp.ID,
		Container:   containerName,
		Address:     address,
		AuthToken:   authToken,
		LastUsed:    time.Now(),
	}, nil
}
// workerAddressFromInspect returns "<ip>:<lsWorkerControlPort>" for the
// container's endpoint on networkName. If that network has no IP, it falls
// back to any other attached network that has one; map iteration order
// decides which. It errors when network settings or every IP are missing.
func workerAddressFromInspect(inspect container.InspectResponse, networkName string) (string, error) {
	settings := inspect.NetworkSettings
	if settings == nil {
		return "", fmt.Errorf("ls worker inspect missing network settings")
	}
	port := strconv.Itoa(lsWorkerControlPort)
	ipOf := func(ep *network.EndpointSettings) string {
		if ep == nil {
			return ""
		}
		return strings.TrimSpace(ep.IPAddress)
	}
	// A missing key yields a nil endpoint, which ipOf maps to "".
	if ip := ipOf(settings.Networks[networkName]); ip != "" {
		return net.JoinHostPort(ip, port), nil
	}
	for _, ep := range settings.Networks {
		if ip := ipOf(ep); ip != "" {
			return net.JoinHostPort(ip, port), nil
		}
	}
	return "", fmt.Errorf("ls worker missing IP address on network %s", networkName)
}
// firstContainerName returns names[0], or "" when names is empty or nil.
func firstContainerName(names []string) string {
	if len(names) > 0 {
		return names[0]
	}
	return ""
}
// waitForWorkerHealthy polls GET /healthz on the worker every 500ms until it
// answers 200 OK. It gives up after cfg.StartupTimeout, or earlier when the
// manager is closed, because the polling context derives from m.ctx.
// Response bodies are drained (up to 4 KiB) before closing so the
// keep-alive connection can be reused.
func (m *workerManager) waitForWorkerHealthy(handle *workerHandle) error {
	ctx, cancel := context.WithTimeout(m.ctx, m.cfg.StartupTimeout)
	defer cancel()
	target := workerURL(handle, "/healthz", nil)
	for {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil)
		if err != nil {
			return err
		}
		req.Header.Set("X-Worker-Token", handle.AuthToken)
		resp, err := m.http.Do(req)
		if err == nil {
			_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4<<10))
			_ = resp.Body.Close()
			if resp.StatusCode == http.StatusOK {
				return nil
			}
		}
		select {
		case <-ctx.Done():
			return fmt.Errorf("worker %s failed health check: %w", handle.Container, ctx.Err())
		case <-time.After(500 * time.Millisecond):
		}
	}
}
// waitForWorkerReady polls GET /readyz (with routing_key when routingKey is
// non-blank) every 500ms until the worker answers 200 OK. It gives up after
// cfg.StartupTimeout, or earlier when the manager is closed.
//
// Non-OK replies with a body are logged only when the status or body changes.
// Otherwise a slow worker would emit an identical warning every poll.
// Bodies are read through a 64 KiB limit.
func (m *workerManager) waitForWorkerReady(handle *workerHandle, routingKey string) error {
	ctx, cancel := context.WithTimeout(m.ctx, m.cfg.StartupTimeout)
	defer cancel()
	values := url.Values{}
	if strings.TrimSpace(routingKey) != "" {
		values.Set("routing_key", routingKey)
	}
	target := workerURL(handle, "/readyz", values)
	lastLogged := ""
	for {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil)
		if err != nil {
			return err
		}
		req.Header.Set("X-Worker-Token", handle.AuthToken)
		resp, err := m.http.Do(req)
		if err == nil {
			body, _ := io.ReadAll(io.LimitReader(resp.Body, 64<<10))
			_ = resp.Body.Close()
			if resp.StatusCode == http.StatusOK {
				return nil
			}
			if len(body) > 0 {
				snippet := truncate(string(body), 200)
				summary := strconv.Itoa(resp.StatusCode) + " " + snippet
				if summary != lastLogged {
					lastLogged = summary
					m.logger.Warn("ls worker not ready yet", "container", handle.Container, "status", resp.StatusCode, "body", snippet)
				}
			}
		}
		select {
		case <-ctx.Done():
			return fmt.Errorf("worker %s not ready for routing key %q: %w", handle.Container, routingKey, ctx.Err())
		case <-time.After(500 * time.Millisecond):
		}
	}
}
// syncWorkerState POSTs state as JSON to the worker's /account/state
// endpoint and expects 200 OK.
//
// The hex SHA-256 of the payload is recorded on the handle after a
// successful push, so identical state is not re-sent. The handle is shared
// by concurrent GetOrCreate calls, so handle.LastStateSHA is read and written
// only while holding m.mu. The lock is not held during the HTTP request.
func (m *workerManager) syncWorkerState(handle *workerHandle, state *workerAccountState) error {
	if state == nil {
		return fmt.Errorf("ls worker state is nil")
	}
	body, err := json.Marshal(state)
	if err != nil {
		return fmt.Errorf("marshal worker state: %w", err)
	}
	sum := fmt.Sprintf("%x", sha256.Sum256(body))
	m.mu.Lock()
	unchanged := handle.LastStateSHA == sum
	m.mu.Unlock()
	if unchanged {
		return nil
	}
	ctx, cancel := context.WithTimeout(m.ctx, m.cfg.RequestTimeout)
	defer cancel()
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, workerURL(handle, "/account/state", nil), bytes.NewReader(body))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("X-Worker-Token", handle.AuthToken)
	resp, err := m.http.Do(req)
	if err != nil {
		return fmt.Errorf("sync worker state: %w", err)
	}
	defer resp.Body.Close()
	// Bounded read: only a short snippet is reported on error.
	respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64<<10))
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("sync worker state HTTP %d: %s", resp.StatusCode, truncate(string(respBody), 200))
	}
	m.mu.Lock()
	handle.LastStateSHA = sum
	m.mu.Unlock()
	return nil
}
// workerURL builds an http:// URL for path on the worker's control address.
// When values is non-nil, its encoded form becomes the query string.
func workerURL(handle *workerHandle, path string, values url.Values) string {
	u := &url.URL{Scheme: "http", Host: handle.Address, Path: path}
	if values != nil {
		u.RawQuery = values.Encode()
	}
	return u.String()
}
// removeWorkerContainer stops the handle's container with a 5-second grace
// period, then force-removes it. A nil handle or a blank container ID is a
// no-op. A stop failure is ignored because the forced removal follows; only
// a removal failure is returned.
func (m *workerManager) removeWorkerContainer(ctx context.Context, handle *workerHandle) error {
	if handle == nil {
		return nil
	}
	id := handle.ContainerID
	if strings.TrimSpace(id) == "" {
		return nil
	}
	grace := 5
	_ = m.docker.ContainerStop(ctx, id, container.StopOptions{Timeout: &grace})
	err := m.docker.ContainerRemove(ctx, id, container.RemoveOptions{Force: true})
	if err == nil {
		return nil
	}
	return fmt.Errorf("remove ls worker container %s: %w", id, err)
}
// ensureStateLocked returns the state entry for accountID. If the entry is
// missing or nil, it first stores an empty one. Callers must hold m.mu.
func (m *workerManager) ensureStateLocked(accountID string) *workerAccountState {
	if existing := m.state[accountID]; existing != nil {
		return existing
	}
	fresh := &workerAccountState{}
	m.state[accountID] = fresh
	return fresh
}
func resolveWorkerProxy(proxyURL string) (string, *url.URL, error) {
resolved := resolveLSProxy(proxyURL)
normalized, parsed, err := proxyurl.Parse(resolved)
if err != nil {
return "", nil, err
}
if parsed == nil {
return "", nil, nil
}
switch strings.ToLower(parsed.Scheme) {
case "socks5", "socks5h":
return normalized, parsed, nil
default:
return "", nil, fmt.Errorf("ls worker only supports socks5/socks5h proxies, got %s", parsed.Scheme)
}
}
// proxyHash returns the lowercase hex SHA-256 of proxyURL after trimming
// surrounding whitespace. It returns "direct" when the trimmed URL is empty.
func proxyHash(proxyURL string) string {
	trimmed := strings.TrimSpace(proxyURL)
	if len(trimmed) == 0 {
		return "direct"
	}
	digest := sha256.Sum256([]byte(trimmed))
	return fmt.Sprintf("%x", digest[:])
}
// buildWorkerKey forms the m.workers map key "<accountID>:<proxyHash>".
func buildWorkerKey(accountID, proxyHash string) string {
	const sep = ":"
	return accountID + sep + proxyHash
}
// cloneInt32Ptr returns a pointer to a fresh copy of *v, or nil when v is nil.
func cloneInt32Ptr(v *int32) *int32 {
	if v != nil {
		dup := new(int32)
		*dup = *v
		return dup
	}
	return nil
}
// cloneWorkerAccountState returns a deep copy of state. Its pointer fields
// (ExpiresAt and both credit amounts) are duplicated too, so the copy can be
// used after m.mu is released. A nil state yields nil.
func cloneWorkerAccountState(state *workerAccountState) *workerAccountState {
	if state == nil {
		return nil
	}
	dup := *state
	if expiry := state.ExpiresAt; expiry != nil {
		t := *expiry
		dup.ExpiresAt = &t
	}
	dup.MinimumCreditAmount = cloneInt32Ptr(state.MinimumCreditAmount)
	dup.AvailableCredits = cloneInt32Ptr(state.AvailableCredits)
	return &dup
}