feat(channels): 模型定价支持一键同步最新模型
从 LiteLLM 定价目录中读取指定平台的最新模型列表, 将尚未录入的模型以新定价条目(价格留空)的形式追加, 管理员只需点击同步最新模型按钮即可完成操作。 - backend/service: PricingService 新增 ListModelNamesByProvider - backend/handler: ChannelHandler 新增 SyncPricingModels (GET /api/v1/admin/channels/pricing/sync-models) - backend/routes: 注册新路由(在 /:id 通配符之前) - backend/wire_gen: 手动更新 NewChannelHandler 调用 - frontend/api: channels.ts 新增 syncPricingModels - frontend/i18n: zh.ts / en.ts 新增 5 个 key - frontend/view: ChannelsView 定价区域标题行新增「同步最新模型」按钮 - tests: pricing_service_test + channel_handler_test 新增单元测试
This commit is contained in:
parent
8927ab091e
commit
92ad68a314
@ -225,7 +225,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||
channelHandler := admin.NewChannelHandler(channelService, billingService)
|
||||
channelHandler := admin.NewChannelHandler(channelService, billingService, pricingService)
|
||||
channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService)
|
||||
channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db)
|
||||
channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository)
|
||||
|
||||
@ -17,11 +17,12 @@ import (
|
||||
type ChannelHandler struct {
|
||||
channelService *service.ChannelService
|
||||
billingService *service.BillingService
|
||||
pricingService *service.PricingService
|
||||
}
|
||||
|
||||
// NewChannelHandler creates a new admin channel handler
|
||||
func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService) *ChannelHandler {
|
||||
return &ChannelHandler{channelService: channelService, billingService: billingService}
|
||||
func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService, pricingService *service.PricingService) *ChannelHandler {
|
||||
return &ChannelHandler{channelService: channelService, billingService: billingService, pricingService: pricingService}
|
||||
}
|
||||
|
||||
// --- Request / Response types ---
|
||||
@ -500,3 +501,34 @@ func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) {
|
||||
"image_output_price": pricing.ImageOutputPricePerToken,
|
||||
})
|
||||
}
|
||||
|
||||
// platformToLiteLLMProvider maps a channel platform name to the corresponding
|
||||
// LiteLLM provider string used as the key in the pricing catalog.
|
||||
var platformToLiteLLMProvider = map[string]string{
|
||||
service.PlatformAnthropic: "anthropic",
|
||||
service.PlatformOpenAI: "openai",
|
||||
service.PlatformGemini: "google",
|
||||
service.PlatformAntigravity: "anthropic",
|
||||
}
|
||||
|
||||
// SyncPricingModels 返回 LiteLLM 定价目录中指定平台的最新模型列表
|
||||
// GET /api/v1/admin/channels/pricing/sync-models?platform=anthropic
|
||||
func (h *ChannelHandler) SyncPricingModels(c *gin.Context) {
|
||||
platform := strings.ToLower(strings.TrimSpace(c.Query("platform")))
|
||||
if platform == "" {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("MISSING_PARAMETER", "platform parameter is required").
|
||||
WithMetadata(map[string]string{"param": "platform"}))
|
||||
return
|
||||
}
|
||||
|
||||
provider, ok := platformToLiteLLMProvider[platform]
|
||||
if !ok {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("UNSUPPORTED_PLATFORM",
|
||||
fmt.Sprintf("unsupported platform: %s", platform)).
|
||||
WithMetadata(map[string]string{"param": "platform"}))
|
||||
return
|
||||
}
|
||||
|
||||
models := h.pricingService.ListModelNamesByProvider(provider)
|
||||
response.Success(c, gin.H{"models": models})
|
||||
}
|
||||
|
||||
@ -3,10 +3,14 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -416,3 +420,58 @@ func TestPricingRequestToService_NilPriceFields(t *testing.T) {
|
||||
require.Nil(t, r.ImageOutputPrice)
|
||||
require.Nil(t, r.PerRequestPrice)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 3. SyncPricingModels handler
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func setupSyncPricingModelsRouter(pricingSvc *service.PricingService) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
h := &ChannelHandler{pricingService: pricingSvc}
|
||||
router.GET("/channels/pricing/sync-models", h.SyncPricingModels)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestSyncPricingModels_MissingPlatform(t *testing.T) {
|
||||
svc := service.NewPricingService(nil, nil)
|
||||
router := setupSyncPricingModelsRouter(svc)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/channels/pricing/sync-models", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
func TestSyncPricingModels_UnsupportedPlatform(t *testing.T) {
|
||||
svc := service.NewPricingService(nil, nil)
|
||||
router := setupSyncPricingModelsRouter(svc)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/channels/pricing/sync-models?platform=unknown", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
func TestSyncPricingModels_ValidPlatform_EmptyService(t *testing.T) {
|
||||
svc := service.NewPricingService(nil, nil)
|
||||
router := setupSyncPricingModelsRouter(svc)
|
||||
|
||||
for _, platform := range []string{"anthropic", "openai", "gemini", "antigravity"} {
|
||||
req := httptest.NewRequest(http.MethodGet, "/channels/pricing/sync-models?platform="+platform, nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code, "platform=%s", platform)
|
||||
|
||||
var body struct {
|
||||
Data struct {
|
||||
Models []string `json:"models"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
require.NotNil(t, body.Data.Models, "models must not be null for platform=%s", platform)
|
||||
}
|
||||
}
|
||||
|
||||
@ -590,6 +590,7 @@ func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
{
|
||||
channels.GET("", h.Admin.Channel.List)
|
||||
channels.GET("/model-pricing", h.Admin.Channel.GetModelDefaultPricing)
|
||||
channels.GET("/pricing/sync-models", h.Admin.Channel.SyncPricingModels)
|
||||
channels.GET("/:id", h.Admin.Channel.GetByID)
|
||||
channels.POST("", h.Admin.Channel.Create)
|
||||
channels.PUT("/:id", h.Admin.Channel.Update)
|
||||
|
||||
@ -9,6 +9,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -903,6 +904,24 @@ func (s *PricingService) getHashFilePath() string {
|
||||
return filepath.Join(s.cfg.Pricing.DataDir, "model_pricing.sha256")
|
||||
}
|
||||
|
||||
// ListModelNamesByProvider returns all model names in the catalog whose
|
||||
// LiteLLMProvider matches the given provider string (case-insensitive).
|
||||
// The returned slice is sorted alphabetically.
|
||||
func (s *PricingService) ListModelNamesByProvider(provider string) []string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
names := make([]string, 0)
|
||||
for name, p := range s.pricingData {
|
||||
if strings.ToLower(p.LiteLLMProvider) == provider {
|
||||
names = append(names, name)
|
||||
}
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
// isNumeric 检查字符串是否为纯数字
|
||||
func isNumeric(s string) bool {
|
||||
for _, c := range s {
|
||||
|
||||
@ -234,3 +234,60 @@ func TestParsePricingData_PreservesServiceTierPriorityFields(t *testing.T) {
|
||||
require.InDelta(t, 0.0000005, pricing.CacheReadInputTokenCostPriority, 1e-12)
|
||||
require.True(t, pricing.SupportsServiceTier)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ListModelNamesByProvider
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestListModelNamesByProvider_ReturnsMatchingModels(t *testing.T) {
|
||||
svc := &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"claude-opus-4-5-20251101": {LiteLLMProvider: "anthropic", InputCostPerToken: 1.5e-5},
|
||||
"claude-sonnet-4-5": {LiteLLMProvider: "anthropic", InputCostPerToken: 3e-6},
|
||||
"gpt-4o": {LiteLLMProvider: "openai", InputCostPerToken: 5e-6},
|
||||
"gemini-2.5-pro": {LiteLLMProvider: "google", InputCostPerToken: 1.25e-6},
|
||||
},
|
||||
}
|
||||
|
||||
got := svc.ListModelNamesByProvider("anthropic")
|
||||
require.ElementsMatch(t, []string{"claude-opus-4-5-20251101", "claude-sonnet-4-5"}, got)
|
||||
// Must be sorted
|
||||
require.Equal(t, "claude-opus-4-5-20251101", got[0])
|
||||
require.Equal(t, "claude-sonnet-4-5", got[1])
|
||||
}
|
||||
|
||||
func TestListModelNamesByProvider_CaseInsensitive(t *testing.T) {
|
||||
svc := &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"gpt-4o": {LiteLLMProvider: "OpenAI", InputCostPerToken: 5e-6},
|
||||
},
|
||||
}
|
||||
|
||||
got := svc.ListModelNamesByProvider("openai")
|
||||
require.Equal(t, []string{"gpt-4o"}, got)
|
||||
|
||||
got2 := svc.ListModelNamesByProvider("OPENAI")
|
||||
require.Equal(t, []string{"gpt-4o"}, got2)
|
||||
}
|
||||
|
||||
func TestListModelNamesByProvider_NoMatch(t *testing.T) {
|
||||
svc := &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"gpt-4o": {LiteLLMProvider: "openai", InputCostPerToken: 5e-6},
|
||||
},
|
||||
}
|
||||
|
||||
got := svc.ListModelNamesByProvider("anthropic")
|
||||
require.NotNil(t, got)
|
||||
require.Empty(t, got)
|
||||
}
|
||||
|
||||
func TestListModelNamesByProvider_EmptyCatalog(t *testing.T) {
|
||||
svc := &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{},
|
||||
}
|
||||
|
||||
got := svc.ListModelNamesByProvider("openai")
|
||||
require.NotNil(t, got)
|
||||
require.Empty(t, got)
|
||||
}
|
||||
|
||||
@ -164,5 +164,19 @@ export async function getModelDefaultPricing(model: string): Promise<ModelDefaul
|
||||
return data
|
||||
}
|
||||
|
||||
const channelsAPI = { list, getById, create, update, remove, getModelDefaultPricing }
|
||||
export interface SyncPricingModelsResult {
|
||||
models: string[]
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch the latest model names from the LiteLLM pricing catalog for the given platform
|
||||
*/
|
||||
export async function syncPricingModels(platform: string): Promise<SyncPricingModelsResult> {
|
||||
const { data } = await apiClient.get<SyncPricingModelsResult>('/admin/channels/pricing/sync-models', {
|
||||
params: { platform }
|
||||
})
|
||||
return data
|
||||
}
|
||||
|
||||
const channelsAPI = { list, getById, create, update, remove, getModelDefaultPricing, syncPricingModels }
|
||||
export default channelsAPI
|
||||
|
||||
@ -2370,10 +2370,15 @@ export default {
|
||||
searchAccountPlaceholder: 'Search accounts...',
|
||||
ruleAccountsHint: 'Leave empty to match all accounts',
|
||||
ruleModelPricing: 'Model Pricing',
|
||||
noGroupsInChannel: 'No groups selected in platform tabs above',
|
||||
unnamed: 'Unnamed'
|
||||
}
|
||||
},
|
||||
noGroupsInChannel: 'No groups selected in platform tabs above',
|
||||
unnamed: 'Unnamed',
|
||||
syncLatestModels: 'Sync Latest Models',
|
||||
syncingModels: 'Syncing...',
|
||||
syncModelsSuccess: 'Synced {count} new model(s)',
|
||||
syncModelsAlreadyUpToDate: 'Models already up to date',
|
||||
syncModelsError: 'Failed to sync models'
|
||||
}
|
||||
},
|
||||
|
||||
riskControl: {
|
||||
title: 'Risk Control',
|
||||
|
||||
@ -2448,7 +2448,12 @@ export default {
|
||||
ruleAccountsHint: '留空表示匹配所有账号',
|
||||
ruleModelPricing: '模型定价',
|
||||
noGroupsInChannel: '上方平台标签页中未选择分组',
|
||||
unnamed: '未命名'
|
||||
unnamed: '未命名',
|
||||
syncLatestModels: '同步最新模型',
|
||||
syncingModels: '同步中...',
|
||||
syncModelsSuccess: '已同步 {count} 个新模型',
|
||||
syncModelsAlreadyUpToDate: '模型列表已是最新',
|
||||
syncModelsError: '同步模型失败'
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
@ -406,9 +406,19 @@
|
||||
<div>
|
||||
<div class="mb-1 flex items-center justify-between">
|
||||
<label class="input-label text-xs mb-0">{{ t('admin.channels.form.modelPricing', 'Model Pricing') }}</label>
|
||||
<button type="button" @click="addPricingEntry(sIdx)" class="text-xs text-primary-600 hover:text-primary-700">
|
||||
+ {{ t('common.add', 'Add') }}
|
||||
</button>
|
||||
<div class="flex items-center gap-2">
|
||||
<button
|
||||
type="button"
|
||||
@click="syncLatestModels(sIdx)"
|
||||
:disabled="syncingPlatform === section.platform"
|
||||
class="text-xs text-gray-500 hover:text-primary-600 disabled:opacity-50"
|
||||
>
|
||||
{{ syncingPlatform === section.platform ? t('admin.channels.form.syncingModels') : t('admin.channels.form.syncLatestModels') }}
|
||||
</button>
|
||||
<button type="button" @click="addPricingEntry(sIdx)" class="text-xs text-primary-600 hover:text-primary-700">
|
||||
+ {{ t('common.add', 'Add') }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
v-if="section.model_pricing.length === 0"
|
||||
@ -834,6 +844,44 @@ function addPricingEntry(sectionIdx: number) {
|
||||
})
|
||||
}
|
||||
|
||||
const syncingPlatform = ref<string | null>(null)
|
||||
|
||||
async function syncLatestModels(sectionIdx: number) {
|
||||
const platform = form.platforms[sectionIdx].platform
|
||||
if (syncingPlatform.value) return
|
||||
syncingPlatform.value = platform
|
||||
try {
|
||||
const result = await adminAPI.channels.syncPricingModels(platform)
|
||||
// Collect all model names already present in this platform's pricing entries
|
||||
const existingModels = new Set<string>()
|
||||
for (const entry of form.platforms[sectionIdx].model_pricing) {
|
||||
for (const m of entry.models) existingModels.add(m)
|
||||
}
|
||||
const newModels = result.models.filter(m => !existingModels.has(m))
|
||||
if (newModels.length === 0) {
|
||||
appStore.showSuccess(t('admin.channels.form.syncModelsAlreadyUpToDate'))
|
||||
return
|
||||
}
|
||||
// Add new models as a single new pricing entry (user fills in prices)
|
||||
form.platforms[sectionIdx].model_pricing.push({
|
||||
models: newModels,
|
||||
billing_mode: 'token',
|
||||
input_price: null,
|
||||
output_price: null,
|
||||
cache_write_price: null,
|
||||
cache_read_price: null,
|
||||
image_output_price: null,
|
||||
per_request_price: null,
|
||||
intervals: []
|
||||
})
|
||||
appStore.showSuccess(t('admin.channels.form.syncModelsSuccess', { count: newModels.length }))
|
||||
} catch (error) {
|
||||
appStore.showError(extractApiErrorMessage(error, t('admin.channels.form.syncModelsError')))
|
||||
} finally {
|
||||
syncingPlatform.value = null
|
||||
}
|
||||
}
|
||||
|
||||
function updatePricingEntry(sectionIdx: number, idx: number, updated: PricingFormEntry) {
|
||||
form.platforms[sectionIdx].model_pricing.splice(idx, 1, updated)
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user