diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 5a190c33..1e5625a6 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -225,7 +225,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db) scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository) scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService) - channelHandler := admin.NewChannelHandler(channelService, billingService) + channelHandler := admin.NewChannelHandler(channelService, billingService, pricingService) channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService) channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db) channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository) diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index 950e6e72..bf547346 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -17,11 +17,12 @@ import ( type ChannelHandler struct { channelService *service.ChannelService billingService *service.BillingService + pricingService *service.PricingService } // NewChannelHandler creates a new admin channel handler -func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService) *ChannelHandler { - return &ChannelHandler{channelService: channelService, billingService: billingService} +func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService, pricingService *service.PricingService) *ChannelHandler { + return &ChannelHandler{channelService: channelService, billingService: billingService, pricingService: pricingService} } // --- Request / Response types --- @@ -500,3 +501,34 @@ func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) { "image_output_price": pricing.ImageOutputPricePerToken, }) } + +// platformToLiteLLMProvider maps a channel platform name to the corresponding +// LiteLLM provider string used as the key in the pricing catalog. +var platformToLiteLLMProvider = map[string]string{ + service.PlatformAnthropic: "anthropic", + service.PlatformOpenAI: "openai", + service.PlatformGemini: "google", + service.PlatformAntigravity: "anthropic", +} + +// SyncPricingModels 返回 LiteLLM 定价目录中指定平台的最新模型列表 +// GET /api/v1/admin/channels/pricing/sync-models?platform=anthropic +func (h *ChannelHandler) SyncPricingModels(c *gin.Context) { + platform := strings.ToLower(strings.TrimSpace(c.Query("platform"))) + if platform == "" { + response.ErrorFrom(c, infraerrors.BadRequest("MISSING_PARAMETER", "platform parameter is required"). + WithMetadata(map[string]string{"param": "platform"})) + return + } + + provider, ok := platformToLiteLLMProvider[platform] + if !ok { + response.ErrorFrom(c, infraerrors.BadRequest("UNSUPPORTED_PLATFORM", + fmt.Sprintf("unsupported platform: %s", platform)). + WithMetadata(map[string]string{"param": "platform"})) + return + } + + models := h.pricingService.ListModelNamesByProvider(provider) + response.Success(c, gin.H{"models": models}) +} diff --git a/backend/internal/handler/admin/channel_handler_test.go b/backend/internal/handler/admin/channel_handler_test.go index 12cd4bdd..d05a1a6a 100644 --- a/backend/internal/handler/admin/channel_handler_test.go +++ b/backend/internal/handler/admin/channel_handler_test.go @@ -3,10 +3,14 @@ package admin import ( + "encoding/json" + "net/http" + "net/http/httptest" "testing" "time" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -416,3 +420,58 @@ func TestPricingRequestToService_NilPriceFields(t *testing.T) { require.Nil(t, r.ImageOutputPrice) require.Nil(t, r.PerRequestPrice) } + +// --------------------------------------------------------------------------- +// 3. SyncPricingModels handler +// --------------------------------------------------------------------------- + +func setupSyncPricingModelsRouter(pricingSvc *service.PricingService) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + h := &ChannelHandler{pricingService: pricingSvc} + router.GET("/channels/pricing/sync-models", h.SyncPricingModels) + return router +} + +func TestSyncPricingModels_MissingPlatform(t *testing.T) { + svc := service.NewPricingService(nil, nil) + router := setupSyncPricingModelsRouter(svc) + + req := httptest.NewRequest(http.MethodGet, "/channels/pricing/sync-models", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusBadRequest, w.Code) +} + +func TestSyncPricingModels_UnsupportedPlatform(t *testing.T) { + svc := service.NewPricingService(nil, nil) + router := setupSyncPricingModelsRouter(svc) + + req := httptest.NewRequest(http.MethodGet, "/channels/pricing/sync-models?platform=unknown", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusBadRequest, w.Code) +} + +func TestSyncPricingModels_ValidPlatform_EmptyService(t *testing.T) { + svc := service.NewPricingService(nil, nil) + router := setupSyncPricingModelsRouter(svc) + + for _, platform := range []string{"anthropic", "openai", "gemini", "antigravity"} { + req := httptest.NewRequest(http.MethodGet, "/channels/pricing/sync-models?platform="+platform, nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, "platform=%s", platform) + + var body struct { + Data struct { + Models []string `json:"models"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.NotNil(t, body.Data.Models, "models must not be null for platform=%s", platform) + } +} diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 92e2f5b6..1be2ce1e 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -590,6 +590,7 @@ func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) { { channels.GET("", h.Admin.Channel.List) channels.GET("/model-pricing", h.Admin.Channel.GetModelDefaultPricing) + channels.GET("/pricing/sync-models", h.Admin.Channel.SyncPricingModels) channels.GET("/:id", h.Admin.Channel.GetByID) channels.POST("", h.Admin.Channel.Create) channels.PUT("/:id", h.Admin.Channel.Update) diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 8a033710..bd0c30df 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -9,6 +9,7 @@ import ( "os" "path/filepath" "regexp" + "sort" "strings" "sync" "time" @@ -903,6 +904,24 @@ func (s *PricingService) getHashFilePath() string { return filepath.Join(s.cfg.Pricing.DataDir, "model_pricing.sha256") } +// ListModelNamesByProvider returns all model names in the catalog whose +// LiteLLMProvider matches the given provider string (case-insensitive). +// The returned slice is sorted alphabetically. +func (s *PricingService) ListModelNamesByProvider(provider string) []string { + s.mu.RLock() + defer s.mu.RUnlock() + + provider = strings.ToLower(strings.TrimSpace(provider)) + names := make([]string, 0) + for name, p := range s.pricingData { + if strings.ToLower(p.LiteLLMProvider) == provider { + names = append(names, name) + } + } + sort.Strings(names) + return names +} + // isNumeric 检查字符串是否为纯数字 func isNumeric(s string) bool { for _, c := range s { diff --git a/backend/internal/service/pricing_service_test.go b/backend/internal/service/pricing_service_test.go index d0b46886..cc8b120a 100644 --- a/backend/internal/service/pricing_service_test.go +++ b/backend/internal/service/pricing_service_test.go @@ -234,3 +234,60 @@ func TestParsePricingData_PreservesServiceTierPriorityFields(t *testing.T) { require.InDelta(t, 0.0000005, pricing.CacheReadInputTokenCostPriority, 1e-12) require.True(t, pricing.SupportsServiceTier) } + +// --------------------------------------------------------------------------- +// ListModelNamesByProvider +// --------------------------------------------------------------------------- + +func TestListModelNamesByProvider_ReturnsMatchingModels(t *testing.T) { + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "claude-opus-4-5-20251101": {LiteLLMProvider: "anthropic", InputCostPerToken: 1.5e-5}, + "claude-sonnet-4-5": {LiteLLMProvider: "anthropic", InputCostPerToken: 3e-6}, + "gpt-4o": {LiteLLMProvider: "openai", InputCostPerToken: 5e-6}, + "gemini-2.5-pro": {LiteLLMProvider: "google", InputCostPerToken: 1.25e-6}, + }, + } + + got := svc.ListModelNamesByProvider("anthropic") + require.ElementsMatch(t, []string{"claude-opus-4-5-20251101", "claude-sonnet-4-5"}, got) + // Must be sorted + require.Equal(t, "claude-opus-4-5-20251101", got[0]) + require.Equal(t, "claude-sonnet-4-5", got[1]) +} + +func TestListModelNamesByProvider_CaseInsensitive(t *testing.T) { + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-4o": {LiteLLMProvider: "OpenAI", InputCostPerToken: 5e-6}, + }, + } + + got := svc.ListModelNamesByProvider("openai") + require.Equal(t, []string{"gpt-4o"}, got) + + got2 := svc.ListModelNamesByProvider("OPENAI") + require.Equal(t, []string{"gpt-4o"}, got2) +} + +func TestListModelNamesByProvider_NoMatch(t *testing.T) { + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-4o": {LiteLLMProvider: "openai", InputCostPerToken: 5e-6}, + }, + } + + got := svc.ListModelNamesByProvider("anthropic") + require.NotNil(t, got) + require.Empty(t, got) +} + +func TestListModelNamesByProvider_EmptyCatalog(t *testing.T) { + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{}, + } + + got := svc.ListModelNamesByProvider("openai") + require.NotNil(t, got) + require.Empty(t, got) +} diff --git a/frontend/src/api/admin/channels.ts b/frontend/src/api/admin/channels.ts index 9d430134..afa43a2d 100644 --- a/frontend/src/api/admin/channels.ts +++ b/frontend/src/api/admin/channels.ts @@ -164,5 +164,19 @@ export async function getModelDefaultPricing(model: string): Promise { + const { data } = await apiClient.get('/admin/channels/pricing/sync-models', { + params: { platform } + }) + return data +} + +const channelsAPI = { list, getById, create, update, remove, getModelDefaultPricing, syncPricingModels } export default channelsAPI diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 01c8b82e..20e0e776 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -2370,10 +2370,15 @@ export default { searchAccountPlaceholder: 'Search accounts...', ruleAccountsHint: 'Leave empty to match all accounts', ruleModelPricing: 'Model Pricing', - noGroupsInChannel: 'No groups selected in platform tabs above', - unnamed: 'Unnamed' - } - }, + noGroupsInChannel: 'No groups selected in platform tabs above', + unnamed: 'Unnamed', + syncLatestModels: 'Sync Latest Models', + syncingModels: 'Syncing...', + syncModelsSuccess: 'Synced {count} new model(s)', + syncModelsAlreadyUpToDate: 'Models already up to date', + syncModelsError: 'Failed to sync models' + } + }, riskControl: { title: 'Risk Control', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 0c62f18d..b10dbf3d 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -2448,7 +2448,12 @@ export default { ruleAccountsHint: '留空表示匹配所有账号', ruleModelPricing: '模型定价', noGroupsInChannel: '上方平台标签页中未选择分组', - unnamed: '未命名' + unnamed: '未命名', + syncLatestModels: '同步最新模型', + syncingModels: '同步中...', + syncModelsSuccess: '已同步 {count} 个新模型', + syncModelsAlreadyUpToDate: '模型列表已是最新', + syncModelsError: '同步模型失败' } }, diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index b2d6d8e6..3a4aeb0d 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -406,9 +406,19 @@
- +
+ + +
(null) + +async function syncLatestModels(sectionIdx: number) { + const platform = form.platforms[sectionIdx].platform + if (syncingPlatform.value) return + syncingPlatform.value = platform + try { + const result = await adminAPI.channels.syncPricingModels(platform) + // Collect all model names already present in this platform's pricing entries + const existingModels = new Set() + for (const entry of form.platforms[sectionIdx].model_pricing) { + for (const m of entry.models) existingModels.add(m) + } + const newModels = result.models.filter(m => !existingModels.has(m)) + if (newModels.length === 0) { + appStore.showSuccess(t('admin.channels.form.syncModelsAlreadyUpToDate')) + return + } + // Add new models as a single new pricing entry (user fills in prices) + form.platforms[sectionIdx].model_pricing.push({ + models: newModels, + billing_mode: 'token', + input_price: null, + output_price: null, + cache_write_price: null, + cache_read_price: null, + image_output_price: null, + per_request_price: null, + intervals: [] + }) + appStore.showSuccess(t('admin.channels.form.syncModelsSuccess', { count: newModels.length })) + } catch (error) { + appStore.showError(extractApiErrorMessage(error, t('admin.channels.form.syncModelsError'))) + } finally { + syncingPlatform.value = null + } +} + function updatePricingEntry(sectionIdx: number, idx: number, updated: PricingFormEntry) { form.platforms[sectionIdx].model_pricing.splice(idx, 1, updated) }