Merge pull request #2582 from wucm667/feat/channel-pricing-sync-models

feat(channels): 模型定价支持一键同步最新模型
This commit is contained in:
Wesley Liddick 2026-05-20 08:43:10 +08:00 committed by GitHub
commit 74e35a0150
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 252 additions and 12 deletions

View File

@ -225,7 +225,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
channelHandler := admin.NewChannelHandler(channelService, billingService)
channelHandler := admin.NewChannelHandler(channelService, billingService, pricingService)
channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService)
channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db)
channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository)

View File

@ -17,11 +17,12 @@ import (
type ChannelHandler struct {
channelService *service.ChannelService
billingService *service.BillingService
pricingService *service.PricingService
}
// NewChannelHandler creates a new admin channel handler
func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService) *ChannelHandler {
return &ChannelHandler{channelService: channelService, billingService: billingService}
func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService, pricingService *service.PricingService) *ChannelHandler {
return &ChannelHandler{channelService: channelService, billingService: billingService, pricingService: pricingService}
}
// --- Request / Response types ---
@ -500,3 +501,34 @@ func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) {
"image_output_price": pricing.ImageOutputPricePerToken,
})
}
// platformToLiteLLMProvider maps a channel platform name to the corresponding
// LiteLLM provider string used as the key in the pricing catalog.
var platformToLiteLLMProvider = map[string]string{
service.PlatformAnthropic: "anthropic",
service.PlatformOpenAI: "openai",
service.PlatformGemini: "google",
service.PlatformAntigravity: "anthropic",
}
// SyncPricingModels 返回 LiteLLM 定价目录中指定平台的最新模型列表
// GET /api/v1/admin/channels/pricing/sync-models?platform=anthropic
func (h *ChannelHandler) SyncPricingModels(c *gin.Context) {
platform := strings.ToLower(strings.TrimSpace(c.Query("platform")))
if platform == "" {
response.ErrorFrom(c, infraerrors.BadRequest("MISSING_PARAMETER", "platform parameter is required").
WithMetadata(map[string]string{"param": "platform"}))
return
}
provider, ok := platformToLiteLLMProvider[platform]
if !ok {
response.ErrorFrom(c, infraerrors.BadRequest("UNSUPPORTED_PLATFORM",
fmt.Sprintf("unsupported platform: %s", platform)).
WithMetadata(map[string]string{"param": "platform"}))
return
}
models := h.pricingService.ListModelNamesByProvider(provider)
response.Success(c, gin.H{"models": models})
}

View File

