Merge pull request #2297 from ZeroDeng01/dev
fix(gemini): 修复 Gemini Vertex Service Account 账号测试时,前置 OAuth token 请求没有使用账号代理的问题
This commit is contained in:
commit
348eeaa06a
@ -563,9 +563,7 @@ func normalizeCodexImportEntry(entry codexImportEntry) (*codexImportAccount, err
|
|||||||
}
|
}
|
||||||
if item.IDToken != "" {
|
if item.IDToken != "" {
|
||||||
item.Credentials["id_token"] = item.IDToken
|
item.Credentials["id_token"] = item.IDToken
|
||||||
if err := enrichCodexImportAccountFromJWT(item, item.IDToken, false, now); err != nil {
|
_ = enrichCodexImportAccountFromJWT(item, item.IDToken, false, now)
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if err := enrichCodexImportAccountFromJWT(item, item.AccessToken, true, now); err != nil {
|
if err := enrichCodexImportAccountFromJWT(item, item.AccessToken, true, now); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@ -16,6 +16,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -173,7 +175,7 @@ func getVertexServiceAccountAccessToken(ctx context.Context, cache GeminiTokenCa
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key)
|
accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key, vertexServiceAccountProxyURL(account))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@ -183,7 +185,36 @@ func getVertexServiceAccountAccessToken(ctx context.Context, cache GeminiTokenCa
|
|||||||
return accessToken, nil
|
return accessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAccountKey) (string, time.Duration, error) {
|
func vertexServiceAccountProxyURL(account *Account) string {
|
||||||
|
if account == nil || account.ProxyID == nil || account.Proxy == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return account.Proxy.URL()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newVertexServiceAccountHTTPClient(proxyURL string) (*http.Client, error) {
|
||||||
|
proxyURL = strings.TrimSpace(proxyURL)
|
||||||
|
if proxyURL == "" {
|
||||||
|
return &http.Client{Timeout: 15 * time.Second}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, parsedProxy, err := proxyurl.Parse(proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defaultTransport, ok := http.DefaultTransport.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected default transport type %T", http.DefaultTransport)
|
||||||
|
}
|
||||||
|
transport := defaultTransport.Clone()
|
||||||
|
transport.Proxy = nil
|
||||||
|
if err := proxyutil.ConfigureTransportProxy(transport, parsedProxy); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &http.Client{Timeout: 15 * time.Second, Transport: transport}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAccountKey, proxyURL string) (string, time.Duration, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
claims := jwt.MapClaims{
|
claims := jwt.MapClaims{
|
||||||
"iss": key.ClientEmail,
|
"iss": key.ClientEmail,
|
||||||
@ -215,7 +246,10 @@ func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAc
|
|||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
client := &http.Client{Timeout: 15 * time.Second}
|
client, err := newVertexServiceAccountHTTPClient(proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, fmt.Errorf("configure service account token proxy: %w", err)
|
||||||
|
}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", 0, fmt.Errorf("service account token request failed: %w", err)
|
return "", 0, fmt.Errorf("service account token request failed: %w", err)
|
||||||
|
|||||||
@ -1,8 +1,17 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@ -75,3 +84,70 @@ func TestParseVertexServiceAccountKey(t *testing.T) {
|
|||||||
require.Equal(t, vertexDefaultTokenURL, key.TokenURI)
|
require.Equal(t, vertexDefaultTokenURL, key.TokenURI)
|
||||||
require.True(t, strings.Contains(key.PrivateKey, "BEGIN PRIVATE KEY"))
|
require.True(t, strings.Contains(key.PrivateKey, "BEGIN PRIVATE KEY"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestVertexServiceAccountProxyURL(t *testing.T) {
|
||||||
|
proxyID := int64(7)
|
||||||
|
account := &Account{
|
||||||
|
ProxyID: &proxyID,
|
||||||
|
Proxy: &Proxy{
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "proxy.example.com",
|
||||||
|
Port: 8080,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, "http://proxy.example.com:8080", vertexServiceAccountProxyURL(account))
|
||||||
|
require.Empty(t, vertexServiceAccountProxyURL(&Account{Proxy: account.Proxy}))
|
||||||
|
require.Empty(t, vertexServiceAccountProxyURL(&Account{ProxyID: &proxyID}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExchangeVertexServiceAccountTokenUsesProxy(t *testing.T) {
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
require.NoError(t, err)
|
||||||
|
pemBytes := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||||
|
})
|
||||||
|
|
||||||
|
seenProxyRequest := make(chan string, 1)
|
||||||
|
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
seenProxyRequest <- r.URL.String()
|
||||||
|
require.Equal(t, "oauth2.googleapis.com", r.URL.Host)
|
||||||
|
require.Equal(t, "/token", r.URL.Path)
|
||||||
|
require.NoError(t, r.ParseForm())
|
||||||
|
require.Equal(t, "urn:ietf:params:oauth:grant-type:jwt-bearer", r.PostForm.Get("grant_type"))
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
require.NoError(t, json.NewEncoder(w).Encode(vertexTokenResponse{
|
||||||
|
AccessToken: "proxied-token",
|
||||||
|
TokenType: "Bearer",
|
||||||
|
ExpiresIn: 3600,
|
||||||
|
}))
|
||||||
|
}))
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
key := &vertexServiceAccountKey{
|
||||||
|
ClientEmail: "svc@example.iam.gserviceaccount.com",
|
||||||
|
PrivateKey: string(pemBytes),
|
||||||
|
TokenURI: "http://oauth2.googleapis.com/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
token, ttl, err := exchangeVertexServiceAccountToken(contextWithTestTimeout(t), key, proxy.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "proxied-token", token)
|
||||||
|
require.Equal(t, 55*time.Minute, ttl)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-seenProxyRequest:
|
||||||
|
require.Equal(t, "http://oauth2.googleapis.com/token", got)
|
||||||
|
default:
|
||||||
|
t.Fatal("token request did not reach proxy")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func contextWithTestTimeout(t *testing.T) context.Context {
|
||||||
|
t.Helper()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user