package lspool import ( "bufio" "context" "fmt" "io" "log/slog" "net" "net/http" "net/url" "strings" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "golang.org/x/net/proxy" ) type lsProxyBridge struct { listener net.Listener server *http.Server url string upstream string } type lsProxyBridgeManager struct { mu sync.Mutex bridges map[string]*lsProxyBridge logger *slog.Logger } var globalLSProxyBridgeManager = &lsProxyBridgeManager{ bridges: make(map[string]*lsProxyBridge), logger: slog.Default().With("component", "lspool-proxy-bridge"), } var ( lsProxyBridgeDialTimeout = 10 * time.Second lsProxyBridgeProbeTargets = []string{ "cloudcode-pa.googleapis.com:443", "oauthaccountmanager.googleapis.com:443", } ) func prepareLSProxyURL(raw string) (string, error) { normalized, parsed, err := proxyurl.Parse(raw) if err != nil { return "", err } if parsed == nil { return "", nil } switch strings.ToLower(parsed.Scheme) { case "http", "https": return normalized, nil case "socks5", "socks5h": return globalLSProxyBridgeManager.ensure(normalized, parsed) default: return "", fmt.Errorf("unsupported LS proxy scheme: %s", parsed.Scheme) } } func (m *lsProxyBridgeManager) ensure(key string, upstream *url.URL) (string, error) { m.mu.Lock() defer m.mu.Unlock() if bridge := m.bridges[key]; bridge != nil { return bridge.url, nil } bridge, err := newLSProxyBridge(upstream, m.logger) if err != nil { return "", err } m.bridges[key] = bridge return bridge.url, nil } func (m *lsProxyBridgeManager) closeAll() { m.mu.Lock() defer m.mu.Unlock() for key, bridge := range m.bridges { if bridge != nil { _ = bridge.server.Close() _ = bridge.listener.Close() } delete(m.bridges, key) } } func closeAllLSProxyBridgesForTest() { globalLSProxyBridgeManager.closeAll() } func newLSProxyBridge(upstream *url.URL, logger *slog.Logger) (*lsProxyBridge, error) { dialer, err := proxy.FromURL(upstream, proxy.Direct) if err != nil { return nil, fmt.Errorf("create SOCKS dialer: %w", err) } listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return nil, fmt.Errorf("listen LS proxy bridge: %w", err) } bridge := &lsProxyBridge{ listener: listener, url: "http://" + listener.Addr().String(), upstream: upstream.Redacted(), } server := &http.Server{ Handler: http.HandlerFunc(bridge.connectHandler(dialer, logger)), ReadHeaderTimeout: 10 * time.Second, IdleTimeout: 2 * time.Minute, } bridge.server = server go func() { if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { logger.Error("LS proxy bridge serve failed", "upstream", bridge.upstream, "err", err) } }() logger.Info("LS proxy bridge started", "upstream", bridge.upstream, "listen", bridge.url) go bridge.probeConnectivity(dialer, logger) return bridge, nil } func (b *lsProxyBridge) connectHandler(dialer proxy.Dialer, logger *slog.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodConnect { http.Error(w, "CONNECT only", http.StatusMethodNotAllowed) return } targetAddr := strings.TrimSpace(r.Host) if targetAddr == "" { targetAddr = strings.TrimSpace(r.URL.Host) } if targetAddr == "" { http.Error(w, "missing target host", http.StatusBadRequest) return } if _, _, err := net.SplitHostPort(targetAddr); err != nil { targetAddr = net.JoinHostPort(targetAddr, "443") } startedAt := time.Now() logger.Info("LS proxy bridge CONNECT", "upstream", b.upstream, "target", targetAddr) dialCtx, cancel := context.WithTimeout(r.Context(), lsProxyBridgeDialTimeout) defer cancel() targetConn, err := dialViaProxy(dialCtx, dialer, targetAddr) if err != nil { logger.Warn("LS proxy bridge dial failed", "upstream", b.upstream, "target", targetAddr, "elapsed", time.Since(startedAt).Truncate(time.Millisecond), "err", err) http.Error(w, "proxy dial failed", http.StatusBadGateway) return } logger.Info("LS proxy bridge CONNECT established", "upstream", b.upstream, "target", targetAddr, "elapsed", time.Since(startedAt).Truncate(time.Millisecond)) hijacker, ok := w.(http.Hijacker) if !ok { _ = targetConn.Close() http.Error(w, "hijack unsupported", http.StatusInternalServerError) return } clientConn, rw, err := hijacker.Hijack() if err != nil { _ = targetConn.Close() return } if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil { _ = targetConn.Close() _ = clientConn.Close() return } if rw != nil && rw.Reader.Buffered() > 0 { if _, err := io.CopyN(targetConn, rw, int64(rw.Reader.Buffered())); err != nil { _ = targetConn.Close() _ = clientConn.Close() return } } tunnelConns(clientConn, targetConn) } } func dialViaProxy(ctx context.Context, dialer proxy.Dialer, targetAddr string) (net.Conn, error) { if contextDialer, ok := dialer.(proxy.ContextDialer); ok { return contextDialer.DialContext(ctx, "tcp", targetAddr) } type dialResult struct { conn net.Conn err error } resultCh := make(chan dialResult, 1) go func() { conn, err := dialer.Dial("tcp", targetAddr) resultCh <- dialResult{conn: conn, err: err} }() select { case <-ctx.Done(): return nil, ctx.Err() case result := <-resultCh: return result.conn, result.err } } func (b *lsProxyBridge) probeConnectivity(dialer proxy.Dialer, logger *slog.Logger) { for _, targetAddr := range lsProxyBridgeProbeTargets { startedAt := time.Now() ctx, cancel := context.WithTimeout(context.Background(), lsProxyBridgeDialTimeout) conn, err := dialViaProxy(ctx, dialer, targetAddr) cancel() if err != nil { logger.Warn("LS proxy bridge probe failed", "upstream", b.upstream, "target", targetAddr, "elapsed", time.Since(startedAt).Truncate(time.Millisecond), "err", err) continue } _ = conn.Close() logger.Info("LS proxy bridge probe succeeded", "upstream", b.upstream, "target", targetAddr, "elapsed", time.Since(startedAt).Truncate(time.Millisecond)) } } func tunnelConns(clientConn net.Conn, targetConn net.Conn) { var once sync.Once closeBoth := func() { _ = clientConn.Close() _ = targetConn.Close() } go func() { _, _ = io.Copy(targetConn, clientConn) once.Do(closeBoth) }() go func() { _, _ = io.Copy(clientConn, targetConn) once.Do(closeBoth) }() } func readConnectResponse(br *bufio.Reader) (*http.Response, error) { return http.ReadResponse(br, &http.Request{Method: http.MethodConnect}) }