Merge pull request #2554 from Arron196/feature/sync-upstream-models-pr
feat: 支持从上游同步账号可用模型列表
This commit is contained in:
commit
03473d3ee8
@ -1994,6 +1994,48 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
response.Success(c, models)
|
||||
}
|
||||
|
||||
// SyncUpstreamModels handles syncing live supported models from an account's upstream.
|
||||
// POST /api/v1/admin/accounts/:id/models/sync-upstream
|
||||
func (h *AccountHandler) SyncUpstreamModels(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
if h.accountTestService == nil {
|
||||
response.InternalError(c, "Account test service is not configured")
|
||||
return
|
||||
}
|
||||
|
||||
models, err := h.accountTestService.FetchUpstreamSupportedModels(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
var syncErr *service.UpstreamModelSyncError
|
||||
if errors.As(err, &syncErr) {
|
||||
switch syncErr.Kind {
|
||||
case service.UpstreamModelSyncErrorConfiguration, service.UpstreamModelSyncErrorUnsupported:
|
||||
response.BadRequest(c, syncErr.SafeMessage())
|
||||
default:
|
||||
slog.Warn("sync_upstream_models_failed", "account_id", accountID, "kind", syncErr.Kind)
|
||||
response.Error(c, http.StatusBadGateway, syncErr.SafeMessage())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
slog.Warn("sync_upstream_models_failed", "account_id", accountID)
|
||||
response.Error(c, http.StatusBadGateway, "Failed to sync upstream models from upstream")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"models": models})
|
||||
}
|
||||
|
||||
// SetPrivacy handles setting privacy for a single OpenAI/Antigravity OAuth account
|
||||
// POST /api/v1/admin/accounts/:id/set-privacy
|
||||
func (h *AccountHandler) SetPrivacy(c *gin.Context) {
|
||||
|
||||
@ -3,10 +3,14 @@ package admin
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -33,6 +37,39 @@ func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
|
||||
return router
|
||||
}
|
||||
|
||||
type syncUpstreamHTTPUpstream struct {
|
||||
resp *http.Response
|
||||
err error
|
||||
}
|
||||
|
||||
func (u *syncUpstreamHTTPUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||
if u.err != nil {
|
||||
return nil, u.err
|
||||
}
|
||||
return u.resp, nil
|
||||
}
|
||||
|
||||
func (u *syncUpstreamHTTPUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
|
||||
return u.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
func setupSyncUpstreamModelsRouter(adminSvc service.AdminService, upstream service.HTTPUpstream) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
accountTestSvc := service.NewAccountTestService(
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
upstream,
|
||||
&config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
|
||||
nil,
|
||||
)
|
||||
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, accountTestSvc, nil, nil, nil, nil, nil)
|
||||
router.POST("/api/v1/admin/accounts/:id/models/sync-upstream", handler.SyncUpstreamModels)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
@ -103,3 +140,58 @@ func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefau
|
||||
require.NotEmpty(t, resp.Data)
|
||||
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
|
||||
func TestAccountHandlerSyncUpstreamModels_ConfigErrorReturnsBadRequest(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 44,
|
||||
Name: "openai-apikey-missing-key",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"base_url": "https://openai.example.com/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupSyncUpstreamModelsRouter(svc, &syncUpstreamHTTPUpstream{})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/44/models/sync-upstream", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
require.Contains(t, rec.Body.String(), "No OpenAI API key is available")
|
||||
}
|
||||
|
||||
func TestAccountHandlerSyncUpstreamModels_UpstreamErrorDoesNotExposeBody(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 45,
|
||||
Name: "openai-apikey-upstream-error",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "openai-key",
|
||||
"base_url": "https://openai.example.com/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
upstream := &syncUpstreamHTTPUpstream{resp: &http.Response{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":"SECRET_TOKEN should not be exposed"}`)),
|
||||
}}
|
||||
router := setupSyncUpstreamModelsRouter(svc, upstream)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/45/models/sync-upstream", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadGateway, rec.Code)
|
||||
require.Contains(t, rec.Body.String(), "Upstream model list request failed with HTTP 502")
|
||||
require.NotContains(t, rec.Body.String(), "SECRET_TOKEN")
|
||||
}
|
||||
|
||||
@ -254,6 +254,8 @@ const (
|
||||
proxyTLSHandshakeTimeout = 5 * time.Second
|
||||
// clientTimeout 整体请求超时(含连接、发送、等待响应、读取 body)
|
||||
clientTimeout = 10 * time.Second
|
||||
// fetchAvailableModelsBodyLimit limits model-list responses to avoid unbounded memory use.
|
||||
fetchAvailableModelsBodyLimit int64 = 8 << 20
|
||||
)
|
||||
|
||||
func NewClient(proxyURL string) (*Client, error) {
|
||||
@ -655,6 +657,10 @@ type FetchAvailableModelsResponse struct {
|
||||
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
||||
// 支持 URL fallback:sandbox → daily → prod
|
||||
func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) {
|
||||
if c == nil || c.httpClient == nil {
|
||||
return nil, nil, errors.New("antigravity client is not configured")
|
||||
}
|
||||
|
||||
reqBody := FetchAvailableModelsRequest{Project: projectID}
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
@ -664,6 +670,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
// 固定顺序:prod -> daily
|
||||
availableURLs := BaseURLs
|
||||
|
||||
fetchClient := c.fetchAvailableModelsHTTPClient()
|
||||
var lastErr error
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
apiURL := baseURL + "/v1internal:fetchAvailableModels"
|
||||
@ -676,7 +683,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", GetUserAgentForContext(ctx))
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
resp, err := fetchClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
|
||||
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
@ -686,11 +693,14 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
|
||||
respBodyBytes, err := io.ReadAll(resp.Body)
|
||||
respBodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, fetchAvailableModelsBodyLimit+1))
|
||||
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
if int64(len(respBodyBytes)) > fetchAvailableModelsBodyLimit {
|
||||
return nil, nil, fmt.Errorf("响应超过 %d 字节", fetchAvailableModelsBodyLimit)
|
||||
}
|
||||
|
||||
// 检查是否需要 URL 降级
|
||||
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||
@ -726,6 +736,42 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
|
||||
func (c *Client) fetchAvailableModelsHTTPClient() *http.Client {
|
||||
fetchClient := *c.httpClient
|
||||
fetchClient.CheckRedirect = checkFetchAvailableModelsRedirect
|
||||
return &fetchClient
|
||||
}
|
||||
|
||||
func checkFetchAvailableModelsRedirect(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 10 {
|
||||
return errors.New("stopped after 10 redirects")
|
||||
}
|
||||
if req == nil || req.URL == nil {
|
||||
return errors.New("redirect url is nil")
|
||||
}
|
||||
if !isAllowedFetchAvailableModelsRedirectHost(req.URL.Hostname()) {
|
||||
return fmt.Errorf("redirect to unsupported host: %s", req.URL.Hostname())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isAllowedFetchAvailableModelsRedirectHost(host string) bool {
|
||||
host = strings.ToLower(strings.TrimSpace(host))
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
for _, baseURL := range BaseURLs {
|
||||
parsed, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(host, parsed.Hostname()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ── Privacy API ──────────────────────────────────────────────────────
|
||||
|
||||
// privacyBaseURL 隐私设置 API 仅使用 daily 端点(与 Antigravity 客户端行为一致)
|
||||
|
||||
@ -303,6 +303,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
|
||||
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
|
||||
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
|
||||
accounts.POST("/:id/models/sync-upstream", h.Admin.Account.SyncUpstreamModels)
|
||||
accounts.POST("/batch", h.Admin.Account.BatchCreate)
|
||||
accounts.GET("/data", h.Admin.Account.ExportData)
|
||||
accounts.POST("/data", h.Admin.Account.ImportData)
|
||||
|
||||
474
backend/internal/service/upstream_models.go
Normal file
474
backend/internal/service/upstream_models.go
Normal file
@ -0,0 +1,474 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
)
|
||||
|
||||
const upstreamModelsBodyLimit int64 = 8 << 20
|
||||
|
||||
// UpstreamModelSyncErrorKind classifies model sync failures for safe HTTP mapping.
|
||||
type UpstreamModelSyncErrorKind string
|
||||
|
||||
const (
|
||||
// UpstreamModelSyncErrorConfiguration means the account or server configuration cannot perform the sync.
|
||||
UpstreamModelSyncErrorConfiguration UpstreamModelSyncErrorKind = "configuration"
|
||||
// UpstreamModelSyncErrorUnsupported means the account format is intentionally unsupported for live model sync.
|
||||
UpstreamModelSyncErrorUnsupported UpstreamModelSyncErrorKind = "unsupported"
|
||||
// UpstreamModelSyncErrorUpstream means the configured upstream failed or returned an unusable response.
|
||||
UpstreamModelSyncErrorUpstream UpstreamModelSyncErrorKind = "upstream"
|
||||
)
|
||||
|
||||
// UpstreamModelSyncError keeps internal failure details wrapped while exposing a safe client message.
|
||||
type UpstreamModelSyncError struct {
|
||||
Kind UpstreamModelSyncErrorKind
|
||||
Message string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *UpstreamModelSyncError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
if e.Err == nil {
|
||||
return e.Message
|
||||
}
|
||||
return e.Message + ": " + e.Err.Error()
|
||||
}
|
||||
|
||||
func (e *UpstreamModelSyncError) Unwrap() error {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// SafeMessage returns the sanitized message that can be sent to API clients.
|
||||
func (e *UpstreamModelSyncError) SafeMessage() string {
|
||||
if e == nil || strings.TrimSpace(e.Message) == "" {
|
||||
return "Failed to sync upstream models"
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
|
||||
func newUpstreamModelSyncConfigError(message string, err error) error {
|
||||
return &UpstreamModelSyncError{Kind: UpstreamModelSyncErrorConfiguration, Message: message, Err: err}
|
||||
}
|
||||
|
||||
func newUpstreamModelSyncUnsupportedError(message string, err error) error {
|
||||
return &UpstreamModelSyncError{Kind: UpstreamModelSyncErrorUnsupported, Message: message, Err: err}
|
||||
}
|
||||
|
||||
func newUpstreamModelSyncUpstreamError(message string, err error) error {
|
||||
return &UpstreamModelSyncError{Kind: UpstreamModelSyncErrorUpstream, Message: message, Err: err}
|
||||
}
|
||||
|
||||
// FetchUpstreamSupportedModels fetches the live model list from the account's upstream API format.
|
||||
func (s *AccountTestService) FetchUpstreamSupportedModels(ctx context.Context, account *Account) ([]string, error) {
|
||||
if s == nil {
|
||||
return nil, newUpstreamModelSyncConfigError("Account test service is not configured", nil)
|
||||
}
|
||||
if account == nil {
|
||||
return nil, newUpstreamModelSyncConfigError("Account is required", nil)
|
||||
}
|
||||
|
||||
if account.Platform == PlatformAntigravity && account.Type != AccountTypeAPIKey {
|
||||
return s.fetchAntigravityOAuthUpstreamModels(ctx, account)
|
||||
}
|
||||
|
||||
if s.httpUpstream == nil {
|
||||
return nil, newUpstreamModelSyncConfigError("Upstream HTTP client is not configured", nil)
|
||||
}
|
||||
|
||||
req, err := s.buildUpstreamModelsRequest(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
proxyURL := upstreamModelsProxyURL(account)
|
||||
resp, err := s.doUpstreamModelsRequest(req, proxyURL, account)
|
||||
if err != nil {
|
||||
return nil, newUpstreamModelSyncUpstreamError("Failed to request upstream model list", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, upstreamModelsBodyLimit+1))
|
||||
if err != nil {
|
||||
return nil, newUpstreamModelSyncUpstreamError("Failed to read upstream model list", err)
|
||||
}
|
||||
if int64(len(body)) > upstreamModelsBodyLimit {
|
||||
return nil, newUpstreamModelSyncUpstreamError("Upstream model list response is too large", fmt.Errorf("response exceeds %d bytes", upstreamModelsBodyLimit))
|
||||
}
|
||||
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
return nil, newUpstreamModelSyncUpstreamError(
|
||||
fmt.Sprintf("Upstream model list request failed with HTTP %d", resp.StatusCode),
|
||||
fmt.Errorf("upstream model list returned HTTP %d", resp.StatusCode),
|
||||
)
|
||||
}
|
||||
|
||||
models, err := extractUpstreamModelIDs(body)
|
||||
if err != nil {
|
||||
return nil, newUpstreamModelSyncUpstreamError("Upstream model list response was not valid JSON", err)
|
||||
}
|
||||
if len(models) == 0 {
|
||||
return nil, newUpstreamModelSyncUpstreamError("Upstream returned no supported models", nil)
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (s *AccountTestService) buildUpstreamModelsRequest(ctx context.Context, account *Account) (*http.Request, error) {
|
||||
switch {
|
||||
case account.Platform == PlatformAntigravity:
|
||||
return s.buildAntigravityAPIKeyModelsRequest(ctx, account)
|
||||
case account.IsOpenAI():
|
||||
return s.buildOpenAIUpstreamModelsRequest(ctx, account)
|
||||
case account.IsGemini():
|
||||
return s.buildGeminiUpstreamModelsRequest(ctx, account)
|
||||
case account.IsAnthropic():
|
||||
return s.buildAnthropicUpstreamModelsRequest(ctx, account)
|
||||
default:
|
||||
return nil, newUpstreamModelSyncUnsupportedError(
|
||||
fmt.Sprintf("Unsupported platform for upstream model sync: %s", account.Platform), nil,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountTestService) buildAnthropicUpstreamModelsRequest(ctx context.Context, account *Account) (*http.Request, error) {
|
||||
if account.IsBedrock() || account.Type == AccountTypeServiceAccount {
|
||||
return nil, newUpstreamModelSyncUnsupportedError(
|
||||
fmt.Sprintf("Unsupported Anthropic account type for upstream model sync: %s", account.Type), nil,
|
||||
)
|
||||
}
|
||||
|
||||
baseURL := "https://api.anthropic.com"
|
||||
authHeaderName := ""
|
||||
authHeaderValue := ""
|
||||
betaHeader := ""
|
||||
|
||||
if account.IsOAuth() {
|
||||
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
|
||||
if accessToken == "" && s.claudeTokenProvider != nil {
|
||||
token, tokenErr := s.claudeTokenProvider.GetAccessToken(ctx, account)
|
||||
if tokenErr != nil {
|
||||
return nil, newUpstreamModelSyncUpstreamError("Failed to get Anthropic access token", tokenErr)
|
||||
}
|
||||
accessToken = strings.TrimSpace(token)
|
||||
}
|
||||
if accessToken == "" {
|
||||
return nil, newUpstreamModelSyncConfigError("No Anthropic access token is available", nil)
|
||||
}
|
||||
authHeaderName = "Authorization"
|
||||
authHeaderValue = "Bearer " + accessToken
|
||||
betaHeader = claude.DefaultBetaHeader
|
||||
} else if account.Type == AccountTypeAPIKey {
|
||||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||||
if apiKey == "" {
|
||||
return nil, newUpstreamModelSyncConfigError("No Anthropic API key is available", nil)
|
||||
}
|
||||
baseURL = account.GetBaseURL()
|
||||
if strings.TrimSpace(baseURL) == "" {
|
||||
baseURL = "https://api.anthropic.com"
|
||||
}
|
||||
authHeaderName = "x-api-key"
|
||||
authHeaderValue = apiKey
|
||||
betaHeader = claude.APIKeyBetaHeader
|
||||
} else {
|
||||
return nil, newUpstreamModelSyncUnsupportedError(
|
||||
fmt.Sprintf("Unsupported Anthropic account type for upstream model sync: %s", account.Type), nil,
|
||||
)
|
||||
}
|
||||
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, newUpstreamModelSyncConfigError("Invalid Anthropic base URL", err)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, buildV1ModelsURL(normalizedBaseURL), nil)
|
||||
if err != nil {
|
||||
return nil, newUpstreamModelSyncConfigError("Invalid Anthropic model list URL", err)
|
||||
}
|
||||
for key, value := range claude.DefaultHeaders {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Set("anthropic-beta", betaHeader)
|
||||
req.Header.Set(authHeaderName, authHeaderValue)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (s *AccountTestService) buildAntigravityAPIKeyModelsRequest(ctx context.Context, account *Account) (*http.Request, error) {
|
||||
if account.Type != AccountTypeAPIKey {
|
||||
return nil, newUpstreamModelSyncUnsupportedError(
|
||||
fmt.Sprintf("Unsupported Antigravity account type for upstream model sync: %s", account.Type), nil,
|
||||
)
|
||||
}
|
||||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||||
if apiKey == "" {
|
||||
return nil, newUpstreamModelSyncConfigError("No Antigravity API key is available", nil)
|
||||
}
|
||||
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/")
|
||||
if baseURL == "" {
|
||||
return nil, newUpstreamModelSyncConfigError("Antigravity API-key base URL is required for upstream model sync", nil)
|
||||
}
|
||||
if !strings.HasSuffix(strings.ToLower(baseURL), "/antigravity") {
|
||||
return nil, newUpstreamModelSyncUnsupportedError(
|
||||
"Antigravity API-key upstream model sync requires a compatible gateway base URL ending in /antigravity; use Antigravity OAuth for official Cloud Code upstreams",
|
||||
nil,
|
||||
)
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, newUpstreamModelSyncConfigError("Invalid Antigravity base URL", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, buildV1ModelsURL(normalizedBaseURL), nil)
|
||||
if err != nil {
|
||||
return nil, newUpstreamModelSyncConfigError("Invalid Antigravity model list URL", err)
|
||||
}
|
||||
for key, value := range claude.DefaultHeaders {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Set("anthropic-beta", claude.APIKeyBetaHeader)
|
||||
req.Header.Set("x-api-key", apiKey)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (s *AccountTestService) buildOpenAIUpstreamModelsRequest(ctx context.Context, account *Account) (*http.Request, error) {
|
||||
if account.Type != AccountTypeAPIKey {
|
||||
return nil, newUpstreamModelSyncUnsupportedError(
|
||||
fmt.Sprintf("Unsupported OpenAI account type for upstream model sync: %s", account.Type), nil,
|
||||
)
|
||||
}
|
||||
apiKey := strings.TrimSpace(account.GetOpenAIApiKey())
|
||||
if apiKey == "" {
|
||||
return nil, newUpstreamModelSyncConfigError("No OpenAI API key is available", nil)
|
||||
}
|
||||
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if strings.TrimSpace(baseURL) == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, newUpstreamModelSyncConfigError("Invalid OpenAI base URL", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, buildOpenAIModelsURL(normalizedBaseURL), nil)
|
||||
if err != nil {
|
||||
return nil, newUpstreamModelSyncConfigError("Invalid OpenAI model list URL", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (s *AccountTestService) buildGeminiUpstreamModelsRequest(ctx context.Context, account *Account) (*http.Request, error) {
|
||||
baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
|
||||
if strings.TrimSpace(baseURL) == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, newUpstreamModelSyncConfigError("Invalid Gemini base URL", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, buildGeminiModelsURL(normalizedBaseURL), nil)
|
||||
if err != nil {
|
||||
return nil, newUpstreamModelSyncConfigError("Invalid Gemini model list URL", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
switch account.Type {
|
||||
case AccountTypeAPIKey:
|
||||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||||
if apiKey == "" {
|
||||
return nil, newUpstreamModelSyncConfigError("No Gemini API key is available", nil)
|
||||
}
|
||||
req.Header.Set("x-goog-api-key", apiKey)
|
||||
case AccountTypeOAuth:
|
||||
if strings.TrimSpace(account.GetCredential("project_id")) != "" {
|
||||
return nil, newUpstreamModelSyncUnsupportedError("Gemini Code Assist model listing is not supported by this sync button", nil)
|
||||
}
|
||||
if s.geminiTokenProvider == nil {
|
||||
return nil, newUpstreamModelSyncConfigError("Gemini token provider is not configured", nil)
|
||||
}
|
||||
accessToken, tokenErr := s.geminiTokenProvider.GetAccessToken(ctx, account)
|
||||
if tokenErr != nil {
|
||||
return nil, newUpstreamModelSyncUpstreamError("Failed to get Gemini access token", tokenErr)
|
||||
}
|
||||
accessToken = strings.TrimSpace(accessToken)
|
||||
if accessToken == "" {
|
||||
return nil, newUpstreamModelSyncConfigError("No Gemini access token is available", nil)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
default:
|
||||
return nil, newUpstreamModelSyncUnsupportedError(
|
||||
fmt.Sprintf("Unsupported Gemini account type for upstream model sync: %s", account.Type), nil,
|
||||
)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (s *AccountTestService) fetchAntigravityOAuthUpstreamModels(ctx context.Context, account *Account) ([]string, error) {
|
||||
if s.antigravityGatewayService == nil || s.antigravityGatewayService.GetTokenProvider() == nil {
|
||||
return nil, newUpstreamModelSyncConfigError("Antigravity token provider is not configured", nil)
|
||||
}
|
||||
|
||||
accessToken, err := s.antigravityGatewayService.GetTokenProvider().GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, newUpstreamModelSyncUpstreamError("Failed to get Antigravity access token", err)
|
||||
}
|
||||
accessToken = strings.TrimSpace(accessToken)
|
||||
if accessToken == "" {
|
||||
return nil, newUpstreamModelSyncConfigError("No Antigravity access token is available", nil)
|
||||
}
|
||||
|
||||
client, err := antigravity.NewClient(upstreamModelsProxyURL(account))
|
||||
if err != nil {
|
||||
return nil, newUpstreamModelSyncConfigError("Failed to configure Antigravity client", err)
|
||||
}
|
||||
modelsResp, _, err := client.FetchAvailableModels(ctx, accessToken, strings.TrimSpace(account.GetCredential("project_id")))
|
||||
if err != nil {
|
||||
return nil, newUpstreamModelSyncUpstreamError("Failed to fetch Antigravity available models", err)
|
||||
}
|
||||
if modelsResp == nil || len(modelsResp.Models) == 0 {
|
||||
return nil, newUpstreamModelSyncUpstreamError("Upstream returned no supported models", nil)
|
||||
}
|
||||
|
||||
models := make([]string, 0, len(modelsResp.Models))
|
||||
for modelID := range modelsResp.Models {
|
||||
models = append(models, strings.TrimSpace(modelID))
|
||||
}
|
||||
return dedupeAndSortModelIDs(models), nil
|
||||
}
|
||||
|
||||
func (s *AccountTestService) doUpstreamModelsRequest(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
|
||||
if s.tlsFPProfileService == nil {
|
||||
return s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, nil)
|
||||
}
|
||||
return s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
}
|
||||
|
||||
func upstreamModelsProxyURL(account *Account) string {
|
||||
if account != nil && account.ProxyID != nil && account.Proxy != nil {
|
||||
return account.Proxy.URL()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func buildV1ModelsURL(base string) string {
|
||||
normalized := strings.TrimRight(strings.TrimSpace(base), "/")
|
||||
if strings.HasSuffix(normalized, "/v1/models") {
|
||||
return normalized
|
||||
}
|
||||
if strings.HasSuffix(normalized, "/v1") {
|
||||
return normalized + "/models"
|
||||
}
|
||||
return normalized + "/v1/models"
|
||||
}
|
||||
|
||||
func buildOpenAIModelsURL(base string) string {
|
||||
normalized := strings.TrimRight(strings.TrimSpace(base), "/")
|
||||
if strings.HasSuffix(normalized, "/v1/models") {
|
||||
return normalized
|
||||
}
|
||||
if strings.HasSuffix(normalized, "/v1") {
|
||||
return normalized + "/models"
|
||||
}
|
||||
return normalized + "/v1/models"
|
||||
}
|
||||
|
||||
func buildGeminiModelsURL(base string) string {
|
||||
normalized := strings.TrimRight(strings.TrimSpace(base), "/")
|
||||
if strings.HasSuffix(normalized, "/v1beta/models") {
|
||||
return normalized
|
||||
}
|
||||
if strings.HasSuffix(normalized, "/v1beta") {
|
||||
return normalized + "/models"
|
||||
}
|
||||
return normalized + "/v1beta/models"
|
||||
}
|
||||
|
||||
type upstreamModelEntry struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func extractUpstreamModelIDs(body []byte) ([]string, error) {
|
||||
var response struct {
|
||||
Data []upstreamModelEntry `json:"data"`
|
||||
Models []upstreamModelEntry `json:"models"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
var arrayResponse []upstreamModelEntry
|
||||
if arrayErr := json.Unmarshal(body, &arrayResponse); arrayErr != nil {
|
||||
return nil, fmt.Errorf("parse upstream model list: %w", err)
|
||||
}
|
||||
|
||||
models := make([]string, 0, len(arrayResponse))
|
||||
for _, entry := range arrayResponse {
|
||||
models = append(models, upstreamModelEntryID(entry))
|
||||
}
|
||||
return dedupeAndSortModelIDs(models), nil
|
||||
}
|
||||
|
||||
models := make([]string, 0, len(response.Data)+len(response.Models))
|
||||
for _, entry := range response.Data {
|
||||
models = append(models, upstreamModelEntryID(entry))
|
||||
}
|
||||
for _, entry := range response.Models {
|
||||
models = append(models, upstreamModelEntryID(entry))
|
||||
}
|
||||
|
||||
if len(models) == 0 {
|
||||
var arrayResponse []upstreamModelEntry
|
||||
if err := json.Unmarshal(body, &arrayResponse); err == nil {
|
||||
for _, entry := range arrayResponse {
|
||||
models = append(models, upstreamModelEntryID(entry))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return dedupeAndSortModelIDs(models), nil
|
||||
}
|
||||
|
||||
func upstreamModelEntryID(entry upstreamModelEntry) string {
|
||||
modelID := strings.TrimSpace(entry.ID)
|
||||
if modelID == "" {
|
||||
modelID = strings.TrimSpace(entry.Name)
|
||||
}
|
||||
return strings.TrimPrefix(modelID, "models/")
|
||||
}
|
||||
|
||||
func dedupeAndSortModelIDs(models []string) []string {
|
||||
seen := make(map[string]struct{}, len(models))
|
||||
result := make([]string, 0, len(models))
|
||||
for _, model := range models {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[model]; exists {
|
||||
continue
|
||||
}
|
||||
seen[model] = struct{}{}
|
||||
result = append(result, model)
|
||||
}
|
||||
sort.Strings(result)
|
||||
return result
|
||||
}
|
||||
226
backend/internal/service/upstream_models_test.go
Normal file
226
backend/internal/service/upstream_models_test.go
Normal file
@ -0,0 +1,226 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func upstreamModelSyncTestConfig() *config.Config {
|
||||
return &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildV1ModelsURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Equal(t, "https://api.anthropic.com/v1/models", buildV1ModelsURL("https://api.anthropic.com"))
|
||||
require.Equal(t, "https://api.anthropic.com/v1/models", buildV1ModelsURL("https://api.anthropic.com/v1"))
|
||||
require.Equal(t, "https://api.anthropic.com/v1/models", buildV1ModelsURL("https://api.anthropic.com/v1/models"))
|
||||
require.Equal(t, "https://gateway.example.com/antigravity/v1/models", buildV1ModelsURL("https://gateway.example.com/antigravity/"))
|
||||
}
|
||||
|
||||
func TestBuildGeminiModelsURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Equal(t, "https://generativelanguage.googleapis.com/v1beta/models", buildGeminiModelsURL("https://generativelanguage.googleapis.com"))
|
||||
require.Equal(t, "https://generativelanguage.googleapis.com/v1beta/models", buildGeminiModelsURL("https://generativelanguage.googleapis.com/v1beta"))
|
||||
require.Equal(t, "https://generativelanguage.googleapis.com/v1beta/models", buildGeminiModelsURL("https://generativelanguage.googleapis.com/v1beta/models"))
|
||||
}
|
||||
|
||||
func TestExtractUpstreamModelIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "openai and anthropic data array",
|
||||
body: `{"data":[{"id":"claude-sonnet-4-5"},{"id":"gpt-5"},{"id":"gpt-5"},{"id":""}]}`,
|
||||
want: []string{"claude-sonnet-4-5", "gpt-5"},
|
||||
},
|
||||
{
|
||||
name: "gemini models array strips prefix",
|
||||
body: `{"models":[{"name":"models/gemini-2.5-pro"},{"name":"gemini-2.5-flash"}]}`,
|
||||
want: []string{"gemini-2.5-flash", "gemini-2.5-pro"},
|
||||
},
|
||||
{
|
||||
name: "top level array",
|
||||
body: `[{"id":"z-model"},{"name":"models/a-model"}]`,
|
||||
want: []string{"a-model", "z-model"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := extractUpstreamModelIDs([]byte(tt.body))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildUpstreamModelsRequestsForAPIKeyAccounts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := &AccountTestService{cfg: upstreamModelSyncTestConfig()}
|
||||
ctx := context.Background()
|
||||
|
||||
anthropicReq, err := svc.buildAnthropicUpstreamModelsRequest(ctx, &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "anthropic-key",
|
||||
"base_url": "https://anthropic.example.com/v1",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "https://anthropic.example.com/v1/models", anthropicReq.URL.String())
|
||||
require.Equal(t, "anthropic-key", anthropicReq.Header.Get("x-api-key"))
|
||||
require.Equal(t, "2023-06-01", anthropicReq.Header.Get("anthropic-version"))
|
||||
|
||||
openAIReq, err := svc.buildOpenAIUpstreamModelsRequest(ctx, &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "openai-key",
|
||||
"base_url": "https://openai.example.com",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "https://openai.example.com/v1/models", openAIReq.URL.String())
|
||||
require.Equal(t, "Bearer openai-key", openAIReq.Header.Get("Authorization"))
|
||||
|
||||
geminiReq, err := svc.buildGeminiUpstreamModelsRequest(ctx, &Account{
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "gemini-key",
|
||||
"base_url": "https://generativelanguage.googleapis.com/v1beta",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "https://generativelanguage.googleapis.com/v1beta/models", geminiReq.URL.String())
|
||||
require.Equal(t, "gemini-key", geminiReq.Header.Get("x-goog-api-key"))
|
||||
|
||||
antigravityReq, err := svc.buildAntigravityAPIKeyModelsRequest(ctx, &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "antigravity-key",
|
||||
"base_url": "https://gateway.example.com/antigravity",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "https://gateway.example.com/antigravity/v1/models", antigravityReq.URL.String())
|
||||
require.Equal(t, "antigravity-key", antigravityReq.Header.Get("x-api-key"))
|
||||
}
|
||||
|
||||
func TestBuildAntigravityAPIKeyModelsRequestRejectsOfficialCloudCodeBase(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := &AccountTestService{cfg: upstreamModelSyncTestConfig()}
|
||||
_, err := svc.buildAntigravityAPIKeyModelsRequest(context.Background(), &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "antigravity-key",
|
||||
"base_url": "https://cloudcode-pa.googleapis.com",
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
var syncErr *UpstreamModelSyncError
|
||||
require.True(t, errors.As(err, &syncErr))
|
||||
require.Equal(t, UpstreamModelSyncErrorUnsupported, syncErr.Kind)
|
||||
require.Contains(t, syncErr.SafeMessage(), "compatible gateway")
|
||||
}
|
||||
|
||||
func TestBuildAnthropicUpstreamModelsRequestRejectsBedrock(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := &AccountTestService{cfg: upstreamModelSyncTestConfig()}
|
||||
_, err := svc.buildAnthropicUpstreamModelsRequest(context.Background(), &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
var syncErr *UpstreamModelSyncError
|
||||
require.True(t, errors.As(err, &syncErr))
|
||||
require.Equal(t, UpstreamModelSyncErrorUnsupported, syncErr.Kind)
|
||||
}
|
||||
|
||||
func TestFetchUpstreamSupportedModelsParsesOpenAIResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"data":[{"id":"gpt-5"},{"id":"gpt-5"},{"name":"o3"}]}`)),
|
||||
}}
|
||||
svc := &AccountTestService{
|
||||
httpUpstream: upstream,
|
||||
cfg: upstreamModelSyncTestConfig(),
|
||||
}
|
||||
|
||||
models, err := svc.FetchUpstreamSupportedModels(context.Background(), &Account{
|
||||
ID: 7,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "openai-key",
|
||||
"base_url": "https://openai.example.com/v1",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"gpt-5", "o3"}, models)
|
||||
require.Equal(t, "https://openai.example.com/v1/models", upstream.lastReq.URL.String())
|
||||
require.Equal(t, "Bearer openai-key", upstream.lastReq.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
func TestFetchUpstreamSupportedModelsDoesNotExposeUpstreamBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":"SECRET_TOKEN should not be exposed"}`)),
|
||||
}}
|
||||
svc := &AccountTestService{
|
||||
httpUpstream: upstream,
|
||||
cfg: upstreamModelSyncTestConfig(),
|
||||
}
|
||||
|
||||
_, err := svc.FetchUpstreamSupportedModels(context.Background(), &Account{
|
||||
ID: 8,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "openai-key",
|
||||
"base_url": "https://openai.example.com/v1",
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.NotContains(t, err.Error(), "SECRET_TOKEN")
|
||||
|
||||
var syncErr *UpstreamModelSyncError
|
||||
require.True(t, errors.As(err, &syncErr))
|
||||
require.Equal(t, UpstreamModelSyncErrorUpstream, syncErr.Kind)
|
||||
require.NotContains(t, syncErr.SafeMessage(), "SECRET_TOKEN")
|
||||
require.Contains(t, syncErr.SafeMessage(), "HTTP 502")
|
||||
}
|
||||
@ -446,6 +446,20 @@ export async function getAvailableModels(id: number): Promise<ClaudeModel[]> {
|
||||
return data
|
||||
}
|
||||
|
||||
export interface SyncUpstreamModelsResult {
|
||||
models: string[]
|
||||
}
|
||||
|
||||
/**
|
||||
* Sync live supported models from the account's upstream model-list endpoint
|
||||
* @param id - Account ID
|
||||
* @returns List of model IDs returned by the upstream
|
||||
*/
|
||||
export async function syncUpstreamModels(id: number): Promise<SyncUpstreamModelsResult> {
|
||||
const { data } = await apiClient.post<SyncUpstreamModelsResult>(`/admin/accounts/${id}/models/sync-upstream`)
|
||||
return data
|
||||
}
|
||||
|
||||
export interface CRSPreviewAccount {
|
||||
crs_account_id: string
|
||||
kind: string
|
||||
@ -660,6 +674,7 @@ export const accountsAPI = {
|
||||
resetTempUnschedulable,
|
||||
setSchedulable,
|
||||
getAvailableModels,
|
||||
syncUpstreamModels,
|
||||
generateAuthUrl,
|
||||
exchangeCode,
|
||||
refreshOpenAIToken,
|
||||
|
||||
@ -139,7 +139,7 @@
|
||||
|
||||
<!-- Whitelist Mode -->
|
||||
<div v-if="modelRestrictionMode === 'whitelist'">
|
||||
<ModelWhitelistSelector v-model="allowedModels" :platform="account?.platform || 'anthropic'" />
|
||||
<ModelWhitelistSelector v-model="allowedModels" :platform="account?.platform || 'anthropic'" :account-id="account?.id" />
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
|
||||
<span v-if="allowedModels.length === 0">{{
|
||||
@ -454,7 +454,7 @@
|
||||
|
||||
<!-- Whitelist Mode -->
|
||||
<div v-if="modelRestrictionMode === 'whitelist'">
|
||||
<ModelWhitelistSelector v-model="allowedModels" :platform="account?.platform || 'anthropic'" />
|
||||
<ModelWhitelistSelector v-model="allowedModels" :platform="account?.platform || 'anthropic'" :account-id="account?.id" />
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
|
||||
<span v-if="allowedModels.length === 0">{{
|
||||
@ -666,7 +666,7 @@
|
||||
|
||||
<!-- Whitelist Mode -->
|
||||
<div v-if="modelRestrictionMode === 'whitelist'">
|
||||
<ModelWhitelistSelector v-model="allowedModels" :platform="account?.platform || 'anthropic'" />
|
||||
<ModelWhitelistSelector v-model="allowedModels" :platform="account?.platform || 'anthropic'" :account-id="account?.id" />
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
|
||||
<span v-if="allowedModels.length === 0">{{
|
||||
@ -987,6 +987,17 @@
|
||||
<p class="text-xs text-purple-700 dark:text-purple-400">{{ t('admin.accounts.mapRequestModels') }}</p>
|
||||
</div>
|
||||
|
||||
<div class="mb-3 flex flex-wrap gap-2">
|
||||
<button
|
||||
type="button"
|
||||
@click="syncAntigravityUpstreamModels"
|
||||
:disabled="isSyncingAntigravityUpstream || !account?.id"
|
||||
class="rounded-lg border border-emerald-200 px-3 py-1.5 text-sm text-emerald-600 hover:bg-emerald-50 disabled:cursor-not-allowed disabled:opacity-60 dark:border-emerald-800 dark:text-emerald-400 dark:hover:bg-emerald-900/30"
|
||||
>
|
||||
{{ isSyncingAntigravityUpstream ? t('admin.accounts.syncUpstreamModelsLoading') : t('admin.accounts.syncUpstreamModels') }}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div v-if="antigravityModelMappings.length > 0" class="mb-3 space-y-2">
|
||||
<div
|
||||
v-for="(mapping, index) in antigravityModelMappings"
|
||||
@ -2288,6 +2299,7 @@ const allowOverages = ref(false) // For antigravity accounts: enable AI Credits
|
||||
const antigravityModelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
|
||||
const antigravityWhitelistModels = ref<string[]>([])
|
||||
const antigravityModelMappings = ref<ModelMapping[]>([])
|
||||
const isSyncingAntigravityUpstream = ref(false)
|
||||
const tempUnschedEnabled = ref(false)
|
||||
const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
|
||||
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-model-mapping')
|
||||
@ -2935,6 +2947,40 @@ const addAntigravityPresetMapping = (from: string, to: string) => {
|
||||
antigravityModelMappings.value.push({ from, to })
|
||||
}
|
||||
|
||||
const syncAntigravityUpstreamModels = async () => {
|
||||
if (!props.account?.id || isSyncingAntigravityUpstream.value) return
|
||||
|
||||
isSyncingAntigravityUpstream.value = true
|
||||
try {
|
||||
const result = await adminAPI.accounts.syncUpstreamModels(props.account.id)
|
||||
const upstreamModels = result.models.map((model) => model.trim()).filter(Boolean)
|
||||
if (upstreamModels.length === 0) {
|
||||
appStore.showInfo(t('admin.accounts.syncUpstreamModelsEmpty'))
|
||||
return
|
||||
}
|
||||
|
||||
let addedCount = 0
|
||||
for (const model of upstreamModels) {
|
||||
const exists = antigravityModelMappings.value.some((mapping) => mapping.from === model)
|
||||
if (!exists) {
|
||||
antigravityModelMappings.value.push({ from: model, to: model })
|
||||
addedCount += 1
|
||||
}
|
||||
}
|
||||
|
||||
if (addedCount > 0) {
|
||||
appStore.showSuccess(t('admin.accounts.syncUpstreamModelsSuccess', { count: addedCount, total: upstreamModels.length }))
|
||||
} else {
|
||||
appStore.showInfo(t('admin.accounts.syncUpstreamModelsNoChanges', { count: upstreamModels.length }))
|
||||
}
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : t('admin.accounts.syncUpstreamModelsFailed')
|
||||
appStore.showError(t('admin.accounts.syncUpstreamModelsError', { message }))
|
||||
} finally {
|
||||
isSyncingAntigravityUpstream.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// Error code toggle helper
|
||||
const toggleErrorCode = (code: number) => {
|
||||
const index = selectedErrorCodes.value.indexOf(code)
|
||||
|
||||
@ -85,6 +85,15 @@
|
||||
>
|
||||
{{ t('admin.accounts.fillRelatedModels') }}
|
||||
</button>
|
||||
<button
|
||||
v-if="canSyncUpstream"
|
||||
type="button"
|
||||
@click="syncUpstreamModels"
|
||||
:disabled="isSyncingUpstream"
|
||||
class="rounded-lg border border-emerald-200 px-3 py-1.5 text-sm text-emerald-600 hover:bg-emerald-50 disabled:cursor-not-allowed disabled:opacity-60 dark:border-emerald-800 dark:text-emerald-400 dark:hover:bg-emerald-900/30"
|
||||
>
|
||||
{{ isSyncingUpstream ? t('admin.accounts.syncUpstreamModelsLoading') : t('admin.accounts.syncUpstreamModels') }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
@click="clearAll"
|
||||
@ -123,6 +132,7 @@
|
||||
import { ref, computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import { accountsAPI } from '@/api/admin/accounts'
|
||||
import ModelIcon from '@/components/common/ModelIcon.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import { allModels, getModelsByPlatform } from '@/composables/useModelWhitelist'
|
||||
@ -133,6 +143,7 @@ const props = defineProps<{
|
||||
modelValue: string[]
|
||||
platform?: string
|
||||
platforms?: string[]
|
||||
accountId?: number
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
@ -145,6 +156,7 @@ const showDropdown = ref(false)
|
||||
const searchQuery = ref('')
|
||||
const customModel = ref('')
|
||||
const isComposing = ref(false)
|
||||
const isSyncingUpstream = ref(false)
|
||||
const normalizedPlatforms = computed(() => {
|
||||
const rawPlatforms =
|
||||
props.platforms && props.platforms.length > 0
|
||||
@ -162,6 +174,13 @@ const normalizedPlatforms = computed(() => {
|
||||
)
|
||||
})
|
||||
|
||||
const upstreamSyncPlatforms = new Set(['anthropic', 'openai', 'gemini', 'antigravity'])
|
||||
const canSyncUpstream = computed(() => {
|
||||
if (!props.accountId) return false
|
||||
if (normalizedPlatforms.value.length === 0) return true
|
||||
return normalizedPlatforms.value.some(platform => upstreamSyncPlatforms.has(platform.toLowerCase()))
|
||||
})
|
||||
|
||||
const availableOptions = computed(() => {
|
||||
if (normalizedPlatforms.value.length === 0) {
|
||||
return allModels
|
||||
@ -229,6 +248,41 @@ const fillRelated = () => {
|
||||
emit('update:modelValue', newModels)
|
||||
}
|
||||
|
||||
const syncUpstreamModels = async () => {
|
||||
if (!props.accountId || isSyncingUpstream.value) return
|
||||
|
||||
isSyncingUpstream.value = true
|
||||
try {
|
||||
const result = await accountsAPI.syncUpstreamModels(props.accountId)
|
||||
const upstreamModels = result.models.map(model => model.trim()).filter(Boolean)
|
||||
if (upstreamModels.length === 0) {
|
||||
appStore.showInfo(t('admin.accounts.syncUpstreamModelsEmpty'))
|
||||
return
|
||||
}
|
||||
|
||||
const newModels = [...props.modelValue]
|
||||
let addedCount = 0
|
||||
for (const model of upstreamModels) {
|
||||
if (!newModels.includes(model)) {
|
||||
newModels.push(model)
|
||||
addedCount += 1
|
||||
}
|
||||
}
|
||||
|
||||
emit('update:modelValue', newModels)
|
||||
if (addedCount > 0) {
|
||||
appStore.showSuccess(t('admin.accounts.syncUpstreamModelsSuccess', { count: addedCount, total: upstreamModels.length }))
|
||||
} else {
|
||||
appStore.showInfo(t('admin.accounts.syncUpstreamModelsNoChanges', { count: upstreamModels.length }))
|
||||
}
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : t('admin.accounts.syncUpstreamModelsFailed')
|
||||
appStore.showError(t('admin.accounts.syncUpstreamModelsError', { message }))
|
||||
} finally {
|
||||
isSyncingUpstream.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const clearAll = () => {
|
||||
emit('update:modelValue', [])
|
||||
}
|
||||
|
||||
@ -3229,6 +3229,13 @@ export default {
|
||||
searchModels: 'Search models...',
|
||||
noMatchingModels: 'No matching models',
|
||||
fillRelatedModels: 'Sync latest supported models',
|
||||
syncUpstreamModels: 'Sync upstream supported models',
|
||||
syncUpstreamModelsLoading: 'Syncing upstream...',
|
||||
syncUpstreamModelsSuccess: 'Synced {count} new model(s) from upstream ({total} upstream total)',
|
||||
syncUpstreamModelsNoChanges: 'All {count} upstream model(s) are already in the whitelist',
|
||||
syncUpstreamModelsEmpty: 'Upstream returned no models to sync',
|
||||
syncUpstreamModelsFailed: 'Failed to sync upstream models',
|
||||
syncUpstreamModelsError: 'Failed to sync upstream models: {message}',
|
||||
clearAllModels: 'Clear all models',
|
||||
customModelName: 'Custom model name',
|
||||
enterCustomModelName: 'Enter custom model name',
|
||||
|
||||
@ -3373,6 +3373,13 @@ export default {
|
||||
searchModels: '搜索模型...',
|
||||
noMatchingModels: '没有匹配的模型',
|
||||
fillRelatedModels: '同步最新支持模型',
|
||||
syncUpstreamModels: '同步上游支持的模型',
|
||||
syncUpstreamModelsLoading: '同步上游中...',
|
||||
syncUpstreamModelsSuccess: '已从上游同步 {count} 个新模型(上游共 {total} 个)',
|
||||
syncUpstreamModelsNoChanges: '上游 {count} 个模型均已在白名单中',
|
||||
syncUpstreamModelsEmpty: '上游没有返回可同步的模型',
|
||||
syncUpstreamModelsFailed: '同步上游模型失败',
|
||||
syncUpstreamModelsError: '同步上游模型失败:{message}',
|
||||
clearAllModels: '清除所有模型',
|
||||
customModelName: '自定义模型名称',
|
||||
enterCustomModelName: '输入自定义模型名称',
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user