feat(channel): 渠道管理系统 — 多模式定价 + 统一计费解析
Cherry-picked from release/custom-0.1.106: a9117600
This commit is contained in:
parent
b384570de3
commit
91c9b8d062
@ -49,6 +49,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
refreshTokenCache := repository.NewRefreshTokenCache(redisClient)
|
||||
settingRepository := repository.NewSettingRepository(client)
|
||||
groupRepository := repository.NewGroupRepository(client, db)
|
||||
channelRepository := repository.NewChannelRepository(db)
|
||||
settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig)
|
||||
emailCache := repository.NewEmailCache(redisClient)
|
||||
emailService := service.NewEmailService(settingRepository, emailCache)
|
||||
@ -175,7 +176,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService)
|
||||
channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator)
|
||||
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
||||
_ = modelPricingResolver // Phase 4: 已注册,后续 Gateway 迁移时使用
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService)
|
||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
@ -213,7 +217,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
channelHandler := admin.NewChannelHandler(channelService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
|
||||
308
backend/internal/handler/admin/channel_handler.go
Normal file
308
backend/internal/handler/admin/channel_handler.go
Normal file
@ -0,0 +1,308 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ChannelHandler handles admin channel management
|
||||
type ChannelHandler struct {
|
||||
channelService *service.ChannelService
|
||||
}
|
||||
|
||||
// NewChannelHandler creates a new admin channel handler
|
||||
func NewChannelHandler(channelService *service.ChannelService) *ChannelHandler {
|
||||
return &ChannelHandler{channelService: channelService}
|
||||
}
|
||||
|
||||
// --- Request / Response types ---
|
||||
|
||||
type createChannelRequest struct {
|
||||
Name string `json:"name" binding:"required,max=100"`
|
||||
Description string `json:"description"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
|
||||
}
|
||||
|
||||
type updateChannelRequest struct {
|
||||
Name string `json:"name" binding:"omitempty,max=100"`
|
||||
Description *string `json:"description"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
|
||||
}
|
||||
|
||||
type channelModelPricingRequest struct {
|
||||
Models []string `json:"models" binding:"required,min=1,max=100"`
|
||||
BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"`
|
||||
InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"`
|
||||
OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"`
|
||||
CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"`
|
||||
CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"`
|
||||
ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"`
|
||||
Intervals []pricingIntervalRequest `json:"intervals"`
|
||||
}
|
||||
|
||||
type pricingIntervalRequest struct {
|
||||
MinTokens int `json:"min_tokens"`
|
||||
MaxTokens *int `json:"max_tokens"`
|
||||
TierLabel string `json:"tier_label"`
|
||||
InputPrice *float64 `json:"input_price"`
|
||||
OutputPrice *float64 `json:"output_price"`
|
||||
CacheWritePrice *float64 `json:"cache_write_price"`
|
||||
CacheReadPrice *float64 `json:"cache_read_price"`
|
||||
PerRequestPrice *float64 `json:"per_request_price"`
|
||||
SortOrder int `json:"sort_order"`
|
||||
}
|
||||
|
||||
type channelResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Status string `json:"status"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
type channelModelPricingResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Models []string `json:"models"`
|
||||
BillingMode string `json:"billing_mode"`
|
||||
InputPrice *float64 `json:"input_price"`
|
||||
OutputPrice *float64 `json:"output_price"`
|
||||
CacheWritePrice *float64 `json:"cache_write_price"`
|
||||
CacheReadPrice *float64 `json:"cache_read_price"`
|
||||
ImageOutputPrice *float64 `json:"image_output_price"`
|
||||
Intervals []pricingIntervalResponse `json:"intervals"`
|
||||
}
|
||||
|
||||
type pricingIntervalResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
MinTokens int `json:"min_tokens"`
|
||||
MaxTokens *int `json:"max_tokens"`
|
||||
TierLabel string `json:"tier_label,omitempty"`
|
||||
InputPrice *float64 `json:"input_price"`
|
||||
OutputPrice *float64 `json:"output_price"`
|
||||
CacheWritePrice *float64 `json:"cache_write_price"`
|
||||
CacheReadPrice *float64 `json:"cache_read_price"`
|
||||
PerRequestPrice *float64 `json:"per_request_price"`
|
||||
SortOrder int `json:"sort_order"`
|
||||
}
|
||||
|
||||
func channelToResponse(ch *service.Channel) *channelResponse {
|
||||
if ch == nil {
|
||||
return nil
|
||||
}
|
||||
resp := &channelResponse{
|
||||
ID: ch.ID,
|
||||
Name: ch.Name,
|
||||
Description: ch.Description,
|
||||
Status: ch.Status,
|
||||
GroupIDs: ch.GroupIDs,
|
||||
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
}
|
||||
if resp.GroupIDs == nil {
|
||||
resp.GroupIDs = []int64{}
|
||||
}
|
||||
|
||||
resp.ModelPricing = make([]channelModelPricingResponse, 0, len(ch.ModelPricing))
|
||||
for _, p := range ch.ModelPricing {
|
||||
models := p.Models
|
||||
if models == nil {
|
||||
models = []string{}
|
||||
}
|
||||
billingMode := string(p.BillingMode)
|
||||
if billingMode == "" {
|
||||
billingMode = "token"
|
||||
}
|
||||
intervals := make([]pricingIntervalResponse, 0, len(p.Intervals))
|
||||
for _, iv := range p.Intervals {
|
||||
intervals = append(intervals, pricingIntervalResponse{
|
||||
ID: iv.ID,
|
||||
MinTokens: iv.MinTokens,
|
||||
MaxTokens: iv.MaxTokens,
|
||||
TierLabel: iv.TierLabel,
|
||||
InputPrice: iv.InputPrice,
|
||||
OutputPrice: iv.OutputPrice,
|
||||
CacheWritePrice: iv.CacheWritePrice,
|
||||
CacheReadPrice: iv.CacheReadPrice,
|
||||
PerRequestPrice: iv.PerRequestPrice,
|
||||
SortOrder: iv.SortOrder,
|
||||
})
|
||||
}
|
||||
resp.ModelPricing = append(resp.ModelPricing, channelModelPricingResponse{
|
||||
ID: p.ID,
|
||||
Models: models,
|
||||
BillingMode: billingMode,
|
||||
InputPrice: p.InputPrice,
|
||||
OutputPrice: p.OutputPrice,
|
||||
CacheWritePrice: p.CacheWritePrice,
|
||||
CacheReadPrice: p.CacheReadPrice,
|
||||
ImageOutputPrice: p.ImageOutputPrice,
|
||||
Intervals: intervals,
|
||||
})
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func pricingRequestToService(reqs []channelModelPricingRequest) []service.ChannelModelPricing {
|
||||
result := make([]service.ChannelModelPricing, 0, len(reqs))
|
||||
for _, r := range reqs {
|
||||
billingMode := service.BillingMode(r.BillingMode)
|
||||
if billingMode == "" {
|
||||
billingMode = service.BillingModeToken
|
||||
}
|
||||
intervals := make([]service.PricingInterval, 0, len(r.Intervals))
|
||||
for _, iv := range r.Intervals {
|
||||
intervals = append(intervals, service.PricingInterval{
|
||||
MinTokens: iv.MinTokens,
|
||||
MaxTokens: iv.MaxTokens,
|
||||
TierLabel: iv.TierLabel,
|
||||
InputPrice: iv.InputPrice,
|
||||
OutputPrice: iv.OutputPrice,
|
||||
CacheWritePrice: iv.CacheWritePrice,
|
||||
CacheReadPrice: iv.CacheReadPrice,
|
||||
PerRequestPrice: iv.PerRequestPrice,
|
||||
SortOrder: iv.SortOrder,
|
||||
})
|
||||
}
|
||||
result = append(result, service.ChannelModelPricing{
|
||||
Models: r.Models,
|
||||
BillingMode: billingMode,
|
||||
InputPrice: r.InputPrice,
|
||||
OutputPrice: r.OutputPrice,
|
||||
CacheWritePrice: r.CacheWritePrice,
|
||||
CacheReadPrice: r.CacheReadPrice,
|
||||
ImageOutputPrice: r.ImageOutputPrice,
|
||||
Intervals: intervals,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// --- Handlers ---
|
||||
|
||||
// List handles listing channels with pagination
|
||||
// GET /api/v1/admin/channels
|
||||
func (h *ChannelHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
status := c.Query("status")
|
||||
search := strings.TrimSpace(c.Query("search"))
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
channels, pag, err := h.channelService.List(c.Request.Context(), pagination.PaginationParams{Page: page, PageSize: pageSize}, status, search)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]*channelResponse, 0, len(channels))
|
||||
for i := range channels {
|
||||
out = append(out, channelToResponse(&channels[i]))
|
||||
}
|
||||
response.Paginated(c, out, pag.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a channel by ID
|
||||
// GET /api/v1/admin/channels/:id
|
||||
func (h *ChannelHandler) GetByID(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid channel ID")
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := h.channelService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, channelToResponse(channel))
|
||||
}
|
||||
|
||||
// Create handles creating a new channel
|
||||
// POST /api/v1/admin/channels
|
||||
func (h *ChannelHandler) Create(c *gin.Context) {
|
||||
var req createChannelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ModelPricing: pricingRequestToService(req.ModelPricing),
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, channelToResponse(channel))
|
||||
}
|
||||
|
||||
// Update handles updating a channel
|
||||
// PUT /api/v1/admin/channels/:id
|
||||
func (h *ChannelHandler) Update(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid channel ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req updateChannelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
input := &service.UpdateChannelInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Status: req.Status,
|
||||
GroupIDs: req.GroupIDs,
|
||||
}
|
||||
if req.ModelPricing != nil {
|
||||
pricing := pricingRequestToService(*req.ModelPricing)
|
||||
input.ModelPricing = &pricing
|
||||
}
|
||||
|
||||
channel, err := h.channelService.Update(c.Request.Context(), id, input)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, channelToResponse(channel))
|
||||
}
|
||||
|
||||
// Delete handles deleting a channel
|
||||
// DELETE /api/v1/admin/channels/:id
|
||||
func (h *ChannelHandler) Delete(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid channel ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.channelService.Delete(c.Request.Context(), id); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Channel deleted successfully"})
|
||||
}
|
||||
@ -30,6 +30,7 @@ type AdminHandlers struct {
|
||||
TLSFingerprintProfile *admin.TLSFingerprintProfileHandler
|
||||
APIKey *admin.AdminAPIKeyHandler
|
||||
ScheduledTest *admin.ScheduledTestHandler
|
||||
Channel *admin.ChannelHandler
|
||||
}
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
|
||||
@ -33,6 +33,7 @@ func ProvideAdminHandlers(
|
||||
tlsFingerprintProfileHandler *admin.TLSFingerprintProfileHandler,
|
||||
apiKeyHandler *admin.AdminAPIKeyHandler,
|
||||
scheduledTestHandler *admin.ScheduledTestHandler,
|
||||
channelHandler *admin.ChannelHandler,
|
||||
) *AdminHandlers {
|
||||
return &AdminHandlers{
|
||||
Dashboard: dashboardHandler,
|
||||
@ -59,6 +60,7 @@ func ProvideAdminHandlers(
|
||||
TLSFingerprintProfile: tlsFingerprintProfileHandler,
|
||||
APIKey: apiKeyHandler,
|
||||
ScheduledTest: scheduledTestHandler,
|
||||
Channel: channelHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@ -150,6 +152,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewTLSFingerprintProfileHandler,
|
||||
admin.NewAdminAPIKeyHandler,
|
||||
admin.NewScheduledTestHandler,
|
||||
admin.NewChannelHandler,
|
||||
|
||||
// AdminHandlers and Handlers constructors
|
||||
ProvideAdminHandlers,
|
||||
|
||||
392
backend/internal/repository/channel_repo.go
Normal file
392
backend/internal/repository/channel_repo.go
Normal file
@ -0,0 +1,392 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
type channelRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewChannelRepository 创建渠道数据访问实例
|
||||
func NewChannelRepository(db *sql.DB) service.ChannelRepository {
|
||||
return &channelRepository{db: db}
|
||||
}
|
||||
|
||||
// runInTx 在事务中执行 fn,成功 commit,失败 rollback。
|
||||
func (r *channelRepository) runInTx(ctx context.Context, fn func(tx *sql.Tx) error) error {
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin tx: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (r *channelRepository) Create(ctx context.Context, channel *service.Channel) error {
|
||||
return r.runInTx(ctx, func(tx *sql.Tx) error {
|
||||
err := tx.QueryRowContext(ctx,
|
||||
`INSERT INTO channels (name, description, status) VALUES ($1, $2, $3)
|
||||
RETURNING id, created_at, updated_at`,
|
||||
channel.Name, channel.Description, channel.Status,
|
||||
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
return service.ErrChannelExists
|
||||
}
|
||||
return fmt.Errorf("insert channel: %w", err)
|
||||
}
|
||||
|
||||
// 设置分组关联
|
||||
if len(channel.GroupIDs) > 0 {
|
||||
if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 设置模型定价
|
||||
if len(channel.ModelPricing) > 0 {
|
||||
if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
|
||||
ch := &service.Channel{}
|
||||
err := r.db.QueryRowContext(ctx,
|
||||
`SELECT id, name, description, status, created_at, updated_at
|
||||
FROM channels WHERE id = $1`, id,
|
||||
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, service.ErrChannelNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get channel: %w", err)
|
||||
}
|
||||
|
||||
groupIDs, err := r.GetGroupIDs(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch.GroupIDs = groupIDs
|
||||
|
||||
pricing, err := r.ListModelPricing(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch.ModelPricing = pricing
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) Update(ctx context.Context, channel *service.Channel) error {
|
||||
return r.runInTx(ctx, func(tx *sql.Tx) error {
|
||||
result, err := tx.ExecContext(ctx,
|
||||
`UPDATE channels SET name = $1, description = $2, status = $3, updated_at = NOW()
|
||||
WHERE id = $4`,
|
||||
channel.Name, channel.Description, channel.Status, channel.ID,
|
||||
)
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
return service.ErrChannelExists
|
||||
}
|
||||
return fmt.Errorf("update channel: %w", err)
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return service.ErrChannelNotFound
|
||||
}
|
||||
|
||||
// 更新分组关联
|
||||
if channel.GroupIDs != nil {
|
||||
if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 更新模型定价
|
||||
if channel.ModelPricing != nil {
|
||||
if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *channelRepository) Delete(ctx context.Context, id int64) error {
|
||||
result, err := r.db.ExecContext(ctx, `DELETE FROM channels WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete channel: %w", err)
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return service.ErrChannelNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.Channel, *pagination.PaginationResult, error) {
|
||||
where := []string{"1=1"}
|
||||
args := []any{}
|
||||
argIdx := 1
|
||||
|
||||
if status != "" {
|
||||
where = append(where, fmt.Sprintf("c.status = $%d", argIdx))
|
||||
args = append(args, status)
|
||||
argIdx++
|
||||
}
|
||||
if search != "" {
|
||||
where = append(where, fmt.Sprintf("(c.name ILIKE $%d OR c.description ILIKE $%d)", argIdx, argIdx))
|
||||
args = append(args, "%"+escapeLike(search)+"%")
|
||||
argIdx++
|
||||
}
|
||||
|
||||
whereClause := strings.Join(where, " AND ")
|
||||
|
||||
// 计数
|
||||
var total int64
|
||||
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM channels c WHERE %s", whereClause)
|
||||
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
return nil, nil, fmt.Errorf("count channels: %w", err)
|
||||
}
|
||||
|
||||
pageSize := params.Limit() // 约束在 [1, 100]
|
||||
page := params.Page
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
// 查询 channel 列表
|
||||
dataQuery := fmt.Sprintf(
|
||||
`SELECT c.id, c.name, c.description, c.status, c.created_at, c.updated_at
|
||||
FROM channels c WHERE %s ORDER BY c.id DESC LIMIT $%d OFFSET $%d`,
|
||||
whereClause, argIdx, argIdx+1,
|
||||
)
|
||||
args = append(args, pageSize, offset)
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, dataQuery, args...)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("query channels: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var channels []service.Channel
|
||||
var channelIDs []int64
|
||||
for rows.Next() {
|
||||
var ch service.Channel
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
return nil, nil, fmt.Errorf("scan channel: %w", err)
|
||||
}
|
||||
channels = append(channels, ch)
|
||||
channelIDs = append(channelIDs, ch.ID)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, nil, fmt.Errorf("iterate channels: %w", err)
|
||||
}
|
||||
|
||||
// 批量加载分组 ID 和模型定价(避免 N+1)
|
||||
if len(channelIDs) > 0 {
|
||||
groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for i := range channels {
|
||||
channels[i].GroupIDs = groupMap[channels[i].ID]
|
||||
channels[i].ModelPricing = pricingMap[channels[i].ID]
|
||||
}
|
||||
}
|
||||
|
||||
pages := 0
|
||||
if total > 0 {
|
||||
pages = int((total + int64(pageSize) - 1) / int64(pageSize))
|
||||
}
|
||||
|
||||
paginationResult := &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Pages: pages,
|
||||
}
|
||||
|
||||
return channels, paginationResult, nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, name, description, status, created_at, updated_at FROM channels ORDER BY id`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query all channels: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var channels []service.Channel
|
||||
var channelIDs []int64
|
||||
for rows.Next() {
|
||||
var ch service.Channel
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan channel: %w", err)
|
||||
}
|
||||
channels = append(channels, ch)
|
||||
channelIDs = append(channelIDs, ch.ID)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate channels: %w", err)
|
||||
}
|
||||
|
||||
if len(channelIDs) == 0 {
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
// 批量加载分组 ID
|
||||
groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 批量加载模型定价
|
||||
pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range channels {
|
||||
channels[i].GroupIDs = groupMap[channels[i].ID]
|
||||
channels[i].ModelPricing = pricingMap[channels[i].ID]
|
||||
}
|
||||
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
// --- 批量加载辅助方法 ---
|
||||
|
||||
// batchLoadGroupIDs 批量加载多个渠道的分组 ID
|
||||
func (r *channelRepository) batchLoadGroupIDs(ctx context.Context, channelIDs []int64) (map[int64][]int64, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT channel_id, group_id FROM channel_groups
|
||||
WHERE channel_id = ANY($1) ORDER BY channel_id, group_id`,
|
||||
pq.Array(channelIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch load group ids: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
groupMap := make(map[int64][]int64, len(channelIDs))
|
||||
for rows.Next() {
|
||||
var channelID, groupID int64
|
||||
if err := rows.Scan(&channelID, &groupID); err != nil {
|
||||
return nil, fmt.Errorf("scan group id: %w", err)
|
||||
}
|
||||
groupMap[channelID] = append(groupMap[channelID], groupID)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate group ids: %w", err)
|
||||
}
|
||||
return groupMap, nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
var exists bool
|
||||
err := r.db.QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1)`, name,
|
||||
).Scan(&exists)
|
||||
return exists, err
|
||||
}
|
||||
|
||||
func (r *channelRepository) ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error) {
|
||||
var exists bool
|
||||
err := r.db.QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1 AND id != $2)`, name, excludeID,
|
||||
).Scan(&exists)
|
||||
return exists, err
|
||||
}
|
||||
|
||||
// --- 分组关联 ---
|
||||
|
||||
func (r *channelRepository) GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT group_id FROM channel_groups WHERE channel_id = $1 ORDER BY group_id`, channelID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group ids: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var ids []int64
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, fmt.Errorf("scan group id: %w", err)
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate group ids: %w", err)
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error {
|
||||
return setGroupIDsTx(ctx, r.db, channelID, groupIDs)
|
||||
}
|
||||
|
||||
func (r *channelRepository) GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
var channelID int64
|
||||
err := r.db.QueryRowContext(ctx,
|
||||
`SELECT channel_id FROM channel_groups WHERE group_id = $1`, groupID,
|
||||
).Scan(&channelID)
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, nil
|
||||
}
|
||||
return channelID, err
|
||||
}
|
||||
|
||||
func (r *channelRepository) GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT group_id FROM channel_groups WHERE group_id = ANY($1) AND channel_id != $2`,
|
||||
pq.Array(groupIDs), channelID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get groups in other channels: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var conflicting []int64
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, fmt.Errorf("scan conflicting group id: %w", err)
|
||||
}
|
||||
conflicting = append(conflicting, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate conflicting group ids: %w", err)
|
||||
}
|
||||
return conflicting, nil
|
||||
}
|
||||
285
backend/internal/repository/channel_repo_pricing.go
Normal file
285
backend/internal/repository/channel_repo_pricing.go
Normal file
@ -0,0 +1,285 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// --- 模型定价 ---
|
||||
|
||||
func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]service.ChannelModelPricing, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, created_at, updated_at
|
||||
FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`, channelID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list model pricing: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
result, pricingIDs, err := scanModelPricingRows(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(pricingIDs) > 0 {
|
||||
intervalMap, err := r.batchLoadIntervals(ctx, pricingIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range result {
|
||||
result[i].Intervals = intervalMap[result[i].ID]
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) CreateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error {
|
||||
return createModelPricingExec(ctx, r.db, pricing)
|
||||
}
|
||||
|
||||
func (r *channelRepository) UpdateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error {
|
||||
modelsJSON, err := json.Marshal(pricing.Models)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal models: %w", err)
|
||||
}
|
||||
billingMode := pricing.BillingMode
|
||||
if billingMode == "" {
|
||||
billingMode = service.BillingModeToken
|
||||
}
|
||||
result, err := r.db.ExecContext(ctx,
|
||||
`UPDATE channel_model_pricing
|
||||
SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, updated_at = NOW()
|
||||
WHERE id = $8`,
|
||||
modelsJSON, billingMode, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
|
||||
pricing.ImageOutputPrice, pricing.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update model pricing: %w", err)
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("pricing entry not found: %d", pricing.ID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) DeleteModelPricing(ctx context.Context, id int64) error {
|
||||
_, err := r.db.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete model pricing: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []service.ChannelModelPricing) error {
|
||||
return r.runInTx(ctx, func(tx *sql.Tx) error {
|
||||
return replaceModelPricingTx(ctx, tx, channelID, pricingList)
|
||||
})
|
||||
}
|
||||
|
||||
// --- 批量加载辅助方法 ---
|
||||
|
||||
// batchLoadModelPricing 批量加载多个渠道的模型定价(含区间)
|
||||
func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, created_at, updated_at
|
||||
FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`,
|
||||
pq.Array(channelIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch load model pricing: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
allPricing, allPricingIDs, err := scanModelPricingRows(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 按 channelID 分组
|
||||
pricingMap := make(map[int64][]service.ChannelModelPricing, len(channelIDs))
|
||||
for _, p := range allPricing {
|
||||
pricingMap[p.ChannelID] = append(pricingMap[p.ChannelID], p)
|
||||
}
|
||||
|
||||
// 批量加载所有区间
|
||||
if len(allPricingIDs) > 0 {
|
||||
intervalMap, err := r.batchLoadIntervals(ctx, allPricingIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for chID := range pricingMap {
|
||||
for i := range pricingMap[chID] {
|
||||
pricingMap[chID][i].Intervals = intervalMap[pricingMap[chID][i].ID]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return pricingMap, nil
|
||||
}
|
||||
|
||||
// batchLoadIntervals 批量加载多个定价条目的区间
|
||||
func (r *channelRepository) batchLoadIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, pricing_id, min_tokens, max_tokens, tier_label,
|
||||
input_price, output_price, cache_write_price, cache_read_price,
|
||||
per_request_price, sort_order, created_at, updated_at
|
||||
FROM channel_pricing_intervals
|
||||
WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`,
|
||||
pq.Array(pricingIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch load intervals: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
intervalMap := make(map[int64][]service.PricingInterval, len(pricingIDs))
|
||||
for rows.Next() {
|
||||
var iv service.PricingInterval
|
||||
if err := rows.Scan(
|
||||
&iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel,
|
||||
&iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice,
|
||||
&iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan interval: %w", err)
|
||||
}
|
||||
intervalMap[iv.PricingID] = append(intervalMap[iv.PricingID], iv)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate intervals: %w", err)
|
||||
}
|
||||
return intervalMap, nil
|
||||
}
|
||||
|
||||
// --- 共享 scan 辅助 ---
|
||||
|
||||
// scanModelPricingRows 扫描 model pricing 行,返回结果列表和 ID 列表
|
||||
func scanModelPricingRows(rows *sql.Rows) ([]service.ChannelModelPricing, []int64, error) {
|
||||
var result []service.ChannelModelPricing
|
||||
var pricingIDs []int64
|
||||
for rows.Next() {
|
||||
var p service.ChannelModelPricing
|
||||
var modelsJSON []byte
|
||||
if err := rows.Scan(
|
||||
&p.ID, &p.ChannelID, &modelsJSON, &p.BillingMode,
|
||||
&p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
|
||||
&p.ImageOutputPrice, &p.CreatedAt, &p.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, nil, fmt.Errorf("scan model pricing: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(modelsJSON, &p.Models); err != nil {
|
||||
p.Models = []string{}
|
||||
}
|
||||
pricingIDs = append(pricingIDs, p.ID)
|
||||
result = append(result, p)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, nil, fmt.Errorf("iterate model pricing: %w", err)
|
||||
}
|
||||
return result, pricingIDs, nil
|
||||
}
|
||||
|
||||
// --- 事务内辅助方法 ---
|
||||
|
||||
// dbExec 是 *sql.DB 和 *sql.Tx 共享的最小 SQL 执行接口
|
||||
type dbExec interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
||||
}
|
||||
|
||||
func setGroupIDsTx(ctx context.Context, exec dbExec, channelID int64, groupIDs []int64) error {
|
||||
if _, err := exec.ExecContext(ctx, `DELETE FROM channel_groups WHERE channel_id = $1`, channelID); err != nil {
|
||||
return fmt.Errorf("delete old group associations: %w", err)
|
||||
}
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := exec.ExecContext(ctx,
|
||||
`INSERT INTO channel_groups (channel_id, group_id)
|
||||
SELECT $1, unnest($2::bigint[])`,
|
||||
channelID, pq.Array(groupIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert group associations: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func createModelPricingExec(ctx context.Context, exec dbExec, pricing *service.ChannelModelPricing) error {
|
||||
modelsJSON, err := json.Marshal(pricing.Models)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal models: %w", err)
|
||||
}
|
||||
billingMode := pricing.BillingMode
|
||||
if billingMode == "" {
|
||||
billingMode = service.BillingModeToken
|
||||
}
|
||||
err = exec.QueryRowContext(ctx,
|
||||
`INSERT INTO channel_model_pricing (channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, created_at, updated_at`,
|
||||
pricing.ChannelID, modelsJSON, billingMode,
|
||||
pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
|
||||
pricing.ImageOutputPrice,
|
||||
).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert model pricing: %w", err)
|
||||
}
|
||||
|
||||
for i := range pricing.Intervals {
|
||||
pricing.Intervals[i].PricingID = pricing.ID
|
||||
if err := createIntervalExec(ctx, exec, &pricing.Intervals[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createIntervalExec(ctx context.Context, exec dbExec, iv *service.PricingInterval) error {
|
||||
return exec.QueryRowContext(ctx,
|
||||
`INSERT INTO channel_pricing_intervals
|
||||
(pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
|
||||
iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel,
|
||||
iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice,
|
||||
iv.PerRequestPrice, iv.SortOrder,
|
||||
).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt)
|
||||
}
|
||||
|
||||
func replaceModelPricingTx(ctx context.Context, exec dbExec, channelID int64, pricingList []service.ChannelModelPricing) error {
|
||||
if _, err := exec.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE channel_id = $1`, channelID); err != nil {
|
||||
return fmt.Errorf("delete old model pricing: %w", err)
|
||||
}
|
||||
for i := range pricingList {
|
||||
pricingList[i].ChannelID = channelID
|
||||
if err := createModelPricingExec(ctx, exec, &pricingList[i]); err != nil {
|
||||
return fmt.Errorf("insert model pricing: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isUniqueViolation 检查 pq 唯一约束违反错误
|
||||
func isUniqueViolation(err error) bool {
|
||||
if pqErr, ok := err.(*pq.Error); ok {
|
||||
return pqErr.Code == "23505"
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// escapeLike 转义 LIKE/ILIKE 模式中的特殊字符
|
||||
func escapeLike(s string) string {
|
||||
s = strings.ReplaceAll(s, `\`, `\\`)
|
||||
s = strings.ReplaceAll(s, `%`, `\%`)
|
||||
s = strings.ReplaceAll(s, `_`, `\_`)
|
||||
return s
|
||||
}
|
||||
@ -74,6 +74,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewUserGroupRateRepository,
|
||||
NewErrorPassthroughRepository,
|
||||
NewTLSFingerprintProfileRepository,
|
||||
NewChannelRepository,
|
||||
|
||||
// Cache implementations
|
||||
NewGatewayCache,
|
||||
|
||||
@ -87,6 +87,9 @@ func RegisterAdminRoutes(
|
||||
|
||||
// 定时测试计划
|
||||
registerScheduledTestRoutes(admin, h)
|
||||
|
||||
// 渠道管理
|
||||
registerChannelRoutes(admin, h)
|
||||
}
|
||||
}
|
||||
|
||||
@ -567,3 +570,14 @@ func registerTLSFingerprintProfileRoutes(admin *gin.RouterGroup, h *handler.Hand
|
||||
profiles.DELETE("/:id", h.Admin.TLSFingerprintProfile.Delete)
|
||||
}
|
||||
}
|
||||
|
||||
func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
channels := admin.Group("/channels")
|
||||
{
|
||||
channels.GET("", h.Admin.Channel.List)
|
||||
channels.GET("/:id", h.Admin.Channel.GetByID)
|
||||
channels.POST("", h.Admin.Channel.Create)
|
||||
channels.PUT("/:id", h.Admin.Channel.Update)
|
||||
channels.DELETE("/:id", h.Admin.Channel.Delete)
|
||||
}
|
||||
}
|
||||
|
||||
@ -371,13 +371,193 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
return nil, fmt.Errorf("pricing not found for model: %s", model)
|
||||
}
|
||||
|
||||
// GetModelPricingWithChannel 获取模型定价,渠道配置的价格覆盖默认值
|
||||
// 仅覆盖渠道中非 nil 的价格字段,nil 字段使用默认定价
|
||||
func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing *ChannelModelPricing) (*ModelPricing, error) {
|
||||
pricing, err := s.GetModelPricing(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if channelPricing == nil {
|
||||
return pricing, nil
|
||||
}
|
||||
if channelPricing.InputPrice != nil {
|
||||
pricing.InputPricePerToken = *channelPricing.InputPrice
|
||||
pricing.InputPricePerTokenPriority = *channelPricing.InputPrice
|
||||
}
|
||||
if channelPricing.OutputPrice != nil {
|
||||
pricing.OutputPricePerToken = *channelPricing.OutputPrice
|
||||
pricing.OutputPricePerTokenPriority = *channelPricing.OutputPrice
|
||||
}
|
||||
if channelPricing.CacheWritePrice != nil {
|
||||
pricing.CacheCreationPricePerToken = *channelPricing.CacheWritePrice
|
||||
pricing.CacheCreation5mPrice = *channelPricing.CacheWritePrice
|
||||
pricing.CacheCreation1hPrice = *channelPricing.CacheWritePrice
|
||||
}
|
||||
if channelPricing.CacheReadPrice != nil {
|
||||
pricing.CacheReadPricePerToken = *channelPricing.CacheReadPrice
|
||||
pricing.CacheReadPricePerTokenPriority = *channelPricing.CacheReadPrice
|
||||
}
|
||||
return pricing, nil
|
||||
}
|
||||
|
||||
// CalculateCostWithChannel 使用渠道定价计算费用
|
||||
// Deprecated: 使用 CalculateCostUnified 代替
|
||||
func (s *BillingService) CalculateCostWithChannel(model string, tokens UsageTokens, rateMultiplier float64, channelPricing *ChannelModelPricing) (*CostBreakdown, error) {
|
||||
return s.calculateCostInternal(model, tokens, rateMultiplier, "", channelPricing)
|
||||
}
|
||||
|
||||
// --- 统一计费入口 ---
|
||||
|
||||
// CostInput 统一计费输入
|
||||
type CostInput struct {
|
||||
Ctx context.Context
|
||||
Model string
|
||||
GroupID *int64 // 用于渠道定价查找
|
||||
Tokens UsageTokens
|
||||
RequestCount int // 按次计费时使用
|
||||
SizeTier string // 按次/图片模式的层级标签("1K","2K","4K","HD" 等)
|
||||
RateMultiplier float64
|
||||
ServiceTier string // "priority","flex","" 等
|
||||
Resolver *ModelPricingResolver // 定价解析器
|
||||
}
|
||||
|
||||
// CalculateCostUnified 统一计费入口,支持三种计费模式。
|
||||
// 使用 ModelPricingResolver 解析定价,然后根据 BillingMode 分发计算。
|
||||
func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown, error) {
|
||||
if input.Resolver == nil {
|
||||
// 无 Resolver,回退到旧路径
|
||||
return s.calculateCostInternal(input.Model, input.Tokens, input.RateMultiplier, input.ServiceTier, nil)
|
||||
}
|
||||
|
||||
resolved := input.Resolver.Resolve(input.Ctx, PricingInput{
|
||||
Model: input.Model,
|
||||
GroupID: input.GroupID,
|
||||
})
|
||||
|
||||
if input.RateMultiplier <= 0 {
|
||||
input.RateMultiplier = 1.0
|
||||
}
|
||||
|
||||
switch resolved.Mode {
|
||||
case BillingModePerRequest, BillingModeImage:
|
||||
return s.calculatePerRequestCost(resolved, input)
|
||||
default: // BillingModeToken
|
||||
return s.calculateTokenCost(resolved, input)
|
||||
}
|
||||
}
|
||||
|
||||
// calculateTokenCost 按 token 区间计费
|
||||
func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input CostInput) (*CostBreakdown, error) {
|
||||
totalContext := input.Tokens.InputTokens + input.Tokens.CacheReadTokens
|
||||
|
||||
pricing := input.Resolver.GetIntervalPricing(resolved, totalContext)
|
||||
if pricing == nil {
|
||||
return nil, fmt.Errorf("no pricing available for model: %s", input.Model)
|
||||
}
|
||||
|
||||
pricing = s.applyModelSpecificPricingPolicy(input.Model, pricing)
|
||||
|
||||
breakdown := &CostBreakdown{}
|
||||
inputPricePerToken := pricing.InputPricePerToken
|
||||
outputPricePerToken := pricing.OutputPricePerToken
|
||||
cacheReadPricePerToken := pricing.CacheReadPricePerToken
|
||||
tierMultiplier := 1.0
|
||||
|
||||
if usePriorityServiceTierPricing(input.ServiceTier, pricing) {
|
||||
if pricing.InputPricePerTokenPriority > 0 {
|
||||
inputPricePerToken = pricing.InputPricePerTokenPriority
|
||||
}
|
||||
if pricing.OutputPricePerTokenPriority > 0 {
|
||||
outputPricePerToken = pricing.OutputPricePerTokenPriority
|
||||
}
|
||||
if pricing.CacheReadPricePerTokenPriority > 0 {
|
||||
cacheReadPricePerToken = pricing.CacheReadPricePerTokenPriority
|
||||
}
|
||||
} else {
|
||||
tierMultiplier = serviceTierCostMultiplier(input.ServiceTier)
|
||||
}
|
||||
|
||||
// 长上下文定价(仅在无区间定价时应用,区间定价已包含上下文分层)
|
||||
if len(resolved.Intervals) == 0 && s.shouldApplySessionLongContextPricing(input.Tokens, pricing) {
|
||||
inputPricePerToken *= pricing.LongContextInputMultiplier
|
||||
outputPricePerToken *= pricing.LongContextOutputMultiplier
|
||||
}
|
||||
|
||||
breakdown.InputCost = float64(input.Tokens.InputTokens) * inputPricePerToken
|
||||
breakdown.OutputCost = float64(input.Tokens.OutputTokens) * outputPricePerToken
|
||||
|
||||
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
|
||||
if input.Tokens.CacheCreation5mTokens == 0 && input.Tokens.CacheCreation1hTokens == 0 && input.Tokens.CacheCreationTokens > 0 {
|
||||
breakdown.CacheCreationCost = float64(input.Tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
|
||||
} else {
|
||||
breakdown.CacheCreationCost = float64(input.Tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
|
||||
float64(input.Tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
|
||||
}
|
||||
} else {
|
||||
breakdown.CacheCreationCost = float64(input.Tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
|
||||
}
|
||||
|
||||
breakdown.CacheReadCost = float64(input.Tokens.CacheReadTokens) * cacheReadPricePerToken
|
||||
|
||||
if tierMultiplier != 1.0 {
|
||||
breakdown.InputCost *= tierMultiplier
|
||||
breakdown.OutputCost *= tierMultiplier
|
||||
breakdown.CacheCreationCost *= tierMultiplier
|
||||
breakdown.CacheReadCost *= tierMultiplier
|
||||
}
|
||||
|
||||
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost +
|
||||
breakdown.CacheCreationCost + breakdown.CacheReadCost
|
||||
breakdown.ActualCost = breakdown.TotalCost * input.RateMultiplier
|
||||
|
||||
return breakdown, nil
|
||||
}
|
||||
|
||||
// calculatePerRequestCost 按次/图片计费
|
||||
func (s *BillingService) calculatePerRequestCost(resolved *ResolvedPricing, input CostInput) (*CostBreakdown, error) {
|
||||
count := input.RequestCount
|
||||
if count <= 0 {
|
||||
count = 1
|
||||
}
|
||||
|
||||
var unitPrice float64
|
||||
|
||||
if input.SizeTier != "" {
|
||||
unitPrice = input.Resolver.GetRequestTierPrice(resolved, input.SizeTier)
|
||||
}
|
||||
|
||||
if unitPrice == 0 {
|
||||
totalContext := input.Tokens.InputTokens + input.Tokens.CacheReadTokens
|
||||
unitPrice = input.Resolver.GetRequestTierPriceByContext(resolved, totalContext)
|
||||
}
|
||||
|
||||
totalCost := unitPrice * float64(count)
|
||||
actualCost := totalCost * input.RateMultiplier
|
||||
|
||||
return &CostBreakdown{
|
||||
TotalCost: totalCost,
|
||||
ActualCost: actualCost,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CalculateCost 计算使用费用
|
||||
func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) {
|
||||
return s.CalculateCostWithServiceTier(model, tokens, rateMultiplier, "")
|
||||
return s.calculateCostInternal(model, tokens, rateMultiplier, "", nil)
|
||||
}
|
||||
|
||||
func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) {
|
||||
pricing, err := s.GetModelPricing(model)
|
||||
return s.calculateCostInternal(model, tokens, rateMultiplier, serviceTier, nil)
|
||||
}
|
||||
|
||||
func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string, channelPricing *ChannelModelPricing) (*CostBreakdown, error) {
|
||||
var pricing *ModelPricing
|
||||
var err error
|
||||
if channelPricing != nil {
|
||||
pricing, err = s.GetModelPricingWithChannel(model, channelPricing)
|
||||
} else {
|
||||
pricing, err = s.GetModelPricing(model)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
171
backend/internal/service/channel.go
Normal file
171
backend/internal/service/channel.go
Normal file
@ -0,0 +1,171 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BillingMode 计费模式
|
||||
type BillingMode string
|
||||
|
||||
const (
|
||||
BillingModeToken BillingMode = "token" // 按 token 区间计费
|
||||
BillingModePerRequest BillingMode = "per_request" // 按次计费(支持上下文窗口分层)
|
||||
BillingModeImage BillingMode = "image" // 图片计费(当前按次,预留 token 计费)
|
||||
)
|
||||
|
||||
// IsValid 检查 BillingMode 是否为合法值
|
||||
func (m BillingMode) IsValid() bool {
|
||||
switch m {
|
||||
case BillingModeToken, BillingModePerRequest, BillingModeImage, "":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Channel 渠道实体
|
||||
type Channel struct {
|
||||
ID int64
|
||||
Name string
|
||||
Description string
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
// 关联的分组 ID 列表
|
||||
GroupIDs []int64
|
||||
// 模型定价列表
|
||||
ModelPricing []ChannelModelPricing
|
||||
}
|
||||
|
||||
// ChannelModelPricing 渠道模型定价条目
|
||||
type ChannelModelPricing struct {
|
||||
ID int64
|
||||
ChannelID int64
|
||||
Models []string // 绑定的模型列表
|
||||
BillingMode BillingMode // 计费模式
|
||||
InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价
|
||||
OutputPrice *float64 // 每 token 输出价格(USD)
|
||||
CacheWritePrice *float64 // 缓存写入价格
|
||||
CacheReadPrice *float64 // 缓存读取价格
|
||||
ImageOutputPrice *float64 // 图片输出价格(向后兼容)
|
||||
Intervals []PricingInterval // 区间定价列表
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// PricingInterval 定价区间(token 区间 / 按次分层 / 图片分辨率分层)
|
||||
type PricingInterval struct {
|
||||
ID int64
|
||||
PricingID int64
|
||||
MinTokens int // 区间下界(含)
|
||||
MaxTokens *int // 区间上界(不含),nil = 无上限
|
||||
TierLabel string // 层级标签(按次/图片模式:1K, 2K, 4K, HD 等)
|
||||
InputPrice *float64 // token 模式:每 token 输入价
|
||||
OutputPrice *float64 // token 模式:每 token 输出价
|
||||
CacheWritePrice *float64 // token 模式:缓存写入价
|
||||
CacheReadPrice *float64 // token 模式:缓存读取价
|
||||
PerRequestPrice *float64 // 按次/图片模式:每次请求价格
|
||||
SortOrder int
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// IsActive 判断渠道是否启用
|
||||
func (c *Channel) IsActive() bool {
|
||||
return c.Status == StatusActive
|
||||
}
|
||||
|
||||
// GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。
|
||||
// 优先精确匹配,然后通配符匹配(如 claude-opus-*)。大小写不敏感。
|
||||
// 返回值拷贝,不污染缓存。
|
||||
func (c *Channel) GetModelPricing(model string) *ChannelModelPricing {
|
||||
modelLower := strings.ToLower(model)
|
||||
|
||||
// 第一轮:精确匹配
|
||||
for i := range c.ModelPricing {
|
||||
for _, m := range c.ModelPricing[i].Models {
|
||||
if strings.ToLower(m) == modelLower {
|
||||
cp := c.ModelPricing[i].Clone()
|
||||
return &cp
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 第二轮:通配符匹配(仅支持末尾 *)
|
||||
for i := range c.ModelPricing {
|
||||
for _, m := range c.ModelPricing[i].Models {
|
||||
mLower := strings.ToLower(m)
|
||||
if strings.HasSuffix(mLower, "*") {
|
||||
prefix := strings.TrimSuffix(mLower, "*")
|
||||
if strings.HasPrefix(modelLower, prefix) {
|
||||
cp := c.ModelPricing[i].Clone()
|
||||
return &cp
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindMatchingInterval 在区间列表中查找匹配 totalTokens 的区间。
|
||||
// 通用辅助函数,供 GetIntervalForContext、ModelPricingResolver 等复用。
|
||||
func FindMatchingInterval(intervals []PricingInterval, totalTokens int) *PricingInterval {
|
||||
for i := range intervals {
|
||||
iv := &intervals[i]
|
||||
if totalTokens >= iv.MinTokens && (iv.MaxTokens == nil || totalTokens < *iv.MaxTokens) {
|
||||
return iv
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetIntervalForContext 根据总 context token 数查找匹配的区间。
|
||||
func (p *ChannelModelPricing) GetIntervalForContext(totalTokens int) *PricingInterval {
|
||||
return FindMatchingInterval(p.Intervals, totalTokens)
|
||||
}
|
||||
|
||||
// GetTierByLabel 根据标签查找层级(用于 per_request / image 模式)
|
||||
func (p *ChannelModelPricing) GetTierByLabel(label string) *PricingInterval {
|
||||
labelLower := strings.ToLower(label)
|
||||
for i := range p.Intervals {
|
||||
if strings.ToLower(p.Intervals[i].TierLabel) == labelLower {
|
||||
return &p.Intervals[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clone 返回 ChannelModelPricing 的拷贝(切片独立,指针字段共享,调用方只读安全)
|
||||
func (p ChannelModelPricing) Clone() ChannelModelPricing {
|
||||
cp := p
|
||||
if p.Models != nil {
|
||||
cp.Models = make([]string, len(p.Models))
|
||||
copy(cp.Models, p.Models)
|
||||
}
|
||||
if p.Intervals != nil {
|
||||
cp.Intervals = make([]PricingInterval, len(p.Intervals))
|
||||
copy(cp.Intervals, p.Intervals)
|
||||
}
|
||||
return cp
|
||||
}
|
||||
|
||||
// Clone 返回 Channel 的深拷贝
|
||||
func (c *Channel) Clone() *Channel {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
cp := *c
|
||||
if c.GroupIDs != nil {
|
||||
cp.GroupIDs = make([]int64, len(c.GroupIDs))
|
||||
copy(cp.GroupIDs, c.GroupIDs)
|
||||
}
|
||||
if c.ModelPricing != nil {
|
||||
cp.ModelPricing = make([]ChannelModelPricing, len(c.ModelPricing))
|
||||
for i := range c.ModelPricing {
|
||||
cp.ModelPricing[i] = c.ModelPricing[i].Clone()
|
||||
}
|
||||
}
|
||||
return &cp
|
||||
}
|
||||
338
backend/internal/service/channel_service.go
Normal file
338
backend/internal/service/channel_service.go
Normal file
@ -0,0 +1,338 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrChannelNotFound = infraerrors.NotFound("CHANNEL_NOT_FOUND", "channel not found")
|
||||
ErrChannelExists = infraerrors.Conflict("CHANNEL_EXISTS", "channel name already exists")
|
||||
ErrGroupAlreadyInChannel = infraerrors.Conflict(
|
||||
"GROUP_ALREADY_IN_CHANNEL",
|
||||
"one or more groups already belong to another channel",
|
||||
)
|
||||
)
|
||||
|
||||
// ChannelRepository 渠道数据访问接口
|
||||
type ChannelRepository interface {
|
||||
Create(ctx context.Context, channel *Channel) error
|
||||
GetByID(ctx context.Context, id int64) (*Channel, error)
|
||||
Update(ctx context.Context, channel *Channel) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error)
|
||||
ListAll(ctx context.Context) ([]Channel, error)
|
||||
ExistsByName(ctx context.Context, name string) (bool, error)
|
||||
ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error)
|
||||
|
||||
// 分组关联
|
||||
GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error)
|
||||
SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error
|
||||
GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error)
|
||||
|
||||
// 模型定价
|
||||
ListModelPricing(ctx context.Context, channelID int64) ([]ChannelModelPricing, error)
|
||||
CreateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error
|
||||
UpdateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error
|
||||
DeleteModelPricing(ctx context.Context, id int64) error
|
||||
ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error
|
||||
}
|
||||
|
||||
// channelCache 渠道缓存快照
|
||||
type channelCache struct {
|
||||
// byID: channelID -> *Channel(含 ModelPricing)
|
||||
byID map[int64]*Channel
|
||||
// byGroupID: groupID -> channelID
|
||||
byGroupID map[int64]int64
|
||||
loadedAt time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
channelCacheTTL = 60 * time.Second
|
||||
channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
|
||||
channelCacheDBTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// ChannelService 渠道管理服务
|
||||
type ChannelService struct {
|
||||
repo ChannelRepository
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||
|
||||
cache atomic.Value // *channelCache
|
||||
cacheSF singleflight.Group
|
||||
}
|
||||
|
||||
// NewChannelService 创建渠道服务实例
|
||||
func NewChannelService(repo ChannelRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *ChannelService {
|
||||
s := &ChannelService{
|
||||
repo: repo,
|
||||
authCacheInvalidator: authCacheInvalidator,
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// loadCache 加载或返回缓存的渠道数据
|
||||
func (s *ChannelService) loadCache(ctx context.Context) (*channelCache, error) {
|
||||
if cached, ok := s.cache.Load().(*channelCache); ok {
|
||||
if time.Since(cached.loadedAt) < channelCacheTTL {
|
||||
return cached, nil
|
||||
}
|
||||
}
|
||||
|
||||
result, err, _ := s.cacheSF.Do("channel_cache", func() (any, error) {
|
||||
// 双重检查
|
||||
if cached, ok := s.cache.Load().(*channelCache); ok {
|
||||
if time.Since(cached.loadedAt) < channelCacheTTL {
|
||||
return cached, nil
|
||||
}
|
||||
}
|
||||
return s.buildCache(ctx)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result.(*channelCache), nil
|
||||
}
|
||||
|
||||
// buildCache 从数据库构建渠道缓存。
|
||||
// 使用独立 context 避免请求取消导致空值被长期缓存。
|
||||
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
|
||||
// 断开请求取消链,避免客户端断连导致空值被长期缓存
|
||||
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout)
|
||||
defer cancel()
|
||||
|
||||
channels, err := s.repo.ListAll(dbCtx)
|
||||
if err != nil {
|
||||
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
|
||||
slog.Warn("failed to build channel cache", "error", err)
|
||||
errorCache := &channelCache{
|
||||
byID: make(map[int64]*Channel),
|
||||
byGroupID: make(map[int64]int64),
|
||||
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
|
||||
}
|
||||
s.cache.Store(errorCache)
|
||||
return nil, fmt.Errorf("list all channels: %w", err)
|
||||
}
|
||||
|
||||
cache := &channelCache{
|
||||
byID: make(map[int64]*Channel, len(channels)),
|
||||
byGroupID: make(map[int64]int64),
|
||||
loadedAt: time.Now(),
|
||||
}
|
||||
|
||||
for i := range channels {
|
||||
ch := &channels[i]
|
||||
cache.byID[ch.ID] = ch
|
||||
for _, gid := range ch.GroupIDs {
|
||||
cache.byGroupID[gid] = ch.ID
|
||||
}
|
||||
}
|
||||
|
||||
s.cache.Store(cache)
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
// invalidateCache 使缓存失效,让下次读取时自然重建
|
||||
func (s *ChannelService) invalidateCache() {
|
||||
s.cache.Store((*channelCache)(nil))
|
||||
s.cacheSF.Forget("channel_cache")
|
||||
}
|
||||
|
||||
// GetChannelForGroup 获取分组关联的渠道(热路径,从缓存读取)
|
||||
// 返回深拷贝,不污染缓存。
|
||||
func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) {
|
||||
cache, err := s.loadCache(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
channelID, ok := cache.byGroupID[groupID]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ch, ok := cache.byID[channelID]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if !ch.IsActive() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return ch.Clone(), nil
|
||||
}
|
||||
|
||||
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径)
|
||||
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
|
||||
ch, err := s.GetChannelForGroup(ctx, groupID)
|
||||
if err != nil {
|
||||
slog.Warn("failed to get channel for group", "group_id", groupID, "error", err)
|
||||
return nil
|
||||
}
|
||||
if ch == nil {
|
||||
return nil
|
||||
}
|
||||
return ch.GetModelPricing(model)
|
||||
}
|
||||
|
||||
// --- CRUD ---
|
||||
|
||||
// Create 创建渠道
|
||||
func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) (*Channel, error) {
|
||||
exists, err := s.repo.ExistsByName(ctx, input.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check channel exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, ErrChannelExists
|
||||
}
|
||||
|
||||
// 检查分组冲突
|
||||
if len(input.GroupIDs) > 0 {
|
||||
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check group conflicts: %w", err)
|
||||
}
|
||||
if len(conflicting) > 0 {
|
||||
return nil, ErrGroupAlreadyInChannel
|
||||
}
|
||||
}
|
||||
|
||||
channel := &Channel{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Status: StatusActive,
|
||||
GroupIDs: input.GroupIDs,
|
||||
ModelPricing: input.ModelPricing,
|
||||
}
|
||||
|
||||
if err := s.repo.Create(ctx, channel); err != nil {
|
||||
return nil, fmt.Errorf("create channel: %w", err)
|
||||
}
|
||||
|
||||
s.invalidateCache()
|
||||
return s.repo.GetByID(ctx, channel.ID)
|
||||
}
|
||||
|
||||
// GetByID 获取渠道详情
|
||||
func (s *ChannelService) GetByID(ctx context.Context, id int64) (*Channel, error) {
|
||||
return s.repo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// Update 更新渠道
|
||||
func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChannelInput) (*Channel, error) {
|
||||
channel, err := s.repo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get channel: %w", err)
|
||||
}
|
||||
|
||||
if input.Name != "" && input.Name != channel.Name {
|
||||
exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check channel exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, ErrChannelExists
|
||||
}
|
||||
channel.Name = input.Name
|
||||
}
|
||||
|
||||
if input.Description != nil {
|
||||
channel.Description = *input.Description
|
||||
}
|
||||
|
||||
if input.Status != "" {
|
||||
channel.Status = input.Status
|
||||
}
|
||||
|
||||
// 检查分组冲突
|
||||
if input.GroupIDs != nil {
|
||||
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check group conflicts: %w", err)
|
||||
}
|
||||
if len(conflicting) > 0 {
|
||||
return nil, ErrGroupAlreadyInChannel
|
||||
}
|
||||
channel.GroupIDs = *input.GroupIDs
|
||||
}
|
||||
|
||||
if input.ModelPricing != nil {
|
||||
channel.ModelPricing = *input.ModelPricing
|
||||
}
|
||||
|
||||
if err := s.repo.Update(ctx, channel); err != nil {
|
||||
return nil, fmt.Errorf("update channel: %w", err)
|
||||
}
|
||||
|
||||
s.invalidateCache()
|
||||
|
||||
// 失效关联分组的 auth 缓存
|
||||
if s.authCacheInvalidator != nil {
|
||||
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
|
||||
if err != nil {
|
||||
slog.Warn("failed to get group IDs for cache invalidation", "channel_id", id, "error", err)
|
||||
}
|
||||
for _, gid := range groupIDs {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
|
||||
return s.repo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// Delete 删除渠道
|
||||
func (s *ChannelService) Delete(ctx context.Context, id int64) error {
|
||||
// 先获取关联分组用于失效缓存
|
||||
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
|
||||
if err != nil {
|
||||
slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err)
|
||||
}
|
||||
|
||||
if err := s.repo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete channel: %w", err)
|
||||
}
|
||||
|
||||
s.invalidateCache()
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
for _, gid := range groupIDs {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 获取渠道列表
|
||||
func (s *ChannelService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) {
|
||||
return s.repo.List(ctx, params, status, search)
|
||||
}
|
||||
|
||||
// --- Input types ---
|
||||
|
||||
// CreateChannelInput 创建渠道输入
|
||||
type CreateChannelInput struct {
|
||||
Name string
|
||||
Description string
|
||||
GroupIDs []int64
|
||||
ModelPricing []ChannelModelPricing
|
||||
}
|
||||
|
||||
// UpdateChannelInput 更新渠道输入
|
||||
type UpdateChannelInput struct {
|
||||
Name string
|
||||
Description *string
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
ModelPricing *[]ChannelModelPricing
|
||||
}
|
||||
210
backend/internal/service/channel_test.go
Normal file
210
backend/internal/service/channel_test.go
Normal file
@ -0,0 +1,210 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func channelTestPtrFloat64(v float64) *float64 { return &v }
|
||||
func channelTestPtrInt(v int) *int { return &v }
|
||||
|
||||
func TestGetModelPricing(t *testing.T) {
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: channelTestPtrFloat64(3e-6)},
|
||||
{ID: 2, Models: []string{"claude-*"}, BillingMode: BillingModeToken, InputPrice: channelTestPtrFloat64(5e-6)},
|
||||
{ID: 3, Models: []string{"gpt-5.1"}, BillingMode: BillingModePerRequest},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
wantID int64
|
||||
wantNil bool
|
||||
}{
|
||||
{"exact match", "claude-sonnet-4", 1, false},
|
||||
{"case insensitive", "Claude-Sonnet-4", 1, false},
|
||||
{"wildcard match", "claude-opus-4-20250514", 2, false},
|
||||
{"exact takes priority over wildcard", "claude-sonnet-4", 1, false},
|
||||
{"not found", "gemini-3.1-pro", 0, true},
|
||||
{"per_request model", "gpt-5.1", 3, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ch.GetModelPricing(tt.model)
|
||||
if tt.wantNil {
|
||||
require.Nil(t, result)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, tt.wantID, result.ID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelPricing_ReturnsCopy(t *testing.T) {
|
||||
ch := &Channel{
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: channelTestPtrFloat64(3e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
result := ch.GetModelPricing("claude-sonnet-4")
|
||||
require.NotNil(t, result)
|
||||
|
||||
// Modify the returned copy's slice — original should be unchanged
|
||||
result.Models = append(result.Models, "hacked")
|
||||
|
||||
// Original should be unchanged
|
||||
require.Equal(t, 1, len(ch.ModelPricing[0].Models))
|
||||
}
|
||||
|
||||
func TestGetModelPricing_EmptyPricing(t *testing.T) {
|
||||
ch := &Channel{ModelPricing: nil}
|
||||
require.Nil(t, ch.GetModelPricing("any-model"))
|
||||
|
||||
ch2 := &Channel{ModelPricing: []ChannelModelPricing{}}
|
||||
require.Nil(t, ch2.GetModelPricing("any-model"))
|
||||
}
|
||||
|
||||
func TestGetIntervalForContext(t *testing.T) {
|
||||
p := &ChannelModelPricing{
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: channelTestPtrInt(128000), InputPrice: channelTestPtrFloat64(1e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: channelTestPtrFloat64(2e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokens int
|
||||
wantPrice *float64
|
||||
wantNil bool
|
||||
}{
|
||||
{"first interval", 50000, channelTestPtrFloat64(1e-6), false},
|
||||
{"boundary: at min of second", 128000, channelTestPtrFloat64(2e-6), false},
|
||||
{"boundary: at max of first (exclusive)", 128000, channelTestPtrFloat64(2e-6), false},
|
||||
{"unbounded interval", 500000, channelTestPtrFloat64(2e-6), false},
|
||||
{"zero tokens", 0, channelTestPtrFloat64(1e-6), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := p.GetIntervalForContext(tt.tokens)
|
||||
if tt.wantNil {
|
||||
require.Nil(t, result)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, *tt.wantPrice, *result.InputPrice, 1e-12)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetIntervalForContext_NoMatch(t *testing.T) {
|
||||
p := &ChannelModelPricing{
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 10000, MaxTokens: channelTestPtrInt(50000)},
|
||||
},
|
||||
}
|
||||
require.Nil(t, p.GetIntervalForContext(5000))
|
||||
require.Nil(t, p.GetIntervalForContext(50000))
|
||||
}
|
||||
|
||||
func TestGetIntervalForContext_Empty(t *testing.T) {
|
||||
p := &ChannelModelPricing{Intervals: nil}
|
||||
require.Nil(t, p.GetIntervalForContext(1000))
|
||||
}
|
||||
|
||||
func TestGetTierByLabel(t *testing.T) {
|
||||
p := &ChannelModelPricing{
|
||||
Intervals: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: channelTestPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: channelTestPtrFloat64(0.08)},
|
||||
{TierLabel: "HD", PerRequestPrice: channelTestPtrFloat64(0.12)},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
label string
|
||||
wantNil bool
|
||||
want float64
|
||||
}{
|
||||
{"exact match", "1K", false, 0.04},
|
||||
{"case insensitive", "hd", false, 0.12},
|
||||
{"not found", "4K", true, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := p.GetTierByLabel(tt.label)
|
||||
if tt.wantNil {
|
||||
require.Nil(t, result)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, tt.want, *result.PerRequestPrice, 1e-12)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTierByLabel_Empty(t *testing.T) {
|
||||
p := &ChannelModelPricing{Intervals: nil}
|
||||
require.Nil(t, p.GetTierByLabel("1K"))
|
||||
}
|
||||
|
||||
func TestChannelClone(t *testing.T) {
|
||||
original := &Channel{
|
||||
ID: 1,
|
||||
Name: "test",
|
||||
GroupIDs: []int64{10, 20},
|
||||
ModelPricing: []ChannelModelPricing{
|
||||
{
|
||||
ID: 100,
|
||||
Models: []string{"model-a"},
|
||||
InputPrice: channelTestPtrFloat64(5e-6),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cloned := original.Clone()
|
||||
require.NotNil(t, cloned)
|
||||
require.Equal(t, original.ID, cloned.ID)
|
||||
require.Equal(t, original.Name, cloned.Name)
|
||||
|
||||
// Modify clone slices — original should not change
|
||||
cloned.GroupIDs[0] = 999
|
||||
require.Equal(t, int64(10), original.GroupIDs[0])
|
||||
|
||||
cloned.ModelPricing[0].Models[0] = "hacked"
|
||||
require.Equal(t, "model-a", original.ModelPricing[0].Models[0])
|
||||
}
|
||||
|
||||
func TestChannelClone_Nil(t *testing.T) {
|
||||
var ch *Channel
|
||||
require.Nil(t, ch.Clone())
|
||||
}
|
||||
|
||||
func TestChannelModelPricingClone(t *testing.T) {
|
||||
original := ChannelModelPricing{
|
||||
Models: []string{"a", "b"},
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, TierLabel: "tier1"},
|
||||
},
|
||||
}
|
||||
|
||||
cloned := original.Clone()
|
||||
|
||||
// Modify clone slices — original unchanged
|
||||
cloned.Models[0] = "hacked"
|
||||
require.Equal(t, "a", original.Models[0])
|
||||
|
||||
cloned.Intervals[0].TierLabel = "hacked"
|
||||
require.Equal(t, "tier1", original.Intervals[0].TierLabel)
|
||||
}
|
||||
@ -41,6 +41,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@ -568,6 +568,7 @@ type GatewayService struct {
|
||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||
debugModelRouting atomic.Bool
|
||||
debugClaudeMimic atomic.Bool
|
||||
channelService *ChannelService
|
||||
debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set
|
||||
tlsFPProfileService *TLSFingerprintProfileService
|
||||
}
|
||||
@ -597,6 +598,7 @@ func NewGatewayService(
|
||||
digestStore *DigestSessionStore,
|
||||
settingService *SettingService,
|
||||
tlsFPProfileService *TLSFingerprintProfileService,
|
||||
channelService *ChannelService,
|
||||
) *GatewayService {
|
||||
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
|
||||
modelsListTTL := resolveModelsListCacheTTL(cfg)
|
||||
@ -629,6 +631,7 @@ func NewGatewayService(
|
||||
modelsListCacheTTL: modelsListTTL,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
tlsFPProfileService: tlsFPProfileService,
|
||||
channelService: channelService,
|
||||
}
|
||||
svc.userGroupRateResolver = newUserGroupRateResolver(
|
||||
userGroupRateRepo,
|
||||
@ -7771,7 +7774,16 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
}
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
|
||||
// 渠道定价覆盖
|
||||
var chPricing *ChannelModelPricing
|
||||
if s.channelService != nil && apiKey.Group != nil {
|
||||
chPricing = s.channelService.GetChannelModelPricing(ctx, apiKey.Group.ID, billingModel)
|
||||
}
|
||||
if chPricing != nil {
|
||||
cost, err = s.billingService.CalculateCostWithChannel(billingModel, tokens, multiplier, chPricing)
|
||||
} else {
|
||||
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
|
||||
}
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
@ -7959,7 +7971,16 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
}
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
||||
// 渠道定价覆盖
|
||||
var chPricing2 *ChannelModelPricing
|
||||
if s.channelService != nil && apiKey.Group != nil {
|
||||
chPricing2 = s.channelService.GetChannelModelPricing(ctx, apiKey.Group.ID, billingModel)
|
||||
}
|
||||
if chPricing2 != nil {
|
||||
cost, err = s.billingService.CalculateCostWithChannel(billingModel, tokens, multiplier, chPricing2)
|
||||
} else {
|
||||
cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
||||
}
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
|
||||
198
backend/internal/service/model_pricing_resolver.go
Normal file
198
backend/internal/service/model_pricing_resolver.go
Normal file
@ -0,0 +1,198 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// ResolvedPricing 统一定价解析结果
|
||||
type ResolvedPricing struct {
|
||||
// Mode 计费模式
|
||||
Mode BillingMode
|
||||
|
||||
// Token 模式:基础定价(来自 LiteLLM 或 fallback)
|
||||
BasePricing *ModelPricing
|
||||
|
||||
// Token 模式:区间定价列表(如有,覆盖 BasePricing 中的对应字段)
|
||||
Intervals []PricingInterval
|
||||
|
||||
// 按次/图片模式:分层定价
|
||||
RequestTiers []PricingInterval
|
||||
|
||||
// 来源标识
|
||||
Source string // "channel", "litellm", "fallback"
|
||||
|
||||
// 是否支持缓存细分
|
||||
SupportsCacheBreakdown bool
|
||||
}
|
||||
|
||||
// ModelPricingResolver 统一模型定价解析器。
|
||||
// 解析链:Channel → LiteLLM → Fallback。
|
||||
type ModelPricingResolver struct {
|
||||
channelService *ChannelService
|
||||
billingService *BillingService
|
||||
}
|
||||
|
||||
// NewModelPricingResolver 创建定价解析器实例
|
||||
func NewModelPricingResolver(channelService *ChannelService, billingService *BillingService) *ModelPricingResolver {
|
||||
return &ModelPricingResolver{
|
||||
channelService: channelService,
|
||||
billingService: billingService,
|
||||
}
|
||||
}
|
||||
|
||||
// PricingInput 定价解析输入
|
||||
type PricingInput struct {
|
||||
Model string
|
||||
GroupID *int64 // nil 表示不检查渠道
|
||||
}
|
||||
|
||||
// Resolve 解析模型定价。
|
||||
// 1. 获取基础定价(LiteLLM → Fallback)
|
||||
// 2. 如果指定了 GroupID,查找渠道定价并覆盖
|
||||
func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput) *ResolvedPricing {
|
||||
// 1. 获取基础定价
|
||||
basePricing, source := r.resolveBasePricing(input.Model)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModeToken,
|
||||
BasePricing: basePricing,
|
||||
Source: source,
|
||||
SupportsCacheBreakdown: basePricing != nil && basePricing.SupportsCacheBreakdown,
|
||||
}
|
||||
|
||||
// 2. 如果有 GroupID,尝试渠道覆盖
|
||||
if input.GroupID != nil {
|
||||
r.applyChannelOverrides(ctx, *input.GroupID, input.Model, resolved)
|
||||
}
|
||||
|
||||
return resolved
|
||||
}
|
||||
|
||||
// resolveBasePricing 从 LiteLLM 或 Fallback 获取基础定价
|
||||
func (r *ModelPricingResolver) resolveBasePricing(model string) (*ModelPricing, string) {
|
||||
pricing, err := r.billingService.GetModelPricing(model)
|
||||
if err != nil {
|
||||
slog.Debug("failed to get model pricing from LiteLLM, using fallback",
|
||||
"model", model, "error", err)
|
||||
return nil, "fallback"
|
||||
}
|
||||
return pricing, "litellm"
|
||||
}
|
||||
|
||||
// applyChannelOverrides 应用渠道定价覆盖
|
||||
func (r *ModelPricingResolver) applyChannelOverrides(ctx context.Context, groupID int64, model string, resolved *ResolvedPricing) {
|
||||
chPricing := r.channelService.GetChannelModelPricing(ctx, groupID, model)
|
||||
if chPricing == nil {
|
||||
return
|
||||
}
|
||||
|
||||
resolved.Source = "channel"
|
||||
resolved.Mode = chPricing.BillingMode
|
||||
if resolved.Mode == "" {
|
||||
resolved.Mode = BillingModeToken
|
||||
}
|
||||
|
||||
switch resolved.Mode {
|
||||
case BillingModeToken:
|
||||
r.applyTokenOverrides(chPricing, resolved)
|
||||
case BillingModePerRequest, BillingModeImage:
|
||||
r.applyRequestTierOverrides(chPricing, resolved)
|
||||
}
|
||||
}
|
||||
|
||||
// applyTokenOverrides 应用 token 模式的渠道覆盖
|
||||
func (r *ModelPricingResolver) applyTokenOverrides(chPricing *ChannelModelPricing, resolved *ResolvedPricing) {
|
||||
// 如果有区间定价,使用区间
|
||||
if len(chPricing.Intervals) > 0 {
|
||||
resolved.Intervals = chPricing.Intervals
|
||||
return
|
||||
}
|
||||
|
||||
// 否则用 flat 字段覆盖 BasePricing
|
||||
if resolved.BasePricing == nil {
|
||||
resolved.BasePricing = &ModelPricing{}
|
||||
}
|
||||
|
||||
if chPricing.InputPrice != nil {
|
||||
resolved.BasePricing.InputPricePerToken = *chPricing.InputPrice
|
||||
resolved.BasePricing.InputPricePerTokenPriority = *chPricing.InputPrice
|
||||
}
|
||||
if chPricing.OutputPrice != nil {
|
||||
resolved.BasePricing.OutputPricePerToken = *chPricing.OutputPrice
|
||||
resolved.BasePricing.OutputPricePerTokenPriority = *chPricing.OutputPrice
|
||||
}
|
||||
if chPricing.CacheWritePrice != nil {
|
||||
resolved.BasePricing.CacheCreationPricePerToken = *chPricing.CacheWritePrice
|
||||
resolved.BasePricing.CacheCreation5mPrice = *chPricing.CacheWritePrice
|
||||
resolved.BasePricing.CacheCreation1hPrice = *chPricing.CacheWritePrice
|
||||
}
|
||||
if chPricing.CacheReadPrice != nil {
|
||||
resolved.BasePricing.CacheReadPricePerToken = *chPricing.CacheReadPrice
|
||||
resolved.BasePricing.CacheReadPricePerTokenPriority = *chPricing.CacheReadPrice
|
||||
}
|
||||
}
|
||||
|
||||
// applyRequestTierOverrides 应用按次/图片模式的渠道覆盖
|
||||
func (r *ModelPricingResolver) applyRequestTierOverrides(chPricing *ChannelModelPricing, resolved *ResolvedPricing) {
|
||||
resolved.RequestTiers = chPricing.Intervals
|
||||
}
|
||||
|
||||
// GetIntervalPricing 根据 context token 数获取区间定价。
|
||||
// 如果有区间列表,找到匹配区间并构造 ModelPricing;否则直接返回 BasePricing。
|
||||
func (r *ModelPricingResolver) GetIntervalPricing(resolved *ResolvedPricing, totalContextTokens int) *ModelPricing {
|
||||
if len(resolved.Intervals) == 0 {
|
||||
return resolved.BasePricing
|
||||
}
|
||||
|
||||
iv := FindMatchingInterval(resolved.Intervals, totalContextTokens)
|
||||
if iv == nil {
|
||||
return resolved.BasePricing
|
||||
}
|
||||
|
||||
return intervalToModelPricing(iv, resolved.SupportsCacheBreakdown)
|
||||
}
|
||||
|
||||
// intervalToModelPricing 将区间定价转换为 ModelPricing
|
||||
func intervalToModelPricing(iv *PricingInterval, supportsCacheBreakdown bool) *ModelPricing {
|
||||
pricing := &ModelPricing{
|
||||
SupportsCacheBreakdown: supportsCacheBreakdown,
|
||||
}
|
||||
if iv.InputPrice != nil {
|
||||
pricing.InputPricePerToken = *iv.InputPrice
|
||||
pricing.InputPricePerTokenPriority = *iv.InputPrice
|
||||
}
|
||||
if iv.OutputPrice != nil {
|
||||
pricing.OutputPricePerToken = *iv.OutputPrice
|
||||
pricing.OutputPricePerTokenPriority = *iv.OutputPrice
|
||||
}
|
||||
if iv.CacheWritePrice != nil {
|
||||
pricing.CacheCreationPricePerToken = *iv.CacheWritePrice
|
||||
pricing.CacheCreation5mPrice = *iv.CacheWritePrice
|
||||
pricing.CacheCreation1hPrice = *iv.CacheWritePrice
|
||||
}
|
||||
if iv.CacheReadPrice != nil {
|
||||
pricing.CacheReadPricePerToken = *iv.CacheReadPrice
|
||||
pricing.CacheReadPricePerTokenPriority = *iv.CacheReadPrice
|
||||
}
|
||||
return pricing
|
||||
}
|
||||
|
||||
// GetRequestTierPrice 根据层级标签获取按次价格
|
||||
func (r *ModelPricingResolver) GetRequestTierPrice(resolved *ResolvedPricing, tierLabel string) float64 {
|
||||
for _, tier := range resolved.RequestTiers {
|
||||
if tier.TierLabel == tierLabel && tier.PerRequestPrice != nil {
|
||||
return *tier.PerRequestPrice
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetRequestTierPriceByContext 根据 context token 数获取按次价格
|
||||
func (r *ModelPricingResolver) GetRequestTierPriceByContext(resolved *ResolvedPricing, totalContextTokens int) float64 {
|
||||
iv := FindMatchingInterval(resolved.RequestTiers, totalContextTokens)
|
||||
if iv != nil && iv.PerRequestPrice != nil {
|
||||
return *iv.PerRequestPrice
|
||||
}
|
||||
return 0
|
||||
}
|
||||
164
backend/internal/service/model_pricing_resolver_test.go
Normal file
164
backend/internal/service/model_pricing_resolver_test.go
Normal file
@ -0,0 +1,164 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func resolverPtrFloat64(v float64) *float64 { return &v }
|
||||
func resolverPtrInt(v int) *int { return &v }
|
||||
|
||||
func newTestBillingServiceForResolver() *BillingService {
|
||||
bs := &BillingService{
|
||||
fallbackPrices: make(map[string]*ModelPricing),
|
||||
}
|
||||
bs.fallbackPrices["claude-sonnet-4"] = &ModelPricing{
|
||||
InputPricePerToken: 3e-6,
|
||||
OutputPricePerToken: 15e-6,
|
||||
CacheCreationPricePerToken: 3.75e-6,
|
||||
CacheReadPricePerToken: 0.3e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
return bs
|
||||
}
|
||||
|
||||
func TestResolve_NoGroupID(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "claude-sonnet-4",
|
||||
GroupID: nil,
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Equal(t, BillingModeToken, resolved.Mode)
|
||||
require.NotNil(t, resolved.BasePricing)
|
||||
require.InDelta(t, 3e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
|
||||
// BillingService.GetModelPricing uses fallback internally, but resolveBasePricing
|
||||
// reports "litellm" when GetModelPricing succeeds (regardless of internal source)
|
||||
require.Equal(t, "litellm", resolved.Source)
|
||||
}
|
||||
|
||||
func TestResolve_UnknownModel(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := r.Resolve(context.Background(), PricingInput{
|
||||
Model: "unknown-model-xyz",
|
||||
GroupID: nil,
|
||||
})
|
||||
|
||||
require.NotNil(t, resolved)
|
||||
require.Nil(t, resolved.BasePricing)
|
||||
// Unknown model: GetModelPricing returns error, source is "fallback"
|
||||
require.Equal(t, "fallback", resolved.Source)
|
||||
}
|
||||
|
||||
func TestGetIntervalPricing_NoIntervals(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
basePricing := &ModelPricing{InputPricePerToken: 5e-6}
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModeToken,
|
||||
BasePricing: basePricing,
|
||||
Intervals: nil,
|
||||
}
|
||||
|
||||
result := r.GetIntervalPricing(resolved, 50000)
|
||||
require.Equal(t, basePricing, result)
|
||||
}
|
||||
|
||||
func TestGetIntervalPricing_MatchesInterval(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModeToken,
|
||||
BasePricing: &ModelPricing{InputPricePerToken: 5e-6},
|
||||
SupportsCacheBreakdown: true,
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: resolverPtrInt(128000), InputPrice: resolverPtrFloat64(1e-6), OutputPrice: resolverPtrFloat64(2e-6)},
|
||||
{MinTokens: 128000, MaxTokens: nil, InputPrice: resolverPtrFloat64(3e-6), OutputPrice: resolverPtrFloat64(6e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
result := r.GetIntervalPricing(resolved, 50000)
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 1e-6, result.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 2e-6, result.OutputPricePerToken, 1e-12)
|
||||
require.True(t, result.SupportsCacheBreakdown)
|
||||
|
||||
result2 := r.GetIntervalPricing(resolved, 200000)
|
||||
require.NotNil(t, result2)
|
||||
require.InDelta(t, 3e-6, result2.InputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetIntervalPricing_NoMatch_FallsBackToBase(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
basePricing := &ModelPricing{InputPricePerToken: 99e-6}
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModeToken,
|
||||
BasePricing: basePricing,
|
||||
Intervals: []PricingInterval{
|
||||
{MinTokens: 10000, MaxTokens: resolverPtrInt(50000), InputPrice: resolverPtrFloat64(1e-6)},
|
||||
},
|
||||
}
|
||||
|
||||
result := r.GetIntervalPricing(resolved, 5000)
|
||||
require.Equal(t, basePricing, result)
|
||||
}
|
||||
|
||||
func TestGetRequestTierPrice(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: resolverPtrFloat64(0.04)},
|
||||
{TierLabel: "2K", PerRequestPrice: resolverPtrFloat64(0.08)},
|
||||
},
|
||||
}
|
||||
|
||||
require.InDelta(t, 0.04, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
|
||||
require.InDelta(t, 0.08, r.GetRequestTierPrice(resolved, "2K"), 1e-12)
|
||||
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "4K"), 1e-12)
|
||||
}
|
||||
|
||||
func TestGetRequestTierPriceByContext(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: []PricingInterval{
|
||||
{MinTokens: 0, MaxTokens: resolverPtrInt(128000), PerRequestPrice: resolverPtrFloat64(0.05)},
|
||||
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: resolverPtrFloat64(0.10)},
|
||||
},
|
||||
}
|
||||
|
||||
require.InDelta(t, 0.05, r.GetRequestTierPriceByContext(resolved, 50000), 1e-12)
|
||||
require.InDelta(t, 0.10, r.GetRequestTierPriceByContext(resolved, 200000), 1e-12)
|
||||
}
|
||||
|
||||
func TestGetRequestTierPrice_NilPerRequestPrice(t *testing.T) {
|
||||
bs := newTestBillingServiceForResolver()
|
||||
r := NewModelPricingResolver(&ChannelService{}, bs)
|
||||
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: BillingModePerRequest,
|
||||
RequestTiers: []PricingInterval{
|
||||
{TierLabel: "1K", PerRequestPrice: nil},
|
||||
},
|
||||
}
|
||||
|
||||
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
|
||||
}
|
||||
@ -490,4 +490,6 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideScheduledTestService,
|
||||
ProvideScheduledTestRunnerService,
|
||||
NewGroupCapacityService,
|
||||
NewChannelService,
|
||||
NewModelPricingResolver,
|
||||
)
|
||||
|
||||
56
backend/migrations/081_create_channels.sql
Normal file
56
backend/migrations/081_create_channels.sql
Normal file
@ -0,0 +1,56 @@
|
||||
-- Create channels table for managing pricing channels.
|
||||
-- A channel groups multiple groups together and provides custom model pricing.
|
||||
|
||||
SET LOCAL lock_timeout = '5s';
|
||||
SET LOCAL statement_timeout = '10min';
|
||||
|
||||
-- 渠道表
|
||||
CREATE TABLE IF NOT EXISTS channels (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
name VARCHAR(100) NOT NULL,
|
||||
description TEXT DEFAULT '',
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'active',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- 渠道名称唯一索引
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_channels_name ON channels (name);
|
||||
CREATE INDEX IF NOT EXISTS idx_channels_status ON channels (status);
|
||||
|
||||
-- 渠道-分组关联表(每个分组只能属于一个渠道)
|
||||
CREATE TABLE IF NOT EXISTS channel_groups (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
|
||||
group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_channel_groups_group_id ON channel_groups (group_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_channel_groups_channel_id ON channel_groups (channel_id);
|
||||
|
||||
-- 渠道模型定价表(一条定价可绑定多个模型)
|
||||
CREATE TABLE IF NOT EXISTS channel_model_pricing (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
|
||||
models JSONB NOT NULL DEFAULT '[]',
|
||||
input_price NUMERIC(20,12),
|
||||
output_price NUMERIC(20,12),
|
||||
cache_write_price NUMERIC(20,12),
|
||||
cache_read_price NUMERIC(20,12),
|
||||
image_output_price NUMERIC(20,8),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_channel_model_pricing_channel_id ON channel_model_pricing (channel_id);
|
||||
|
||||
COMMENT ON TABLE channels IS '渠道管理:关联多个分组,提供自定义模型定价';
|
||||
COMMENT ON TABLE channel_groups IS '渠道-分组关联表:每个分组最多属于一个渠道';
|
||||
COMMENT ON TABLE channel_model_pricing IS '渠道模型定价:一条定价可绑定多个模型,价格一致';
|
||||
COMMENT ON COLUMN channel_model_pricing.models IS '绑定的模型列表,JSON 数组,如 ["claude-opus-4-6","claude-opus-4-6-thinking"]';
|
||||
COMMENT ON COLUMN channel_model_pricing.input_price IS '每 token 输入价格(USD),NULL 表示使用默认';
|
||||
COMMENT ON COLUMN channel_model_pricing.output_price IS '每 token 输出价格(USD),NULL 表示使用默认';
|
||||
COMMENT ON COLUMN channel_model_pricing.cache_write_price IS '缓存写入每 token 价格,NULL 表示使用默认';
|
||||
COMMENT ON COLUMN channel_model_pricing.cache_read_price IS '缓存读取每 token 价格,NULL 表示使用默认';
|
||||
COMMENT ON COLUMN channel_model_pricing.image_output_price IS '图片输出价格(Gemini Image 等),NULL 表示使用默认';
|
||||
67
backend/migrations/082_refactor_channel_pricing.sql
Normal file
67
backend/migrations/082_refactor_channel_pricing.sql
Normal file
@ -0,0 +1,67 @@
|
||||
-- Extend channel_model_pricing with billing_mode and add context-interval child table.
|
||||
-- Supports three billing modes: token (per-token with context intervals),
|
||||
-- per_request (per-request with context-size tiers), and image (per-image).
|
||||
|
||||
SET LOCAL lock_timeout = '5s';
|
||||
SET LOCAL statement_timeout = '10min';
|
||||
|
||||
-- 1. 为 channel_model_pricing 添加 billing_mode 列
|
||||
ALTER TABLE channel_model_pricing
|
||||
ADD COLUMN IF NOT EXISTS billing_mode VARCHAR(20) NOT NULL DEFAULT 'token';
|
||||
|
||||
COMMENT ON COLUMN channel_model_pricing.billing_mode IS '计费模式:token(按 token 区间计费)、per_request(按次计费)、image(图片计费)';
|
||||
|
||||
-- 2. 创建区间定价子表
|
||||
CREATE TABLE IF NOT EXISTS channel_pricing_intervals (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
pricing_id BIGINT NOT NULL REFERENCES channel_model_pricing(id) ON DELETE CASCADE,
|
||||
min_tokens INT NOT NULL DEFAULT 0,
|
||||
max_tokens INT,
|
||||
tier_label VARCHAR(50),
|
||||
input_price NUMERIC(20,12),
|
||||
output_price NUMERIC(20,12),
|
||||
cache_write_price NUMERIC(20,12),
|
||||
cache_read_price NUMERIC(20,12),
|
||||
per_request_price NUMERIC(20,12),
|
||||
sort_order INT NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_channel_pricing_intervals_pricing_id
|
||||
ON channel_pricing_intervals (pricing_id);
|
||||
|
||||
COMMENT ON TABLE channel_pricing_intervals IS '渠道定价区间:支持按 token 区间、按次分层、图片分辨率分层';
|
||||
COMMENT ON COLUMN channel_pricing_intervals.min_tokens IS '区间下界(含),token 模式使用';
|
||||
COMMENT ON COLUMN channel_pricing_intervals.max_tokens IS '区间上界(不含),NULL 表示无上限';
|
||||
COMMENT ON COLUMN channel_pricing_intervals.tier_label IS '层级标签,按次/图片模式使用(如 1K、2K、4K、HD)';
|
||||
COMMENT ON COLUMN channel_pricing_intervals.input_price IS 'token 模式:每 token 输入价';
|
||||
COMMENT ON COLUMN channel_pricing_intervals.output_price IS 'token 模式:每 token 输出价';
|
||||
COMMENT ON COLUMN channel_pricing_intervals.cache_write_price IS 'token 模式:缓存写入价';
|
||||
COMMENT ON COLUMN channel_pricing_intervals.cache_read_price IS 'token 模式:缓存读取价';
|
||||
COMMENT ON COLUMN channel_pricing_intervals.per_request_price IS '按次/图片模式:每次请求价格';
|
||||
|
||||
-- 3. 迁移现有 flat 定价为单区间 [0, +inf)
|
||||
-- 仅迁移有明确定价(至少一个价格字段非 NULL)的条目
|
||||
INSERT INTO channel_pricing_intervals (pricing_id, min_tokens, max_tokens, input_price, output_price, cache_write_price, cache_read_price, sort_order)
|
||||
SELECT
|
||||
cmp.id,
|
||||
0,
|
||||
NULL,
|
||||
cmp.input_price,
|
||||
cmp.output_price,
|
||||
cmp.cache_write_price,
|
||||
cmp.cache_read_price,
|
||||
0
|
||||
FROM channel_model_pricing cmp
|
||||
WHERE cmp.billing_mode = 'token'
|
||||
AND (cmp.input_price IS NOT NULL OR cmp.output_price IS NOT NULL
|
||||
OR cmp.cache_write_price IS NOT NULL OR cmp.cache_read_price IS NOT NULL)
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM channel_pricing_intervals cpi WHERE cpi.pricing_id = cmp.id
|
||||
);
|
||||
|
||||
-- 4. 迁移 image_output_price 为 image 模式的区间条目
|
||||
-- 将有 image_output_price 的现有条目复制为 billing_mode='image' 的独立条目
|
||||
-- 注意:这里不改变原条目的 billing_mode,而是将 image_output_price 作为向后兼容字段保留
|
||||
-- 实际的 image 计费在未来由独立的 billing_mode='image' 条目处理
|
||||
121
frontend/src/api/admin/channels.ts
Normal file
121
frontend/src/api/admin/channels.ts
Normal file
@ -0,0 +1,121 @@
|
||||
/**
|
||||
* Admin Channels API endpoints
|
||||
* Handles channel management for administrators
|
||||
*/
|
||||
|
||||
import { apiClient } from '../client'
|
||||
|
||||
export type BillingMode = 'token' | 'per_request' | 'image'
|
||||
|
||||
export interface PricingInterval {
|
||||
id?: number
|
||||
min_tokens: number
|
||||
max_tokens: number | null
|
||||
tier_label: string
|
||||
input_price: number | null
|
||||
output_price: number | null
|
||||
cache_write_price: number | null
|
||||
cache_read_price: number | null
|
||||
per_request_price: number | null
|
||||
sort_order: number
|
||||
}
|
||||
|
||||
export interface ChannelModelPricing {
|
||||
id?: number
|
||||
models: string[]
|
||||
billing_mode: BillingMode
|
||||
input_price: number | null
|
||||
output_price: number | null
|
||||
cache_write_price: number | null
|
||||
cache_read_price: number | null
|
||||
image_output_price: number | null
|
||||
intervals: PricingInterval[]
|
||||
}
|
||||
|
||||
export interface Channel {
|
||||
id: number
|
||||
name: string
|
||||
description: string
|
||||
status: string
|
||||
group_ids: number[]
|
||||
model_pricing: ChannelModelPricing[]
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface CreateChannelRequest {
|
||||
name: string
|
||||
description?: string
|
||||
group_ids?: number[]
|
||||
model_pricing?: ChannelModelPricing[]
|
||||
}
|
||||
|
||||
export interface UpdateChannelRequest {
|
||||
name?: string
|
||||
description?: string
|
||||
status?: string
|
||||
group_ids?: number[]
|
||||
model_pricing?: ChannelModelPricing[]
|
||||
}
|
||||
|
||||
interface PaginatedResponse<T> {
|
||||
items: T[]
|
||||
total: number
|
||||
}
|
||||
|
||||
/**
|
||||
* List channels with pagination
|
||||
*/
|
||||
export async function list(
|
||||
page: number = 1,
|
||||
pageSize: number = 20,
|
||||
filters?: {
|
||||
status?: string
|
||||
search?: string
|
||||
},
|
||||
options?: { signal?: AbortSignal }
|
||||
): Promise<PaginatedResponse<Channel>> {
|
||||
const { data } = await apiClient.get<PaginatedResponse<Channel>>('/admin/channels', {
|
||||
params: {
|
||||
page,
|
||||
page_size: pageSize,
|
||||
...filters
|
||||
},
|
||||
signal: options?.signal
|
||||
})
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Get channel by ID
|
||||
*/
|
||||
export async function getById(id: number): Promise<Channel> {
|
||||
const { data } = await apiClient.get<Channel>(`/admin/channels/${id}`)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new channel
|
||||
*/
|
||||
export async function create(req: CreateChannelRequest): Promise<Channel> {
|
||||
const { data } = await apiClient.post<Channel>('/admin/channels', req)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a channel
|
||||
*/
|
||||
export async function update(id: number, req: UpdateChannelRequest): Promise<Channel> {
|
||||
const { data } = await apiClient.put<Channel>(`/admin/channels/${id}`, req)
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete a channel
|
||||
*/
|
||||
export async function remove(id: number): Promise<void> {
|
||||
await apiClient.delete(`/admin/channels/${id}`)
|
||||
}
|
||||
|
||||
const channelsAPI = { list, getById, create, update, remove }
|
||||
export default channelsAPI
|
||||
@ -25,6 +25,7 @@ import apiKeysAPI from './apiKeys'
|
||||
import scheduledTestsAPI from './scheduledTests'
|
||||
import backupAPI from './backup'
|
||||
import tlsFingerprintProfileAPI from './tlsFingerprintProfile'
|
||||
import channelsAPI from './channels'
|
||||
|
||||
/**
|
||||
* Unified admin API object for convenient access
|
||||
@ -51,7 +52,8 @@ export const adminAPI = {
|
||||
apiKeys: apiKeysAPI,
|
||||
scheduledTests: scheduledTestsAPI,
|
||||
backup: backupAPI,
|
||||
tlsFingerprintProfiles: tlsFingerprintProfileAPI
|
||||
tlsFingerprintProfiles: tlsFingerprintProfileAPI,
|
||||
channels: channelsAPI
|
||||
}
|
||||
|
||||
export {
|
||||
@ -76,7 +78,8 @@ export {
|
||||
apiKeysAPI,
|
||||
scheduledTestsAPI,
|
||||
backupAPI,
|
||||
tlsFingerprintProfileAPI
|
||||
tlsFingerprintProfileAPI,
|
||||
channelsAPI
|
||||
}
|
||||
|
||||
export default adminAPI
|
||||
|
||||
160
frontend/src/components/admin/channel/IntervalRow.vue
Normal file
160
frontend/src/components/admin/channel/IntervalRow.vue
Normal file
@ -0,0 +1,160 @@
|
||||
<template>
|
||||
<div class="flex items-start gap-2 rounded border border-gray-200 bg-white p-2 dark:border-dark-500 dark:bg-dark-700">
|
||||
<!-- Token mode: context range + prices -->
|
||||
<template v-if="mode === 'token'">
|
||||
<div class="w-20">
|
||||
<label class="text-xs text-gray-400">{{ t('admin.channels.form.minTokens', 'Min (K)') }}</label>
|
||||
<input
|
||||
:value="interval.min_tokens"
|
||||
@input="emitField('min_tokens', toInt(($event.target as HTMLInputElement).value))"
|
||||
type="number"
|
||||
min="0"
|
||||
class="input mt-0.5 text-xs"
|
||||
/>
|
||||
</div>
|
||||
<div class="w-20">
|
||||
<label class="text-xs text-gray-400">{{ t('admin.channels.form.maxTokens', 'Max (K)') }}</label>
|
||||
<input
|
||||
:value="interval.max_tokens ?? ''"
|
||||
@input="emitField('max_tokens', toIntOrNull(($event.target as HTMLInputElement).value))"
|
||||
type="number"
|
||||
min="0"
|
||||
class="input mt-0.5 text-xs"
|
||||
:placeholder="'∞'"
|
||||
/>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<label class="text-xs text-gray-400">{{ t('admin.channels.form.inputPrice', 'Input') }}</label>
|
||||
<input
|
||||
:value="interval.input_price"
|
||||
@input="emitField('input_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number"
|
||||
step="any" min="0"
|
||||
class="input mt-0.5 text-xs"
|
||||
/>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<label class="text-xs text-gray-400">{{ t('admin.channels.form.outputPrice', 'Output') }}</label>
|
||||
<input
|
||||
:value="interval.output_price"
|
||||
@input="emitField('output_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number"
|
||||
step="any" min="0"
|
||||
class="input mt-0.5 text-xs"
|
||||
/>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<label class="text-xs text-gray-400">{{ t('admin.channels.form.cacheWritePrice', 'Cache W') }}</label>
|
||||
<input
|
||||
:value="interval.cache_write_price"
|
||||
@input="emitField('cache_write_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number"
|
||||
step="any" min="0"
|
||||
class="input mt-0.5 text-xs"
|
||||
/>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<label class="text-xs text-gray-400">{{ t('admin.channels.form.cacheReadPrice', 'Cache R') }}</label>
|
||||
<input
|
||||
:value="interval.cache_read_price"
|
||||
@input="emitField('cache_read_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number"
|
||||
step="any" min="0"
|
||||
class="input mt-0.5 text-xs"
|
||||
/>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<!-- Per-request / Image mode: tier label + price -->
|
||||
<template v-else>
|
||||
<div class="w-24">
|
||||
<label class="text-xs text-gray-400">
|
||||
{{ mode === 'image'
|
||||
? t('admin.channels.form.resolution', 'Resolution')
|
||||
: t('admin.channels.form.tierLabel', 'Tier')
|
||||
}}
|
||||
</label>
|
||||
<input
|
||||
:value="interval.tier_label"
|
||||
@input="emitField('tier_label', ($event.target as HTMLInputElement).value)"
|
||||
type="text"
|
||||
class="input mt-0.5 text-xs"
|
||||
:placeholder="mode === 'image' ? '1K / 2K / 4K' : ''"
|
||||
/>
|
||||
</div>
|
||||
<div class="w-20">
|
||||
<label class="text-xs text-gray-400">{{ t('admin.channels.form.minTokens', 'Min') }}</label>
|
||||
<input
|
||||
:value="interval.min_tokens"
|
||||
@input="emitField('min_tokens', toInt(($event.target as HTMLInputElement).value))"
|
||||
type="number"
|
||||
min="0"
|
||||
class="input mt-0.5 text-xs"
|
||||
/>
|
||||
</div>
|
||||
<div class="w-20">
|
||||
<label class="text-xs text-gray-400">{{ t('admin.channels.form.maxTokens', 'Max') }}</label>
|
||||
<input
|
||||
:value="interval.max_tokens ?? ''"
|
||||
@input="emitField('max_tokens', toIntOrNull(($event.target as HTMLInputElement).value))"
|
||||
type="number"
|
||||
min="0"
|
||||
class="input mt-0.5 text-xs"
|
||||
:placeholder="'∞'"
|
||||
/>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<label class="text-xs text-gray-400">{{ t('admin.channels.form.perRequestPrice', 'Price') }}</label>
|
||||
<input
|
||||
:value="interval.per_request_price"
|
||||
@input="emitField('per_request_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number"
|
||||
step="any" min="0"
|
||||
class="input mt-0.5 text-xs"
|
||||
/>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
@click="emit('remove')"
|
||||
class="mt-4 rounded p-0.5 text-gray-400 hover:text-red-500"
|
||||
>
|
||||
<Icon name="x" size="sm" />
|
||||
</button>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import type { IntervalFormEntry } from './types'
|
||||
import type { BillingMode } from '@/api/admin/channels'
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const props = defineProps<{
|
||||
interval: IntervalFormEntry
|
||||
mode: BillingMode
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
update: [interval: IntervalFormEntry]
|
||||
remove: []
|
||||
}>()
|
||||
|
||||
function emitField(field: keyof IntervalFormEntry, value: string | number | null) {
|
||||
emit('update', { ...props.interval, [field]: value === '' ? null : value })
|
||||
}
|
||||
|
||||
function toInt(val: string): number {
|
||||
const n = parseInt(val, 10)
|
||||
return isNaN(n) ? 0 : n
|
||||
}
|
||||
|
||||
function toIntOrNull(val: string): number | null {
|
||||
if (val === '') return null
|
||||
const n = parseInt(val, 10)
|
||||
return isNaN(n) ? null : n
|
||||
}
|
||||
</script>
|
||||
260
frontend/src/components/admin/channel/PricingEntryCard.vue
Normal file
260
frontend/src/components/admin/channel/PricingEntryCard.vue
Normal file
@ -0,0 +1,260 @@
|
||||
<template>
|
||||
<div class="rounded-lg border border-gray-200 bg-gray-50 p-3 dark:border-dark-600 dark:bg-dark-800">
|
||||
<!-- Header: Models + Billing Mode + Remove -->
|
||||
<div class="mb-2 flex items-start gap-2">
|
||||
<div class="flex-1">
|
||||
<label class="text-xs font-medium text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.models', 'Models (comma separated, supports *)') }}
|
||||
</label>
|
||||
<textarea
|
||||
:value="entry.modelsInput"
|
||||
@input="emit('update', { ...entry, modelsInput: ($event.target as HTMLTextAreaElement).value })"
|
||||
rows="2"
|
||||
class="input mt-1 text-sm"
|
||||
:placeholder="t('admin.channels.form.modelsPlaceholder', 'claude-sonnet-4-20250514, claude-opus-4-20250514, *')"
|
||||
></textarea>
|
||||
</div>
|
||||
<div class="w-40">
|
||||
<label class="text-xs font-medium text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.billingMode', 'Billing Mode') }}
|
||||
</label>
|
||||
<Select
|
||||
:modelValue="entry.billing_mode"
|
||||
@update:modelValue="emit('update', { ...entry, billing_mode: $event as BillingMode, intervals: [] })"
|
||||
:options="billingModeOptions"
|
||||
class="mt-1"
|
||||
/>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
@click="emit('remove')"
|
||||
class="mt-5 rounded p-1 text-gray-400 hover:text-red-500"
|
||||
>
|
||||
<Icon name="trash" size="sm" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Token mode: flat prices + intervals -->
|
||||
<div v-if="entry.billing_mode === 'token'">
|
||||
<!-- Flat prices (used when no intervals) -->
|
||||
<div class="grid grid-cols-2 gap-2 sm:grid-cols-4">
|
||||
<div>
|
||||
<label class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.inputPrice', 'Input Price') }}
|
||||
</label>
|
||||
<input
|
||||
:value="entry.input_price"
|
||||
@input="emitField('input_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number"
|
||||
step="any" min="0"
|
||||
class="input mt-1 text-sm"
|
||||
:placeholder="t('admin.channels.form.pricePlaceholder', 'Default')"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.outputPrice', 'Output Price') }}
|
||||
</label>
|
||||
<input
|
||||
:value="entry.output_price"
|
||||
@input="emitField('output_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number"
|
||||
step="any" min="0"
|
||||
class="input mt-1 text-sm"
|
||||
:placeholder="t('admin.channels.form.pricePlaceholder', 'Default')"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.cacheWritePrice', 'Cache Write') }}
|
||||
</label>
|
||||
<input
|
||||
:value="entry.cache_write_price"
|
||||
@input="emitField('cache_write_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number"
|
||||
step="any" min="0"
|
||||
class="input mt-1 text-sm"
|
||||
:placeholder="t('admin.channels.form.pricePlaceholder', 'Default')"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.cacheReadPrice', 'Cache Read') }}
|
||||
</label>
|
||||
<input
|
||||
:value="entry.cache_read_price"
|
||||
@input="emitField('cache_read_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number"
|
||||
step="any" min="0"
|
||||
class="input mt-1 text-sm"
|
||||
:placeholder="t('admin.channels.form.pricePlaceholder', 'Default')"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Token intervals -->
|
||||
<div class="mt-3">
|
||||
<div class="flex items-center justify-between">
|
||||
<label class="text-xs font-medium text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.intervals', 'Context Intervals (optional)') }}
|
||||
</label>
|
||||
<button type="button" @click="addInterval" class="text-xs text-primary-600 hover:text-primary-700">
|
||||
+ {{ t('admin.channels.form.addInterval', 'Add Interval') }}
|
||||
</button>
|
||||
</div>
|
||||
<div v-if="entry.intervals && entry.intervals.length > 0" class="mt-2 space-y-2">
|
||||
<IntervalRow
|
||||
v-for="(iv, idx) in entry.intervals"
|
||||
:key="idx"
|
||||
:interval="iv"
|
||||
:mode="entry.billing_mode"
|
||||
@update="updateInterval(idx, $event)"
|
||||
@remove="removeInterval(idx)"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Per-request mode: tiers -->
|
||||
<div v-else-if="entry.billing_mode === 'per_request'">
|
||||
<div class="flex items-center justify-between">
|
||||
<label class="text-xs font-medium text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.requestTiers', 'Request Tiers') }}
|
||||
</label>
|
||||
<button type="button" @click="addInterval" class="text-xs text-primary-600 hover:text-primary-700">
|
||||
+ {{ t('admin.channels.form.addTier', 'Add Tier') }}
|
||||
</button>
|
||||
</div>
|
||||
<div v-if="entry.intervals && entry.intervals.length > 0" class="mt-2 space-y-2">
|
||||
<IntervalRow
|
||||
v-for="(iv, idx) in entry.intervals"
|
||||
:key="idx"
|
||||
:interval="iv"
|
||||
:mode="entry.billing_mode"
|
||||
@update="updateInterval(idx, $event)"
|
||||
@remove="removeInterval(idx)"
|
||||
/>
|
||||
</div>
|
||||
<div v-else class="mt-2 rounded border border-dashed border-gray-300 p-3 text-center text-xs text-gray-400 dark:border-dark-500">
|
||||
{{ t('admin.channels.form.noTiersYet', 'No tiers. Add one to configure per-request pricing.') }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Image mode: tiers -->
|
||||
<div v-else-if="entry.billing_mode === 'image'">
|
||||
<div class="flex items-center justify-between">
|
||||
<label class="text-xs font-medium text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.imageTiers', 'Image Tiers') }}
|
||||
</label>
|
||||
<button type="button" @click="addImageTier" class="text-xs text-primary-600 hover:text-primary-700">
|
||||
+ {{ t('admin.channels.form.addTier', 'Add Tier') }}
|
||||
</button>
|
||||
</div>
|
||||
<div v-if="entry.intervals && entry.intervals.length > 0" class="mt-2 space-y-2">
|
||||
<IntervalRow
|
||||
v-for="(iv, idx) in entry.intervals"
|
||||
:key="idx"
|
||||
:interval="iv"
|
||||
:mode="entry.billing_mode"
|
||||
@update="updateInterval(idx, $event)"
|
||||
@remove="removeInterval(idx)"
|
||||
/>
|
||||
</div>
|
||||
<div v-else>
|
||||
<!-- Legacy image_output_price fallback -->
|
||||
<div class="mt-2 grid grid-cols-2 gap-2 sm:grid-cols-4">
|
||||
<div>
|
||||
<label class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.channels.form.imageOutputPrice', 'Image Output Price') }}
|
||||
</label>
|
||||
<input
|
||||
:value="entry.image_output_price"
|
||||
@input="emitField('image_output_price', ($event.target as HTMLInputElement).value)"
|
||||
type="number"
|
||||
step="any" min="0"
|
||||
class="input mt-1 text-sm"
|
||||
:placeholder="t('admin.channels.form.pricePlaceholder', 'Default')"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import Select from '@/components/common/Select.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import IntervalRow from './IntervalRow.vue'
|
||||
import type { PricingFormEntry, IntervalFormEntry } from './types'
|
||||
import type { BillingMode } from '@/api/admin/channels'
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const props = defineProps<{
|
||||
entry: PricingFormEntry
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
update: [entry: PricingFormEntry]
|
||||
remove: []
|
||||
}>()
|
||||
|
||||
const billingModeOptions = computed(() => [
|
||||
{ value: 'token', label: t('admin.channels.billingMode.token', 'Token') },
|
||||
{ value: 'per_request', label: t('admin.channels.billingMode.perRequest', 'Per Request') },
|
||||
{ value: 'image', label: t('admin.channels.billingMode.image', 'Image') }
|
||||
])
|
||||
|
||||
function emitField(field: keyof PricingFormEntry, value: string) {
|
||||
emit('update', { ...props.entry, [field]: value === '' ? null : value })
|
||||
}
|
||||
|
||||
function addInterval() {
|
||||
const intervals = [...(props.entry.intervals || [])]
|
||||
intervals.push({
|
||||
min_tokens: 0,
|
||||
max_tokens: null,
|
||||
tier_label: '',
|
||||
input_price: null,
|
||||
output_price: null,
|
||||
cache_write_price: null,
|
||||
cache_read_price: null,
|
||||
per_request_price: null,
|
||||
sort_order: intervals.length
|
||||
})
|
||||
emit('update', { ...props.entry, intervals })
|
||||
}
|
||||
|
||||
function addImageTier() {
|
||||
const intervals = [...(props.entry.intervals || [])]
|
||||
const labels = ['1K', '2K', '4K', 'HD']
|
||||
const nextLabel = labels[intervals.length] || ''
|
||||
intervals.push({
|
||||
min_tokens: 0,
|
||||
max_tokens: null,
|
||||
tier_label: nextLabel,
|
||||
input_price: null,
|
||||
output_price: null,
|
||||
cache_write_price: null,
|
||||
cache_read_price: null,
|
||||
per_request_price: null,
|
||||
sort_order: intervals.length
|
||||
})
|
||||
emit('update', { ...props.entry, intervals })
|
||||
}
|
||||
|
||||
function updateInterval(idx: number, updated: IntervalFormEntry) {
|
||||
const intervals = [...(props.entry.intervals || [])]
|
||||
intervals[idx] = updated
|
||||
emit('update', { ...props.entry, intervals })
|
||||
}
|
||||
|
||||
function removeInterval(idx: number) {
|
||||
const intervals = [...(props.entry.intervals || [])]
|
||||
intervals.splice(idx, 1)
|
||||
emit('update', { ...props.entry, intervals })
|
||||
}
|
||||
</script>
|
||||
59
frontend/src/components/admin/channel/types.ts
Normal file
59
frontend/src/components/admin/channel/types.ts
Normal file
@ -0,0 +1,59 @@
|
||||
import type { BillingMode, PricingInterval } from '@/api/admin/channels'
|
||||
|
||||
export interface IntervalFormEntry {
|
||||
min_tokens: number
|
||||
max_tokens: number | null
|
||||
tier_label: string
|
||||
input_price: number | string | null
|
||||
output_price: number | string | null
|
||||
cache_write_price: number | string | null
|
||||
cache_read_price: number | string | null
|
||||
per_request_price: number | string | null
|
||||
sort_order: number
|
||||
}
|
||||
|
||||
export interface PricingFormEntry {
|
||||
modelsInput: string
|
||||
billing_mode: BillingMode
|
||||
input_price: number | string | null
|
||||
output_price: number | string | null
|
||||
cache_write_price: number | string | null
|
||||
cache_read_price: number | string | null
|
||||
per_request_price: number | string | null
|
||||
image_output_price: number | string | null
|
||||
intervals: IntervalFormEntry[]
|
||||
}
|
||||
|
||||
export function toNullableNumber(val: number | string | null | undefined): number | null {
|
||||
if (val === null || val === undefined || val === '') return null
|
||||
const num = Number(val)
|
||||
return isNaN(num) ? null : num
|
||||
}
|
||||
|
||||
export function apiIntervalsToForm(intervals: PricingInterval[]): IntervalFormEntry[] {
|
||||
return (intervals || []).map(iv => ({
|
||||
min_tokens: iv.min_tokens,
|
||||
max_tokens: iv.max_tokens,
|
||||
tier_label: iv.tier_label || '',
|
||||
input_price: iv.input_price,
|
||||
output_price: iv.output_price,
|
||||
cache_write_price: iv.cache_write_price,
|
||||
cache_read_price: iv.cache_read_price,
|
||||
per_request_price: iv.per_request_price,
|
||||
sort_order: iv.sort_order
|
||||
}))
|
||||
}
|
||||
|
||||
export function formIntervalsToAPI(intervals: IntervalFormEntry[]): PricingInterval[] {
|
||||
return (intervals || []).map(iv => ({
|
||||
min_tokens: iv.min_tokens,
|
||||
max_tokens: iv.max_tokens,
|
||||
tier_label: iv.tier_label,
|
||||
input_price: toNullableNumber(iv.input_price),
|
||||
output_price: toNullableNumber(iv.output_price),
|
||||
cache_write_price: toNullableNumber(iv.cache_write_price),
|
||||
cache_read_price: toNullableNumber(iv.cache_read_price),
|
||||
per_request_price: toNullableNumber(iv.per_request_price),
|
||||
sort_order: iv.sort_order
|
||||
}))
|
||||
}
|
||||
@ -287,6 +287,21 @@ const FolderIcon = {
|
||||
)
|
||||
}
|
||||
|
||||
const ChannelIcon = {
|
||||
render: () =>
|
||||
h(
|
||||
'svg',
|
||||
{ fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' },
|
||||
[
|
||||
h('path', {
|
||||
'stroke-linecap': 'round',
|
||||
'stroke-linejoin': 'round',
|
||||
d: 'M6.429 9.75L2.25 12l4.179 2.25m0-4.5l5.571 3 5.571-3m-11.142 0L2.25 7.5 12 2.25l9.75 5.25-4.179 2.25m0 0l4.179 2.25L12 17.25 2.25 12m15.321-2.25l4.179 2.25L12 17.25l-9.75-5.25'
|
||||
})
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
const CreditCardIcon = {
|
||||
render: () =>
|
||||
h(
|
||||
@ -568,6 +583,7 @@ const adminNavItems = computed((): NavItem[] => {
|
||||
: []),
|
||||
{ path: '/admin/users', label: t('nav.users'), icon: UsersIcon, hideInSimpleMode: true },
|
||||
{ path: '/admin/groups', label: t('nav.groups'), icon: FolderIcon, hideInSimpleMode: true },
|
||||
{ path: '/admin/channels', label: t('nav.channels', '渠道管理'), icon: ChannelIcon, hideInSimpleMode: true },
|
||||
{ path: '/admin/subscriptions', label: t('nav.subscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
|
||||
{ path: '/admin/accounts', label: t('nav.accounts'), icon: GlobeIcon },
|
||||
{ path: '/admin/announcements', label: t('nav.announcements'), icon: BellIcon },
|
||||
|
||||
@ -278,6 +278,16 @@ const routes: RouteRecordRaw[] = [
|
||||
descriptionKey: 'admin.groups.description'
|
||||
}
|
||||
},
|
||||
{
|
||||
path: '/admin/channels',
|
||||
name: 'AdminChannels',
|
||||
component: () => import('@/views/admin/ChannelsView.vue'),
|
||||
meta: {
|
||||
requiresAuth: true,
|
||||
requiresAdmin: true,
|
||||
title: 'Channel Management'
|
||||
}
|
||||
},
|
||||
{
|
||||
path: '/admin/subscriptions',
|
||||
name: 'AdminSubscriptions',
|
||||
|
||||
628
frontend/src/views/admin/ChannelsView.vue
Normal file
628
frontend/src/views/admin/ChannelsView.vue
Normal file
@ -0,0 +1,628 @@
|
||||
<template>
|
||||
<AppLayout>
|
||||
<TablePageLayout>
|
||||
<template #filters>
|
||||
<div class="flex flex-col justify-between gap-4 lg:flex-row lg:items-start">
|
||||
<!-- Left: Search + Filters -->
|
||||
<div class="flex flex-1 flex-wrap items-center gap-3">
|
||||
<div class="relative w-full sm:w-64">
|
||||
<Icon
|
||||
name="search"
|
||||
size="md"
|
||||
class="absolute left-3 top-1/2 -translate-y-1/2 text-gray-400 dark:text-gray-500"
|
||||
/>
|
||||
<input
|
||||
v-model="searchQuery"
|
||||
type="text"
|
||||
:placeholder="t('admin.channels.searchChannels', 'Search channels...')"
|
||||
class="input pl-10"
|
||||
@input="handleSearch"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<Select
|
||||
v-model="filters.status"
|
||||
:options="statusFilterOptions"
|
||||
:placeholder="t('admin.channels.allStatus', 'All Status')"
|
||||
class="w-40"
|
||||
@change="loadChannels"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Right: Actions -->
|
||||
<div class="flex w-full flex-shrink-0 flex-wrap items-center justify-end gap-3 lg:w-auto">
|
||||
<button
|
||||
@click="loadChannels"
|
||||
:disabled="loading"
|
||||
class="btn btn-secondary"
|
||||
:title="t('common.refresh', 'Refresh')"
|
||||
>
|
||||
<Icon name="refresh" size="md" :class="loading ? 'animate-spin' : ''" />
|
||||
</button>
|
||||
<button @click="openCreateDialog" class="btn btn-primary">
|
||||
<Icon name="plus" size="md" class="mr-2" />
|
||||
{{ t('admin.channels.createChannel', 'Create Channel') }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<template #table>
|
||||
<DataTable :columns="columns" :data="channels" :loading="loading">
|
||||
<template #cell-name="{ value }">
|
||||
<span class="font-medium text-gray-900 dark:text-white">{{ value }}</span>
|
||||
</template>
|
||||
|
||||
<template #cell-description="{ value }">
|
||||
<span class="text-sm text-gray-600 dark:text-gray-400">{{ value || '-' }}</span>
|
||||
</template>
|
||||
|
||||
<template #cell-status="{ value }">
|
||||
<span
|
||||
:class="[
|
||||
'inline-flex items-center rounded-full px-2 py-0.5 text-xs font-medium',
|
||||
value === 'active'
|
||||
? 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400'
|
||||
: 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400'
|
||||
]"
|
||||
>
|
||||
{{ value === 'active' ? t('admin.channels.statusActive', 'Active') : t('admin.channels.statusDisabled', 'Disabled') }}
|
||||
</span>
|
||||
</template>
|
||||
|
||||
<template #cell-group_count="{ row }">
|
||||
<span
|
||||
class="inline-flex items-center rounded bg-gray-100 px-2 py-0.5 text-xs font-medium text-gray-800 dark:bg-dark-600 dark:text-gray-300"
|
||||
>
|
||||
{{ (row.group_ids || []).length }}
|
||||
{{ t('admin.channels.groupsUnit', 'groups') }}
|
||||
</span>
|
||||
</template>
|
||||
|
||||
<template #cell-pricing_count="{ row }">
|
||||
<span
|
||||
class="inline-flex items-center rounded bg-gray-100 px-2 py-0.5 text-xs font-medium text-gray-800 dark:bg-dark-600 dark:text-gray-300"
|
||||
>
|
||||
{{ (row.model_pricing || []).length }}
|
||||
{{ t('admin.channels.pricingUnit', 'pricing rules') }}
|
||||
</span>
|
||||
</template>
|
||||
|
||||
<template #cell-created_at="{ value }">
|
||||
<span class="text-sm text-gray-600 dark:text-gray-400">
|
||||
{{ formatDate(value) }}
|
||||
</span>
|
||||
</template>
|
||||
|
||||
<template #cell-actions="{ row }">
|
||||
<div class="flex items-center gap-1">
|
||||
<button
|
||||
@click="openEditDialog(row)"
|
||||
class="flex flex-col items-center gap-0.5 rounded-lg p-1.5 text-gray-500 transition-colors hover:bg-gray-100 hover:text-primary-600 dark:hover:bg-dark-700 dark:hover:text-primary-400"
|
||||
>
|
||||
<Icon name="edit" size="sm" />
|
||||
<span class="text-xs">{{ t('common.edit', 'Edit') }}</span>
|
||||
</button>
|
||||
<button
|
||||
@click="handleDelete(row)"
|
||||
class="flex flex-col items-center gap-0.5 rounded-lg p-1.5 text-gray-500 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20 dark:hover:text-red-400"
|
||||
>
|
||||
<Icon name="trash" size="sm" />
|
||||
<span class="text-xs">{{ t('common.delete', 'Delete') }}</span>
|
||||
</button>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<template #empty>
|
||||
<EmptyState
|
||||
:title="t('admin.channels.noChannelsYet', 'No Channels Yet')"
|
||||
:description="t('admin.channels.createFirstChannel', 'Create your first channel to manage model pricing')"
|
||||
:action-text="t('admin.channels.createChannel', 'Create Channel')"
|
||||
@action="openCreateDialog"
|
||||
/>
|
||||
</template>
|
||||
</DataTable>
|
||||
</template>
|
||||
|
||||
<template #pagination>
|
||||
<Pagination
|
||||
v-if="pagination.total > 0"
|
||||
:page="pagination.page"
|
||||
:total="pagination.total"
|
||||
:page-size="pagination.page_size"
|
||||
@update:page="handlePageChange"
|
||||
@update:pageSize="handlePageSizeChange"
|
||||
/>
|
||||
</template>
|
||||
</TablePageLayout>
|
||||
|
||||
<!-- Create/Edit Dialog -->
|
||||
<BaseDialog
|
||||
:show="showDialog"
|
||||
:title="editingChannel ? t('admin.channels.editChannel', 'Edit Channel') : t('admin.channels.createChannel', 'Create Channel')"
|
||||
width="extra-wide"
|
||||
@close="closeDialog"
|
||||
>
|
||||
<form id="channel-form" @submit.prevent="handleSubmit" class="space-y-5">
|
||||
<!-- Name -->
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.channels.form.name', 'Name') }}</label>
|
||||
<input
|
||||
v-model="form.name"
|
||||
type="text"
|
||||
required
|
||||
class="input"
|
||||
:placeholder="t('admin.channels.form.namePlaceholder', 'Enter channel name')"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Description -->
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.channels.form.description', 'Description') }}</label>
|
||||
<textarea
|
||||
v-model="form.description"
|
||||
rows="2"
|
||||
class="input"
|
||||
:placeholder="t('admin.channels.form.descriptionPlaceholder', 'Optional description')"
|
||||
></textarea>
|
||||
</div>
|
||||
|
||||
<!-- Status (edit only) -->
|
||||
<div v-if="editingChannel">
|
||||
<label class="input-label">{{ t('admin.channels.form.status', 'Status') }}</label>
|
||||
<Select v-model="form.status" :options="statusEditOptions" />
|
||||
</div>
|
||||
|
||||
<!-- Group Association -->
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.channels.form.groups', 'Associated Groups') }}</label>
|
||||
<div
|
||||
class="max-h-48 overflow-auto rounded-lg border border-gray-200 bg-white p-2 dark:border-dark-600 dark:bg-dark-800"
|
||||
>
|
||||
<div v-if="groupsLoading" class="py-4 text-center text-sm text-gray-500">
|
||||
{{ t('common.loading', 'Loading...') }}
|
||||
</div>
|
||||
<div v-else-if="allGroups.length === 0" class="py-4 text-center text-sm text-gray-500">
|
||||
{{ t('admin.channels.form.noGroupsAvailable', 'No groups available') }}
|
||||
</div>
|
||||
<label
|
||||
v-for="group in allGroups"
|
||||
:key="group.id"
|
||||
class="flex cursor-pointer items-center gap-2 rounded px-2 py-1.5 hover:bg-gray-50 dark:hover:bg-dark-700"
|
||||
:class="{ 'opacity-50': isGroupInOtherChannel(group.id) }"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="form.group_ids.includes(group.id)"
|
||||
:disabled="isGroupInOtherChannel(group.id)"
|
||||
class="h-4 w-4 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
@change="toggleGroup(group.id)"
|
||||
/>
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{ group.name }}</span>
|
||||
<span
|
||||
v-if="isGroupInOtherChannel(group.id)"
|
||||
class="ml-auto text-xs text-gray-400"
|
||||
>
|
||||
{{ getGroupInOtherChannelLabel(group.id) }}
|
||||
</span>
|
||||
<span
|
||||
v-if="group.platform"
|
||||
class="ml-auto text-xs text-gray-400 dark:text-gray-500"
|
||||
>
|
||||
{{ group.platform }}
|
||||
</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Model Pricing -->
|
||||
<div>
|
||||
<div class="mb-2 flex items-center justify-between">
|
||||
<label class="input-label mb-0">{{ t('admin.channels.form.modelPricing', 'Model Pricing') }}</label>
|
||||
<button type="button" @click="addPricingEntry" class="btn btn-secondary btn-sm">
|
||||
<Icon name="plus" size="sm" class="mr-1" />
|
||||
{{ t('common.add', 'Add') }}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-if="form.model_pricing.length === 0"
|
||||
class="rounded-lg border border-dashed border-gray-300 p-4 text-center text-sm text-gray-500 dark:border-dark-500 dark:text-gray-400"
|
||||
>
|
||||
{{ t('admin.channels.form.noPricingRules', 'No pricing rules yet. Click "Add" to create one.') }}
|
||||
</div>
|
||||
|
||||
<div v-else class="space-y-3">
|
||||
<PricingEntryCard
|
||||
v-for="(entry, idx) in form.model_pricing"
|
||||
:key="idx"
|
||||
:entry="entry"
|
||||
@update="updatePricingEntry(idx, $event)"
|
||||
@remove="removePricingEntry(idx)"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
<template #footer>
|
||||
<div class="flex justify-end gap-3">
|
||||
<button @click="closeDialog" type="button" class="btn btn-secondary">
|
||||
{{ t('common.cancel', 'Cancel') }}
|
||||
</button>
|
||||
<button
|
||||
type="submit"
|
||||
form="channel-form"
|
||||
:disabled="submitting"
|
||||
class="btn btn-primary"
|
||||
>
|
||||
{{ submitting
|
||||
? t('common.submitting', 'Submitting...')
|
||||
: editingChannel
|
||||
? t('common.update', 'Update')
|
||||
: t('common.create', 'Create')
|
||||
}}
|
||||
</button>
|
||||
</div>
|
||||
</template>
|
||||
</BaseDialog>
|
||||
|
||||
<!-- Delete Confirmation -->
|
||||
<ConfirmDialog
|
||||
:show="showDeleteDialog"
|
||||
:title="t('admin.channels.deleteChannel', 'Delete Channel')"
|
||||
:message="deleteConfirmMessage"
|
||||
:confirm-text="t('common.delete', 'Delete')"
|
||||
:cancel-text="t('common.cancel', 'Cancel')"
|
||||
:danger="true"
|
||||
@confirm="confirmDelete"
|
||||
@cancel="showDeleteDialog = false"
|
||||
/>
|
||||
</AppLayout>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest } from '@/api/admin/channels'
|
||||
import type { PricingFormEntry } from '@/components/admin/channel/types'
|
||||
import { toNullableNumber, apiIntervalsToForm, formIntervalsToAPI } from '@/components/admin/channel/types'
|
||||
import type { AdminGroup } from '@/types'
|
||||
import type { Column } from '@/components/common/types'
|
||||
import AppLayout from '@/components/layout/AppLayout.vue'
|
||||
import TablePageLayout from '@/components/layout/TablePageLayout.vue'
|
||||
import DataTable from '@/components/common/DataTable.vue'
|
||||
import Pagination from '@/components/common/Pagination.vue'
|
||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
||||
import EmptyState from '@/components/common/EmptyState.vue'
|
||||
import Select from '@/components/common/Select.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import PricingEntryCard from '@/components/admin/channel/PricingEntryCard.vue'
|
||||
import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
|
||||
|
||||
const { t } = useI18n()
|
||||
const appStore = useAppStore()
|
||||
|
||||
// ── Table columns ──
|
||||
const columns = computed<Column[]>(() => [
|
||||
{ key: 'name', label: t('admin.channels.columns.name', 'Name'), sortable: true },
|
||||
{ key: 'description', label: t('admin.channels.columns.description', 'Description'), sortable: false },
|
||||
{ key: 'status', label: t('admin.channels.columns.status', 'Status'), sortable: true },
|
||||
{ key: 'group_count', label: t('admin.channels.columns.groups', 'Groups'), sortable: false },
|
||||
{ key: 'pricing_count', label: t('admin.channels.columns.pricing', 'Pricing'), sortable: false },
|
||||
{ key: 'created_at', label: t('admin.channels.columns.createdAt', 'Created'), sortable: true },
|
||||
{ key: 'actions', label: t('admin.channels.columns.actions', 'Actions'), sortable: false }
|
||||
])
|
||||
|
||||
const statusFilterOptions = computed(() => [
|
||||
{ value: '', label: t('admin.channels.allStatus', 'All Status') },
|
||||
{ value: 'active', label: t('admin.channels.statusActive', 'Active') },
|
||||
{ value: 'disabled', label: t('admin.channels.statusDisabled', 'Disabled') }
|
||||
])
|
||||
|
||||
const statusEditOptions = computed(() => [
|
||||
{ value: 'active', label: t('admin.channels.statusActive', 'Active') },
|
||||
{ value: 'disabled', label: t('admin.channels.statusDisabled', 'Disabled') }
|
||||
])
|
||||
|
||||
// ── State ──
|
||||
const channels = ref<Channel[]>([])
|
||||
const loading = ref(false)
|
||||
const searchQuery = ref('')
|
||||
const filters = reactive({ status: '' })
|
||||
const pagination = reactive({
|
||||
page: 1,
|
||||
page_size: getPersistedPageSize(),
|
||||
total: 0
|
||||
})
|
||||
|
||||
// Dialog state
|
||||
const showDialog = ref(false)
|
||||
const editingChannel = ref<Channel | null>(null)
|
||||
const submitting = ref(false)
|
||||
const showDeleteDialog = ref(false)
|
||||
const deletingChannel = ref<Channel | null>(null)
|
||||
|
||||
// Groups
|
||||
const allGroups = ref<AdminGroup[]>([])
|
||||
const groupsLoading = ref(false)
|
||||
|
||||
// Form data
|
||||
const form = reactive({
|
||||
name: '',
|
||||
description: '',
|
||||
status: 'active',
|
||||
group_ids: [] as number[],
|
||||
model_pricing: [] as PricingFormEntry[]
|
||||
})
|
||||
|
||||
let abortController: AbortController | null = null
|
||||
|
||||
// ── Helpers ──
|
||||
function formatDate(value: string): string {
|
||||
if (!value) return '-'
|
||||
return new Date(value).toLocaleDateString()
|
||||
}
|
||||
|
||||
// ── Group helpers ──
|
||||
const groupToChannelMap = computed(() => {
|
||||
const map = new Map<number, Channel>()
|
||||
for (const ch of channels.value) {
|
||||
if (editingChannel.value && ch.id === editingChannel.value.id) continue
|
||||
for (const gid of ch.group_ids || []) {
|
||||
map.set(gid, ch)
|
||||
}
|
||||
}
|
||||
return map
|
||||
})
|
||||
|
||||
function isGroupInOtherChannel(groupId: number): boolean {
|
||||
return groupToChannelMap.value.has(groupId)
|
||||
}
|
||||
|
||||
function getGroupChannelName(groupId: number): string {
|
||||
return groupToChannelMap.value.get(groupId)?.name || ''
|
||||
}
|
||||
|
||||
function getGroupInOtherChannelLabel(groupId: number): string {
|
||||
const name = getGroupChannelName(groupId)
|
||||
return t('admin.channels.form.inOtherChannel', { name }, `In "${name}"`)
|
||||
}
|
||||
|
||||
const deleteConfirmMessage = computed(() => {
|
||||
const name = deletingChannel.value?.name || ''
|
||||
return t(
|
||||
'admin.channels.deleteConfirm',
|
||||
{ name },
|
||||
`Are you sure you want to delete channel "${name}"? This action cannot be undone.`
|
||||
)
|
||||
})
|
||||
|
||||
function toggleGroup(groupId: number) {
|
||||
const idx = form.group_ids.indexOf(groupId)
|
||||
if (idx >= 0) {
|
||||
form.group_ids.splice(idx, 1)
|
||||
} else {
|
||||
form.group_ids.push(groupId)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Pricing helpers ──
|
||||
function addPricingEntry() {
|
||||
form.model_pricing.push({
|
||||
modelsInput: '',
|
||||
billing_mode: 'token',
|
||||
input_price: null,
|
||||
output_price: null,
|
||||
cache_write_price: null,
|
||||
cache_read_price: null,
|
||||
per_request_price: null,
|
||||
image_output_price: null,
|
||||
intervals: []
|
||||
})
|
||||
}
|
||||
|
||||
function updatePricingEntry(idx: number, updated: PricingFormEntry) {
|
||||
form.model_pricing[idx] = updated
|
||||
}
|
||||
|
||||
function removePricingEntry(idx: number) {
|
||||
form.model_pricing.splice(idx, 1)
|
||||
}
|
||||
|
||||
function formPricingToAPI(): ChannelModelPricing[] {
|
||||
return form.model_pricing
|
||||
.filter(e => e.modelsInput.trim())
|
||||
.map(e => ({
|
||||
models: e.modelsInput.split(',').map(m => m.trim()).filter(Boolean),
|
||||
billing_mode: e.billing_mode,
|
||||
input_price: toNullableNumber(e.input_price),
|
||||
output_price: toNullableNumber(e.output_price),
|
||||
cache_write_price: toNullableNumber(e.cache_write_price),
|
||||
cache_read_price: toNullableNumber(e.cache_read_price),
|
||||
image_output_price: toNullableNumber(e.image_output_price),
|
||||
intervals: formIntervalsToAPI(e.intervals || [])
|
||||
}))
|
||||
}
|
||||
|
||||
function apiPricingToForm(pricing: ChannelModelPricing[]): PricingFormEntry[] {
|
||||
return pricing.map(p => ({
|
||||
modelsInput: p.models.join(', '),
|
||||
billing_mode: p.billing_mode,
|
||||
input_price: p.input_price,
|
||||
output_price: p.output_price,
|
||||
cache_write_price: p.cache_write_price,
|
||||
cache_read_price: p.cache_read_price,
|
||||
per_request_price: null,
|
||||
image_output_price: p.image_output_price,
|
||||
intervals: apiIntervalsToForm(p.intervals || [])
|
||||
}))
|
||||
}
|
||||
|
||||
// ── Load data ──
|
||||
async function loadChannels() {
|
||||
if (abortController) abortController.abort()
|
||||
const ctrl = new AbortController()
|
||||
abortController = ctrl
|
||||
loading.value = true
|
||||
|
||||
try {
|
||||
const response = await adminAPI.channels.list(pagination.page, pagination.page_size, {
|
||||
status: filters.status || undefined,
|
||||
search: searchQuery.value || undefined
|
||||
}, { signal: ctrl.signal })
|
||||
|
||||
if (ctrl.signal.aborted || abortController !== ctrl) return
|
||||
channels.value = response.items || []
|
||||
pagination.total = response.total
|
||||
} catch (error: any) {
|
||||
if (error?.name === 'AbortError' || error?.code === 'ERR_CANCELED') return
|
||||
appStore.showError(t('admin.channels.loadError', 'Failed to load channels'))
|
||||
console.error('Error loading channels:', error)
|
||||
} finally {
|
||||
if (abortController === ctrl) {
|
||||
loading.value = false
|
||||
abortController = null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function loadGroups() {
|
||||
groupsLoading.value = true
|
||||
try {
|
||||
allGroups.value = await adminAPI.groups.getAll()
|
||||
} catch (error) {
|
||||
console.error('Error loading groups:', error)
|
||||
} finally {
|
||||
groupsLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
let searchTimeout: ReturnType<typeof setTimeout>
|
||||
function handleSearch() {
|
||||
clearTimeout(searchTimeout)
|
||||
searchTimeout = setTimeout(() => {
|
||||
pagination.page = 1
|
||||
loadChannels()
|
||||
}, 300)
|
||||
}
|
||||
|
||||
function handlePageChange(page: number) {
|
||||
pagination.page = page
|
||||
loadChannels()
|
||||
}
|
||||
|
||||
function handlePageSizeChange(pageSize: number) {
|
||||
pagination.page_size = pageSize
|
||||
pagination.page = 1
|
||||
loadChannels()
|
||||
}
|
||||
|
||||
// ── Dialog ──
|
||||
function resetForm() {
|
||||
form.name = ''
|
||||
form.description = ''
|
||||
form.status = 'active'
|
||||
form.group_ids = []
|
||||
form.model_pricing = []
|
||||
}
|
||||
|
||||
function openCreateDialog() {
|
||||
editingChannel.value = null
|
||||
resetForm()
|
||||
loadGroups()
|
||||
showDialog.value = true
|
||||
}
|
||||
|
||||
function openEditDialog(channel: Channel) {
|
||||
editingChannel.value = channel
|
||||
form.name = channel.name
|
||||
form.description = channel.description || ''
|
||||
form.status = channel.status
|
||||
form.group_ids = [...(channel.group_ids || [])]
|
||||
form.model_pricing = apiPricingToForm(channel.model_pricing || [])
|
||||
loadGroups()
|
||||
showDialog.value = true
|
||||
}
|
||||
|
||||
function closeDialog() {
|
||||
showDialog.value = false
|
||||
editingChannel.value = null
|
||||
resetForm()
|
||||
}
|
||||
|
||||
async function handleSubmit() {
|
||||
if (submitting.value) return
|
||||
if (!form.name.trim()) {
|
||||
appStore.showError(t('admin.channels.nameRequired', 'Please enter a channel name'))
|
||||
return
|
||||
}
|
||||
|
||||
submitting.value = true
|
||||
try {
|
||||
if (editingChannel.value) {
|
||||
const req: UpdateChannelRequest = {
|
||||
name: form.name.trim(),
|
||||
description: form.description.trim() || undefined,
|
||||
status: form.status,
|
||||
group_ids: form.group_ids,
|
||||
model_pricing: formPricingToAPI()
|
||||
}
|
||||
await adminAPI.channels.update(editingChannel.value.id, req)
|
||||
appStore.showSuccess(t('admin.channels.updateSuccess', 'Channel updated'))
|
||||
} else {
|
||||
const req: CreateChannelRequest = {
|
||||
name: form.name.trim(),
|
||||
description: form.description.trim() || undefined,
|
||||
group_ids: form.group_ids,
|
||||
model_pricing: formPricingToAPI()
|
||||
}
|
||||
await adminAPI.channels.create(req)
|
||||
appStore.showSuccess(t('admin.channels.createSuccess', 'Channel created'))
|
||||
}
|
||||
closeDialog()
|
||||
loadChannels()
|
||||
} catch (error: any) {
|
||||
const msg = error.response?.data?.detail || (editingChannel.value
|
||||
? t('admin.channels.updateError', 'Failed to update channel')
|
||||
: t('admin.channels.createError', 'Failed to create channel'))
|
||||
appStore.showError(msg)
|
||||
console.error('Error saving channel:', error)
|
||||
} finally {
|
||||
submitting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// ── Delete ──
|
||||
function handleDelete(channel: Channel) {
|
||||
deletingChannel.value = channel
|
||||
showDeleteDialog.value = true
|
||||
}
|
||||
|
||||
async function confirmDelete() {
|
||||
if (!deletingChannel.value) return
|
||||
|
||||
try {
|
||||
await adminAPI.channels.remove(deletingChannel.value.id)
|
||||
appStore.showSuccess(t('admin.channels.deleteSuccess', 'Channel deleted'))
|
||||
showDeleteDialog.value = false
|
||||
deletingChannel.value = null
|
||||
loadChannels()
|
||||
} catch (error: any) {
|
||||
appStore.showError(error.response?.data?.detail || t('admin.channels.deleteError', 'Failed to delete channel'))
|
||||
console.error('Error deleting channel:', error)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Lifecycle ──
|
||||
onMounted(() => {
|
||||
loadChannels()
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
clearTimeout(searchTimeout)
|
||||
abortController?.abort()
|
||||
})
|
||||
</script>
|
||||
Loading…
x
Reference in New Issue
Block a user