From 2a17c0b229598a9a1791da3852d68edaca6ddbf9 Mon Sep 17 00:00:00 2001 From: ZeroDeng Date: Fri, 8 May 2026 14:03:15 +0800 Subject: [PATCH] fix(gemini): route Vertex token exchange through account proxy --- .../service/vertex_service_account.go | 36 ++++++++- .../service/vertex_service_account_test.go | 76 +++++++++++++++++++ 2 files changed, 109 insertions(+), 3 deletions(-) diff --git a/backend/internal/service/vertex_service_account.go b/backend/internal/service/vertex_service_account.go index 4430cf81..c7a2279d 100644 --- a/backend/internal/service/vertex_service_account.go +++ b/backend/internal/service/vertex_service_account.go @@ -16,6 +16,8 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/golang-jwt/jwt/v5" ) @@ -173,7 +175,7 @@ func getVertexServiceAccountAccessToken(ctx context.Context, cache GeminiTokenCa } } - accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key) + accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key, vertexServiceAccountProxyURL(account)) if err != nil { return "", err } @@ -183,7 +185,32 @@ func getVertexServiceAccountAccessToken(ctx context.Context, cache GeminiTokenCa return accessToken, nil } -func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAccountKey) (string, time.Duration, error) { +func vertexServiceAccountProxyURL(account *Account) string { + if account == nil || account.ProxyID == nil || account.Proxy == nil { + return "" + } + return account.Proxy.URL() +} + +func newVertexServiceAccountHTTPClient(proxyURL string) (*http.Client, error) { + proxyURL = strings.TrimSpace(proxyURL) + if proxyURL == "" { + return &http.Client{Timeout: 15 * time.Second}, nil + } + + _, parsedProxy, err := proxyurl.Parse(proxyURL) + if err != nil { + return nil, err + } + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Proxy = nil + if err := proxyutil.ConfigureTransportProxy(transport, parsedProxy); err != nil { + return nil, err + } + return &http.Client{Timeout: 15 * time.Second, Transport: transport}, nil +} + +func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAccountKey, proxyURL string) (string, time.Duration, error) { now := time.Now() claims := jwt.MapClaims{ "iss": key.ClientEmail, @@ -215,7 +242,10 @@ func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAc } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - client := &http.Client{Timeout: 15 * time.Second} + client, err := newVertexServiceAccountHTTPClient(proxyURL) + if err != nil { + return "", 0, fmt.Errorf("configure service account token proxy: %w", err) + } resp, err := client.Do(req) if err != nil { return "", 0, fmt.Errorf("service account token request failed: %w", err) diff --git a/backend/internal/service/vertex_service_account_test.go b/backend/internal/service/vertex_service_account_test.go index 519f5b2f..d77a1988 100644 --- a/backend/internal/service/vertex_service_account_test.go +++ b/backend/internal/service/vertex_service_account_test.go @@ -1,8 +1,17 @@ package service import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "net/http" + "net/http/httptest" "strings" "testing" + "time" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" @@ -75,3 +84,70 @@ func TestParseVertexServiceAccountKey(t *testing.T) { require.Equal(t, vertexDefaultTokenURL, key.TokenURI) require.True(t, strings.Contains(key.PrivateKey, "BEGIN PRIVATE KEY")) } + +func TestVertexServiceAccountProxyURL(t *testing.T) { + proxyID := int64(7) + account := &Account{ + ProxyID: &proxyID, + Proxy: &Proxy{ + Protocol: "http", + Host: "proxy.example.com", + Port: 8080, + }, + } + + require.Equal(t, "http://proxy.example.com:8080", vertexServiceAccountProxyURL(account)) + require.Empty(t, vertexServiceAccountProxyURL(&Account{Proxy: account.Proxy})) + require.Empty(t, vertexServiceAccountProxyURL(&Account{ProxyID: &proxyID})) +} + +func TestExchangeVertexServiceAccountTokenUsesProxy(t *testing.T) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + pemBytes := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + + seenProxyRequest := make(chan string, 1) + proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seenProxyRequest <- r.URL.String() + require.Equal(t, "oauth2.googleapis.com", r.URL.Host) + require.Equal(t, "/token", r.URL.Path) + require.NoError(t, r.ParseForm()) + require.Equal(t, "urn:ietf:params:oauth:grant-type:jwt-bearer", r.PostForm.Get("grant_type")) + + w.Header().Set("Content-Type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(vertexTokenResponse{ + AccessToken: "proxied-token", + TokenType: "Bearer", + ExpiresIn: 3600, + })) + })) + defer proxy.Close() + + key := &vertexServiceAccountKey{ + ClientEmail: "svc@example.iam.gserviceaccount.com", + PrivateKey: string(pemBytes), + TokenURI: "http://oauth2.googleapis.com/token", + } + + token, ttl, err := exchangeVertexServiceAccountToken(contextWithTestTimeout(t), key, proxy.URL) + require.NoError(t, err) + require.Equal(t, "proxied-token", token) + require.Equal(t, 55*time.Minute, ttl) + + select { + case got := <-seenProxyRequest: + require.Equal(t, "http://oauth2.googleapis.com/token", got) + default: + t.Fatal("token request did not reach proxy") + } +} + +func contextWithTestTimeout(t *testing.T) context.Context { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + return ctx +}