269 lines
6.5 KiB
Go
269 lines
6.5 KiB
Go
package lspool
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
|
"golang.org/x/net/proxy"
|
|
)
|
|
|
|
type lsProxyBridge struct {
|
|
listener net.Listener
|
|
server *http.Server
|
|
url string
|
|
upstream string
|
|
}
|
|
|
|
type lsProxyBridgeManager struct {
|
|
mu sync.Mutex
|
|
bridges map[string]*lsProxyBridge
|
|
logger *slog.Logger
|
|
}
|
|
|
|
var globalLSProxyBridgeManager = &lsProxyBridgeManager{
|
|
bridges: make(map[string]*lsProxyBridge),
|
|
logger: slog.Default().With("component", "lspool-proxy-bridge"),
|
|
}
|
|
|
|
var (
|
|
lsProxyBridgeDialTimeout = 10 * time.Second
|
|
lsProxyBridgeProbeTargets = []string{
|
|
"cloudcode-pa.googleapis.com:443",
|
|
"oauthaccountmanager.googleapis.com:443",
|
|
}
|
|
)
|
|
|
|
// prepareLSProxyURL turns a configured proxy URL into one the language server
// can use directly.
//
// HTTP(S) proxies are returned in their normalized form. SOCKS5/SOCKS5h proxies
// are fronted by a local HTTP CONNECT bridge, and the bridge's URL is returned
// instead. When proxyurl.Parse yields no URL (presumably for empty input —
// confirm against proxyurl), it returns "" with a nil error. Parse errors are
// returned unchanged; any other scheme is rejected with an error.
func prepareLSProxyURL(raw string) (string, error) {
	normalized, parsed, err := proxyurl.Parse(raw)
	switch {
	case err != nil:
		return "", err
	case parsed == nil:
		return "", nil
	}

	scheme := strings.ToLower(parsed.Scheme)
	if scheme == "http" || scheme == "https" {
		return normalized, nil
	}
	if scheme == "socks5" || scheme == "socks5h" {
		// Reuse (or lazily start) one bridge per normalized upstream URL.
		return globalLSProxyBridgeManager.ensure(normalized, parsed)
	}
	return "", fmt.Errorf("unsupported LS proxy scheme: %s", parsed.Scheme)
}
|
|
|
|
// ensure returns the local URL of the bridge cached under key. If no bridge is
// cached yet, it first creates one for upstream.
//
// The manager's mutex is held for the whole call, including bridge creation, so
// concurrent callers with the same key share a single bridge. A creation error
// is returned as-is and nothing is cached.
func (m *lsProxyBridgeManager) ensure(key string, upstream *url.URL) (string, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	existing := m.bridges[key]
	if existing == nil {
		created, err := newLSProxyBridge(upstream, m.logger)
		if err != nil {
			return "", err
		}
		m.bridges[key] = created
		existing = created
	}
	return existing.url, nil
}
|
|
|
|
func (m *lsProxyBridgeManager) closeAll() {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
for key, bridge := range m.bridges {
|
|
if bridge != nil {
|
|
_ = bridge.server.Close()
|
|
_ = bridge.listener.Close()
|
|
}
|
|
delete(m.bridges, key)
|
|
}
|
|
}
|
|
|
|
func closeAllLSProxyBridgesForTest() {
|
|
globalLSProxyBridgeManager.closeAll()
|
|
}
|
|
|
|
// newLSProxyBridge creates a proxy dialer from upstream and listens on an
// ephemeral 127.0.0.1 port. It serves CONNECT requests on that port by
// tunnelling them through the dialer.
//
// The returned bridge is already serving. A background goroutine also dials
// lsProxyBridgeProbeTargets once and logs the results.
//
// A wrapped error is returned when the dialer cannot be built from upstream or
// when the loopback listener cannot be opened. The serve goroutine exits once
// the server is closed (e.g. by lsProxyBridgeManager.closeAll); any other serve
// failure is only logged.
func newLSProxyBridge(upstream *url.URL, logger *slog.Logger) (*lsProxyBridge, error) {
	socksDialer, dialerErr := proxy.FromURL(upstream, proxy.Direct)
	if dialerErr != nil {
		return nil, fmt.Errorf("create SOCKS dialer: %w", dialerErr)
	}

	ln, listenErr := net.Listen("tcp", "127.0.0.1:0")
	if listenErr != nil {
		return nil, fmt.Errorf("listen LS proxy bridge: %w", listenErr)
	}

	b := &lsProxyBridge{
		listener: ln,
		url:      "http://" + ln.Addr().String(),
		upstream: upstream.Redacted(), // never log credentials
	}
	srv := &http.Server{
		Handler:           b.connectHandler(socksDialer, logger),
		ReadHeaderTimeout: 10 * time.Second,
		IdleTimeout:       2 * time.Minute,
	}
	b.server = srv

	go func() {
		serveErr := srv.Serve(ln)
		if serveErr == nil || serveErr == http.ErrServerClosed {
			return
		}
		logger.Error("LS proxy bridge serve failed", "upstream", b.upstream, "err", serveErr)
	}()

	logger.Info("LS proxy bridge started", "upstream", b.upstream, "listen", b.url)
	go b.probeConnectivity(socksDialer, logger)
	return b, nil
}
|
|
|
|
// connectHandler returns the bridge's HTTP handler.
//
// It accepts only CONNECT requests; anything else gets 405. The requested
// host:port is dialed through dialer within lsProxyBridgeDialTimeout, and a
// target without a port defaults to :443. A failed dial is logged and answered
// with 502.
//
// On success the client connection is hijacked and sent a raw
// "200 Connection Established" line. The two connections are then spliced with
// tunnelConns, which owns them from that point on.
func (b *lsProxyBridge) connectHandler(dialer proxy.Dialer, logger *slog.Logger) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodConnect {
			http.Error(w, "CONNECT only", http.StatusMethodNotAllowed)
			return
		}

		targetAddr := lsProxyBridgeConnectTarget(r)
		if targetAddr == "" {
			http.Error(w, "missing target host", http.StatusBadRequest)
			return
		}

		startedAt := time.Now()
		logger.Info("LS proxy bridge CONNECT", "upstream", b.upstream, "target", targetAddr)

		// Derived from the request context, so the upstream dial is also
		// abandoned if net/http cancels the request (e.g. the client goes away).
		dialCtx, cancel := context.WithTimeout(r.Context(), lsProxyBridgeDialTimeout)
		defer cancel()

		targetConn, err := dialViaProxy(dialCtx, dialer, targetAddr)
		if err != nil {
			logger.Warn("LS proxy bridge dial failed",
				"upstream", b.upstream,
				"target", targetAddr,
				"elapsed", time.Since(startedAt).Truncate(time.Millisecond),
				"err", err)
			http.Error(w, "proxy dial failed", http.StatusBadGateway)
			return
		}
		logger.Info("LS proxy bridge CONNECT established",
			"upstream", b.upstream,
			"target", targetAddr,
			"elapsed", time.Since(startedAt).Truncate(time.Millisecond))

		hijacker, ok := w.(http.Hijacker)
		if !ok {
			_ = targetConn.Close()
			http.Error(w, "hijack unsupported", http.StatusInternalServerError)
			return
		}

		clientConn, rw, err := hijacker.Hijack()
		if err != nil {
			_ = targetConn.Close()
			logger.Warn("LS proxy bridge hijack failed",
				"upstream", b.upstream,
				"target", targetAddr,
				"err", err)
			return
		}

		// After Hijack net/http no longer writes a response, so the status line
		// is written by hand on the raw connection.
		if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil {
			_ = targetConn.Close()
			_ = clientConn.Close()
			return
		}

		// Bytes the client sent right after the CONNECT header may already sit in
		// the server's read buffer; forward them before splicing raw connections.
		if rw != nil && rw.Reader.Buffered() > 0 {
			if _, err := io.CopyN(targetConn, rw, int64(rw.Reader.Buffered())); err != nil {
				_ = targetConn.Close()
				_ = clientConn.Close()
				return
			}
		}

		tunnelConns(clientConn, targetConn)
	}
}

