feat(channels): add custom account stats pricing rules
Allow channels to configure independent model pricing for account statistics cost calculation, decoupled from user billing. Backend: - Migration 101: channels.apply_pricing_to_account_stats toggle, channel_account_stats_pricing_rules/model_pricing tables, usage_logs.account_stats_cost column - resolveAccountStatsCost: match rules by group/account, then channel pricing, fallback to original formula when unconfigured - Integrate into both GatewayService.recordUsageCore and OpenAIGatewayService.RecordUsage - Update 8 account stats SQL queries to use COALESCE(account_stats_cost, total_cost) * account_rate_multiplier - 23 unit tests for matching, pricing lookup, and cost calculation Frontend: - Channel edit dialog: toggle + custom rules UI with group/account multi-select and pricing entry cards - API types and i18n (zh/en)
This commit is contained in:
parent
7fad9f604f
commit
7535e312e0
@ -26,28 +26,30 @@ func NewChannelHandler(channelService *service.ChannelService, billingService *s
|
|||||||
// --- Request / Response types ---
|
// --- Request / Response types ---
|
||||||
|
|
||||||
type createChannelRequest struct {
|
type createChannelRequest struct {
|
||||||
Name string `json:"name" binding:"required,max=100"`
|
Name string `json:"name" binding:"required,max=100"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
GroupIDs []int64 `json:"group_ids"`
|
GroupIDs []int64 `json:"group_ids"`
|
||||||
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
|
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
|
||||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
||||||
RestrictModels bool `json:"restrict_models"`
|
RestrictModels bool `json:"restrict_models"`
|
||||||
Features string `json:"features"`
|
Features string `json:"features"`
|
||||||
FeaturesConfig map[string]any `json:"features_config"`
|
ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
|
||||||
|
AccountStatsPricingRules []accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type updateChannelRequest struct {
|
type updateChannelRequest struct {
|
||||||
Name string `json:"name" binding:"omitempty,max=100"`
|
Name string `json:"name" binding:"omitempty,max=100"`
|
||||||
Description *string `json:"description"`
|
Description *string `json:"description"`
|
||||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||||
GroupIDs *[]int64 `json:"group_ids"`
|
GroupIDs *[]int64 `json:"group_ids"`
|
||||||
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
|
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
|
||||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
||||||
RestrictModels *bool `json:"restrict_models"`
|
RestrictModels *bool `json:"restrict_models"`
|
||||||
Features *string `json:"features"`
|
Features *string `json:"features"`
|
||||||
FeaturesConfig map[string]any `json:"features_config"`
|
ApplyPricingToAccountStats *bool `json:"apply_pricing_to_account_stats"`
|
||||||
|
AccountStatsPricingRules *[]accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type channelModelPricingRequest struct {
|
type channelModelPricingRequest struct {
|
||||||
@ -75,20 +77,28 @@ type pricingIntervalRequest struct {
|
|||||||
SortOrder int `json:"sort_order"`
|
SortOrder int `json:"sort_order"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type accountStatsPricingRuleRequest struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
GroupIDs []int64 `json:"group_ids"`
|
||||||
|
AccountIDs []int64 `json:"account_ids"`
|
||||||
|
Pricing []channelModelPricingRequest `json:"pricing"`
|
||||||
|
}
|
||||||
|
|
||||||
type channelResponse struct {
|
type channelResponse struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
BillingModelSource string `json:"billing_model_source"`
|
BillingModelSource string `json:"billing_model_source"`
|
||||||
RestrictModels bool `json:"restrict_models"`
|
RestrictModels bool `json:"restrict_models"`
|
||||||
Features string `json:"features"`
|
Features string `json:"features"`
|
||||||
FeaturesConfig map[string]any `json:"features_config"`
|
GroupIDs []int64 `json:"group_ids"`
|
||||||
GroupIDs []int64 `json:"group_ids"`
|
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
|
||||||
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
|
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
|
||||||
CreatedAt string `json:"created_at"`
|
AccountStatsPricingRules []accountStatsPricingRuleResponse `json:"account_stats_pricing_rules"`
|
||||||
UpdatedAt string `json:"updated_at"`
|
CreatedAt string `json:"created_at"`
|
||||||
|
UpdatedAt string `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type channelModelPricingResponse struct {
|
type channelModelPricingResponse struct {
|
||||||
@ -118,6 +128,14 @@ type pricingIntervalResponse struct {
|
|||||||
SortOrder int `json:"sort_order"`
|
SortOrder int `json:"sort_order"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type accountStatsPricingRuleResponse struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
GroupIDs []int64 `json:"group_ids"`
|
||||||
|
AccountIDs []int64 `json:"account_ids"`
|
||||||
|
Pricing []channelModelPricingResponse `json:"pricing"`
|
||||||
|
}
|
||||||
|
|
||||||
func channelToResponse(ch *service.Channel) *channelResponse {
|
func channelToResponse(ch *service.Channel) *channelResponse {
|
||||||
if ch == nil {
|
if ch == nil {
|
||||||
return nil
|
return nil
|
||||||
@ -129,7 +147,6 @@ func channelToResponse(ch *service.Channel) *channelResponse {
|
|||||||
Status: ch.Status,
|
Status: ch.Status,
|
||||||
RestrictModels: ch.RestrictModels,
|
RestrictModels: ch.RestrictModels,
|
||||||
Features: ch.Features,
|
Features: ch.Features,
|
||||||
FeaturesConfig: ch.FeaturesConfig,
|
|
||||||
GroupIDs: ch.GroupIDs,
|
GroupIDs: ch.GroupIDs,
|
||||||
ModelMapping: ch.ModelMapping,
|
ModelMapping: ch.ModelMapping,
|
||||||
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||||
@ -150,6 +167,29 @@ func channelToResponse(ch *service.Channel) *channelResponse {
|
|||||||
for _, p := range ch.ModelPricing {
|
for _, p := range ch.ModelPricing {
|
||||||
resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p))
|
resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resp.ApplyPricingToAccountStats = ch.ApplyPricingToAccountStats
|
||||||
|
resp.AccountStatsPricingRules = make([]accountStatsPricingRuleResponse, 0, len(ch.AccountStatsPricingRules))
|
||||||
|
for _, rule := range ch.AccountStatsPricingRules {
|
||||||
|
ruleResp := accountStatsPricingRuleResponse{
|
||||||
|
ID: rule.ID,
|
||||||
|
Name: rule.Name,
|
||||||
|
GroupIDs: rule.GroupIDs,
|
||||||
|
AccountIDs: rule.AccountIDs,
|
||||||
|
Pricing: make([]channelModelPricingResponse, 0, len(rule.Pricing)),
|
||||||
|
}
|
||||||
|
if ruleResp.GroupIDs == nil {
|
||||||
|
ruleResp.GroupIDs = []int64{}
|
||||||
|
}
|
||||||
|
if ruleResp.AccountIDs == nil {
|
||||||
|
ruleResp.AccountIDs = []int64{}
|
||||||
|
}
|
||||||
|
for i := range rule.Pricing {
|
||||||
|
ruleResp.Pricing = append(ruleResp.Pricing, pricingToResponse(&rule.Pricing[i]))
|
||||||
|
}
|
||||||
|
resp.AccountStatsPricingRules = append(resp.AccountStatsPricingRules, ruleResp)
|
||||||
|
}
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -241,6 +281,15 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func accountStatsPricingRuleRequestToService(r accountStatsPricingRuleRequest) service.AccountStatsPricingRule {
|
||||||
|
return service.AccountStatsPricingRule{
|
||||||
|
Name: r.Name,
|
||||||
|
GroupIDs: r.GroupIDs,
|
||||||
|
AccountIDs: r.AccountIDs,
|
||||||
|
Pricing: pricingRequestToService(r.Pricing),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// --- Handlers ---
|
// --- Handlers ---
|
||||||
|
|
||||||
// List handles listing channels with pagination
|
// List handles listing channels with pagination
|
||||||
@ -300,16 +349,24 @@ func (h *ChannelHandler) Create(c *gin.Context) {
|
|||||||
|
|
||||||
pricing := pricingRequestToService(req.ModelPricing)
|
pricing := pricingRequestToService(req.ModelPricing)
|
||||||
|
|
||||||
|
var statsRules []service.AccountStatsPricingRule
|
||||||
|
for i, r := range req.AccountStatsPricingRules {
|
||||||
|
rule := accountStatsPricingRuleRequestToService(r)
|
||||||
|
rule.SortOrder = i
|
||||||
|
statsRules = append(statsRules, rule)
|
||||||
|
}
|
||||||
|
|
||||||
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
|
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Description: req.Description,
|
Description: req.Description,
|
||||||
GroupIDs: req.GroupIDs,
|
GroupIDs: req.GroupIDs,
|
||||||
ModelPricing: pricing,
|
ModelPricing: pricing,
|
||||||
ModelMapping: req.ModelMapping,
|
ModelMapping: req.ModelMapping,
|
||||||
BillingModelSource: req.BillingModelSource,
|
BillingModelSource: req.BillingModelSource,
|
||||||
RestrictModels: req.RestrictModels,
|
RestrictModels: req.RestrictModels,
|
||||||
Features: req.Features,
|
Features: req.Features,
|
||||||
FeaturesConfig: req.FeaturesConfig,
|
ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
|
||||||
|
AccountStatsPricingRules: statsRules,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
@ -335,20 +392,29 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
input := &service.UpdateChannelInput{
|
input := &service.UpdateChannelInput{
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Description: req.Description,
|
Description: req.Description,
|
||||||
Status: req.Status,
|
Status: req.Status,
|
||||||
GroupIDs: req.GroupIDs,
|
GroupIDs: req.GroupIDs,
|
||||||
ModelMapping: req.ModelMapping,
|
ModelMapping: req.ModelMapping,
|
||||||
BillingModelSource: req.BillingModelSource,
|
BillingModelSource: req.BillingModelSource,
|
||||||
RestrictModels: req.RestrictModels,
|
RestrictModels: req.RestrictModels,
|
||||||
Features: req.Features,
|
Features: req.Features,
|
||||||
FeaturesConfig: req.FeaturesConfig,
|
ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
|
||||||
}
|
}
|
||||||
if req.ModelPricing != nil {
|
if req.ModelPricing != nil {
|
||||||
pricing := pricingRequestToService(*req.ModelPricing)
|
pricing := pricingRequestToService(*req.ModelPricing)
|
||||||
input.ModelPricing = &pricing
|
input.ModelPricing = &pricing
|
||||||
}
|
}
|
||||||
|
if req.AccountStatsPricingRules != nil {
|
||||||
|
statsRules := make([]service.AccountStatsPricingRule, 0, len(*req.AccountStatsPricingRules))
|
||||||
|
for i, r := range *req.AccountStatsPricingRules {
|
||||||
|
rule := accountStatsPricingRuleRequestToService(r)
|
||||||
|
rule.SortOrder = i
|
||||||
|
statsRules = append(statsRules, rule)
|
||||||
|
}
|
||||||
|
input.AccountStatsPricingRules = &statsRules
|
||||||
|
}
|
||||||
|
|
||||||
channel, err := h.channelService.Update(c.Request.Context(), id, input)
|
channel, err := h.channelService.Update(c.Request.Context(), id, input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -41,14 +41,10 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = tx.QueryRowContext(ctx,
|
err = tx.QueryRowContext(ctx,
|
||||||
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
RETURNING id, created_at, updated_at`,
|
RETURNING id, created_at, updated_at`,
|
||||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON,
|
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ApplyPricingToAccountStats,
|
||||||
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
|
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isUniqueViolation(err) {
|
if isUniqueViolation(err) {
|
||||||
@ -71,17 +67,24 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 设置账号统计定价规则
|
||||||
|
if len(channel.AccountStatsPricingRules) > 0 {
|
||||||
|
if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
|
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
|
||||||
ch := &service.Channel{}
|
ch := &service.Channel{}
|
||||||
var modelMappingJSON, featuresConfigJSON []byte
|
var modelMappingJSON []byte
|
||||||
err := r.db.QueryRowContext(ctx,
|
err := r.db.QueryRowContext(ctx,
|
||||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at
|
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats, created_at, updated_at
|
||||||
FROM channels WHERE id = $1`, id,
|
FROM channels WHERE id = $1`, id,
|
||||||
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt)
|
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, service.ErrChannelNotFound
|
return nil, service.ErrChannelNotFound
|
||||||
}
|
}
|
||||||
@ -89,7 +92,6 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
|
|||||||
return nil, fmt.Errorf("get channel: %w", err)
|
return nil, fmt.Errorf("get channel: %w", err)
|
||||||
}
|
}
|
||||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||||
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
|
||||||
|
|
||||||
groupIDs, err := r.GetGroupIDs(ctx, id)
|
groupIDs, err := r.GetGroupIDs(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -103,6 +105,12 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
|
|||||||
}
|
}
|
||||||
ch.ModelPricing = pricing
|
ch.ModelPricing = pricing
|
||||||
|
|
||||||
|
statsPricingRules, err := r.loadAccountStatsPricingRules(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ch.AccountStatsPricingRules = statsPricingRules
|
||||||
|
|
||||||
return ch, nil
|
return ch, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -112,14 +120,10 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
result, err := tx.ExecContext(ctx,
|
result, err := tx.ExecContext(ctx,
|
||||||
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, updated_at = NOW()
|
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, apply_pricing_to_account_stats = $8, updated_at = NOW()
|
||||||
WHERE id = $9`,
|
WHERE id = $9`,
|
||||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ID,
|
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ApplyPricingToAccountStats, channel.ID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isUniqueViolation(err) {
|
if isUniqueViolation(err) {
|
||||||
@ -146,6 +150,13 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 更新账号统计定价规则
|
||||||
|
if channel.AccountStatsPricingRules != nil {
|
||||||
|
if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -196,7 +207,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
|||||||
|
|
||||||
// 查询 channel 列表
|
// 查询 channel 列表
|
||||||
dataQuery := fmt.Sprintf(
|
dataQuery := fmt.Sprintf(
|
||||||
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.created_at, c.updated_at
|
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.apply_pricing_to_account_stats, c.created_at, c.updated_at
|
||||||
FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
|
FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
|
||||||
whereClause, channelListOrderBy(params), argIdx, argIdx+1,
|
whereClause, channelListOrderBy(params), argIdx, argIdx+1,
|
||||||
)
|
)
|
||||||
@ -212,12 +223,11 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
|||||||
var channelIDs []int64
|
var channelIDs []int64
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var ch service.Channel
|
var ch service.Channel
|
||||||
var modelMappingJSON, featuresConfigJSON []byte
|
var modelMappingJSON []byte
|
||||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||||
return nil, nil, fmt.Errorf("scan channel: %w", err)
|
return nil, nil, fmt.Errorf("scan channel: %w", err)
|
||||||
}
|
}
|
||||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||||
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
|
||||||
channels = append(channels, ch)
|
channels = append(channels, ch)
|
||||||
channelIDs = append(channelIDs, ch.ID)
|
channelIDs = append(channelIDs, ch.ID)
|
||||||
}
|
}
|
||||||
@ -235,9 +245,14 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
for i := range channels {
|
for i := range channels {
|
||||||
channels[i].GroupIDs = groupMap[channels[i].ID]
|
channels[i].GroupIDs = groupMap[channels[i].ID]
|
||||||
channels[i].ModelPricing = pricingMap[channels[i].ID]
|
channels[i].ModelPricing = pricingMap[channels[i].ID]
|
||||||
|
channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -283,7 +298,7 @@ func channelListOrderBy(params pagination.PaginationParams) string {
|
|||||||
|
|
||||||
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
|
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
|
||||||
rows, err := r.db.QueryContext(ctx,
|
rows, err := r.db.QueryContext(ctx,
|
||||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at FROM channels ORDER BY id`,
|
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats, created_at, updated_at FROM channels ORDER BY id`,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query all channels: %w", err)
|
return nil, fmt.Errorf("query all channels: %w", err)
|
||||||
@ -294,12 +309,11 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
|
|||||||
var channelIDs []int64
|
var channelIDs []int64
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var ch service.Channel
|
var ch service.Channel
|
||||||
var modelMappingJSON, featuresConfigJSON []byte
|
var modelMappingJSON []byte
|
||||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||||
return nil, fmt.Errorf("scan channel: %w", err)
|
return nil, fmt.Errorf("scan channel: %w", err)
|
||||||
}
|
}
|
||||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||||
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
|
||||||
channels = append(channels, ch)
|
channels = append(channels, ch)
|
||||||
channelIDs = append(channelIDs, ch.ID)
|
channelIDs = append(channelIDs, ch.ID)
|
||||||
}
|
}
|
||||||
@ -323,9 +337,16 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 批量加载账号统计定价规则
|
||||||
|
statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
for i := range channels {
|
for i := range channels {
|
||||||
channels[i].GroupIDs = groupMap[channels[i].ID]
|
channels[i].GroupIDs = groupMap[channels[i].ID]
|
||||||
channels[i].ModelPricing = pricingMap[channels[i].ID]
|
channels[i].ModelPricing = pricingMap[channels[i].ID]
|
||||||
|
channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID]
|
||||||
}
|
}
|
||||||
|
|
||||||
return channels, nil
|
return channels, nil
|
||||||
@ -467,28 +488,6 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string {
|
|||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
func marshalFeaturesConfig(m map[string]any) ([]byte, error) {
|
|
||||||
if len(m) == 0 {
|
|
||||||
return []byte("{}"), nil
|
|
||||||
}
|
|
||||||
data, err := json.Marshal(m)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("marshal features_config: %w", err)
|
|
||||||
}
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func unmarshalFeaturesConfig(data []byte) map[string]any {
|
|
||||||
if len(data) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var m map[string]any
|
|
||||||
if err := json.Unmarshal(data, &m); err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetGroupPlatforms 批量查询分组 ID 对应的平台
|
// GetGroupPlatforms 批量查询分组 ID 对应的平台
|
||||||
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
||||||
if len(groupIDs) == 0 {
|
if len(groupIDs) == 0 {
|
||||||
|
|||||||
@ -0,0 +1,170 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/lib/pq"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- 账号统计定价规则 ---
|
||||||
|
|
||||||
|
// batchLoadAccountStatsPricingRules 批量加载多个渠道的账号统计定价规则(含模型定价)
|
||||||
|
func (r *channelRepository) batchLoadAccountStatsPricingRules(ctx context.Context, channelIDs []int64) (map[int64][]service.AccountStatsPricingRule, error) {
|
||||||
|
// 1. 查询规则
|
||||||
|
rows, err := r.db.QueryContext(ctx,
|
||||||
|
`SELECT id, channel_id, name, group_ids, account_ids, sort_order, created_at, updated_at
|
||||||
|
FROM channel_account_stats_pricing_rules WHERE channel_id = ANY($1) ORDER BY channel_id, sort_order, id`,
|
||||||
|
pq.Array(channelIDs),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("batch load account stats pricing rules: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var allRules []service.AccountStatsPricingRule
|
||||||
|
var ruleIDs []int64
|
||||||
|
for rows.Next() {
|
||||||
|
var rule service.AccountStatsPricingRule
|
||||||
|
if err := rows.Scan(
|
||||||
|
&rule.ID, &rule.ChannelID, &rule.Name,
|
||||||
|
pq.Array(&rule.GroupIDs), pq.Array(&rule.AccountIDs),
|
||||||
|
&rule.SortOrder, &rule.CreatedAt, &rule.UpdatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan account stats pricing rule: %w", err)
|
||||||
|
}
|
||||||
|
ruleIDs = append(ruleIDs, rule.ID)
|
||||||
|
allRules = append(allRules, rule)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("iterate account stats pricing rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 批量加载规则的模型定价
|
||||||
|
pricingMap, err := r.batchLoadAccountStatsModelPricing(ctx, ruleIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 按 channelID 分组并关联定价
|
||||||
|
result := make(map[int64][]service.AccountStatsPricingRule, len(channelIDs))
|
||||||
|
for i := range allRules {
|
||||||
|
allRules[i].Pricing = pricingMap[allRules[i].ID]
|
||||||
|
result[allRules[i].ChannelID] = append(result[allRules[i].ChannelID], allRules[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// batchLoadAccountStatsModelPricing 批量加载规则的模型定价
|
||||||
|
func (r *channelRepository) batchLoadAccountStatsModelPricing(ctx context.Context, ruleIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
|
||||||
|
if len(ruleIDs) == 0 {
|
||||||
|
return make(map[int64][]service.ChannelModelPricing), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := r.db.QueryContext(ctx,
|
||||||
|
`SELECT id, rule_id, platform, models, billing_mode, input_price, output_price,
|
||||||
|
cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
|
||||||
|
FROM channel_account_stats_model_pricing WHERE rule_id = ANY($1) ORDER BY rule_id, id`,
|
||||||
|
pq.Array(ruleIDs),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("batch load account stats model pricing: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
pricingMap := make(map[int64][]service.ChannelModelPricing, len(ruleIDs))
|
||||||
|
for rows.Next() {
|
||||||
|
var p service.ChannelModelPricing
|
||||||
|
var ruleID int64
|
||||||
|
var modelsJSON []byte
|
||||||
|
if err := rows.Scan(
|
||||||
|
&p.ID, &ruleID, &p.Platform, &modelsJSON, &p.BillingMode,
|
||||||
|
&p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
|
||||||
|
&p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan account stats model pricing: %w", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(modelsJSON, &p.Models); err != nil {
|
||||||
|
p.Models = []string{}
|
||||||
|
}
|
||||||
|
pricingMap[ruleID] = append(pricingMap[ruleID], p)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("iterate account stats model pricing: %w", err)
|
||||||
|
}
|
||||||
|
return pricingMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadAccountStatsPricingRules 加载单个渠道的账号统计定价规则(供 GetByID 使用)
|
||||||
|
func (r *channelRepository) loadAccountStatsPricingRules(ctx context.Context, channelID int64) ([]service.AccountStatsPricingRule, error) {
|
||||||
|
result, err := r.batchLoadAccountStatsPricingRules(ctx, []int64{channelID})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return result[channelID], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// replaceAccountStatsPricingRulesTx 在事务中替换渠道的账号统计定价规则(删除旧的 + 插入新的)
|
||||||
|
func replaceAccountStatsPricingRulesTx(ctx context.Context, tx *sql.Tx, channelID int64, rules []service.AccountStatsPricingRule) error {
|
||||||
|
// CASCADE 会自动删除关联的 model_pricing
|
||||||
|
if _, err := tx.ExecContext(ctx,
|
||||||
|
`DELETE FROM channel_account_stats_pricing_rules WHERE channel_id = $1`, channelID,
|
||||||
|
); err != nil {
|
||||||
|
return fmt.Errorf("delete old account stats pricing rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range rules {
|
||||||
|
rules[i].ChannelID = channelID
|
||||||
|
if err := createAccountStatsPricingRuleTx(ctx, tx, &rules[i]); err != nil {
|
||||||
|
return fmt.Errorf("insert account stats pricing rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createAccountStatsPricingRuleTx 在事务中创建单条账号统计定价规则及其模型定价
|
||||||
|
func createAccountStatsPricingRuleTx(ctx context.Context, tx *sql.Tx, rule *service.AccountStatsPricingRule) error {
|
||||||
|
err := tx.QueryRowContext(ctx,
|
||||||
|
`INSERT INTO channel_account_stats_pricing_rules (channel_id, name, group_ids, account_ids, sort_order)
|
||||||
|
VALUES ($1, $2, $3, $4, $5) RETURNING id, created_at, updated_at`,
|
||||||
|
rule.ChannelID, rule.Name, pq.Array(rule.GroupIDs), pq.Array(rule.AccountIDs), rule.SortOrder,
|
||||||
|
).Scan(&rule.ID, &rule.CreatedAt, &rule.UpdatedAt)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("insert account stats pricing rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for j := range rule.Pricing {
|
||||||
|
if err := createAccountStatsModelPricingTx(ctx, tx, rule.ID, &rule.Pricing[j]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createAccountStatsModelPricingTx 在事务中创建单条账号统计模型定价
|
||||||
|
func createAccountStatsModelPricingTx(ctx context.Context, tx *sql.Tx, ruleID int64, pricing *service.ChannelModelPricing) error {
|
||||||
|
modelsJSON, err := json.Marshal(pricing.Models)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal models: %w", err)
|
||||||
|
}
|
||||||
|
billingMode := pricing.BillingMode
|
||||||
|
if billingMode == "" {
|
||||||
|
billingMode = service.BillingModeToken
|
||||||
|
}
|
||||||
|
platform := pricing.Platform
|
||||||
|
err = tx.QueryRowContext(ctx,
|
||||||
|
`INSERT INTO channel_account_stats_model_pricing (rule_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
|
||||||
|
ruleID, platform, modelsJSON, billingMode,
|
||||||
|
pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
|
||||||
|
pricing.ImageOutputPrice, pricing.PerRequestPrice,
|
||||||
|
).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("insert account stats model pricing: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@ -28,7 +28,7 @@ import (
|
|||||||
gocache "github.com/patrickmn/go-cache"
|
gocache "github.com/patrickmn/go-cache"
|
||||||
)
|
)
|
||||||
|
|
||||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
|
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, account_stats_cost, created_at"
|
||||||
|
|
||||||
// usageLogInsertArgTypes must stay in the same order as:
|
// usageLogInsertArgTypes must stay in the same order as:
|
||||||
// 1. prepareUsageLogInsert().args
|
// 1. prepareUsageLogInsert().args
|
||||||
@ -82,6 +82,7 @@ var usageLogInsertArgTypes = [...]string{
|
|||||||
"text", // model_mapping_chain
|
"text", // model_mapping_chain
|
||||||
"text", // billing_tier
|
"text", // billing_tier
|
||||||
"text", // billing_mode
|
"text", // billing_mode
|
||||||
|
"numeric", // account_stats_cost
|
||||||
"timestamptz", // created_at
|
"timestamptz", // created_at
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -360,6 +361,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
|||||||
model_mapping_chain,
|
model_mapping_chain,
|
||||||
billing_tier,
|
billing_tier,
|
||||||
billing_mode,
|
billing_mode,
|
||||||
|
account_stats_cost,
|
||||||
created_at
|
created_at
|
||||||
) VALUES (
|
) VALUES (
|
||||||
$1, $2, $3, $4, $5, $6, $7,
|
$1, $2, $3, $4, $5, $6, $7,
|
||||||
@ -367,7 +369,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
|||||||
$10, $11, $12, $13,
|
$10, $11, $12, $13,
|
||||||
$14, $15, $16, $17,
|
$14, $15, $16, $17,
|
||||||
$18, $19, $20, $21, $22, $23,
|
$18, $19, $20, $21, $22, $23,
|
||||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
|
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
|
||||||
)
|
)
|
||||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
RETURNING id, created_at
|
RETURNING id, created_at
|
||||||
@ -797,6 +799,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
|||||||
model_mapping_chain,
|
model_mapping_chain,
|
||||||
billing_tier,
|
billing_tier,
|
||||||
billing_mode,
|
billing_mode,
|
||||||
|
account_stats_cost,
|
||||||
created_at
|
created_at
|
||||||
) AS (VALUES `)
|
) AS (VALUES `)
|
||||||
|
|
||||||
@ -873,6 +876,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
|||||||
model_mapping_chain,
|
model_mapping_chain,
|
||||||
billing_tier,
|
billing_tier,
|
||||||
billing_mode,
|
billing_mode,
|
||||||
|
account_stats_cost,
|
||||||
created_at
|
created_at
|
||||||
)
|
)
|
||||||
SELECT
|
SELECT
|
||||||
@ -920,6 +924,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
|||||||
model_mapping_chain,
|
model_mapping_chain,
|
||||||
billing_tier,
|
billing_tier,
|
||||||
billing_mode,
|
billing_mode,
|
||||||
|
account_stats_cost,
|
||||||
created_at
|
created_at
|
||||||
FROM input
|
FROM input
|
||||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
@ -1007,10 +1012,11 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
|||||||
model_mapping_chain,
|
model_mapping_chain,
|
||||||
billing_tier,
|
billing_tier,
|
||||||
billing_mode,
|
billing_mode,
|
||||||
|
account_stats_cost,
|
||||||
created_at
|
created_at
|
||||||
) AS (VALUES `)
|
) AS (VALUES `)
|
||||||
|
|
||||||
args := make([]any, 0, len(preparedList)*45)
|
args := make([]any, 0, len(preparedList)*46)
|
||||||
argPos := 1
|
argPos := 1
|
||||||
for idx, prepared := range preparedList {
|
for idx, prepared := range preparedList {
|
||||||
if idx > 0 {
|
if idx > 0 {
|
||||||
@ -1080,6 +1086,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
|||||||
model_mapping_chain,
|
model_mapping_chain,
|
||||||
billing_tier,
|
billing_tier,
|
||||||
billing_mode,
|
billing_mode,
|
||||||
|
account_stats_cost,
|
||||||
created_at
|
created_at
|
||||||
)
|
)
|
||||||
SELECT
|
SELECT
|
||||||
@ -1127,6 +1134,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
|||||||
model_mapping_chain,
|
model_mapping_chain,
|
||||||
billing_tier,
|
billing_tier,
|
||||||
billing_mode,
|
billing_mode,
|
||||||
|
account_stats_cost,
|
||||||
created_at
|
created_at
|
||||||
FROM input
|
FROM input
|
||||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
@ -1182,6 +1190,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
|||||||
model_mapping_chain,
|
model_mapping_chain,
|
||||||
billing_tier,
|
billing_tier,
|
||||||
billing_mode,
|
billing_mode,
|
||||||
|
account_stats_cost,
|
||||||
created_at
|
created_at
|
||||||
) VALUES (
|
) VALUES (
|
||||||
$1, $2, $3, $4, $5, $6, $7,
|
$1, $2, $3, $4, $5, $6, $7,
|
||||||
@ -1189,7 +1198,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
|||||||
$10, $11, $12, $13,
|
$10, $11, $12, $13,
|
||||||
$14, $15, $16, $17,
|
$14, $15, $16, $17,
|
||||||
$18, $19, $20, $21, $22, $23,
|
$18, $19, $20, $21, $22, $23,
|
||||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
|
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
|
||||||
)
|
)
|
||||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
`, prepared.args...)
|
`, prepared.args...)
|
||||||
@ -1285,6 +1294,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
|||||||
modelMappingChain,
|
modelMappingChain,
|
||||||
billingTier,
|
billingTier,
|
||||||
billingMode,
|
billingMode,
|
||||||
|
log.AccountStatsCost, // account_stats_cost
|
||||||
createdAt,
|
createdAt,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -1959,7 +1969,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
|
|||||||
SELECT
|
SELECT
|
||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||||
COALESCE(SUM(total_cost), 0) as standard_cost,
|
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as user_cost
|
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
@ -1989,7 +1999,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
|||||||
SELECT
|
SELECT
|
||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||||
COALESCE(SUM(total_cost), 0) as standard_cost,
|
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as user_cost
|
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
@ -2026,7 +2036,7 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc
|
|||||||
account_id,
|
account_id,
|
||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||||
COALESCE(SUM(total_cost), 0) as standard_cost,
|
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as user_cost
|
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
@ -2990,7 +3000,7 @@ func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Contex
|
|||||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||||
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
|
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
|
||||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||||
}
|
}
|
||||||
modelExpr := resolveModelDimensionExpression(source)
|
modelExpr := resolveModelDimensionExpression(source)
|
||||||
|
|
||||||
@ -3358,7 +3368,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
|||||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
|
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
|
||||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
%s
|
%s
|
||||||
@ -3433,7 +3443,7 @@ type EndpointStat = usagestats.EndpointStat
|
|||||||
func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
|
func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
|
||||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||||
}
|
}
|
||||||
|
|
||||||
query := fmt.Sprintf(`
|
query := fmt.Sprintf(`
|
||||||
@ -3500,7 +3510,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con
|
|||||||
func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
|
func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
|
||||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||||
}
|
}
|
||||||
|
|
||||||
query := fmt.Sprintf(`
|
query := fmt.Sprintf(`
|
||||||
@ -3591,7 +3601,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||||
COALESCE(SUM(total_cost), 0) as cost,
|
COALESCE(SUM(total_cost), 0) as cost,
|
||||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
|
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as user_cost
|
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
|
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
|
||||||
@ -4069,6 +4079,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
modelMappingChain sql.NullString
|
modelMappingChain sql.NullString
|
||||||
billingTier sql.NullString
|
billingTier sql.NullString
|
||||||
billingMode sql.NullString
|
billingMode sql.NullString
|
||||||
|
accountStatsCost sql.NullFloat64
|
||||||
createdAt time.Time
|
createdAt time.Time
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -4118,6 +4129,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
&modelMappingChain,
|
&modelMappingChain,
|
||||||
&billingTier,
|
&billingTier,
|
||||||
&billingMode,
|
&billingMode,
|
||||||
|
&accountStatsCost,
|
||||||
&createdAt,
|
&createdAt,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -4214,6 +4226,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
if billingMode.Valid {
|
if billingMode.Valid {
|
||||||
log.BillingMode = &billingMode.String
|
log.BillingMode = &billingMode.String
|
||||||
}
|
}
|
||||||
|
if accountStatsCost.Valid {
|
||||||
|
log.AccountStatsCost = &accountStatsCost.Float64
|
||||||
|
}
|
||||||
|
|
||||||
return log, nil
|
return log, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -85,6 +85,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
|||||||
sqlmock.AnyArg(), // model_mapping_chain
|
sqlmock.AnyArg(), // model_mapping_chain
|
||||||
sqlmock.AnyArg(), // billing_tier
|
sqlmock.AnyArg(), // billing_tier
|
||||||
sqlmock.AnyArg(), // billing_mode
|
sqlmock.AnyArg(), // billing_mode
|
||||||
|
sqlmock.AnyArg(), // account_stats_cost
|
||||||
createdAt,
|
createdAt,
|
||||||
).
|
).
|
||||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
|
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
|
||||||
@ -163,6 +164,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
|||||||
sqlmock.AnyArg(), // model_mapping_chain
|
sqlmock.AnyArg(), // model_mapping_chain
|
||||||
sqlmock.AnyArg(), // billing_tier
|
sqlmock.AnyArg(), // billing_tier
|
||||||
sqlmock.AnyArg(), // billing_mode
|
sqlmock.AnyArg(), // billing_mode
|
||||||
|
sqlmock.AnyArg(), // account_stats_cost
|
||||||
createdAt,
|
createdAt,
|
||||||
).
|
).
|
||||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
|
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
|
||||||
@ -483,10 +485,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
|||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
false,
|
false,
|
||||||
sql.NullInt64{}, // channel_id
|
sql.NullInt64{}, // channel_id
|
||||||
sql.NullString{}, // model_mapping_chain
|
sql.NullString{}, // model_mapping_chain
|
||||||
sql.NullString{}, // billing_tier
|
sql.NullString{}, // billing_tier
|
||||||
sql.NullString{}, // billing_mode
|
sql.NullString{}, // billing_mode
|
||||||
|
sql.NullFloat64{}, // account_stats_cost
|
||||||
now,
|
now,
|
||||||
}})
|
}})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -530,10 +533,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
|||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
false,
|
false,
|
||||||
sql.NullInt64{}, // channel_id
|
sql.NullInt64{}, // channel_id
|
||||||
sql.NullString{}, // model_mapping_chain
|
sql.NullString{}, // model_mapping_chain
|
||||||
sql.NullString{}, // billing_tier
|
sql.NullString{}, // billing_tier
|
||||||
sql.NullString{}, // billing_mode
|
sql.NullString{}, // billing_mode
|
||||||
|
sql.NullFloat64{}, // account_stats_cost
|
||||||
now,
|
now,
|
||||||
}})
|
}})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -577,10 +581,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
|||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
false,
|
false,
|
||||||
sql.NullInt64{}, // channel_id
|
sql.NullInt64{}, // channel_id
|
||||||
sql.NullString{}, // model_mapping_chain
|
sql.NullString{}, // model_mapping_chain
|
||||||
sql.NullString{}, // billing_tier
|
sql.NullString{}, // billing_tier
|
||||||
sql.NullString{}, // billing_mode
|
sql.NullString{}, // billing_mode
|
||||||
|
sql.NullFloat64{}, // account_stats_cost
|
||||||
now,
|
now,
|
||||||
}})
|
}})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
192
backend/internal/service/account_stats_pricing.go
Normal file
192
backend/internal/service/account_stats_pricing.go
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// resolveAccountStatsCost 计算账号统计定价费用。
|
||||||
|
// 返回 nil 表示不覆盖,使用默认公式(total_cost × account_rate_multiplier)。
|
||||||
|
//
|
||||||
|
// 匹配优先级(先命中为准):
|
||||||
|
// 1. 自定义规则(AccountStatsPricingRules,按数组顺序遍历)
|
||||||
|
// 2. 渠道已有的模型定价(ApplyPricingToAccountStats 开启时)
|
||||||
|
// 3. nil → 走默认公式
|
||||||
|
func resolveAccountStatsCost(
|
||||||
|
ctx context.Context,
|
||||||
|
channelService *ChannelService,
|
||||||
|
billingService *BillingService,
|
||||||
|
accountID int64,
|
||||||
|
groupID int64,
|
||||||
|
billingModel string,
|
||||||
|
tokens UsageTokens,
|
||||||
|
requestCount int,
|
||||||
|
serviceTier string,
|
||||||
|
) *float64 {
|
||||||
|
if channelService == nil || billingService == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
channel, err := channelService.GetChannelForGroup(ctx, groupID)
|
||||||
|
if err != nil || channel == nil || !channel.ApplyPricingToAccountStats {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
platform := channelService.GetGroupPlatform(ctx, groupID)
|
||||||
|
modelLower := strings.ToLower(billingModel)
|
||||||
|
|
||||||
|
// 优先级 1:自定义规则
|
||||||
|
if cost := tryCustomRules(channel, accountID, groupID, platform, modelLower, tokens, requestCount); cost != nil {
|
||||||
|
return cost
|
||||||
|
}
|
||||||
|
|
||||||
|
// 优先级 2:渠道已有模型定价
|
||||||
|
return tryChannelPricing(ctx, channelService, groupID, billingModel, tokens, requestCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryCustomRules 遍历自定义规则,按数组顺序先命中为准。
|
||||||
|
func tryCustomRules(
|
||||||
|
channel *Channel, accountID, groupID int64,
|
||||||
|
platform, modelLower string, tokens UsageTokens, requestCount int,
|
||||||
|
) *float64 {
|
||||||
|
for _, rule := range channel.AccountStatsPricingRules {
|
||||||
|
if !matchAccountStatsRule(&rule, accountID, groupID) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pricing := findPricingForModel(rule.Pricing, platform, modelLower)
|
||||||
|
if pricing == nil {
|
||||||
|
continue // 规则匹配但模型不在规则定价中,继续下一条
|
||||||
|
}
|
||||||
|
return calculateStatsCost(pricing, tokens, requestCount)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryChannelPricing 使用渠道已有的模型定价计算账号统计费用。
|
||||||
|
func tryChannelPricing(
|
||||||
|
ctx context.Context, channelService *ChannelService,
|
||||||
|
groupID int64, billingModel string, tokens UsageTokens, requestCount int,
|
||||||
|
) *float64 {
|
||||||
|
pricing := channelService.GetChannelModelPricing(ctx, groupID, billingModel)
|
||||||
|
if pricing == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return calculateStatsCost(pricing, tokens, requestCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchAccountStatsRule 检查规则是否匹配指定的 accountID 和 groupID。
|
||||||
|
// 匹配条件:accountID ∈ rule.AccountIDs 或 groupID ∈ rule.GroupIDs。
|
||||||
|
// 如果规则的 AccountIDs 和 GroupIDs 都为空,视为不匹配。
|
||||||
|
func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int64) bool {
|
||||||
|
if len(rule.AccountIDs) == 0 && len(rule.GroupIDs) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, id := range rule.AccountIDs {
|
||||||
|
if id == accountID {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, id := range rule.GroupIDs {
|
||||||
|
if id == groupID {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// wildcardMatch 通配符匹配候选项(用于排序)
|
||||||
|
type wildcardMatch struct {
|
||||||
|
prefixLen int
|
||||||
|
pricing *ChannelModelPricing
|
||||||
|
}
|
||||||
|
|
||||||
|
// findPricingForModel 在定价列表中查找匹配的模型定价。
|
||||||
|
// 先精确匹配,再通配符匹配(前缀越长优先级越高)。
|
||||||
|
func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing {
|
||||||
|
// 精确匹配优先
|
||||||
|
for i := range pricingList {
|
||||||
|
p := &pricingList[i]
|
||||||
|
if !isPlatformMatch(platform, p.Platform) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, m := range p.Models {
|
||||||
|
if strings.ToLower(m) == modelLower {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 通配符匹配:收集所有匹配项,按前缀长度降序取最长
|
||||||
|
var matches []wildcardMatch
|
||||||
|
for i := range pricingList {
|
||||||
|
p := &pricingList[i]
|
||||||
|
if !isPlatformMatch(platform, p.Platform) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, m := range p.Models {
|
||||||
|
ml := strings.ToLower(m)
|
||||||
|
if !strings.HasSuffix(ml, "*") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
prefix := strings.TrimSuffix(ml, "*")
|
||||||
|
if strings.HasPrefix(modelLower, prefix) {
|
||||||
|
matches = append(matches, wildcardMatch{prefixLen: len(prefix), pricing: p})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
sort.Slice(matches, func(i, j int) bool {
|
||||||
|
return matches[i].prefixLen > matches[j].prefixLen
|
||||||
|
})
|
||||||
|
return matches[0].pricing
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPlatformMatch 判断平台是否匹配(空平台视为不限平台)。
|
||||||
|
func isPlatformMatch(queryPlatform, pricingPlatform string) bool {
|
||||||
|
if queryPlatform == "" || pricingPlatform == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return queryPlatform == pricingPlatform
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateStatsCost 使用给定的定价计算费用(不含任何倍率,原始费用)。
|
||||||
|
func calculateStatsCost(pricing *ChannelModelPricing, tokens UsageTokens, requestCount int) *float64 {
|
||||||
|
if pricing == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
switch pricing.BillingMode {
|
||||||
|
case BillingModePerRequest, BillingModeImage:
|
||||||
|
return calculatePerRequestStatsCost(pricing, requestCount)
|
||||||
|
default:
|
||||||
|
return calculateTokenStatsCost(pricing, tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculatePerRequestStatsCost 按次/图片计费。
|
||||||
|
func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int) *float64 {
|
||||||
|
if pricing.PerRequestPrice == nil || *pricing.PerRequestPrice <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cost := *pricing.PerRequestPrice * float64(requestCount)
|
||||||
|
return &cost
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateTokenStatsCost Token 计费。
|
||||||
|
func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 {
|
||||||
|
deref := func(p *float64) float64 {
|
||||||
|
if p == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return *p
|
||||||
|
}
|
||||||
|
cost := float64(tokens.InputTokens)*deref(pricing.InputPrice) +
|
||||||
|
float64(tokens.OutputTokens)*deref(pricing.OutputPrice) +
|
||||||
|
float64(tokens.CacheCreationTokens)*deref(pricing.CacheWritePrice) +
|
||||||
|
float64(tokens.CacheReadTokens)*deref(pricing.CacheReadPrice) +
|
||||||
|
float64(tokens.ImageOutputTokens)*deref(pricing.ImageOutputPrice)
|
||||||
|
if cost == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &cost
|
||||||
|
}
|
||||||
430
backend/internal/service/account_stats_pricing_test.go
Normal file
430
backend/internal/service/account_stats_pricing_test.go
Normal file
@ -0,0 +1,430 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// matchAccountStatsRule
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestMatchAccountStatsRule_BothEmpty_NoMatch(t *testing.T) {
|
||||||
|
rule := &AccountStatsPricingRule{}
|
||||||
|
require.False(t, matchAccountStatsRule(rule, 1, 10))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchAccountStatsRule_AccountIDMatch(t *testing.T) {
|
||||||
|
rule := &AccountStatsPricingRule{AccountIDs: []int64{1, 2, 3}}
|
||||||
|
require.True(t, matchAccountStatsRule(rule, 2, 999))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchAccountStatsRule_GroupIDMatch(t *testing.T) {
|
||||||
|
rule := &AccountStatsPricingRule{GroupIDs: []int64{10, 20}}
|
||||||
|
require.True(t, matchAccountStatsRule(rule, 999, 20))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchAccountStatsRule_BothConfigured_AccountMatch(t *testing.T) {
|
||||||
|
rule := &AccountStatsPricingRule{
|
||||||
|
AccountIDs: []int64{1, 2},
|
||||||
|
GroupIDs: []int64{10, 20},
|
||||||
|
}
|
||||||
|
require.True(t, matchAccountStatsRule(rule, 2, 999))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchAccountStatsRule_BothConfigured_GroupMatch(t *testing.T) {
|
||||||
|
rule := &AccountStatsPricingRule{
|
||||||
|
AccountIDs: []int64{1, 2},
|
||||||
|
GroupIDs: []int64{10, 20},
|
||||||
|
}
|
||||||
|
require.True(t, matchAccountStatsRule(rule, 999, 10))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchAccountStatsRule_BothConfigured_NeitherMatch(t *testing.T) {
|
||||||
|
rule := &AccountStatsPricingRule{
|
||||||
|
AccountIDs: []int64{1, 2},
|
||||||
|
GroupIDs: []int64{10, 20},
|
||||||
|
}
|
||||||
|
require.False(t, matchAccountStatsRule(rule, 999, 999))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// findPricingForModel
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestFindPricingForModel(t *testing.T) {
|
||||||
|
exactPricing := ChannelModelPricing{
|
||||||
|
ID: 1,
|
||||||
|
Models: []string{"claude-opus-4"},
|
||||||
|
}
|
||||||
|
wildcardPricing := ChannelModelPricing{
|
||||||
|
ID: 2,
|
||||||
|
Models: []string{"claude-*"},
|
||||||
|
}
|
||||||
|
platformPricing := ChannelModelPricing{
|
||||||
|
ID: 3,
|
||||||
|
Platform: "openai",
|
||||||
|
Models: []string{"gpt-4o"},
|
||||||
|
}
|
||||||
|
emptyPlatformPricing := ChannelModelPricing{
|
||||||
|
ID: 4,
|
||||||
|
Models: []string{"gemini-2.5-pro"},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
list []ChannelModelPricing
|
||||||
|
platform string
|
||||||
|
model string
|
||||||
|
wantID int64
|
||||||
|
wantNil bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact match",
|
||||||
|
list: []ChannelModelPricing{exactPricing},
|
||||||
|
platform: "anthropic",
|
||||||
|
model: "claude-opus-4",
|
||||||
|
wantID: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exact match case insensitive",
|
||||||
|
list: []ChannelModelPricing{{ID: 5, Models: []string{"Claude-Opus-4"}}},
|
||||||
|
platform: "",
|
||||||
|
model: "claude-opus-4",
|
||||||
|
wantID: 5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard match",
|
||||||
|
list: []ChannelModelPricing{wildcardPricing},
|
||||||
|
platform: "anthropic",
|
||||||
|
model: "claude-opus-4",
|
||||||
|
wantID: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exact match takes priority over wildcard",
|
||||||
|
list: []ChannelModelPricing{wildcardPricing, exactPricing},
|
||||||
|
platform: "anthropic",
|
||||||
|
model: "claude-opus-4",
|
||||||
|
wantID: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "platform mismatch skipped",
|
||||||
|
list: []ChannelModelPricing{platformPricing},
|
||||||
|
platform: "anthropic",
|
||||||
|
model: "gpt-4o",
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty platform in pricing matches any",
|
||||||
|
list: []ChannelModelPricing{emptyPlatformPricing},
|
||||||
|
platform: "gemini",
|
||||||
|
model: "gemini-2.5-pro",
|
||||||
|
wantID: 4,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty platform in query matches any pricing platform",
|
||||||
|
list: []ChannelModelPricing{platformPricing},
|
||||||
|
platform: "",
|
||||||
|
model: "gpt-4o",
|
||||||
|
wantID: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no match at all",
|
||||||
|
list: []ChannelModelPricing{exactPricing, wildcardPricing},
|
||||||
|
platform: "anthropic",
|
||||||
|
model: "gpt-4o",
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty list returns nil",
|
||||||
|
list: nil,
|
||||||
|
model: "claude-opus-4",
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "longer wildcard prefix wins over shorter",
|
||||||
|
list: []ChannelModelPricing{
|
||||||
|
{ID: 10, Models: []string{"claude-*"}},
|
||||||
|
{ID: 11, Models: []string{"claude-opus-*"}},
|
||||||
|
},
|
||||||
|
platform: "",
|
||||||
|
model: "claude-opus-4",
|
||||||
|
wantID: 11, // "claude-opus-" (12 chars) > "claude-" (7 chars)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "shorter wildcard used when longer does not match",
|
||||||
|
list: []ChannelModelPricing{
|
||||||
|
{ID: 10, Models: []string{"claude-*"}},
|
||||||
|
{ID: 11, Models: []string{"claude-opus-*"}},
|
||||||
|
},
|
||||||
|
platform: "",
|
||||||
|
model: "claude-sonnet-4",
|
||||||
|
wantID: 10, // only "claude-*" matches
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := findPricingForModel(tt.list, tt.platform, tt.model)
|
||||||
|
if tt.wantNil {
|
||||||
|
require.Nil(t, result)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, tt.wantID, result.ID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// calculateStatsCost
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_NilPricing(t *testing.T) {
|
||||||
|
result := calculateStatsCost(nil, UsageTokens{}, 1)
|
||||||
|
require.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_TokenBilling(t *testing.T) {
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
BillingMode: BillingModeToken,
|
||||||
|
InputPrice: testPtrFloat64(0.001),
|
||||||
|
OutputPrice: testPtrFloat64(0.002),
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
}
|
||||||
|
result := calculateStatsCost(pricing, tokens, 1)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
|
||||||
|
require.InDelta(t, 0.2, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_TokenBilling_WithCache(t *testing.T) {
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
BillingMode: BillingModeToken,
|
||||||
|
InputPrice: testPtrFloat64(0.001),
|
||||||
|
OutputPrice: testPtrFloat64(0.002),
|
||||||
|
CacheWritePrice: testPtrFloat64(0.003),
|
||||||
|
CacheReadPrice: testPtrFloat64(0.0005),
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
CacheCreationTokens: 200,
|
||||||
|
CacheReadTokens: 300,
|
||||||
|
}
|
||||||
|
result := calculateStatsCost(pricing, tokens, 1)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005
|
||||||
|
// = 0.1 + 0.1 + 0.6 + 0.15 = 0.95
|
||||||
|
require.InDelta(t, 0.95, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_TokenBilling_WithImageOutput(t *testing.T) {
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
BillingMode: BillingModeToken,
|
||||||
|
InputPrice: testPtrFloat64(0.001),
|
||||||
|
OutputPrice: testPtrFloat64(0.002),
|
||||||
|
ImageOutputPrice: testPtrFloat64(0.01),
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
ImageOutputTokens: 10,
|
||||||
|
}
|
||||||
|
result := calculateStatsCost(pricing, tokens, 1)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3
|
||||||
|
require.InDelta(t, 0.3, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_TokenBilling_PartialPricesNil(t *testing.T) {
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
BillingMode: BillingModeToken,
|
||||||
|
InputPrice: testPtrFloat64(0.001),
|
||||||
|
// OutputPrice, CacheWritePrice, etc. are all nil → treated as 0
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
CacheCreationTokens: 200,
|
||||||
|
}
|
||||||
|
result := calculateStatsCost(pricing, tokens, 1)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// Only input contributes: 100*0.001 = 0.1
|
||||||
|
require.InDelta(t, 0.1, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_TokenBilling_AllTokensZero(t *testing.T) {
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
BillingMode: BillingModeToken,
|
||||||
|
InputPrice: testPtrFloat64(0.001),
|
||||||
|
OutputPrice: testPtrFloat64(0.002),
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{} // all zeros
|
||||||
|
result := calculateStatsCost(pricing, tokens, 1)
|
||||||
|
// totalCost == 0 → returns nil (does not override, falls back to default formula)
|
||||||
|
require.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_PerRequestBilling(t *testing.T) {
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
BillingMode: BillingModePerRequest,
|
||||||
|
PerRequestPrice: testPtrFloat64(0.05),
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{InputTokens: 999, OutputTokens: 999}
|
||||||
|
result := calculateStatsCost(pricing, tokens, 3)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// 0.05 * 3 = 0.15
|
||||||
|
require.InDelta(t, 0.15, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_PerRequestBilling_PriceNil(t *testing.T) {
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
BillingMode: BillingModePerRequest,
|
||||||
|
// PerRequestPrice is nil
|
||||||
|
}
|
||||||
|
result := calculateStatsCost(pricing, UsageTokens{}, 1)
|
||||||
|
require.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_PerRequestBilling_PriceZero(t *testing.T) {
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
BillingMode: BillingModePerRequest,
|
||||||
|
PerRequestPrice: testPtrFloat64(0),
|
||||||
|
}
|
||||||
|
result := calculateStatsCost(pricing, UsageTokens{}, 1)
|
||||||
|
// price == 0 → condition *pricing.PerRequestPrice > 0 is false → returns nil
|
||||||
|
require.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_ImageBilling(t *testing.T) {
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
BillingMode: BillingModeImage,
|
||||||
|
PerRequestPrice: testPtrFloat64(0.10),
|
||||||
|
}
|
||||||
|
result := calculateStatsCost(pricing, UsageTokens{}, 2)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// 0.10 * 2 = 0.20
|
||||||
|
require.InDelta(t, 0.20, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_ImageBilling_PriceNil(t *testing.T) {
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
BillingMode: BillingModeImage,
|
||||||
|
// PerRequestPrice is nil
|
||||||
|
}
|
||||||
|
result := calculateStatsCost(pricing, UsageTokens{}, 1)
|
||||||
|
require.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_DefaultBillingMode_FallsToToken(t *testing.T) {
|
||||||
|
// BillingMode is empty string (default) → falls into token billing
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
InputPrice: testPtrFloat64(0.001),
|
||||||
|
OutputPrice: testPtrFloat64(0.002),
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
}
|
||||||
|
result := calculateStatsCost(pricing, tokens, 1)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.InDelta(t, 0.2, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// tryCustomRules — 多规则顺序测试
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestTryCustomRules_FirstMatchWins(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||||
|
{
|
||||||
|
GroupIDs: []int64{1},
|
||||||
|
Pricing: []ChannelModelPricing{
|
||||||
|
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01), OutputPrice: testPtrFloat64(0.02)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
GroupIDs: []int64{1},
|
||||||
|
Pricing: []ChannelModelPricing{
|
||||||
|
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99), OutputPrice: testPtrFloat64(0.99)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
|
||||||
|
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// 应使用第一条规则的价格:100*0.01 + 50*0.02 = 2.0
|
||||||
|
require.InDelta(t, 2.0, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryCustomRules_SkipsNonMatchingRules(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||||
|
{
|
||||||
|
AccountIDs: []int64{888}, // 不匹配
|
||||||
|
Pricing: []ChannelModelPricing{
|
||||||
|
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
GroupIDs: []int64{1}, // 匹配
|
||||||
|
Pricing: []ChannelModelPricing{
|
||||||
|
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{InputTokens: 100}
|
||||||
|
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// 跳过规则1(账号不匹配),使用规则2:100*0.05 = 5.0
|
||||||
|
require.InDelta(t, 5.0, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryCustomRules_NoMatch_ReturnsNil(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||||
|
{
|
||||||
|
AccountIDs: []int64{888},
|
||||||
|
Pricing: []ChannelModelPricing{
|
||||||
|
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{InputTokens: 100}
|
||||||
|
result := tryCustomRules(channel, 999, 2, "", "claude-opus-4", tokens, 1)
|
||||||
|
require.Nil(t, result) // 账号和分组都不匹配
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryCustomRules_RuleMatchesButModelNot_ContinuesToNext(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||||
|
{
|
||||||
|
GroupIDs: []int64{1},
|
||||||
|
Pricing: []ChannelModelPricing{
|
||||||
|
{ID: 100, Models: []string{"gpt-4o"}, InputPrice: testPtrFloat64(0.01)}, // 模型不匹配
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
GroupIDs: []int64{1},
|
||||||
|
Pricing: []ChannelModelPricing{
|
||||||
|
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)}, // 模型匹配
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{InputTokens: 100}
|
||||||
|
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.InDelta(t, 5.0, *result, 1e-12) // 使用规则2
|
||||||
|
}
|
||||||
@ -49,21 +49,25 @@ type Channel struct {
|
|||||||
ModelPricing []ChannelModelPricing
|
ModelPricing []ChannelModelPricing
|
||||||
// 渠道级模型映射(按平台分组:platform → {src→dst})
|
// 渠道级模型映射(按平台分组:platform → {src→dst})
|
||||||
ModelMapping map[string]map[string]string
|
ModelMapping map[string]map[string]string
|
||||||
// 渠道特性配置(如 {"web_search_emulation": {"anthropic": true}})
|
|
||||||
FeaturesConfig map[string]any
|
// 账号统计定价
|
||||||
|
ApplyPricingToAccountStats bool // 是否应用渠道模型定价到账号统计
|
||||||
|
AccountStatsPricingRules []AccountStatsPricingRule // 自定义账号统计定价规则(按 SortOrder 排序,先命中为准)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。
|
// AccountStatsPricingRule 账号统计定价规则
|
||||||
func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool {
|
// 每条规则包含匹配条件(分组/账号)和独立的模型定价。
|
||||||
if c == nil || c.FeaturesConfig == nil {
|
// 多条规则按 SortOrder 排序,先命中为准。
|
||||||
return false
|
type AccountStatsPricingRule struct {
|
||||||
}
|
ID int64
|
||||||
wse, ok := c.FeaturesConfig[featureKeyWebSearchEmulation].(map[string]any)
|
ChannelID int64
|
||||||
if !ok {
|
Name string
|
||||||
return false
|
GroupIDs []int64
|
||||||
}
|
AccountIDs []int64
|
||||||
enabled, ok := wse[platform].(bool)
|
SortOrder int
|
||||||
return ok && enabled
|
Pricing []ChannelModelPricing // 规则内的模型定价(复用现有定价结构)
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChannelModelPricing 渠道模型定价条目
|
// ChannelModelPricing 渠道模型定价条目
|
||||||
@ -192,6 +196,26 @@ func (c *Channel) Clone() *Channel {
|
|||||||
cp.ModelMapping[platform] = inner
|
cp.ModelMapping[platform] = inner
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if c.AccountStatsPricingRules != nil {
|
||||||
|
cp.AccountStatsPricingRules = make([]AccountStatsPricingRule, len(c.AccountStatsPricingRules))
|
||||||
|
for i, rule := range c.AccountStatsPricingRules {
|
||||||
|
cp.AccountStatsPricingRules[i] = rule
|
||||||
|
if rule.GroupIDs != nil {
|
||||||
|
cp.AccountStatsPricingRules[i].GroupIDs = make([]int64, len(rule.GroupIDs))
|
||||||
|
copy(cp.AccountStatsPricingRules[i].GroupIDs, rule.GroupIDs)
|
||||||
|
}
|
||||||
|
if rule.AccountIDs != nil {
|
||||||
|
cp.AccountStatsPricingRules[i].AccountIDs = make([]int64, len(rule.AccountIDs))
|
||||||
|
copy(cp.AccountStatsPricingRules[i].AccountIDs, rule.AccountIDs)
|
||||||
|
}
|
||||||
|
if rule.Pricing != nil {
|
||||||
|
cp.AccountStatsPricingRules[i].Pricing = make([]ChannelModelPricing, len(rule.Pricing))
|
||||||
|
for j := range rule.Pricing {
|
||||||
|
cp.AccountStatsPricingRules[i].Pricing[j] = rule.Pricing[j].Clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return &cp
|
return &cp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -416,6 +416,15 @@ func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64)
|
|||||||
return ch.Clone(), nil
|
return ch.Clone(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetGroupPlatform 获取分组的平台标识(从缓存)
|
||||||
|
func (s *ChannelService) GetGroupPlatform(ctx context.Context, groupID int64) string {
|
||||||
|
cache, err := s.loadCache(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return cache.groupPlatform[groupID]
|
||||||
|
}
|
||||||
|
|
||||||
// channelLookup 热路径公共查找结果
|
// channelLookup 热路径公共查找结果
|
||||||
type channelLookup struct {
|
type channelLookup struct {
|
||||||
cache *channelCache
|
cache *channelCache
|
||||||
@ -656,16 +665,17 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
|||||||
}
|
}
|
||||||
|
|
||||||
channel := &Channel{
|
channel := &Channel{
|
||||||
Name: input.Name,
|
Name: input.Name,
|
||||||
Description: input.Description,
|
Description: input.Description,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
BillingModelSource: input.BillingModelSource,
|
BillingModelSource: input.BillingModelSource,
|
||||||
RestrictModels: input.RestrictModels,
|
RestrictModels: input.RestrictModels,
|
||||||
GroupIDs: input.GroupIDs,
|
GroupIDs: input.GroupIDs,
|
||||||
ModelPricing: input.ModelPricing,
|
ModelPricing: input.ModelPricing,
|
||||||
ModelMapping: input.ModelMapping,
|
ModelMapping: input.ModelMapping,
|
||||||
Features: input.Features,
|
Features: input.Features,
|
||||||
FeaturesConfig: input.FeaturesConfig,
|
ApplyPricingToAccountStats: input.ApplyPricingToAccountStats,
|
||||||
|
AccountStatsPricingRules: input.AccountStatsPricingRules,
|
||||||
}
|
}
|
||||||
if channel.BillingModelSource == "" {
|
if channel.BillingModelSource == "" {
|
||||||
channel.BillingModelSource = BillingModelSourceChannelMapped
|
channel.BillingModelSource = BillingModelSourceChannelMapped
|
||||||
@ -754,8 +764,11 @@ func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel,
|
|||||||
if input.BillingModelSource != "" {
|
if input.BillingModelSource != "" {
|
||||||
channel.BillingModelSource = input.BillingModelSource
|
channel.BillingModelSource = input.BillingModelSource
|
||||||
}
|
}
|
||||||
if input.FeaturesConfig != nil {
|
if input.ApplyPricingToAccountStats != nil {
|
||||||
channel.FeaturesConfig = input.FeaturesConfig
|
channel.ApplyPricingToAccountStats = *input.ApplyPricingToAccountStats
|
||||||
|
}
|
||||||
|
if input.AccountStatsPricingRules != nil {
|
||||||
|
channel.AccountStatsPricingRules = *input.AccountStatsPricingRules
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -922,27 +935,29 @@ func detectConflicts(entries []modelEntry, platform, errCode, label string) erro
|
|||||||
|
|
||||||
// CreateChannelInput 创建渠道输入
|
// CreateChannelInput 创建渠道输入
|
||||||
type CreateChannelInput struct {
|
type CreateChannelInput struct {
|
||||||
Name string
|
Name string
|
||||||
Description string
|
Description string
|
||||||
GroupIDs []int64
|
GroupIDs []int64
|
||||||
ModelPricing []ChannelModelPricing
|
ModelPricing []ChannelModelPricing
|
||||||
ModelMapping map[string]map[string]string // platform → {src→dst}
|
ModelMapping map[string]map[string]string // platform → {src→dst}
|
||||||
BillingModelSource string
|
BillingModelSource string
|
||||||
RestrictModels bool
|
RestrictModels bool
|
||||||
Features string
|
Features string
|
||||||
FeaturesConfig map[string]any
|
ApplyPricingToAccountStats bool
|
||||||
|
AccountStatsPricingRules []AccountStatsPricingRule
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateChannelInput 更新渠道输入
|
// UpdateChannelInput 更新渠道输入
|
||||||
type UpdateChannelInput struct {
|
type UpdateChannelInput struct {
|
||||||
Name string
|
Name string
|
||||||
Description *string
|
Description *string
|
||||||
Status string
|
Status string
|
||||||
GroupIDs *[]int64
|
GroupIDs *[]int64
|
||||||
ModelPricing *[]ChannelModelPricing
|
ModelPricing *[]ChannelModelPricing
|
||||||
ModelMapping map[string]map[string]string // platform → {src→dst}
|
ModelMapping map[string]map[string]string // platform → {src→dst}
|
||||||
BillingModelSource string
|
BillingModelSource string
|
||||||
RestrictModels *bool
|
RestrictModels *bool
|
||||||
Features *string
|
Features *string
|
||||||
FeaturesConfig map[string]any
|
ApplyPricingToAccountStats *bool
|
||||||
|
AccountStatsPricingRules *[]AccountStatsPricingRule
|
||||||
}
|
}
|
||||||
|
|||||||
@ -7559,6 +7559,23 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
|||||||
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
|
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
|
||||||
requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
|
requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
|
||||||
|
|
||||||
|
// 计算账号统计定价费用
|
||||||
|
if apiKey.GroupID != nil {
|
||||||
|
usageLog.AccountStatsCost = resolveAccountStatsCost(
|
||||||
|
ctx, s.channelService, s.billingService,
|
||||||
|
account.ID, *apiKey.GroupID, billingModel,
|
||||||
|
UsageTokens{
|
||||||
|
InputTokens: result.Usage.InputTokens,
|
||||||
|
OutputTokens: result.Usage.OutputTokens,
|
||||||
|
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||||
|
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||||
|
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||||
|
},
|
||||||
|
1, // requestCount
|
||||||
|
"", // serviceTier: Anthropic 平台不使用 service tier
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||||
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||||
|
|||||||
@ -4569,6 +4569,15 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
usageLog.SubscriptionID = &subscription.ID
|
usageLog.SubscriptionID = &subscription.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 计算账号统计定价费用
|
||||||
|
if apiKey.GroupID != nil {
|
||||||
|
usageLog.AccountStatsCost = resolveAccountStatsCost(
|
||||||
|
ctx, s.channelService, s.billingService,
|
||||||
|
account.ID, *apiKey.GroupID, billingModel,
|
||||||
|
tokens, 1, serviceTier,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
|
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
|
||||||
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||||
|
|||||||
@ -146,6 +146,8 @@ type UsageLog struct {
|
|||||||
RateMultiplier float64
|
RateMultiplier float64
|
||||||
// AccountRateMultiplier 账号计费倍率快照(nil 表示历史数据,按 1.0 处理)
|
// AccountRateMultiplier 账号计费倍率快照(nil 表示历史数据,按 1.0 处理)
|
||||||
AccountRateMultiplier *float64
|
AccountRateMultiplier *float64
|
||||||
|
// AccountStatsCost 账号统计定价预计算费用(nil = 使用默认公式 total_cost × account_rate_multiplier)
|
||||||
|
AccountStatsCost *float64
|
||||||
|
|
||||||
BillingType int8
|
BillingType int8
|
||||||
RequestType RequestType
|
RequestType RequestType
|
||||||
|
|||||||
38
backend/migrations/101_add_account_stats_pricing.sql
Normal file
38
backend/migrations/101_add_account_stats_pricing.sql
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
-- Account statistics pricing: allow channels to configure custom pricing for account cost tracking.
|
||||||
|
|
||||||
|
-- 1. Channel-level toggle
|
||||||
|
ALTER TABLE channels ADD COLUMN IF NOT EXISTS apply_pricing_to_account_stats BOOLEAN NOT NULL DEFAULT FALSE;
|
||||||
|
|
||||||
|
-- 2. Account stats pricing rules (ordered list per channel)
|
||||||
|
CREATE TABLE IF NOT EXISTS channel_account_stats_pricing_rules (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
|
||||||
|
name VARCHAR(100) NOT NULL DEFAULT '',
|
||||||
|
group_ids BIGINT[] NOT NULL DEFAULT '{}',
|
||||||
|
account_ids BIGINT[] NOT NULL DEFAULT '{}',
|
||||||
|
sort_order INT NOT NULL DEFAULT 0,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_cas_pricing_rules_channel_id ON channel_account_stats_pricing_rules(channel_id);
|
||||||
|
|
||||||
|
-- 3. Model pricing for each rule (same structure as channel_model_pricing)
|
||||||
|
CREATE TABLE IF NOT EXISTS channel_account_stats_model_pricing (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
rule_id BIGINT NOT NULL REFERENCES channel_account_stats_pricing_rules(id) ON DELETE CASCADE,
|
||||||
|
platform VARCHAR(50) NOT NULL DEFAULT '',
|
||||||
|
models JSONB NOT NULL DEFAULT '[]',
|
||||||
|
billing_mode VARCHAR(20) NOT NULL DEFAULT 'token',
|
||||||
|
input_price NUMERIC(20,10),
|
||||||
|
output_price NUMERIC(20,10),
|
||||||
|
cache_write_price NUMERIC(20,10),
|
||||||
|
cache_read_price NUMERIC(20,10),
|
||||||
|
image_output_price NUMERIC(20,10),
|
||||||
|
per_request_price NUMERIC(20,10),
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_cas_model_pricing_rule_id ON channel_account_stats_model_pricing(rule_id);
|
||||||
|
|
||||||
|
-- 4. Usage logs: pre-computed account stats cost (NULL = use default formula)
|
||||||
|
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS account_stats_cost NUMERIC(20,10);
|
||||||
@ -34,6 +34,14 @@ export interface ChannelModelPricing {
|
|||||||
intervals: PricingInterval[]
|
intervals: PricingInterval[]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface AccountStatsPricingRule {
|
||||||
|
id?: number
|
||||||
|
name: string
|
||||||
|
group_ids: number[]
|
||||||
|
account_ids: number[]
|
||||||
|
pricing: ChannelModelPricing[]
|
||||||
|
}
|
||||||
|
|
||||||
export interface Channel {
|
export interface Channel {
|
||||||
id: number
|
id: number
|
||||||
name: string
|
name: string
|
||||||
@ -41,10 +49,11 @@ export interface Channel {
|
|||||||
status: string
|
status: string
|
||||||
billing_model_source: string // "requested" | "upstream"
|
billing_model_source: string // "requested" | "upstream"
|
||||||
restrict_models: boolean
|
restrict_models: boolean
|
||||||
features_config?: Record<string, unknown>
|
|
||||||
group_ids: number[]
|
group_ids: number[]
|
||||||
model_pricing: ChannelModelPricing[]
|
model_pricing: ChannelModelPricing[]
|
||||||
model_mapping: Record<string, Record<string, string>> // platform → {src→dst}
|
model_mapping: Record<string, Record<string, string>> // platform → {src→dst}
|
||||||
|
apply_pricing_to_account_stats: boolean
|
||||||
|
account_stats_pricing_rules: AccountStatsPricingRule[]
|
||||||
created_at: string
|
created_at: string
|
||||||
updated_at: string
|
updated_at: string
|
||||||
}
|
}
|
||||||
@ -57,7 +66,8 @@ export interface CreateChannelRequest {
|
|||||||
model_mapping?: Record<string, Record<string, string>>
|
model_mapping?: Record<string, Record<string, string>>
|
||||||
billing_model_source?: string
|
billing_model_source?: string
|
||||||
restrict_models?: boolean
|
restrict_models?: boolean
|
||||||
features_config?: Record<string, unknown>
|
apply_pricing_to_account_stats?: boolean
|
||||||
|
account_stats_pricing_rules?: AccountStatsPricingRule[]
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface UpdateChannelRequest {
|
export interface UpdateChannelRequest {
|
||||||
@ -69,7 +79,8 @@ export interface UpdateChannelRequest {
|
|||||||
model_mapping?: Record<string, Record<string, string>>
|
model_mapping?: Record<string, Record<string, string>>
|
||||||
billing_model_source?: string
|
billing_model_source?: string
|
||||||
restrict_models?: boolean
|
restrict_models?: boolean
|
||||||
features_config?: Record<string, unknown>
|
apply_pricing_to_account_stats?: boolean
|
||||||
|
account_stats_pricing_rules?: AccountStatsPricingRule[]
|
||||||
}
|
}
|
||||||
|
|
||||||
interface PaginatedResponse<T> {
|
interface PaginatedResponse<T> {
|
||||||
|
|||||||
@ -1844,7 +1844,18 @@ export default {
|
|||||||
noPlatforms: 'Click "Add Platform" to start configuring the channel',
|
noPlatforms: 'Click "Add Platform" to start configuring the channel',
|
||||||
mappingCount: 'mappings',
|
mappingCount: 'mappings',
|
||||||
pricingEntry: 'Pricing Entry',
|
pricingEntry: 'Pricing Entry',
|
||||||
noModels: 'No models added'
|
noModels: 'No models added',
|
||||||
|
applyPricingToAccountStats: 'Apply Pricing to Account Stats',
|
||||||
|
applyPricingToAccountStatsDesc: 'When enabled, account statistics cost will use channel model pricing. Account rate multiplier still applies.',
|
||||||
|
accountStatsPricingRules: 'Custom Account Stats Pricing Rules',
|
||||||
|
addRule: 'Add Rule',
|
||||||
|
noRulesConfigured: 'No custom rules configured. Channel model pricing above will be used.',
|
||||||
|
ruleName: 'Rule name (optional)',
|
||||||
|
ruleGroups: 'Groups',
|
||||||
|
ruleAccounts: 'Account IDs',
|
||||||
|
ruleAccountsPlaceholder: 'Enter account IDs, comma-separated',
|
||||||
|
ruleModelPricing: 'Model Pricing',
|
||||||
|
noGroupsInChannel: 'No groups selected in platform tabs above'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
|||||||
@ -1923,7 +1923,18 @@ export default {
|
|||||||
noPlatforms: '点击"添加平台"开始配置渠道',
|
noPlatforms: '点击"添加平台"开始配置渠道',
|
||||||
mappingCount: '条映射',
|
mappingCount: '条映射',
|
||||||
pricingEntry: '定价配置',
|
pricingEntry: '定价配置',
|
||||||
noModels: '未添加模型'
|
noModels: '未添加模型',
|
||||||
|
applyPricingToAccountStats: '应用模型定价到账号统计',
|
||||||
|
applyPricingToAccountStatsDesc: '启用后,账号统计费用将使用渠道模型定价计算。账号自身的统计倍率仍然生效。',
|
||||||
|
accountStatsPricingRules: '自定义账号统计定价规则',
|
||||||
|
addRule: '添加规则',
|
||||||
|
noRulesConfigured: '未配置自定义规则,将使用上方的模型定价。',
|
||||||
|
ruleName: '规则名称(可选)',
|
||||||
|
ruleGroups: '分组',
|
||||||
|
ruleAccounts: '账号 ID',
|
||||||
|
ruleAccountsPlaceholder: '输入账号 ID,逗号分隔',
|
||||||
|
ruleModelPricing: '模型定价',
|
||||||
|
noGroupsInChannel: '上方平台标签页中未选择分组'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
|||||||
@ -306,24 +306,6 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Web Search Emulation (Anthropic only) -->
|
|
||||||
<div v-if="section.platform === 'anthropic'" class="border-t border-gray-200 pt-3 dark:border-dark-600">
|
|
||||||
<div class="flex items-center justify-between">
|
|
||||||
<div>
|
|
||||||
<label class="text-xs font-medium text-orange-600 dark:text-orange-400">
|
|
||||||
{{ t('admin.channels.form.webSearchEmulation') }}
|
|
||||||
</label>
|
|
||||||
<p v-if="webSearchGlobalEnabled" class="mt-0.5 text-[11px] text-amber-500 dark:text-amber-400">
|
|
||||||
{{ t('admin.channels.form.webSearchEmulationHint') }}
|
|
||||||
</p>
|
|
||||||
<p v-else class="mt-0.5 text-[11px] text-gray-400">
|
|
||||||
{{ t('admin.channels.form.webSearchEmulationGlobalDisabled') }}
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
<Toggle v-model="section.web_search_emulation" :disabled="!webSearchGlobalEnabled" />
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- Model Mapping -->
|
<!-- Model Mapping -->
|
||||||
<div>
|
<div>
|
||||||
<div class="mb-1 flex items-center justify-between">
|
<div class="mb-1 flex items-center justify-between">
|
||||||
@ -398,6 +380,143 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Account Stats Pricing (always visible, not tied to platform tabs) -->
|
||||||
|
<div class="mt-6 border-t border-gray-200 pt-4 dark:border-dark-700">
|
||||||
|
<!-- Toggle -->
|
||||||
|
<div class="flex items-center justify-between mb-3">
|
||||||
|
<div>
|
||||||
|
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||||
|
{{ t('admin.channels.form.applyPricingToAccountStats', 'Apply Pricing to Account Stats') }}
|
||||||
|
</label>
|
||||||
|
<p class="mt-0.5 text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t('admin.channels.form.applyPricingToAccountStatsDesc', 'When enabled, account statistics cost will use channel model pricing. Account rate multiplier still applies.') }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<Toggle
|
||||||
|
:modelValue="form.apply_pricing_to_account_stats"
|
||||||
|
@update:modelValue="form.apply_pricing_to_account_stats = $event"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Custom rules (only when toggle is on) -->
|
||||||
|
<div v-if="form.apply_pricing_to_account_stats" class="mt-4 space-y-4">
|
||||||
|
<div class="flex items-center justify-between">
|
||||||
|
<h4 class="text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||||
|
{{ t('admin.channels.form.accountStatsPricingRules', 'Custom Account Stats Pricing Rules') }}
|
||||||
|
</h4>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
@click="addAccountStatsRule()"
|
||||||
|
class="rounded-lg border border-primary-300 px-3 py-1 text-xs font-medium text-primary-600 hover:bg-primary-50 dark:border-primary-600 dark:text-primary-400 dark:hover:bg-primary-900/20"
|
||||||
|
>
|
||||||
|
+ {{ t('admin.channels.form.addRule', 'Add Rule') }}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<p
|
||||||
|
v-if="form.account_stats_pricing_rules.length === 0"
|
||||||
|
class="text-xs italic text-gray-400 dark:text-gray-500"
|
||||||
|
>
|
||||||
|
{{ t('admin.channels.form.noRulesConfigured', 'No custom rules configured. Channel model pricing above will be used.') }}
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<!-- Rule cards -->
|
||||||
|
<div
|
||||||
|
v-for="(rule, ruleIndex) in form.account_stats_pricing_rules"
|
||||||
|
:key="ruleIndex"
|
||||||
|
class="space-y-3 rounded-lg border border-gray-200 p-4 dark:border-dark-600"
|
||||||
|
>
|
||||||
|
<div class="flex items-center justify-between">
|
||||||
|
<input
|
||||||
|
v-model="rule.name"
|
||||||
|
:placeholder="t('admin.channels.form.ruleName', 'Rule name (optional)')"
|
||||||
|
class="bg-transparent text-sm font-medium text-gray-700 placeholder-gray-400 outline-none dark:text-gray-300"
|
||||||
|
/>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
@click="removeAccountStatsRule(ruleIndex)"
|
||||||
|
class="text-xs text-red-500 hover:text-red-700"
|
||||||
|
>
|
||||||
|
{{ t('common.delete', 'Delete') }}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Group selection (multi-select from channel's groups) -->
|
||||||
|
<div>
|
||||||
|
<label class="text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t('admin.channels.form.ruleGroups', 'Groups') }}
|
||||||
|
</label>
|
||||||
|
<div class="mt-1 flex flex-wrap gap-1">
|
||||||
|
<label
|
||||||
|
v-for="gid in allFormGroupIds"
|
||||||
|
:key="gid"
|
||||||
|
class="inline-flex cursor-pointer items-center gap-1 rounded-md border px-2 py-1 text-xs transition-colors"
|
||||||
|
:class="rule.group_ids.includes(gid)
|
||||||
|
? 'border-primary-300 bg-primary-50 dark:border-primary-700 dark:bg-primary-900/20'
|
||||||
|
: 'border-gray-200 hover:bg-gray-50 dark:border-dark-600 dark:hover:bg-dark-700'"
|
||||||
|
>
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
:checked="rule.group_ids.includes(gid)"
|
||||||
|
class="h-3 w-3 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||||
|
@change="rule.group_ids.includes(gid) ? rule.group_ids.splice(rule.group_ids.indexOf(gid), 1) : rule.group_ids.push(gid)"
|
||||||
|
/>
|
||||||
|
<span>{{ getGroupNameById(gid) }}</span>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
<p v-if="allFormGroupIds.length === 0" class="mt-1 text-xs text-gray-400">
|
||||||
|
{{ t('admin.channels.form.noGroupsInChannel', 'No groups selected in platform tabs above') }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Account IDs input -->
|
||||||
|
<div>
|
||||||
|
<label class="text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t('admin.channels.form.ruleAccounts', 'Account IDs') }}
|
||||||
|
</label>
|
||||||
|
<input
|
||||||
|
:value="rule.account_ids.join(', ')"
|
||||||
|
@change="rule.account_ids = parseAccountIdsInput(($event.target as HTMLInputElement).value)"
|
||||||
|
:placeholder="t('admin.channels.form.ruleAccountsPlaceholder', 'Enter account IDs, comma-separated')"
|
||||||
|
class="input mt-1 text-sm"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Model Pricing entries -->
|
||||||
|
<div>
|
||||||
|
<div class="mb-1 flex items-center justify-between">
|
||||||
|
<label class="text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t('admin.channels.form.ruleModelPricing', 'Model Pricing') }}
|
||||||
|
</label>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
@click="addRulePricingEntry(ruleIndex)"
|
||||||
|
class="text-xs text-primary-600 hover:text-primary-700"
|
||||||
|
>
|
||||||
|
+ {{ t('common.add', 'Add') }}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
v-if="rule.pricing.length === 0"
|
||||||
|
class="rounded border border-dashed border-gray-300 p-2 text-center text-xs text-gray-400 dark:border-dark-500"
|
||||||
|
>
|
||||||
|
{{ t('admin.channels.form.noPricingRules', 'No pricing rules yet. Click "Add" to create one.') }}
|
||||||
|
</div>
|
||||||
|
<div v-else class="space-y-2">
|
||||||
|
<PricingEntryCard
|
||||||
|
v-for="(entry, pIdx) in rule.pricing"
|
||||||
|
:key="pIdx"
|
||||||
|
:entry="entry"
|
||||||
|
platform=""
|
||||||
|
@update="rule.pricing.splice(pIdx, 1, $event)"
|
||||||
|
@remove="removeRulePricingEntry(ruleIndex, pIdx)"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</form>
|
</form>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@ -441,9 +560,8 @@
|
|||||||
import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
|
import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import { useAppStore } from '@/stores/app'
|
import { useAppStore } from '@/stores/app'
|
||||||
import { extractApiErrorMessage } from '@/utils/apiError'
|
|
||||||
import { adminAPI } from '@/api/admin'
|
import { adminAPI } from '@/api/admin'
|
||||||
import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest } from '@/api/admin/channels'
|
import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest, AccountStatsPricingRule } from '@/api/admin/channels'
|
||||||
import type { PricingFormEntry } from '@/components/admin/channel/types'
|
import type { PricingFormEntry } from '@/components/admin/channel/types'
|
||||||
import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict, validateIntervals } from '@/components/admin/channel/types'
|
import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict, validateIntervals } from '@/components/admin/channel/types'
|
||||||
import type { AdminGroup, GroupPlatform } from '@/types'
|
import type { AdminGroup, GroupPlatform } from '@/types'
|
||||||
@ -465,18 +583,6 @@ import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
|
|||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
const appStore = useAppStore()
|
const appStore = useAppStore()
|
||||||
|
|
||||||
// Web Search global enabled state (loaded once on mount)
|
|
||||||
const webSearchGlobalEnabled = ref(false)
|
|
||||||
async function loadWebSearchGlobalState() {
|
|
||||||
try {
|
|
||||||
const cfg = await adminAPI.settings.getWebSearchEmulationConfig()
|
|
||||||
webSearchGlobalEnabled.value = cfg?.enabled === true && (cfg?.providers?.length ?? 0) > 0
|
|
||||||
} catch (err: unknown) {
|
|
||||||
console.warn('Failed to load web search global state:', err)
|
|
||||||
webSearchGlobalEnabled.value = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── Platform Section type ──
|
// ── Platform Section type ──
|
||||||
interface PlatformSection {
|
interface PlatformSection {
|
||||||
platform: GroupPlatform
|
platform: GroupPlatform
|
||||||
@ -485,7 +591,6 @@ interface PlatformSection {
|
|||||||
group_ids: number[]
|
group_ids: number[]
|
||||||
model_mapping: Record<string, string>
|
model_mapping: Record<string, string>
|
||||||
model_pricing: PricingFormEntry[]
|
model_pricing: PricingFormEntry[]
|
||||||
web_search_emulation: boolean
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Table columns ──
|
// ── Table columns ──
|
||||||
@ -553,7 +658,14 @@ const form = reactive({
|
|||||||
status: 'active',
|
status: 'active',
|
||||||
restrict_models: false,
|
restrict_models: false,
|
||||||
billing_model_source: 'channel_mapped' as string,
|
billing_model_source: 'channel_mapped' as string,
|
||||||
platforms: [] as PlatformSection[]
|
platforms: [] as PlatformSection[],
|
||||||
|
apply_pricing_to_account_stats: false,
|
||||||
|
account_stats_pricing_rules: [] as Array<{
|
||||||
|
name: string
|
||||||
|
group_ids: number[]
|
||||||
|
account_ids: number[]
|
||||||
|
pricing: PricingFormEntry[]
|
||||||
|
}>
|
||||||
})
|
})
|
||||||
|
|
||||||
let abortController: AbortController | null = null
|
let abortController: AbortController | null = null
|
||||||
@ -597,8 +709,7 @@ function addPlatformSection(platform: GroupPlatform) {
|
|||||||
collapsed: false,
|
collapsed: false,
|
||||||
group_ids: [],
|
group_ids: [],
|
||||||
model_mapping: {},
|
model_mapping: {},
|
||||||
model_pricing: [],
|
model_pricing: []
|
||||||
web_search_emulation: false,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -711,15 +822,89 @@ function renameMappingKey(sectionIdx: number, oldKey: string, newKey: string) {
|
|||||||
mapping[newKey] = value
|
mapping[newKey] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Account Stats Pricing helpers ──
|
||||||
|
function addAccountStatsRule() {
|
||||||
|
form.account_stats_pricing_rules.push({
|
||||||
|
name: '',
|
||||||
|
group_ids: [],
|
||||||
|
account_ids: [],
|
||||||
|
pricing: []
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
function addRulePricingEntry(ruleIndex: number) {
|
||||||
|
form.account_stats_pricing_rules[ruleIndex].pricing.push({
|
||||||
|
models: [],
|
||||||
|
billing_mode: 'token',
|
||||||
|
input_price: null,
|
||||||
|
output_price: null,
|
||||||
|
cache_write_price: null,
|
||||||
|
cache_read_price: null,
|
||||||
|
image_output_price: null,
|
||||||
|
per_request_price: null,
|
||||||
|
intervals: []
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
function removeAccountStatsRule(ruleIndex: number) {
|
||||||
|
form.account_stats_pricing_rules.splice(ruleIndex, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
function removeRulePricingEntry(ruleIndex: number, pricingIndex: number) {
|
||||||
|
form.account_stats_pricing_rules[ruleIndex].pricing.splice(pricingIndex, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
function getGroupNameById(groupId: number): string {
|
||||||
|
const group = allGroups.value.find(g => g.id === groupId)
|
||||||
|
return group ? group.name : `#${groupId}`
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Collect all group_ids from enabled platform sections */
|
||||||
|
const allFormGroupIds = computed(() => {
|
||||||
|
const ids = new Set<number>()
|
||||||
|
for (const section of form.platforms) {
|
||||||
|
if (!section.enabled) continue
|
||||||
|
for (const gid of section.group_ids) {
|
||||||
|
ids.add(gid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return [...ids]
|
||||||
|
})
|
||||||
|
|
||||||
|
function parseAccountIdsInput(value: string): number[] {
|
||||||
|
return value
|
||||||
|
.split(',')
|
||||||
|
.map(s => parseInt(s.trim()))
|
||||||
|
.filter(n => !isNaN(n) && n > 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
function accountStatsRulesToAPI(): AccountStatsPricingRule[] {
|
||||||
|
return form.account_stats_pricing_rules.map(rule => ({
|
||||||
|
name: rule.name,
|
||||||
|
group_ids: rule.group_ids,
|
||||||
|
account_ids: rule.account_ids,
|
||||||
|
pricing: rule.pricing
|
||||||
|
.filter(p => p.models.length > 0)
|
||||||
|
.map(p => ({
|
||||||
|
platform: '',
|
||||||
|
models: p.models,
|
||||||
|
billing_mode: p.billing_mode,
|
||||||
|
input_price: mTokToPerToken(p.input_price),
|
||||||
|
output_price: mTokToPerToken(p.output_price),
|
||||||
|
cache_write_price: mTokToPerToken(p.cache_write_price),
|
||||||
|
cache_read_price: mTokToPerToken(p.cache_read_price),
|
||||||
|
image_output_price: mTokToPerToken(p.image_output_price),
|
||||||
|
per_request_price: p.per_request_price != null && p.per_request_price !== '' ? Number(p.per_request_price) : null,
|
||||||
|
intervals: formIntervalsToAPI(p.intervals || [])
|
||||||
|
}))
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
// ── Form ↔ API conversion ──
|
// ── Form ↔ API conversion ──
|
||||||
function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[], model_mapping: Record<string, Record<string, string>>, features_config: Record<string, unknown> } {
|
function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[], model_mapping: Record<string, Record<string, string>> } {
|
||||||
const group_ids: number[] = []
|
const group_ids: number[] = []
|
||||||
const model_pricing: ChannelModelPricing[] = []
|
const model_pricing: ChannelModelPricing[] = []
|
||||||
const model_mapping: Record<string, Record<string, string>> = {}
|
const model_mapping: Record<string, Record<string, string>> = {}
|
||||||
// Preserve existing features_config fields not managed by the form
|
|
||||||
const featuresConfig: Record<string, unknown> = editingChannel.value?.features_config
|
|
||||||
? { ...editingChannel.value.features_config }
|
|
||||||
: {}
|
|
||||||
|
|
||||||
for (const section of form.platforms) {
|
for (const section of form.platforms) {
|
||||||
if (!section.enabled) continue
|
if (!section.enabled) continue
|
||||||
@ -748,19 +933,7 @@ function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Collect web_search_emulation (only anthropic platform supports it)
|
return { group_ids, model_pricing, model_mapping }
|
||||||
const wsEmulation: Record<string, boolean> = {}
|
|
||||||
for (const section of form.platforms) {
|
|
||||||
if (!section.enabled) continue
|
|
||||||
if (section.web_search_emulation && section.platform === 'anthropic') {
|
|
||||||
wsEmulation[section.platform] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (Object.keys(wsEmulation).length > 0) {
|
|
||||||
featuresConfig.web_search_emulation = wsEmulation
|
|
||||||
}
|
|
||||||
|
|
||||||
return { group_ids, model_pricing, model_mapping, features_config: featuresConfig }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function apiToForm(channel: Channel): PlatformSection[] {
|
function apiToForm(channel: Channel): PlatformSection[] {
|
||||||
@ -804,19 +977,13 @@ function apiToForm(channel: Channel): PlatformSection[] {
|
|||||||
intervals: apiIntervalsToForm(p.intervals || [])
|
intervals: apiIntervalsToForm(p.intervals || [])
|
||||||
} as PricingFormEntry))
|
} as PricingFormEntry))
|
||||||
|
|
||||||
// Read web_search_emulation from features_config
|
|
||||||
const fc = channel.features_config
|
|
||||||
const wsEmulation = fc?.web_search_emulation as Record<string, boolean> | undefined
|
|
||||||
const webSearchEnabled = wsEmulation?.[platform] === true
|
|
||||||
|
|
||||||
sections.push({
|
sections.push({
|
||||||
platform,
|
platform,
|
||||||
enabled: true,
|
enabled: true,
|
||||||
collapsed: false,
|
collapsed: false,
|
||||||
group_ids: groupIds,
|
group_ids: groupIds,
|
||||||
model_mapping: { ...mapping },
|
model_mapping: { ...mapping },
|
||||||
model_pricing: pricing,
|
model_pricing: pricing
|
||||||
web_search_emulation: webSearchEnabled,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -841,10 +1008,10 @@ async function loadChannels() {
|
|||||||
if (ctrl.signal.aborted || abortController !== ctrl) return
|
if (ctrl.signal.aborted || abortController !== ctrl) return
|
||||||
channels.value = response.items || []
|
channels.value = response.items || []
|
||||||
pagination.total = response.total
|
pagination.total = response.total
|
||||||
} catch (error: unknown) {
|
} catch (error: any) {
|
||||||
const e = error as { name?: string; code?: string }
|
if (error?.name === 'AbortError' || error?.code === 'ERR_CANCELED') return
|
||||||
if (e?.name === 'AbortError' || e?.code === 'ERR_CANCELED') return
|
appStore.showError(t('admin.channels.loadError', 'Failed to load channels'))
|
||||||
appStore.showError(extractApiErrorMessage(error, t('admin.channels.loadError', 'Failed to load channels')))
|
console.error('Error loading channels:', error)
|
||||||
} finally {
|
} finally {
|
||||||
if (abortController === ctrl) {
|
if (abortController === ctrl) {
|
||||||
loading.value = false
|
loading.value = false
|
||||||
@ -909,6 +1076,8 @@ function resetForm() {
|
|||||||
form.restrict_models = false
|
form.restrict_models = false
|
||||||
form.billing_model_source = 'channel_mapped'
|
form.billing_model_source = 'channel_mapped'
|
||||||
form.platforms = []
|
form.platforms = []
|
||||||
|
form.apply_pricing_to_account_stats = false
|
||||||
|
form.account_stats_pricing_rules = []
|
||||||
activeTab.value = 'basic'
|
activeTab.value = 'basic'
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -926,6 +1095,23 @@ async function openEditDialog(channel: Channel) {
|
|||||||
form.status = channel.status
|
form.status = channel.status
|
||||||
form.restrict_models = channel.restrict_models || false
|
form.restrict_models = channel.restrict_models || false
|
||||||
form.billing_model_source = channel.billing_model_source || 'channel_mapped'
|
form.billing_model_source = channel.billing_model_source || 'channel_mapped'
|
||||||
|
form.apply_pricing_to_account_stats = channel.apply_pricing_to_account_stats || false
|
||||||
|
form.account_stats_pricing_rules = (channel.account_stats_pricing_rules || []).map(rule => ({
|
||||||
|
name: rule.name || '',
|
||||||
|
group_ids: [...(rule.group_ids || [])],
|
||||||
|
account_ids: [...(rule.account_ids || [])],
|
||||||
|
pricing: (rule.pricing || []).map(p => ({
|
||||||
|
models: [...(p.models || [])],
|
||||||
|
billing_mode: p.billing_mode,
|
||||||
|
input_price: perTokenToMTok(p.input_price),
|
||||||
|
output_price: perTokenToMTok(p.output_price),
|
||||||
|
cache_write_price: perTokenToMTok(p.cache_write_price),
|
||||||
|
cache_read_price: perTokenToMTok(p.cache_read_price),
|
||||||
|
image_output_price: perTokenToMTok(p.image_output_price),
|
||||||
|
per_request_price: p.per_request_price,
|
||||||
|
intervals: apiIntervalsToForm(p.intervals || [])
|
||||||
|
} as PricingFormEntry))
|
||||||
|
}))
|
||||||
// Must load groups first so apiToForm can map groupID → platform
|
// Must load groups first so apiToForm can map groupID → platform
|
||||||
await Promise.all([loadGroups(), loadAllChannelsForConflict()])
|
await Promise.all([loadGroups(), loadAllChannelsForConflict()])
|
||||||
form.platforms = apiToForm(channel)
|
form.platforms = apiToForm(channel)
|
||||||
@ -1024,7 +1210,7 @@ async function handleSubmit() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const { group_ids, model_pricing, model_mapping, features_config } = formToAPI()
|
const { group_ids, model_pricing, model_mapping } = formToAPI()
|
||||||
|
|
||||||
submitting.value = true
|
submitting.value = true
|
||||||
try {
|
try {
|
||||||
@ -1038,7 +1224,8 @@ async function handleSubmit() {
|
|||||||
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {},
|
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {},
|
||||||
billing_model_source: form.billing_model_source,
|
billing_model_source: form.billing_model_source,
|
||||||
restrict_models: form.restrict_models,
|
restrict_models: form.restrict_models,
|
||||||
features_config,
|
apply_pricing_to_account_stats: form.apply_pricing_to_account_stats,
|
||||||
|
account_stats_pricing_rules: accountStatsRulesToAPI()
|
||||||
}
|
}
|
||||||
await adminAPI.channels.update(editingChannel.value.id, req)
|
await adminAPI.channels.update(editingChannel.value.id, req)
|
||||||
appStore.showSuccess(t('admin.channels.updateSuccess', 'Channel updated'))
|
appStore.showSuccess(t('admin.channels.updateSuccess', 'Channel updated'))
|
||||||
@ -1051,17 +1238,20 @@ async function handleSubmit() {
|
|||||||
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {},
|
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {},
|
||||||
billing_model_source: form.billing_model_source,
|
billing_model_source: form.billing_model_source,
|
||||||
restrict_models: form.restrict_models,
|
restrict_models: form.restrict_models,
|
||||||
features_config,
|
apply_pricing_to_account_stats: form.apply_pricing_to_account_stats,
|
||||||
|
account_stats_pricing_rules: accountStatsRulesToAPI()
|
||||||
}
|
}
|
||||||
await adminAPI.channels.create(req)
|
await adminAPI.channels.create(req)
|
||||||
appStore.showSuccess(t('admin.channels.createSuccess', 'Channel created'))
|
appStore.showSuccess(t('admin.channels.createSuccess', 'Channel created'))
|
||||||
}
|
}
|
||||||
closeDialog()
|
closeDialog()
|
||||||
loadChannels()
|
loadChannels()
|
||||||
} catch (error: unknown) {
|
} catch (error: any) {
|
||||||
appStore.showError(extractApiErrorMessage(error, editingChannel.value
|
const msg = error.response?.data?.detail || (editingChannel.value
|
||||||
? t('admin.channels.updateError', 'Failed to update channel')
|
? t('admin.channels.updateError', 'Failed to update channel')
|
||||||
: t('admin.channels.createError', 'Failed to create channel')))
|
: t('admin.channels.createError', 'Failed to create channel'))
|
||||||
|
appStore.showError(msg)
|
||||||
|
console.error('Error saving channel:', error)
|
||||||
} finally {
|
} finally {
|
||||||
submitting.value = false
|
submitting.value = false
|
||||||
}
|
}
|
||||||
@ -1099,8 +1289,9 @@ async function confirmDelete() {
|
|||||||
showDeleteDialog.value = false
|
showDeleteDialog.value = false
|
||||||
deletingChannel.value = null
|
deletingChannel.value = null
|
||||||
loadChannels()
|
loadChannels()
|
||||||
} catch (error: unknown) {
|
} catch (error: any) {
|
||||||
appStore.showError(extractApiErrorMessage(error, t('admin.channels.deleteError', 'Failed to delete channel')))
|
appStore.showError(error.response?.data?.detail || t('admin.channels.deleteError', 'Failed to delete channel'))
|
||||||
|
console.error('Error deleting channel:', error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1108,7 +1299,6 @@ async function confirmDelete() {
|
|||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
loadChannels()
|
loadChannels()
|
||||||
loadGroups()
|
loadGroups()
|
||||||
loadWebSearchGlobalState()
|
|
||||||
})
|
})
|
||||||
|
|
||||||
onUnmounted(() => {
|
onUnmounted(() => {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user