@ -3,10 +3,14 @@
package admin
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@ -416,3 +420,58 @@ func TestPricingRequestToService_NilPriceFields(t *testing.T) {
require.Nil(t, r.ImageOutputPrice)
require.Nil(t, r.PerRequestPrice)
}
// ---------------------------------------------------------------------------
// 3. SyncPricingModels handler
// ---------------------------------------------------------------------------
func setupSyncPricingModelsRouter(pricingSvc *service.PricingService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
h := &ChannelHandler{pricingService: pricingSvc}
router.GET("/channels/pricing/sync-models", h.SyncPricingModels)
return router
}
func TestSyncPricingModels_MissingPlatform(t *testing.T) {
svc := service.NewPricingService(nil, nil)
router := setupSyncPricingModelsRouter(svc)
req := httptest.NewRequest(http.MethodGet, "/channels/pricing/sync-models", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code)
}
func TestSyncPricingModels_UnsupportedPlatform(t *testing.T) {
svc := service.NewPricingService(nil, nil)
router := setupSyncPricingModelsRouter(svc)
req := httptest.NewRequest(http.MethodGet, "/channels/pricing/sync-models?platform=unknown", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code)
}
func TestSyncPricingModels_ValidPlatform_EmptyService(t *testing.T) {
svc := service.NewPricingService(nil, nil)
router := setupSyncPricingModelsRouter(svc)
for _, platform := range []string{"anthropic", "openai", "gemini", "antigravity"} {
req := httptest.NewRequest(http.MethodGet, "/channels/pricing/sync-models?platform="+platform, nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code, "platform=%s", platform)
var body struct {
Data struct {
Models []string `json:"models"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
require.NotNil(t, body.Data.Models, "models must not be null for platform=%s", platform)
}
}

View File

@ -585,6 +585,7 @@ func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{
channels.GET("", h.Admin.Channel.List)
channels.GET("/model-pricing", h.Admin.Channel.GetModelDefaultPricing)
channels.GET("/pricing/sync-models", h.Admin.Channel.SyncPricingModels)
channels.GET("/:id", h.Admin.Channel.GetByID)
channels.POST("", h.Admin.Channel.Create)
channels.PUT("/:id", h.Admin.Channel.Update)

View File

@ -9,6 +9,7 @@ import (
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"sync"
"time"
@ -903,6 +904,24 @@ func (s *PricingService) getHashFilePath() string {
return filepath.Join(s.cfg.Pricing.DataDir, "model_pricing.sha256")
}
// ListModelNamesByProvider returns all model names in the catalog whose
// LiteLLMProvider matches the given provider string (case-insensitive).
// The returned slice is sorted alphabetically.
func (s *PricingService) ListModelNamesByProvider(provider string) []string {
s.mu.RLock()
defer s.mu.RUnlock()
provider = strings.ToLower(strings.TrimSpace(provider))
names := make([]string, 0)
for name, p := range s.pricingData {
if strings.ToLower(p.LiteLLMProvider) == provider {
names = append(names, name)
}
}
sort.Strings(names)
return names
}
// isNumeric 检查字符串是否为纯数字
func isNumeric(s string) bool {
for _, c := range s {

View File

@ -234,3 +234,60 @@ func TestParsePricingData_PreservesServiceTierPriorityFields(t *testing.T) {
require.InDelta(t, 0.0000005, pricing.CacheReadInputTokenCostPriority, 1e-12)
require.True(t, pricing.SupportsServiceTier)
}
// ---------------------------------------------------------------------------
// ListModelNamesByProvider
// ---------------------------------------------------------------------------
func TestListModelNamesByProvider_ReturnsMatchingModels(t *testing.T) {
svc := &PricingService{
pricingData: map[string]*LiteLLMModelPricing{
"claude-opus-4-5-20251101": {LiteLLMProvider: "anthropic", InputCostPerToken: 1.5e-5},
"claude-sonnet-4-5": {LiteLLMProvider: "anthropic", InputCostPerToken: 3e-6},
"gpt-4o": {LiteLLMProvider: "openai", InputCostPerToken: 5e-6},
"gemini-2.5-pro": {LiteLLMProvider: "google", InputCostPerToken: 1.25e-6},
},
}
got := svc.ListModelNamesByProvider("anthropic")
require.ElementsMatch(t, []string{"claude-opus-4-5-20251101", "claude-sonnet-4-5"}, got)
// Must be sorted
require.Equal(t, "claude-opus-4-5-20251101", got[0])
require.Equal(t, "claude-sonnet-4-5", got[1])
}
func TestListModelNamesByProvider_CaseInsensitive(t *testing.T) {
svc := &PricingService{
pricingData: map[string]*LiteLLMModelPricing{
"gpt-4o": {LiteLLMProvider: "OpenAI", InputCostPerToken: 5e-6},
},
}
got := svc.ListModelNamesByProvider("openai")
require.Equal(t, []string{"gpt-4o"}, got)
got2 := svc.ListModelNamesByProvider("OPENAI")
require.Equal(t, []string{"gpt-4o"}, got2)
}
func TestListModelNamesByProvider_NoMatch(t *testing.T) {
svc := &PricingService{
pricingData: map[string]*LiteLLMModelPricing{
"gpt-4o": {LiteLLMProvider: "openai", InputCostPerToken: 5e-6},
},
}
got := svc.ListModelNamesByProvider("anthropic")
require.NotNil(t, got)
require.Empty(t, got)
}
func TestListModelNamesByProvider_EmptyCatalog(t *testing.T) {
svc := &PricingService{
pricingData: map[string]*LiteLLMModelPricing{},
}
got := svc.ListModelNamesByProvider("openai")
require.NotNil(t, got)
require.Empty(t, got)
}

View File

@ -164,5 +164,19 @@ export async function getModelDefaultPricing(model: string): Promise<ModelDefaul
return data
}
const channelsAPI = { list, getById, create, update, remove, getModelDefaultPricing }
export interface SyncPricingModelsResult {
models: string[]
}
/**
* Fetch the latest model names from the LiteLLM pricing catalog for the given platform
*/
export async function syncPricingModels(platform: string): Promise<SyncPricingModelsResult> {
const { data } = await apiClient.get<SyncPricingModelsResult>('/admin/channels/pricing/sync-models', {
params: { platform }
})
return data
}
const channelsAPI = { list, getById, create, update, remove, getModelDefaultPricing, syncPricingModels }
export default channelsAPI

View File

@ -2370,10 +2370,15 @@ export default {
searchAccountPlaceholder: 'Search accounts...',
ruleAccountsHint: 'Leave empty to match all accounts',
ruleModelPricing: 'Model Pricing',
noGroupsInChannel: 'No groups selected in platform tabs above',
unnamed: 'Unnamed'
}
},
noGroupsInChannel: 'No groups selected in platform tabs above',
unnamed: 'Unnamed',
syncLatestModels: 'Sync Latest Models',
syncingModels: 'Syncing...',
syncModelsSuccess: 'Synced {count} new model(s)',
syncModelsAlreadyUpToDate: 'Models already up to date',
syncModelsError: 'Failed to sync models'
}
},
riskControl: {
title: 'Risk Control',

View File

@ -2448,7 +2448,12 @@ export default {
ruleAccountsHint: '留空表示匹配所有账号',
ruleModelPricing: '模型定价',
noGroupsInChannel: '上方平台标签页中未选择分组',
unnamed: '未命名'
unnamed: '未命名',
syncLatestModels: '同步最新模型',
syncingModels: '同步中...',
syncModelsSuccess: '已同步 {count} 个新模型',
syncModelsAlreadyUpToDate: '模型列表已是最新',
syncModelsError: '同步模型失败'
}
},

View File

@ -406,9 +406,19 @@
<div>
<div class="mb-1 flex items-center justify-between">
<label class="input-label text-xs mb-0">{{ t('admin.channels.form.modelPricing', 'Model Pricing') }}</label>
<button type="button" @click="addPricingEntry(sIdx)" class="text-xs text-primary-600 hover:text-primary-700">
+ {{ t('common.add', 'Add') }}
</button>
<div class="flex items-center gap-2">
<button
type="button"
@click="syncLatestModels(sIdx)"
:disabled="syncingPlatform === section.platform"
class="text-xs text-gray-500 hover:text-primary-600 disabled:opacity-50"
>
{{ syncingPlatform === section.platform ? t('admin.channels.form.syncingModels') : t('admin.channels.form.syncLatestModels') }}
</button>
<button type="button" @click="addPricingEntry(sIdx)" class="text-xs text-primary-600 hover:text-primary-700">
+ {{ t('common.add', 'Add') }}
</button>
</div>
</div>
<div
v-if="section.model_pricing.length === 0"
@ -834,6 +844,44 @@ function addPricingEntry(sectionIdx: number) {
})
}
const syncingPlatform = ref<string | null>(null)
async function syncLatestModels(sectionIdx: number) {
const platform = form.platforms[sectionIdx].platform
if (syncingPlatform.value) return
syncingPlatform.value = platform
try {
const result = await adminAPI.channels.syncPricingModels(platform)
// Collect all model names already present in this platform's pricing entries
const existingModels = new Set<string>()
for (const entry of form.platforms[sectionIdx].model_pricing) {
for (const m of entry.models) existingModels.add(m)
}
const newModels = result.models.filter(m => !existingModels.has(m))
if (newModels.length === 0) {
appStore.showSuccess(t('admin.channels.form.syncModelsAlreadyUpToDate'))
return
}
// Add new models as a single new pricing entry (user fills in prices)
form.platforms[sectionIdx].model_pricing.push({
models: newModels,
billing_mode: 'token',
input_price: null,
output_price: null,
cache_write_price: null,
cache_read_price: null,
image_output_price: null,
per_request_price: null,
intervals: []
})
appStore.showSuccess(t('admin.channels.form.syncModelsSuccess', { count: newModels.length }))
} catch (error) {
appStore.showError(extractApiErrorMessage(error, t('admin.channels.form.syncModelsError')))
} finally {
syncingPlatform.value = null
}
}
function updatePricingEntry(sectionIdx: number, idx: number, updated: PricingFormEntry) {
form.platforms[sectionIdx].model_pricing.splice(idx, 1, updated)
}