// lsProxyBridgeConnectTarget extracts the host:port a CONNECT request asks for.
// It prefers r.Host over r.URL.Host, defaults a missing port to 443, and returns
// "" when the request names no host at all.
func lsProxyBridgeConnectTarget(r *http.Request) string {
	target := strings.TrimSpace(r.Host)
	if target == "" {
		target = strings.TrimSpace(r.URL.Host)
	}
	if target == "" {
		return ""
	}
	if _, _, err := net.SplitHostPort(target); err == nil {
		return target
	}
	host := target
	// A bare IPv6 literal arrives bracketed ("[::1]"). Strip the brackets so
	// JoinHostPort does not double them into "[[::1]]:443".
	if len(host) >= 2 && strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}
	return net.JoinHostPort(host, "443")
}
|
|
|
|
// dialViaProxy opens a TCP connection to targetAddr through dialer, honouring
// ctx.
//
// A dialer that implements proxy.ContextDialer is given ctx directly. Otherwise
// the blocking Dial runs in a goroutine, and ctx's error is returned as soon as
// ctx is done. In that case a connection that completes later is closed rather
// than leaked.
func dialViaProxy(ctx context.Context, dialer proxy.Dialer, targetAddr string) (net.Conn, error) {
	if contextDialer, ok := dialer.(proxy.ContextDialer); ok {
		return contextDialer.DialContext(ctx, "tcp", targetAddr)
	}

	// Don't start an uncancellable dial if the caller has already given up.
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	type dialResult struct {
		conn net.Conn
		err  error
	}
	// Buffered so the dial goroutine never blocks on send, even if nobody is
	// still waiting for the result.
	resultCh := make(chan dialResult, 1)
	go func() {
		conn, err := dialer.Dial("tcp", targetAddr)
		resultCh <- dialResult{conn: conn, err: err}
	}()

	select {
	case result := <-resultCh:
		return result.conn, result.err
	case <-ctx.Done():
		// The caller has gone away: wait for the outstanding dial and close
		// whatever it produces, so the socket is not leaked.
		go func() {
			if result := <-resultCh; result.conn != nil {
				_ = result.conn.Close()
			}
		}()
		return nil, ctx.Err()
	}
}
|
|
|
|
// probeConnectivity dials each lsProxyBridgeProbeTargets entry through dialer,
// in order. Each dial is bounded by lsProxyBridgeDialTimeout.
//
// Successful connections are closed immediately. Every outcome is logged with
// the elapsed time, and failures never stop the remaining probes. It is purely
// diagnostic and returns nothing.
func (b *lsProxyBridge) probeConnectivity(dialer proxy.Dialer, logger *slog.Logger) {
	for _, target := range lsProxyBridgeProbeTargets {
		begin := time.Now()

		probeCtx, stop := context.WithTimeout(context.Background(), lsProxyBridgeDialTimeout)
		conn, dialErr := dialViaProxy(probeCtx, dialer, target)
		stop()

		if dialErr == nil {
			_ = conn.Close()
			logger.Info("LS proxy bridge probe succeeded",
				"upstream", b.upstream,
				"target", target,
				"elapsed", time.Since(begin).Truncate(time.Millisecond))
			continue
		}

		logger.Warn("LS proxy bridge probe failed",
			"upstream", b.upstream,
			"target", target,
			"elapsed", time.Since(begin).Truncate(time.Millisecond),
			"err", dialErr)
	}
}
|
|
|
|
// tunnelConns copies bytes in both directions between clientConn and
// targetConn and returns immediately. The copying happens in two goroutines.
//
// As soon as either direction finishes (EOF or error), both connections are
// closed exactly once. That also unblocks the opposite copy, so both goroutines
// exit.
func tunnelConns(clientConn net.Conn, targetConn net.Conn) {
	var closeOnce sync.Once
	shutdown := func() {
		closeOnce.Do(func() {
			_ = clientConn.Close()
			_ = targetConn.Close()
		})
	}

	pump := func(dst, src net.Conn) {
		_, _ = io.Copy(dst, src)
		shutdown()
	}

	go pump(targetConn, clientConn)
	go pump(clientConn, targetConn)
}
|
|
|
|
// readConnectResponse parses an HTTP response from br as the reply to a CONNECT
// request. The synthetic request carries only the CONNECT method, which
// http.ReadResponse uses to decide how to frame the response body.
func readConnectResponse(br *bufio.Reader) (*http.Response, error) {
	connectReq := &http.Request{Method: http.MethodConnect}
	return http.ReadResponse(br, connectReq)
}
|