Merge pull request #1455 from touwaeriol/feat/channel-management

feat(channel): add channel management with multi-mode pricing and billing integration
This commit is contained in:
Wesley Liddick 2026-04-04 23:42:33 +08:00 committed by GitHub
commit bf45581104
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
101 changed files with 12458 additions and 550 deletions

View File

@ -49,6 +49,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
refreshTokenCache := repository.NewRefreshTokenCache(redisClient) refreshTokenCache := repository.NewRefreshTokenCache(redisClient)
settingRepository := repository.NewSettingRepository(client) settingRepository := repository.NewSettingRepository(client)
groupRepository := repository.NewGroupRepository(client, db) groupRepository := repository.NewGroupRepository(client, db)
channelRepository := repository.NewChannelRepository(db)
settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig) settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig)
emailCache := repository.NewEmailCache(redisClient) emailCache := repository.NewEmailCache(redisClient)
emailService := service.NewEmailService(settingRepository, emailCache) emailService := service.NewEmailService(settingRepository, emailCache)
@ -138,11 +139,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache) antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient) internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client) tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient) tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache) tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
@ -175,9 +176,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI) claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
digestSessionStore := service.NewDigestSessionStore() digestSessionStore := service.NewDigestSessionStore()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService) channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
@ -213,7 +216,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db) scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository) scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService) scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler) channelHandler := admin.NewChannelHandler(channelService, billingService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)

View File

@ -744,6 +744,10 @@ var (
{Name: "model", Type: field.TypeString, Size: 100}, {Name: "model", Type: field.TypeString, Size: 100},
{Name: "requested_model", Type: field.TypeString, Nullable: true, Size: 100}, {Name: "requested_model", Type: field.TypeString, Nullable: true, Size: 100},
{Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100}, {Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100},
{Name: "channel_id", Type: field.TypeInt64, Nullable: true},
{Name: "model_mapping_chain", Type: field.TypeString, Nullable: true, Size: 500},
{Name: "billing_tier", Type: field.TypeString, Nullable: true, Size: 50},
{Name: "billing_mode", Type: field.TypeString, Nullable: true, Size: 20},
{Name: "input_tokens", Type: field.TypeInt, Default: 0}, {Name: "input_tokens", Type: field.TypeInt, Default: 0},
{Name: "output_tokens", Type: field.TypeInt, Default: 0}, {Name: "output_tokens", Type: field.TypeInt, Default: 0},
{Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0}, {Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0},
@ -783,31 +787,31 @@ var (
ForeignKeys: []*schema.ForeignKey{ ForeignKeys: []*schema.ForeignKey{
{ {
Symbol: "usage_logs_api_keys_usage_logs", Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[30]}, Columns: []*schema.Column{UsageLogsColumns[34]},
RefColumns: []*schema.Column{APIKeysColumns[0]}, RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_accounts_usage_logs", Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[31]}, Columns: []*schema.Column{UsageLogsColumns[35]},
RefColumns: []*schema.Column{AccountsColumns[0]}, RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_groups_usage_logs", Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[32]}, Columns: []*schema.Column{UsageLogsColumns[36]},
RefColumns: []*schema.Column{GroupsColumns[0]}, RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull, OnDelete: schema.SetNull,
}, },
{ {
Symbol: "usage_logs_users_usage_logs", Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[33]}, Columns: []*schema.Column{UsageLogsColumns[37]},
RefColumns: []*schema.Column{UsersColumns[0]}, RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_user_subscriptions_usage_logs", Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[34]}, Columns: []*schema.Column{UsageLogsColumns[38]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull, OnDelete: schema.SetNull,
}, },
@ -816,32 +820,32 @@ var (
{ {
Name: "usagelog_user_id", Name: "usagelog_user_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[33]}, Columns: []*schema.Column{UsageLogsColumns[37]},
}, },
{ {
Name: "usagelog_api_key_id", Name: "usagelog_api_key_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30]}, Columns: []*schema.Column{UsageLogsColumns[34]},
}, },
{ {
Name: "usagelog_account_id", Name: "usagelog_account_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[31]}, Columns: []*schema.Column{UsageLogsColumns[35]},
}, },
{ {
Name: "usagelog_group_id", Name: "usagelog_group_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[32]}, Columns: []*schema.Column{UsageLogsColumns[36]},
}, },
{ {
Name: "usagelog_subscription_id", Name: "usagelog_subscription_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[34]}, Columns: []*schema.Column{UsageLogsColumns[38]},
}, },
{ {
Name: "usagelog_created_at", Name: "usagelog_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29]}, Columns: []*schema.Column{UsageLogsColumns[33]},
}, },
{ {
Name: "usagelog_model", Name: "usagelog_model",
@ -861,17 +865,17 @@ var (
{ {
Name: "usagelog_user_id_created_at", Name: "usagelog_user_id_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[29]}, Columns: []*schema.Column{UsageLogsColumns[37], UsageLogsColumns[33]},
}, },
{ {
Name: "usagelog_api_key_id_created_at", Name: "usagelog_api_key_id_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[29]}, Columns: []*schema.Column{UsageLogsColumns[34], UsageLogsColumns[33]},
}, },
{ {
Name: "usagelog_group_id_created_at", Name: "usagelog_group_id_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[29]}, Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[33]},
}, },
}, },
} }

View File

@ -19725,6 +19725,11 @@ type UsageLogMutation struct {
model *string model *string
requested_model *string requested_model *string
upstream_model *string upstream_model *string
channel_id *int64
addchannel_id *int64
model_mapping_chain *string
billing_tier *string
billing_mode *string
input_tokens *int input_tokens *int
addinput_tokens *int addinput_tokens *int
output_tokens *int output_tokens *int
@ -20160,6 +20165,223 @@ func (m *UsageLogMutation) ResetUpstreamModel() {
delete(m.clearedFields, usagelog.FieldUpstreamModel) delete(m.clearedFields, usagelog.FieldUpstreamModel)
} }
// SetChannelID sets the "channel_id" field.
func (m *UsageLogMutation) SetChannelID(i int64) {
m.channel_id = &i
m.addchannel_id = nil
}
// ChannelID returns the value of the "channel_id" field in the mutation.
func (m *UsageLogMutation) ChannelID() (r int64, exists bool) {
v := m.channel_id
if v == nil {
return
}
return *v, true
}
// OldChannelID returns the old "channel_id" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldChannelID(ctx context.Context) (v *int64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldChannelID is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldChannelID requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldChannelID: %w", err)
}
return oldValue.ChannelID, nil
}
// AddChannelID adds i to the "channel_id" field.
func (m *UsageLogMutation) AddChannelID(i int64) {
if m.addchannel_id != nil {
*m.addchannel_id += i
} else {
m.addchannel_id = &i
}
}
// AddedChannelID returns the value that was added to the "channel_id" field in this mutation.
func (m *UsageLogMutation) AddedChannelID() (r int64, exists bool) {
v := m.addchannel_id
if v == nil {
return
}
return *v, true
}
// ClearChannelID clears the value of the "channel_id" field.
func (m *UsageLogMutation) ClearChannelID() {
m.channel_id = nil
m.addchannel_id = nil
m.clearedFields[usagelog.FieldChannelID] = struct{}{}
}
// ChannelIDCleared returns if the "channel_id" field was cleared in this mutation.
func (m *UsageLogMutation) ChannelIDCleared() bool {
_, ok := m.clearedFields[usagelog.FieldChannelID]
return ok
}
// ResetChannelID resets all changes to the "channel_id" field.
func (m *UsageLogMutation) ResetChannelID() {
m.channel_id = nil
m.addchannel_id = nil
delete(m.clearedFields, usagelog.FieldChannelID)
}
// SetModelMappingChain sets the "model_mapping_chain" field.
func (m *UsageLogMutation) SetModelMappingChain(s string) {
m.model_mapping_chain = &s
}
// ModelMappingChain returns the value of the "model_mapping_chain" field in the mutation.
func (m *UsageLogMutation) ModelMappingChain() (r string, exists bool) {
v := m.model_mapping_chain
if v == nil {
return
}
return *v, true
}
// OldModelMappingChain returns the old "model_mapping_chain" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldModelMappingChain(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldModelMappingChain is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldModelMappingChain requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldModelMappingChain: %w", err)
}
return oldValue.ModelMappingChain, nil
}
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
func (m *UsageLogMutation) ClearModelMappingChain() {
m.model_mapping_chain = nil
m.clearedFields[usagelog.FieldModelMappingChain] = struct{}{}
}
// ModelMappingChainCleared returns if the "model_mapping_chain" field was cleared in this mutation.
func (m *UsageLogMutation) ModelMappingChainCleared() bool {
_, ok := m.clearedFields[usagelog.FieldModelMappingChain]
return ok
}
// ResetModelMappingChain resets all changes to the "model_mapping_chain" field.
func (m *UsageLogMutation) ResetModelMappingChain() {
m.model_mapping_chain = nil
delete(m.clearedFields, usagelog.FieldModelMappingChain)
}
// SetBillingTier sets the "billing_tier" field.
func (m *UsageLogMutation) SetBillingTier(s string) {
m.billing_tier = &s
}
// BillingTier returns the value of the "billing_tier" field in the mutation.
func (m *UsageLogMutation) BillingTier() (r string, exists bool) {
v := m.billing_tier
if v == nil {
return
}
return *v, true
}
// OldBillingTier returns the old "billing_tier" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldBillingTier(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldBillingTier is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldBillingTier requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldBillingTier: %w", err)
}
return oldValue.BillingTier, nil
}
// ClearBillingTier clears the value of the "billing_tier" field.
func (m *UsageLogMutation) ClearBillingTier() {
m.billing_tier = nil
m.clearedFields[usagelog.FieldBillingTier] = struct{}{}
}
// BillingTierCleared returns if the "billing_tier" field was cleared in this mutation.
func (m *UsageLogMutation) BillingTierCleared() bool {
_, ok := m.clearedFields[usagelog.FieldBillingTier]
return ok
}
// ResetBillingTier resets all changes to the "billing_tier" field.
func (m *UsageLogMutation) ResetBillingTier() {
m.billing_tier = nil
delete(m.clearedFields, usagelog.FieldBillingTier)
}
// SetBillingMode sets the "billing_mode" field.
func (m *UsageLogMutation) SetBillingMode(s string) {
m.billing_mode = &s
}
// BillingMode returns the value of the "billing_mode" field in the mutation.
func (m *UsageLogMutation) BillingMode() (r string, exists bool) {
v := m.billing_mode
if v == nil {
return
}
return *v, true
}
// OldBillingMode returns the old "billing_mode" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldBillingMode(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldBillingMode is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldBillingMode requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldBillingMode: %w", err)
}
return oldValue.BillingMode, nil
}
// ClearBillingMode clears the value of the "billing_mode" field.
func (m *UsageLogMutation) ClearBillingMode() {
m.billing_mode = nil
m.clearedFields[usagelog.FieldBillingMode] = struct{}{}
}
// BillingModeCleared returns if the "billing_mode" field was cleared in this mutation.
func (m *UsageLogMutation) BillingModeCleared() bool {
_, ok := m.clearedFields[usagelog.FieldBillingMode]
return ok
}
// ResetBillingMode resets all changes to the "billing_mode" field.
func (m *UsageLogMutation) ResetBillingMode() {
m.billing_mode = nil
delete(m.clearedFields, usagelog.FieldBillingMode)
}
// SetGroupID sets the "group_id" field. // SetGroupID sets the "group_id" field.
func (m *UsageLogMutation) SetGroupID(i int64) { func (m *UsageLogMutation) SetGroupID(i int64) {
m.group = &i m.group = &i
@ -21781,7 +22003,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *UsageLogMutation) Fields() []string { func (m *UsageLogMutation) Fields() []string {
fields := make([]string, 0, 34) fields := make([]string, 0, 38)
if m.user != nil { if m.user != nil {
fields = append(fields, usagelog.FieldUserID) fields = append(fields, usagelog.FieldUserID)
} }
@ -21803,6 +22025,18 @@ func (m *UsageLogMutation) Fields() []string {
if m.upstream_model != nil { if m.upstream_model != nil {
fields = append(fields, usagelog.FieldUpstreamModel) fields = append(fields, usagelog.FieldUpstreamModel)
} }
if m.channel_id != nil {
fields = append(fields, usagelog.FieldChannelID)
}
if m.model_mapping_chain != nil {
fields = append(fields, usagelog.FieldModelMappingChain)
}
if m.billing_tier != nil {
fields = append(fields, usagelog.FieldBillingTier)
}
if m.billing_mode != nil {
fields = append(fields, usagelog.FieldBillingMode)
}
if m.group != nil { if m.group != nil {
fields = append(fields, usagelog.FieldGroupID) fields = append(fields, usagelog.FieldGroupID)
} }
@ -21906,6 +22140,14 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.RequestedModel() return m.RequestedModel()
case usagelog.FieldUpstreamModel: case usagelog.FieldUpstreamModel:
return m.UpstreamModel() return m.UpstreamModel()
case usagelog.FieldChannelID:
return m.ChannelID()
case usagelog.FieldModelMappingChain:
return m.ModelMappingChain()
case usagelog.FieldBillingTier:
return m.BillingTier()
case usagelog.FieldBillingMode:
return m.BillingMode()
case usagelog.FieldGroupID: case usagelog.FieldGroupID:
return m.GroupID() return m.GroupID()
case usagelog.FieldSubscriptionID: case usagelog.FieldSubscriptionID:
@ -21983,6 +22225,14 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldRequestedModel(ctx) return m.OldRequestedModel(ctx)
case usagelog.FieldUpstreamModel: case usagelog.FieldUpstreamModel:
return m.OldUpstreamModel(ctx) return m.OldUpstreamModel(ctx)
case usagelog.FieldChannelID:
return m.OldChannelID(ctx)
case usagelog.FieldModelMappingChain:
return m.OldModelMappingChain(ctx)
case usagelog.FieldBillingTier:
return m.OldBillingTier(ctx)
case usagelog.FieldBillingMode:
return m.OldBillingMode(ctx)
case usagelog.FieldGroupID: case usagelog.FieldGroupID:
return m.OldGroupID(ctx) return m.OldGroupID(ctx)
case usagelog.FieldSubscriptionID: case usagelog.FieldSubscriptionID:
@ -22095,6 +22345,34 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
} }
m.SetUpstreamModel(v) m.SetUpstreamModel(v)
return nil return nil
case usagelog.FieldChannelID:
v, ok := value.(int64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetChannelID(v)
return nil
case usagelog.FieldModelMappingChain:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetModelMappingChain(v)
return nil
case usagelog.FieldBillingTier:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetBillingTier(v)
return nil
case usagelog.FieldBillingMode:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetBillingMode(v)
return nil
case usagelog.FieldGroupID: case usagelog.FieldGroupID:
v, ok := value.(int64) v, ok := value.(int64)
if !ok { if !ok {
@ -22292,6 +22570,9 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
// this mutation. // this mutation.
func (m *UsageLogMutation) AddedFields() []string { func (m *UsageLogMutation) AddedFields() []string {
var fields []string var fields []string
if m.addchannel_id != nil {
fields = append(fields, usagelog.FieldChannelID)
}
if m.addinput_tokens != nil { if m.addinput_tokens != nil {
fields = append(fields, usagelog.FieldInputTokens) fields = append(fields, usagelog.FieldInputTokens)
} }
@ -22354,6 +22635,8 @@ func (m *UsageLogMutation) AddedFields() []string {
// was not set, or was not defined in the schema. // was not set, or was not defined in the schema.
func (m *UsageLogMutation) AddedField(name string) (ent.Value, bool) { func (m *UsageLogMutation) AddedField(name string) (ent.Value, bool) {
switch name { switch name {
case usagelog.FieldChannelID:
return m.AddedChannelID()
case usagelog.FieldInputTokens: case usagelog.FieldInputTokens:
return m.AddedInputTokens() return m.AddedInputTokens()
case usagelog.FieldOutputTokens: case usagelog.FieldOutputTokens:
@ -22399,6 +22682,13 @@ func (m *UsageLogMutation) AddedField(name string) (ent.Value, bool) {
// type. // type.
func (m *UsageLogMutation) AddField(name string, value ent.Value) error { func (m *UsageLogMutation) AddField(name string, value ent.Value) error {
switch name { switch name {
case usagelog.FieldChannelID:
v, ok := value.(int64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddChannelID(v)
return nil
case usagelog.FieldInputTokens: case usagelog.FieldInputTokens:
v, ok := value.(int) v, ok := value.(int)
if !ok { if !ok {
@ -22539,6 +22829,18 @@ func (m *UsageLogMutation) ClearedFields() []string {
if m.FieldCleared(usagelog.FieldUpstreamModel) { if m.FieldCleared(usagelog.FieldUpstreamModel) {
fields = append(fields, usagelog.FieldUpstreamModel) fields = append(fields, usagelog.FieldUpstreamModel)
} }
if m.FieldCleared(usagelog.FieldChannelID) {
fields = append(fields, usagelog.FieldChannelID)
}
if m.FieldCleared(usagelog.FieldModelMappingChain) {
fields = append(fields, usagelog.FieldModelMappingChain)
}
if m.FieldCleared(usagelog.FieldBillingTier) {
fields = append(fields, usagelog.FieldBillingTier)
}
if m.FieldCleared(usagelog.FieldBillingMode) {
fields = append(fields, usagelog.FieldBillingMode)
}
if m.FieldCleared(usagelog.FieldGroupID) { if m.FieldCleared(usagelog.FieldGroupID) {
fields = append(fields, usagelog.FieldGroupID) fields = append(fields, usagelog.FieldGroupID)
} }
@ -22586,6 +22888,18 @@ func (m *UsageLogMutation) ClearField(name string) error {
case usagelog.FieldUpstreamModel: case usagelog.FieldUpstreamModel:
m.ClearUpstreamModel() m.ClearUpstreamModel()
return nil return nil
case usagelog.FieldChannelID:
m.ClearChannelID()
return nil
case usagelog.FieldModelMappingChain:
m.ClearModelMappingChain()
return nil
case usagelog.FieldBillingTier:
m.ClearBillingTier()
return nil
case usagelog.FieldBillingMode:
m.ClearBillingMode()
return nil
case usagelog.FieldGroupID: case usagelog.FieldGroupID:
m.ClearGroupID() m.ClearGroupID()
return nil return nil
@ -22642,6 +22956,18 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldUpstreamModel: case usagelog.FieldUpstreamModel:
m.ResetUpstreamModel() m.ResetUpstreamModel()
return nil return nil
case usagelog.FieldChannelID:
m.ResetChannelID()
return nil
case usagelog.FieldModelMappingChain:
m.ResetModelMappingChain()
return nil
case usagelog.FieldBillingTier:
m.ResetBillingTier()
return nil
case usagelog.FieldBillingMode:
m.ResetBillingMode()
return nil
case usagelog.FieldGroupID: case usagelog.FieldGroupID:
m.ResetGroupID() m.ResetGroupID()
return nil return nil

View File

@ -875,92 +875,104 @@ func init() {
usagelogDescUpstreamModel := usagelogFields[6].Descriptor() usagelogDescUpstreamModel := usagelogFields[6].Descriptor()
// usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save. // usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error) usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error)
// usagelogDescModelMappingChain is the schema descriptor for model_mapping_chain field.
usagelogDescModelMappingChain := usagelogFields[8].Descriptor()
// usagelog.ModelMappingChainValidator is a validator for the "model_mapping_chain" field. It is called by the builders before save.
usagelog.ModelMappingChainValidator = usagelogDescModelMappingChain.Validators[0].(func(string) error)
// usagelogDescBillingTier is the schema descriptor for billing_tier field.
usagelogDescBillingTier := usagelogFields[9].Descriptor()
// usagelog.BillingTierValidator is a validator for the "billing_tier" field. It is called by the builders before save.
usagelog.BillingTierValidator = usagelogDescBillingTier.Validators[0].(func(string) error)
// usagelogDescBillingMode is the schema descriptor for billing_mode field.
usagelogDescBillingMode := usagelogFields[10].Descriptor()
// usagelog.BillingModeValidator is a validator for the "billing_mode" field. It is called by the builders before save.
usagelog.BillingModeValidator = usagelogDescBillingMode.Validators[0].(func(string) error)
// usagelogDescInputTokens is the schema descriptor for input_tokens field. // usagelogDescInputTokens is the schema descriptor for input_tokens field.
usagelogDescInputTokens := usagelogFields[9].Descriptor() usagelogDescInputTokens := usagelogFields[13].Descriptor()
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field. // usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int) usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
// usagelogDescOutputTokens is the schema descriptor for output_tokens field. // usagelogDescOutputTokens is the schema descriptor for output_tokens field.
usagelogDescOutputTokens := usagelogFields[10].Descriptor() usagelogDescOutputTokens := usagelogFields[14].Descriptor()
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field. // usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int) usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field. // usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
usagelogDescCacheCreationTokens := usagelogFields[11].Descriptor() usagelogDescCacheCreationTokens := usagelogFields[15].Descriptor()
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field. // usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int) usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field. // usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
usagelogDescCacheReadTokens := usagelogFields[12].Descriptor() usagelogDescCacheReadTokens := usagelogFields[16].Descriptor()
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field. // usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int) usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field. // usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
usagelogDescCacheCreation5mTokens := usagelogFields[13].Descriptor() usagelogDescCacheCreation5mTokens := usagelogFields[17].Descriptor()
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field. // usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int) usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field. // usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
usagelogDescCacheCreation1hTokens := usagelogFields[14].Descriptor() usagelogDescCacheCreation1hTokens := usagelogFields[18].Descriptor()
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field. // usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int) usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
// usagelogDescInputCost is the schema descriptor for input_cost field. // usagelogDescInputCost is the schema descriptor for input_cost field.
usagelogDescInputCost := usagelogFields[15].Descriptor() usagelogDescInputCost := usagelogFields[19].Descriptor()
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field. // usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64) usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
// usagelogDescOutputCost is the schema descriptor for output_cost field. // usagelogDescOutputCost is the schema descriptor for output_cost field.
usagelogDescOutputCost := usagelogFields[16].Descriptor() usagelogDescOutputCost := usagelogFields[20].Descriptor()
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field. // usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64) usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field. // usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
usagelogDescCacheCreationCost := usagelogFields[17].Descriptor() usagelogDescCacheCreationCost := usagelogFields[21].Descriptor()
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field. // usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64) usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field. // usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
usagelogDescCacheReadCost := usagelogFields[18].Descriptor() usagelogDescCacheReadCost := usagelogFields[22].Descriptor()
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field. // usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64) usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
// usagelogDescTotalCost is the schema descriptor for total_cost field. // usagelogDescTotalCost is the schema descriptor for total_cost field.
usagelogDescTotalCost := usagelogFields[19].Descriptor() usagelogDescTotalCost := usagelogFields[23].Descriptor()
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field. // usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64) usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
// usagelogDescActualCost is the schema descriptor for actual_cost field. // usagelogDescActualCost is the schema descriptor for actual_cost field.
usagelogDescActualCost := usagelogFields[20].Descriptor() usagelogDescActualCost := usagelogFields[24].Descriptor()
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field. // usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64) usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field. // usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
usagelogDescRateMultiplier := usagelogFields[21].Descriptor() usagelogDescRateMultiplier := usagelogFields[25].Descriptor()
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field. // usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64) usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
// usagelogDescBillingType is the schema descriptor for billing_type field. // usagelogDescBillingType is the schema descriptor for billing_type field.
usagelogDescBillingType := usagelogFields[23].Descriptor() usagelogDescBillingType := usagelogFields[27].Descriptor()
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field. // usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8) usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
// usagelogDescStream is the schema descriptor for stream field. // usagelogDescStream is the schema descriptor for stream field.
usagelogDescStream := usagelogFields[24].Descriptor() usagelogDescStream := usagelogFields[28].Descriptor()
// usagelog.DefaultStream holds the default value on creation for the stream field. // usagelog.DefaultStream holds the default value on creation for the stream field.
usagelog.DefaultStream = usagelogDescStream.Default.(bool) usagelog.DefaultStream = usagelogDescStream.Default.(bool)
// usagelogDescUserAgent is the schema descriptor for user_agent field. // usagelogDescUserAgent is the schema descriptor for user_agent field.
usagelogDescUserAgent := usagelogFields[27].Descriptor() usagelogDescUserAgent := usagelogFields[31].Descriptor()
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save. // usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error) usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
// usagelogDescIPAddress is the schema descriptor for ip_address field. // usagelogDescIPAddress is the schema descriptor for ip_address field.
usagelogDescIPAddress := usagelogFields[28].Descriptor() usagelogDescIPAddress := usagelogFields[32].Descriptor()
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save. // usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error) usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
// usagelogDescImageCount is the schema descriptor for image_count field. // usagelogDescImageCount is the schema descriptor for image_count field.
usagelogDescImageCount := usagelogFields[29].Descriptor() usagelogDescImageCount := usagelogFields[33].Descriptor()
// usagelog.DefaultImageCount holds the default value on creation for the image_count field. // usagelog.DefaultImageCount holds the default value on creation for the image_count field.
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int) usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
// usagelogDescImageSize is the schema descriptor for image_size field. // usagelogDescImageSize is the schema descriptor for image_size field.
usagelogDescImageSize := usagelogFields[30].Descriptor() usagelogDescImageSize := usagelogFields[34].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescMediaType is the schema descriptor for media_type field. // usagelogDescMediaType is the schema descriptor for media_type field.
usagelogDescMediaType := usagelogFields[31].Descriptor() usagelogDescMediaType := usagelogFields[35].Descriptor()
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error) usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field. // usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
usagelogDescCacheTTLOverridden := usagelogFields[32].Descriptor() usagelogDescCacheTTLOverridden := usagelogFields[36].Descriptor()
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field. // usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool) usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
// usagelogDescCreatedAt is the schema descriptor for created_at field. // usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[33].Descriptor() usagelogDescCreatedAt := usagelogFields[37].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin() userMixin := schema.User{}.Mixin()

View File

@ -53,6 +53,10 @@ func (UsageLog) Fields() []ent.Field {
MaxLen(100). MaxLen(100).
Optional(). Optional().
Nillable(), Nillable(),
field.Int64("channel_id").Optional().Nillable().Comment("渠道 ID"),
field.String("model_mapping_chain").MaxLen(500).Optional().Nillable().Comment("模型映射链"),
field.String("billing_tier").MaxLen(50).Optional().Nillable().Comment("计费层级标签"),
field.String("billing_mode").MaxLen(20).Optional().Nillable().Comment("计费模式token/per_request/image"),
field.Int64("group_id"). field.Int64("group_id").
Optional(). Optional().
Nillable(), Nillable(),

View File

@ -36,6 +36,14 @@ type UsageLog struct {
RequestedModel *string `json:"requested_model,omitempty"` RequestedModel *string `json:"requested_model,omitempty"`
// UpstreamModel holds the value of the "upstream_model" field. // UpstreamModel holds the value of the "upstream_model" field.
UpstreamModel *string `json:"upstream_model,omitempty"` UpstreamModel *string `json:"upstream_model,omitempty"`
// 渠道 ID
ChannelID *int64 `json:"channel_id,omitempty"`
// 模型映射链
ModelMappingChain *string `json:"model_mapping_chain,omitempty"`
// 计费层级标签
BillingTier *string `json:"billing_tier,omitempty"`
// 计费模式token/per_request/image
BillingMode *string `json:"billing_mode,omitempty"`
// GroupID holds the value of the "group_id" field. // GroupID holds the value of the "group_id" field.
GroupID *int64 `json:"group_id,omitempty"` GroupID *int64 `json:"group_id,omitempty"`
// SubscriptionID holds the value of the "subscription_id" field. // SubscriptionID holds the value of the "subscription_id" field.
@ -177,9 +185,9 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool) values[i] = new(sql.NullBool)
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier: case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
values[i] = new(sql.NullFloat64) values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldChannelID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
values[i] = new(sql.NullString) values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt: case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime) values[i] = new(sql.NullTime)
@ -248,6 +256,34 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
_m.UpstreamModel = new(string) _m.UpstreamModel = new(string)
*_m.UpstreamModel = value.String *_m.UpstreamModel = value.String
} }
case usagelog.FieldChannelID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field channel_id", values[i])
} else if value.Valid {
_m.ChannelID = new(int64)
*_m.ChannelID = value.Int64
}
case usagelog.FieldModelMappingChain:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field model_mapping_chain", values[i])
} else if value.Valid {
_m.ModelMappingChain = new(string)
*_m.ModelMappingChain = value.String
}
case usagelog.FieldBillingTier:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field billing_tier", values[i])
} else if value.Valid {
_m.BillingTier = new(string)
*_m.BillingTier = value.String
}
case usagelog.FieldBillingMode:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field billing_mode", values[i])
} else if value.Valid {
_m.BillingMode = new(string)
*_m.BillingMode = value.String
}
case usagelog.FieldGroupID: case usagelog.FieldGroupID:
if value, ok := values[i].(*sql.NullInt64); !ok { if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field group_id", values[i]) return fmt.Errorf("unexpected type %T for field group_id", values[i])
@ -505,6 +541,26 @@ func (_m *UsageLog) String() string {
builder.WriteString(*v) builder.WriteString(*v)
} }
builder.WriteString(", ") builder.WriteString(", ")
if v := _m.ChannelID; v != nil {
builder.WriteString("channel_id=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.ModelMappingChain; v != nil {
builder.WriteString("model_mapping_chain=")
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.BillingTier; v != nil {
builder.WriteString("billing_tier=")
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.BillingMode; v != nil {
builder.WriteString("billing_mode=")
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.GroupID; v != nil { if v := _m.GroupID; v != nil {
builder.WriteString("group_id=") builder.WriteString("group_id=")
builder.WriteString(fmt.Sprintf("%v", *v)) builder.WriteString(fmt.Sprintf("%v", *v))

View File

@ -28,6 +28,14 @@ const (
FieldRequestedModel = "requested_model" FieldRequestedModel = "requested_model"
// FieldUpstreamModel holds the string denoting the upstream_model field in the database. // FieldUpstreamModel holds the string denoting the upstream_model field in the database.
FieldUpstreamModel = "upstream_model" FieldUpstreamModel = "upstream_model"
// FieldChannelID holds the string denoting the channel_id field in the database.
FieldChannelID = "channel_id"
// FieldModelMappingChain holds the string denoting the model_mapping_chain field in the database.
FieldModelMappingChain = "model_mapping_chain"
// FieldBillingTier holds the string denoting the billing_tier field in the database.
FieldBillingTier = "billing_tier"
// FieldBillingMode holds the string denoting the billing_mode field in the database.
FieldBillingMode = "billing_mode"
// FieldGroupID holds the string denoting the group_id field in the database. // FieldGroupID holds the string denoting the group_id field in the database.
FieldGroupID = "group_id" FieldGroupID = "group_id"
// FieldSubscriptionID holds the string denoting the subscription_id field in the database. // FieldSubscriptionID holds the string denoting the subscription_id field in the database.
@ -141,6 +149,10 @@ var Columns = []string{
FieldModel, FieldModel,
FieldRequestedModel, FieldRequestedModel,
FieldUpstreamModel, FieldUpstreamModel,
FieldChannelID,
FieldModelMappingChain,
FieldBillingTier,
FieldBillingMode,
FieldGroupID, FieldGroupID,
FieldSubscriptionID, FieldSubscriptionID,
FieldInputTokens, FieldInputTokens,
@ -189,6 +201,12 @@ var (
RequestedModelValidator func(string) error RequestedModelValidator func(string) error
// UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save. // UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
UpstreamModelValidator func(string) error UpstreamModelValidator func(string) error
// ModelMappingChainValidator is a validator for the "model_mapping_chain" field. It is called by the builders before save.
ModelMappingChainValidator func(string) error
// BillingTierValidator is a validator for the "billing_tier" field. It is called by the builders before save.
BillingTierValidator func(string) error
// BillingModeValidator is a validator for the "billing_mode" field. It is called by the builders before save.
BillingModeValidator func(string) error
// DefaultInputTokens holds the default value on creation for the "input_tokens" field. // DefaultInputTokens holds the default value on creation for the "input_tokens" field.
DefaultInputTokens int DefaultInputTokens int
// DefaultOutputTokens holds the default value on creation for the "output_tokens" field. // DefaultOutputTokens holds the default value on creation for the "output_tokens" field.
@ -278,6 +296,26 @@ func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc() return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc()
} }
// ByChannelID orders the results by the channel_id field.
func ByChannelID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldChannelID, opts...).ToFunc()
}
// ByModelMappingChain orders the results by the model_mapping_chain field.
func ByModelMappingChain(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldModelMappingChain, opts...).ToFunc()
}
// ByBillingTier orders the results by the billing_tier field.
func ByBillingTier(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBillingTier, opts...).ToFunc()
}
// ByBillingMode orders the results by the billing_mode field.
func ByBillingMode(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBillingMode, opts...).ToFunc()
}
// ByGroupID orders the results by the group_id field. // ByGroupID orders the results by the group_id field.
func ByGroupID(opts ...sql.OrderTermOption) OrderOption { func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldGroupID, opts...).ToFunc() return sql.OrderByField(FieldGroupID, opts...).ToFunc()

View File

@ -90,6 +90,26 @@ func UpstreamModel(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v)) return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
} }
// ChannelID applies equality check predicate on the "channel_id" field. It's identical to ChannelIDEQ.
func ChannelID(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldChannelID, v))
}
// ModelMappingChain applies equality check predicate on the "model_mapping_chain" field. It's identical to ModelMappingChainEQ.
func ModelMappingChain(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldModelMappingChain, v))
}
// BillingTier applies equality check predicate on the "billing_tier" field. It's identical to BillingTierEQ.
func BillingTier(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldBillingTier, v))
}
// BillingMode applies equality check predicate on the "billing_mode" field. It's identical to BillingModeEQ.
func BillingMode(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldBillingMode, v))
}
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ. // GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
func GroupID(v int64) predicate.UsageLog { func GroupID(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v)) return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
@ -565,6 +585,281 @@ func UpstreamModelContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldUpstreamModel, v)) return predicate.UsageLog(sql.FieldContainsFold(FieldUpstreamModel, v))
} }
// ChannelIDEQ applies the EQ predicate on the "channel_id" field.
func ChannelIDEQ(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldChannelID, v))
}
// ChannelIDNEQ applies the NEQ predicate on the "channel_id" field.
func ChannelIDNEQ(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldChannelID, v))
}
// ChannelIDIn applies the In predicate on the "channel_id" field.
func ChannelIDIn(vs ...int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldChannelID, vs...))
}
// ChannelIDNotIn applies the NotIn predicate on the "channel_id" field.
func ChannelIDNotIn(vs ...int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldChannelID, vs...))
}
// ChannelIDGT applies the GT predicate on the "channel_id" field.
func ChannelIDGT(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldChannelID, v))
}
// ChannelIDGTE applies the GTE predicate on the "channel_id" field.
func ChannelIDGTE(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldChannelID, v))
}
// ChannelIDLT applies the LT predicate on the "channel_id" field.
func ChannelIDLT(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldChannelID, v))
}
// ChannelIDLTE applies the LTE predicate on the "channel_id" field.
func ChannelIDLTE(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldChannelID, v))
}
// ChannelIDIsNil applies the IsNil predicate on the "channel_id" field.
func ChannelIDIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldChannelID))
}
// ChannelIDNotNil applies the NotNil predicate on the "channel_id" field.
func ChannelIDNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldChannelID))
}
// ModelMappingChainEQ applies the EQ predicate on the "model_mapping_chain" field.
func ModelMappingChainEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldModelMappingChain, v))
}
// ModelMappingChainNEQ applies the NEQ predicate on the "model_mapping_chain" field.
func ModelMappingChainNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldModelMappingChain, v))
}
// ModelMappingChainIn applies the In predicate on the "model_mapping_chain" field.
func ModelMappingChainIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldModelMappingChain, vs...))
}
// ModelMappingChainNotIn applies the NotIn predicate on the "model_mapping_chain" field.
func ModelMappingChainNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldModelMappingChain, vs...))
}
// ModelMappingChainGT applies the GT predicate on the "model_mapping_chain" field.
func ModelMappingChainGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldModelMappingChain, v))
}
// ModelMappingChainGTE applies the GTE predicate on the "model_mapping_chain" field.
func ModelMappingChainGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldModelMappingChain, v))
}
// ModelMappingChainLT applies the LT predicate on the "model_mapping_chain" field.
func ModelMappingChainLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldModelMappingChain, v))
}
// ModelMappingChainLTE applies the LTE predicate on the "model_mapping_chain" field.
func ModelMappingChainLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldModelMappingChain, v))
}
// ModelMappingChainContains applies the Contains predicate on the "model_mapping_chain" field.
func ModelMappingChainContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldModelMappingChain, v))
}
// ModelMappingChainHasPrefix applies the HasPrefix predicate on the "model_mapping_chain" field.
func ModelMappingChainHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldModelMappingChain, v))
}
// ModelMappingChainHasSuffix applies the HasSuffix predicate on the "model_mapping_chain" field.
func ModelMappingChainHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldModelMappingChain, v))
}
// ModelMappingChainIsNil applies the IsNil predicate on the "model_mapping_chain" field.
func ModelMappingChainIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldModelMappingChain))
}
// ModelMappingChainNotNil applies the NotNil predicate on the "model_mapping_chain" field.
func ModelMappingChainNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldModelMappingChain))
}
// ModelMappingChainEqualFold applies the EqualFold predicate on the "model_mapping_chain" field.
func ModelMappingChainEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldModelMappingChain, v))
}
// ModelMappingChainContainsFold applies the ContainsFold predicate on the "model_mapping_chain" field.
func ModelMappingChainContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldModelMappingChain, v))
}
// BillingTierEQ applies the EQ predicate on the "billing_tier" field.
func BillingTierEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldBillingTier, v))
}
// BillingTierNEQ applies the NEQ predicate on the "billing_tier" field.
func BillingTierNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldBillingTier, v))
}
// BillingTierIn applies the In predicate on the "billing_tier" field.
func BillingTierIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldBillingTier, vs...))
}
// BillingTierNotIn applies the NotIn predicate on the "billing_tier" field.
func BillingTierNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldBillingTier, vs...))
}
// BillingTierGT applies the GT predicate on the "billing_tier" field.
func BillingTierGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldBillingTier, v))
}
// BillingTierGTE applies the GTE predicate on the "billing_tier" field.
func BillingTierGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldBillingTier, v))
}
// BillingTierLT applies the LT predicate on the "billing_tier" field.
func BillingTierLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldBillingTier, v))
}
// BillingTierLTE applies the LTE predicate on the "billing_tier" field.
func BillingTierLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldBillingTier, v))
}
// BillingTierContains applies the Contains predicate on the "billing_tier" field.
func BillingTierContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldBillingTier, v))
}
// BillingTierHasPrefix applies the HasPrefix predicate on the "billing_tier" field.
func BillingTierHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldBillingTier, v))
}
// BillingTierHasSuffix applies the HasSuffix predicate on the "billing_tier" field.
func BillingTierHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldBillingTier, v))
}
// BillingTierIsNil applies the IsNil predicate on the "billing_tier" field.
func BillingTierIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldBillingTier))
}
// BillingTierNotNil applies the NotNil predicate on the "billing_tier" field.
func BillingTierNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldBillingTier))
}
// BillingTierEqualFold applies the EqualFold predicate on the "billing_tier" field.
func BillingTierEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldBillingTier, v))
}
// BillingTierContainsFold applies the ContainsFold predicate on the "billing_tier" field.
func BillingTierContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldBillingTier, v))
}
// BillingModeEQ applies the EQ predicate on the "billing_mode" field.
func BillingModeEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldBillingMode, v))
}
// BillingModeNEQ applies the NEQ predicate on the "billing_mode" field.
func BillingModeNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldBillingMode, v))
}
// BillingModeIn applies the In predicate on the "billing_mode" field.
func BillingModeIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldBillingMode, vs...))
}
// BillingModeNotIn applies the NotIn predicate on the "billing_mode" field.
func BillingModeNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldBillingMode, vs...))
}
// BillingModeGT applies the GT predicate on the "billing_mode" field.
func BillingModeGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldBillingMode, v))
}
// BillingModeGTE applies the GTE predicate on the "billing_mode" field.
func BillingModeGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldBillingMode, v))
}
// BillingModeLT applies the LT predicate on the "billing_mode" field.
func BillingModeLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldBillingMode, v))
}
// BillingModeLTE applies the LTE predicate on the "billing_mode" field.
func BillingModeLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldBillingMode, v))
}
// BillingModeContains applies the Contains predicate on the "billing_mode" field.
func BillingModeContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldBillingMode, v))
}
// BillingModeHasPrefix applies the HasPrefix predicate on the "billing_mode" field.
func BillingModeHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldBillingMode, v))
}
// BillingModeHasSuffix applies the HasSuffix predicate on the "billing_mode" field.
func BillingModeHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldBillingMode, v))
}
// BillingModeIsNil applies the IsNil predicate on the "billing_mode" field.
func BillingModeIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldBillingMode))
}
// BillingModeNotNil applies the NotNil predicate on the "billing_mode" field.
func BillingModeNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldBillingMode))
}
// BillingModeEqualFold applies the EqualFold predicate on the "billing_mode" field.
func BillingModeEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldBillingMode, v))
}
// BillingModeContainsFold applies the ContainsFold predicate on the "billing_mode" field.
func BillingModeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldBillingMode, v))
}
// GroupIDEQ applies the EQ predicate on the "group_id" field. // GroupIDEQ applies the EQ predicate on the "group_id" field.
func GroupIDEQ(v int64) predicate.UsageLog { func GroupIDEQ(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v)) return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))

View File

@ -85,6 +85,62 @@ func (_c *UsageLogCreate) SetNillableUpstreamModel(v *string) *UsageLogCreate {
return _c return _c
} }
// SetChannelID sets the "channel_id" field.
func (_c *UsageLogCreate) SetChannelID(v int64) *UsageLogCreate {
_c.mutation.SetChannelID(v)
return _c
}
// SetNillableChannelID sets the "channel_id" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableChannelID(v *int64) *UsageLogCreate {
if v != nil {
_c.SetChannelID(*v)
}
return _c
}
// SetModelMappingChain sets the "model_mapping_chain" field.
func (_c *UsageLogCreate) SetModelMappingChain(v string) *UsageLogCreate {
_c.mutation.SetModelMappingChain(v)
return _c
}
// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableModelMappingChain(v *string) *UsageLogCreate {
if v != nil {
_c.SetModelMappingChain(*v)
}
return _c
}
// SetBillingTier sets the "billing_tier" field.
func (_c *UsageLogCreate) SetBillingTier(v string) *UsageLogCreate {
_c.mutation.SetBillingTier(v)
return _c
}
// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableBillingTier(v *string) *UsageLogCreate {
if v != nil {
_c.SetBillingTier(*v)
}
return _c
}
// SetBillingMode sets the "billing_mode" field.
func (_c *UsageLogCreate) SetBillingMode(v string) *UsageLogCreate {
_c.mutation.SetBillingMode(v)
return _c
}
// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableBillingMode(v *string) *UsageLogCreate {
if v != nil {
_c.SetBillingMode(*v)
}
return _c
}
// SetGroupID sets the "group_id" field. // SetGroupID sets the "group_id" field.
func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate { func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate {
_c.mutation.SetGroupID(v) _c.mutation.SetGroupID(v)
@ -634,6 +690,21 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
} }
} }
if v, ok := _c.mutation.ModelMappingChain(); ok {
if err := usagelog.ModelMappingChainValidator(v); err != nil {
return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)}
}
}
if v, ok := _c.mutation.BillingTier(); ok {
if err := usagelog.BillingTierValidator(v); err != nil {
return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)}
}
}
if v, ok := _c.mutation.BillingMode(); ok {
if err := usagelog.BillingModeValidator(v); err != nil {
return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)}
}
}
if _, ok := _c.mutation.InputTokens(); !ok { if _, ok := _c.mutation.InputTokens(); !ok {
return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)} return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)}
} }
@ -760,6 +831,22 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
_node.UpstreamModel = &value _node.UpstreamModel = &value
} }
if value, ok := _c.mutation.ChannelID(); ok {
_spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value)
_node.ChannelID = &value
}
if value, ok := _c.mutation.ModelMappingChain(); ok {
_spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value)
_node.ModelMappingChain = &value
}
if value, ok := _c.mutation.BillingTier(); ok {
_spec.SetField(usagelog.FieldBillingTier, field.TypeString, value)
_node.BillingTier = &value
}
if value, ok := _c.mutation.BillingMode(); ok {
_spec.SetField(usagelog.FieldBillingMode, field.TypeString, value)
_node.BillingMode = &value
}
if value, ok := _c.mutation.InputTokens(); ok { if value, ok := _c.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
_node.InputTokens = value _node.InputTokens = value
@ -1093,6 +1180,84 @@ func (u *UsageLogUpsert) ClearUpstreamModel() *UsageLogUpsert {
return u return u
} }
// SetChannelID sets the "channel_id" field.
func (u *UsageLogUpsert) SetChannelID(v int64) *UsageLogUpsert {
u.Set(usagelog.FieldChannelID, v)
return u
}
// UpdateChannelID sets the "channel_id" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateChannelID() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldChannelID)
return u
}
// AddChannelID adds v to the "channel_id" field.
func (u *UsageLogUpsert) AddChannelID(v int64) *UsageLogUpsert {
u.Add(usagelog.FieldChannelID, v)
return u
}
// ClearChannelID clears the value of the "channel_id" field.
func (u *UsageLogUpsert) ClearChannelID() *UsageLogUpsert {
u.SetNull(usagelog.FieldChannelID)
return u
}
// SetModelMappingChain sets the "model_mapping_chain" field.
func (u *UsageLogUpsert) SetModelMappingChain(v string) *UsageLogUpsert {
u.Set(usagelog.FieldModelMappingChain, v)
return u
}
// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateModelMappingChain() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldModelMappingChain)
return u
}
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
func (u *UsageLogUpsert) ClearModelMappingChain() *UsageLogUpsert {
u.SetNull(usagelog.FieldModelMappingChain)
return u
}
// SetBillingTier sets the "billing_tier" field.
func (u *UsageLogUpsert) SetBillingTier(v string) *UsageLogUpsert {
u.Set(usagelog.FieldBillingTier, v)
return u
}
// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateBillingTier() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldBillingTier)
return u
}
// ClearBillingTier clears the value of the "billing_tier" field.
func (u *UsageLogUpsert) ClearBillingTier() *UsageLogUpsert {
u.SetNull(usagelog.FieldBillingTier)
return u
}
// SetBillingMode sets the "billing_mode" field.
func (u *UsageLogUpsert) SetBillingMode(v string) *UsageLogUpsert {
u.Set(usagelog.FieldBillingMode, v)
return u
}
// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateBillingMode() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldBillingMode)
return u
}
// ClearBillingMode clears the value of the "billing_mode" field.
func (u *UsageLogUpsert) ClearBillingMode() *UsageLogUpsert {
u.SetNull(usagelog.FieldBillingMode)
return u
}
// SetGroupID sets the "group_id" field. // SetGroupID sets the "group_id" field.
func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert { func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert {
u.Set(usagelog.FieldGroupID, v) u.Set(usagelog.FieldGroupID, v)
@ -1724,6 +1889,97 @@ func (u *UsageLogUpsertOne) ClearUpstreamModel() *UsageLogUpsertOne {
}) })
} }
// SetChannelID sets the "channel_id" field.
func (u *UsageLogUpsertOne) SetChannelID(v int64) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetChannelID(v)
})
}
// AddChannelID adds v to the "channel_id" field.
func (u *UsageLogUpsertOne) AddChannelID(v int64) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.AddChannelID(v)
})
}
// UpdateChannelID sets the "channel_id" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateChannelID() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateChannelID()
})
}
// ClearChannelID clears the value of the "channel_id" field.
func (u *UsageLogUpsertOne) ClearChannelID() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearChannelID()
})
}
// SetModelMappingChain sets the "model_mapping_chain" field.
func (u *UsageLogUpsertOne) SetModelMappingChain(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetModelMappingChain(v)
})
}
// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateModelMappingChain() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateModelMappingChain()
})
}
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
func (u *UsageLogUpsertOne) ClearModelMappingChain() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearModelMappingChain()
})
}
// SetBillingTier sets the "billing_tier" field.
func (u *UsageLogUpsertOne) SetBillingTier(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetBillingTier(v)
})
}
// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateBillingTier() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateBillingTier()
})
}
// ClearBillingTier clears the value of the "billing_tier" field.
func (u *UsageLogUpsertOne) ClearBillingTier() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearBillingTier()
})
}
// SetBillingMode sets the "billing_mode" field.
func (u *UsageLogUpsertOne) SetBillingMode(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetBillingMode(v)
})
}
// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateBillingMode() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateBillingMode()
})
}
// ClearBillingMode clears the value of the "billing_mode" field.
func (u *UsageLogUpsertOne) ClearBillingMode() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearBillingMode()
})
}
// SetGroupID sets the "group_id" field. // SetGroupID sets the "group_id" field.
func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne { func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) { return u.Update(func(s *UsageLogUpsert) {
@ -2600,6 +2856,97 @@ func (u *UsageLogUpsertBulk) ClearUpstreamModel() *UsageLogUpsertBulk {
}) })
} }
// SetChannelID sets the "channel_id" field.
func (u *UsageLogUpsertBulk) SetChannelID(v int64) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetChannelID(v)
})
}
// AddChannelID adds v to the "channel_id" field.
func (u *UsageLogUpsertBulk) AddChannelID(v int64) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.AddChannelID(v)
})
}
// UpdateChannelID sets the "channel_id" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateChannelID() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateChannelID()
})
}
// ClearChannelID clears the value of the "channel_id" field.
func (u *UsageLogUpsertBulk) ClearChannelID() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearChannelID()
})
}
// SetModelMappingChain sets the "model_mapping_chain" field.
func (u *UsageLogUpsertBulk) SetModelMappingChain(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetModelMappingChain(v)
})
}
// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateModelMappingChain() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateModelMappingChain()
})
}
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
func (u *UsageLogUpsertBulk) ClearModelMappingChain() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearModelMappingChain()
})
}
// SetBillingTier sets the "billing_tier" field.
func (u *UsageLogUpsertBulk) SetBillingTier(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetBillingTier(v)
})
}
// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateBillingTier() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateBillingTier()
})
}
// ClearBillingTier clears the value of the "billing_tier" field.
func (u *UsageLogUpsertBulk) ClearBillingTier() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearBillingTier()
})
}
// SetBillingMode sets the "billing_mode" field.
func (u *UsageLogUpsertBulk) SetBillingMode(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetBillingMode(v)
})
}
// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateBillingMode() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateBillingMode()
})
}
// ClearBillingMode clears the value of the "billing_mode" field.
func (u *UsageLogUpsertBulk) ClearBillingMode() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearBillingMode()
})
}
// SetGroupID sets the "group_id" field. // SetGroupID sets the "group_id" field.
func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk { func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) { return u.Update(func(s *UsageLogUpsert) {

View File

@ -142,6 +142,93 @@ func (_u *UsageLogUpdate) ClearUpstreamModel() *UsageLogUpdate {
return _u return _u
} }
// SetChannelID sets the "channel_id" field.
func (_u *UsageLogUpdate) SetChannelID(v int64) *UsageLogUpdate {
_u.mutation.ResetChannelID()
_u.mutation.SetChannelID(v)
return _u
}
// SetNillableChannelID sets the "channel_id" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableChannelID(v *int64) *UsageLogUpdate {
if v != nil {
_u.SetChannelID(*v)
}
return _u
}
// AddChannelID adds value to the "channel_id" field.
func (_u *UsageLogUpdate) AddChannelID(v int64) *UsageLogUpdate {
_u.mutation.AddChannelID(v)
return _u
}
// ClearChannelID clears the value of the "channel_id" field.
func (_u *UsageLogUpdate) ClearChannelID() *UsageLogUpdate {
_u.mutation.ClearChannelID()
return _u
}
// SetModelMappingChain sets the "model_mapping_chain" field.
func (_u *UsageLogUpdate) SetModelMappingChain(v string) *UsageLogUpdate {
_u.mutation.SetModelMappingChain(v)
return _u
}
// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableModelMappingChain(v *string) *UsageLogUpdate {
if v != nil {
_u.SetModelMappingChain(*v)
}
return _u
}
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
func (_u *UsageLogUpdate) ClearModelMappingChain() *UsageLogUpdate {
_u.mutation.ClearModelMappingChain()
return _u
}
// SetBillingTier sets the "billing_tier" field.
func (_u *UsageLogUpdate) SetBillingTier(v string) *UsageLogUpdate {
_u.mutation.SetBillingTier(v)
return _u
}
// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableBillingTier(v *string) *UsageLogUpdate {
if v != nil {
_u.SetBillingTier(*v)
}
return _u
}
// ClearBillingTier clears the value of the "billing_tier" field.
func (_u *UsageLogUpdate) ClearBillingTier() *UsageLogUpdate {
_u.mutation.ClearBillingTier()
return _u
}
// SetBillingMode sets the "billing_mode" field.
func (_u *UsageLogUpdate) SetBillingMode(v string) *UsageLogUpdate {
_u.mutation.SetBillingMode(v)
return _u
}
// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableBillingMode(v *string) *UsageLogUpdate {
if v != nil {
_u.SetBillingMode(*v)
}
return _u
}
// ClearBillingMode clears the value of the "billing_mode" field.
func (_u *UsageLogUpdate) ClearBillingMode() *UsageLogUpdate {
_u.mutation.ClearBillingMode()
return _u
}
// SetGroupID sets the "group_id" field. // SetGroupID sets the "group_id" field.
func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate { func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate {
_u.mutation.SetGroupID(v) _u.mutation.SetGroupID(v)
@ -795,6 +882,21 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
} }
} }
if v, ok := _u.mutation.ModelMappingChain(); ok {
if err := usagelog.ModelMappingChainValidator(v); err != nil {
return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)}
}
}
if v, ok := _u.mutation.BillingTier(); ok {
if err := usagelog.BillingTierValidator(v); err != nil {
return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)}
}
}
if v, ok := _u.mutation.BillingMode(); ok {
if err := usagelog.BillingModeValidator(v); err != nil {
return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)}
}
}
if v, ok := _u.mutation.UserAgent(); ok { if v, ok := _u.mutation.UserAgent(); ok {
if err := usagelog.UserAgentValidator(v); err != nil { if err := usagelog.UserAgentValidator(v); err != nil {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
@ -857,6 +959,33 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.UpstreamModelCleared() { if _u.mutation.UpstreamModelCleared() {
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString) _spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
} }
if value, ok := _u.mutation.ChannelID(); ok {
_spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedChannelID(); ok {
_spec.AddField(usagelog.FieldChannelID, field.TypeInt64, value)
}
if _u.mutation.ChannelIDCleared() {
_spec.ClearField(usagelog.FieldChannelID, field.TypeInt64)
}
if value, ok := _u.mutation.ModelMappingChain(); ok {
_spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value)
}
if _u.mutation.ModelMappingChainCleared() {
_spec.ClearField(usagelog.FieldModelMappingChain, field.TypeString)
}
if value, ok := _u.mutation.BillingTier(); ok {
_spec.SetField(usagelog.FieldBillingTier, field.TypeString, value)
}
if _u.mutation.BillingTierCleared() {
_spec.ClearField(usagelog.FieldBillingTier, field.TypeString)
}
if value, ok := _u.mutation.BillingMode(); ok {
_spec.SetField(usagelog.FieldBillingMode, field.TypeString, value)
}
if _u.mutation.BillingModeCleared() {
_spec.ClearField(usagelog.FieldBillingMode, field.TypeString)
}
if value, ok := _u.mutation.InputTokens(); ok { if value, ok := _u.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
} }
@ -1279,6 +1408,93 @@ func (_u *UsageLogUpdateOne) ClearUpstreamModel() *UsageLogUpdateOne {
return _u return _u
} }
// SetChannelID sets the "channel_id" field.
func (_u *UsageLogUpdateOne) SetChannelID(v int64) *UsageLogUpdateOne {
_u.mutation.ResetChannelID()
_u.mutation.SetChannelID(v)
return _u
}
// SetNillableChannelID sets the "channel_id" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableChannelID(v *int64) *UsageLogUpdateOne {
if v != nil {
_u.SetChannelID(*v)
}
return _u
}
// AddChannelID adds value to the "channel_id" field.
func (_u *UsageLogUpdateOne) AddChannelID(v int64) *UsageLogUpdateOne {
_u.mutation.AddChannelID(v)
return _u
}
// ClearChannelID clears the value of the "channel_id" field.
func (_u *UsageLogUpdateOne) ClearChannelID() *UsageLogUpdateOne {
_u.mutation.ClearChannelID()
return _u
}
// SetModelMappingChain sets the "model_mapping_chain" field.
func (_u *UsageLogUpdateOne) SetModelMappingChain(v string) *UsageLogUpdateOne {
_u.mutation.SetModelMappingChain(v)
return _u
}
// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableModelMappingChain(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetModelMappingChain(*v)
}
return _u
}
// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
func (_u *UsageLogUpdateOne) ClearModelMappingChain() *UsageLogUpdateOne {
_u.mutation.ClearModelMappingChain()
return _u
}
// SetBillingTier sets the "billing_tier" field.
func (_u *UsageLogUpdateOne) SetBillingTier(v string) *UsageLogUpdateOne {
_u.mutation.SetBillingTier(v)
return _u
}
// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableBillingTier(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetBillingTier(*v)
}
return _u
}
// ClearBillingTier clears the value of the "billing_tier" field.
func (_u *UsageLogUpdateOne) ClearBillingTier() *UsageLogUpdateOne {
_u.mutation.ClearBillingTier()
return _u
}
// SetBillingMode sets the "billing_mode" field.
func (_u *UsageLogUpdateOne) SetBillingMode(v string) *UsageLogUpdateOne {
_u.mutation.SetBillingMode(v)
return _u
}
// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableBillingMode(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetBillingMode(*v)
}
return _u
}
// ClearBillingMode clears the value of the "billing_mode" field.
func (_u *UsageLogUpdateOne) ClearBillingMode() *UsageLogUpdateOne {
_u.mutation.ClearBillingMode()
return _u
}
// SetGroupID sets the "group_id" field. // SetGroupID sets the "group_id" field.
func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne { func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne {
_u.mutation.SetGroupID(v) _u.mutation.SetGroupID(v)
@ -1945,6 +2161,21 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
} }
} }
if v, ok := _u.mutation.ModelMappingChain(); ok {
if err := usagelog.ModelMappingChainValidator(v); err != nil {
return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)}
}
}
if v, ok := _u.mutation.BillingTier(); ok {
if err := usagelog.BillingTierValidator(v); err != nil {
return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)}
}
}
if v, ok := _u.mutation.BillingMode(); ok {
if err := usagelog.BillingModeValidator(v); err != nil {
return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)}
}
}
if v, ok := _u.mutation.UserAgent(); ok { if v, ok := _u.mutation.UserAgent(); ok {
if err := usagelog.UserAgentValidator(v); err != nil { if err := usagelog.UserAgentValidator(v); err != nil {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
@ -2024,6 +2255,33 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if _u.mutation.UpstreamModelCleared() { if _u.mutation.UpstreamModelCleared() {
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString) _spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
} }
if value, ok := _u.mutation.ChannelID(); ok {
_spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value)
}
if value, ok := _u.mutation.AddedChannelID(); ok {
_spec.AddField(usagelog.FieldChannelID, field.TypeInt64, value)
}
if _u.mutation.ChannelIDCleared() {
_spec.ClearField(usagelog.FieldChannelID, field.TypeInt64)
}
if value, ok := _u.mutation.ModelMappingChain(); ok {
_spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value)
}
if _u.mutation.ModelMappingChainCleared() {
_spec.ClearField(usagelog.FieldModelMappingChain, field.TypeString)
}
if value, ok := _u.mutation.BillingTier(); ok {
_spec.SetField(usagelog.FieldBillingTier, field.TypeString, value)
}
if _u.mutation.BillingTierCleared() {
_spec.ClearField(usagelog.FieldBillingTier, field.TypeString)
}
if value, ok := _u.mutation.BillingMode(); ok {
_spec.SetField(usagelog.FieldBillingMode, field.TypeString, value)
}
if _u.mutation.BillingModeCleared() {
_spec.ClearField(usagelog.FieldBillingMode, field.TypeString)
}
if value, ok := _u.mutation.InputTokens(); ok { if value, ok := _u.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
} }

View File

@ -0,0 +1,452 @@
package admin
import (
"errors"
"fmt"
"strconv"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ChannelHandler handles admin channel management
type ChannelHandler struct {
channelService *service.ChannelService
billingService *service.BillingService
}
// NewChannelHandler creates a new admin channel handler
func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService) *ChannelHandler {
return &ChannelHandler{channelService: channelService, billingService: billingService}
}
// --- Request / Response types ---
type createChannelRequest struct {
Name string `json:"name" binding:"required,max=100"`
Description string `json:"description"`
GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels bool `json:"restrict_models"`
}
type updateChannelRequest struct {
Name string `json:"name" binding:"omitempty,max=100"`
Description *string `json:"description"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
GroupIDs *[]int64 `json:"group_ids"`
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels *bool `json:"restrict_models"`
}
type channelModelPricingRequest struct {
Platform string `json:"platform" binding:"omitempty,max=50"`
Models []string `json:"models" binding:"required,min=1,max=100"`
BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"`
InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"`
OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"`
CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"`
CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"`
ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"`
PerRequestPrice *float64 `json:"per_request_price" binding:"omitempty,min=0"`
Intervals []pricingIntervalRequest `json:"intervals"`
}
type pricingIntervalRequest struct {
MinTokens int `json:"min_tokens"`
MaxTokens *int `json:"max_tokens"`
TierLabel string `json:"tier_label"`
InputPrice *float64 `json:"input_price"`
OutputPrice *float64 `json:"output_price"`
CacheWritePrice *float64 `json:"cache_write_price"`
CacheReadPrice *float64 `json:"cache_read_price"`
PerRequestPrice *float64 `json:"per_request_price"`
SortOrder int `json:"sort_order"`
}
type channelResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Status string `json:"status"`
BillingModelSource string `json:"billing_model_source"`
RestrictModels bool `json:"restrict_models"`
GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
type channelModelPricingResponse struct {
ID int64 `json:"id"`
Platform string `json:"platform"`
Models []string `json:"models"`
BillingMode string `json:"billing_mode"`
InputPrice *float64 `json:"input_price"`
OutputPrice *float64 `json:"output_price"`
CacheWritePrice *float64 `json:"cache_write_price"`
CacheReadPrice *float64 `json:"cache_read_price"`
ImageOutputPrice *float64 `json:"image_output_price"`
PerRequestPrice *float64 `json:"per_request_price"`
Intervals []pricingIntervalResponse `json:"intervals"`
}
type pricingIntervalResponse struct {
ID int64 `json:"id"`
MinTokens int `json:"min_tokens"`
MaxTokens *int `json:"max_tokens"`
TierLabel string `json:"tier_label,omitempty"`
InputPrice *float64 `json:"input_price"`
OutputPrice *float64 `json:"output_price"`
CacheWritePrice *float64 `json:"cache_write_price"`
CacheReadPrice *float64 `json:"cache_read_price"`
PerRequestPrice *float64 `json:"per_request_price"`
SortOrder int `json:"sort_order"`
}
func channelToResponse(ch *service.Channel) *channelResponse {
if ch == nil {
return nil
}
resp := &channelResponse{
ID: ch.ID,
Name: ch.Name,
Description: ch.Description,
Status: ch.Status,
RestrictModels: ch.RestrictModels,
GroupIDs: ch.GroupIDs,
ModelMapping: ch.ModelMapping,
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"),
}
resp.BillingModelSource = ch.BillingModelSource
if resp.BillingModelSource == "" {
resp.BillingModelSource = service.BillingModelSourceChannelMapped
}
if resp.GroupIDs == nil {
resp.GroupIDs = []int64{}
}
if resp.ModelMapping == nil {
resp.ModelMapping = map[string]map[string]string{}
}
resp.ModelPricing = make([]channelModelPricingResponse, 0, len(ch.ModelPricing))
for _, p := range ch.ModelPricing {
resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p))
}
return resp
}
func pricingToResponse(p *service.ChannelModelPricing) channelModelPricingResponse {
models := p.Models
if models == nil {
models = []string{}
}
billingMode := string(p.BillingMode)
if billingMode == "" {
billingMode = string(service.BillingModeToken)
}
platform := p.Platform
if platform == "" {
platform = service.PlatformAnthropic
}
intervals := make([]pricingIntervalResponse, 0, len(p.Intervals))
for _, iv := range p.Intervals {
intervals = append(intervals, intervalToResponse(iv))
}
return channelModelPricingResponse{
ID: p.ID,
Platform: platform,
Models: models,
BillingMode: billingMode,
InputPrice: p.InputPrice,
OutputPrice: p.OutputPrice,
CacheWritePrice: p.CacheWritePrice,
CacheReadPrice: p.CacheReadPrice,
ImageOutputPrice: p.ImageOutputPrice,
PerRequestPrice: p.PerRequestPrice,
Intervals: intervals,
}
}
func intervalToResponse(iv service.PricingInterval) pricingIntervalResponse {
return pricingIntervalResponse{
ID: iv.ID,
MinTokens: iv.MinTokens,
MaxTokens: iv.MaxTokens,
TierLabel: iv.TierLabel,
InputPrice: iv.InputPrice,
OutputPrice: iv.OutputPrice,
CacheWritePrice: iv.CacheWritePrice,
CacheReadPrice: iv.CacheReadPrice,
PerRequestPrice: iv.PerRequestPrice,
SortOrder: iv.SortOrder,
}
}
func pricingRequestToService(reqs []channelModelPricingRequest) []service.ChannelModelPricing {
result := make([]service.ChannelModelPricing, 0, len(reqs))
for _, r := range reqs {
billingMode := service.BillingMode(r.BillingMode)
if billingMode == "" {
billingMode = service.BillingModeToken
}
platform := r.Platform
if platform == "" {
platform = service.PlatformAnthropic
}
intervals := make([]service.PricingInterval, 0, len(r.Intervals))
for _, iv := range r.Intervals {
intervals = append(intervals, service.PricingInterval{
MinTokens: iv.MinTokens,
MaxTokens: iv.MaxTokens,
TierLabel: iv.TierLabel,
InputPrice: iv.InputPrice,
OutputPrice: iv.OutputPrice,
CacheWritePrice: iv.CacheWritePrice,
CacheReadPrice: iv.CacheReadPrice,
PerRequestPrice: iv.PerRequestPrice,
SortOrder: iv.SortOrder,
})
}
result = append(result, service.ChannelModelPricing{
Platform: platform,
Models: r.Models,
BillingMode: billingMode,
InputPrice: r.InputPrice,
OutputPrice: r.OutputPrice,
CacheWritePrice: r.CacheWritePrice,
CacheReadPrice: r.CacheReadPrice,
ImageOutputPrice: r.ImageOutputPrice,
PerRequestPrice: r.PerRequestPrice,
Intervals: intervals,
})
}
return result
}
// validatePricingBillingMode 校验计费配置
func validatePricingBillingMode(pricing []service.ChannelModelPricing) error {
for _, p := range pricing {
// 按次/图片模式必须配置默认价格或区间
if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage {
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
return errors.New("per-request price or intervals required for per_request/image billing mode")
}
}
// 校验价格不能为负
if err := validatePriceNotNegative("input_price", p.InputPrice); err != nil {
return err
}
if err := validatePriceNotNegative("output_price", p.OutputPrice); err != nil {
return err
}
if err := validatePriceNotNegative("cache_write_price", p.CacheWritePrice); err != nil {
return err
}
if err := validatePriceNotNegative("cache_read_price", p.CacheReadPrice); err != nil {
return err
}
if err := validatePriceNotNegative("image_output_price", p.ImageOutputPrice); err != nil {
return err
}
if err := validatePriceNotNegative("per_request_price", p.PerRequestPrice); err != nil {
return err
}
// 校验 interval至少有一个价格字段非空
for _, iv := range p.Intervals {
if iv.InputPrice == nil && iv.OutputPrice == nil &&
iv.CacheWritePrice == nil && iv.CacheReadPrice == nil &&
iv.PerRequestPrice == nil {
return fmt.Errorf("interval [%d, %s] has no price fields set for model %v",
iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models)
}
}
}
return nil
}
func validatePriceNotNegative(field string, val *float64) error {
if val != nil && *val < 0 {
return fmt.Errorf("%s must be >= 0", field)
}
return nil
}
func formatMaxTokens(max *int) string {
if max == nil {
return "∞"
}
return fmt.Sprintf("%d", *max)
}
// --- Handlers ---
// List handles listing channels with pagination
// GET /api/v1/admin/channels
func (h *ChannelHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
status := c.Query("status")
search := strings.TrimSpace(c.Query("search"))
if len(search) > 100 {
search = search[:100]
}
channels, pag, err := h.channelService.List(c.Request.Context(), pagination.PaginationParams{Page: page, PageSize: pageSize}, status, search)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]*channelResponse, 0, len(channels))
for i := range channels {
out = append(out, channelToResponse(&channels[i]))
}
response.Paginated(c, out, pag.Total, page, pageSize)
}
// GetByID handles getting a channel by ID
// GET /api/v1/admin/channels/:id
func (h *ChannelHandler) GetByID(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID"))
return
}
channel, err := h.channelService.GetByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, channelToResponse(channel))
}
// Create handles creating a new channel
// POST /api/v1/admin/channels
func (h *ChannelHandler) Create(c *gin.Context) {
var req createChannelRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
pricing := pricingRequestToService(req.ModelPricing)
if err := validatePricingBillingMode(pricing); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
Name: req.Name,
Description: req.Description,
GroupIDs: req.GroupIDs,
ModelPricing: pricing,
ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, channelToResponse(channel))
}
// Update handles updating a channel
// PUT /api/v1/admin/channels/:id
func (h *ChannelHandler) Update(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID"))
return
}
var req updateChannelRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
input := &service.UpdateChannelInput{
Name: req.Name,
Description: req.Description,
Status: req.Status,
GroupIDs: req.GroupIDs,
ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels,
}
if req.ModelPricing != nil {
pricing := pricingRequestToService(*req.ModelPricing)
if err := validatePricingBillingMode(pricing); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
input.ModelPricing = &pricing
}
channel, err := h.channelService.Update(c.Request.Context(), id, input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, channelToResponse(channel))
}
// Delete handles deleting a channel
// DELETE /api/v1/admin/channels/:id
func (h *ChannelHandler) Delete(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID"))
return
}
if err := h.channelService.Delete(c.Request.Context(), id); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Channel deleted successfully"})
}
// GetModelDefaultPricing 获取模型的默认定价(用于前端自动填充)
// GET /api/v1/admin/channels/model-pricing?model=claude-sonnet-4
func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) {
model := strings.TrimSpace(c.Query("model"))
if model == "" {
response.ErrorFrom(c, infraerrors.BadRequest("MISSING_PARAMETER", "model parameter is required").
WithMetadata(map[string]string{"param": "model"}))
return
}
pricing, err := h.billingService.GetModelPricing(model)
if err != nil {
// 模型不在定价列表中
response.Success(c, gin.H{"found": false})
return
}
response.Success(c, gin.H{
"found": true,
"input_price": pricing.InputPricePerToken,
"output_price": pricing.OutputPricePerToken,
"cache_write_price": pricing.CacheCreationPricePerToken,
"cache_read_price": pricing.CacheReadPricePerToken,
"image_output_price": pricing.ImageOutputPricePerToken,
})
}

View File

@ -0,0 +1,502 @@
//go:build unit
package admin
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// helpers
// ---------------------------------------------------------------------------
func float64Ptr(v float64) *float64 { return &v }
func intPtr(v int) *int { return &v }
// ---------------------------------------------------------------------------
// 1. channelToResponse
// ---------------------------------------------------------------------------
func TestChannelToResponse_NilInput(t *testing.T) {
require.Nil(t, channelToResponse(nil))
}
func TestChannelToResponse_FullChannel(t *testing.T) {
now := time.Date(2025, 6, 1, 12, 0, 0, 0, time.UTC)
ch := &service.Channel{
ID: 42,
Name: "test-channel",
Description: "desc",
Status: "active",
BillingModelSource: "upstream",
RestrictModels: true,
CreatedAt: now,
UpdatedAt: now.Add(time.Hour),
GroupIDs: []int64{1, 2, 3},
ModelPricing: []service.ChannelModelPricing{
{
ID: 10,
Platform: "openai",
Models: []string{"gpt-4"},
BillingMode: service.BillingModeToken,
InputPrice: float64Ptr(0.01),
OutputPrice: float64Ptr(0.03),
CacheWritePrice: float64Ptr(0.005),
CacheReadPrice: float64Ptr(0.002),
PerRequestPrice: float64Ptr(0.5),
},
},
ModelMapping: map[string]map[string]string{
"anthropic": {"claude-3-haiku": "claude-haiku-3"},
},
}
resp := channelToResponse(ch)
require.NotNil(t, resp)
require.Equal(t, int64(42), resp.ID)
require.Equal(t, "test-channel", resp.Name)
require.Equal(t, "desc", resp.Description)
require.Equal(t, "active", resp.Status)
require.Equal(t, "upstream", resp.BillingModelSource)
require.True(t, resp.RestrictModels)
require.Equal(t, []int64{1, 2, 3}, resp.GroupIDs)
require.Equal(t, "2025-06-01T12:00:00Z", resp.CreatedAt)
require.Equal(t, "2025-06-01T13:00:00Z", resp.UpdatedAt)
// model mapping
require.Len(t, resp.ModelMapping, 1)
require.Equal(t, "claude-haiku-3", resp.ModelMapping["anthropic"]["claude-3-haiku"])
// pricing
require.Len(t, resp.ModelPricing, 1)
p := resp.ModelPricing[0]
require.Equal(t, int64(10), p.ID)
require.Equal(t, "openai", p.Platform)
require.Equal(t, []string{"gpt-4"}, p.Models)
require.Equal(t, "token", p.BillingMode)
require.Equal(t, float64Ptr(0.01), p.InputPrice)
require.Equal(t, float64Ptr(0.03), p.OutputPrice)
require.Equal(t, float64Ptr(0.005), p.CacheWritePrice)
require.Equal(t, float64Ptr(0.002), p.CacheReadPrice)
require.Equal(t, float64Ptr(0.5), p.PerRequestPrice)
require.Empty(t, p.Intervals)
}
func TestChannelToResponse_EmptyDefaults(t *testing.T) {
now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
ch := &service.Channel{
ID: 1,
Name: "ch",
BillingModelSource: "",
CreatedAt: now,
UpdatedAt: now,
GroupIDs: nil,
ModelMapping: nil,
ModelPricing: []service.ChannelModelPricing{
{
Platform: "",
BillingMode: "",
Models: []string{"m1"},
},
},
}
resp := channelToResponse(ch)
require.Equal(t, "channel_mapped", resp.BillingModelSource)
require.NotNil(t, resp.GroupIDs)
require.Empty(t, resp.GroupIDs)
require.NotNil(t, resp.ModelMapping)
require.Empty(t, resp.ModelMapping)
require.Len(t, resp.ModelPricing, 1)
require.Equal(t, "anthropic", resp.ModelPricing[0].Platform)
require.Equal(t, "token", resp.ModelPricing[0].BillingMode)
}
func TestChannelToResponse_NilModels(t *testing.T) {
now := time.Now()
ch := &service.Channel{
ID: 1,
Name: "ch",
CreatedAt: now,
UpdatedAt: now,
ModelPricing: []service.ChannelModelPricing{
{
Models: nil,
},
},
}
resp := channelToResponse(ch)
require.Len(t, resp.ModelPricing, 1)
require.NotNil(t, resp.ModelPricing[0].Models)
require.Empty(t, resp.ModelPricing[0].Models)
}
func TestChannelToResponse_WithIntervals(t *testing.T) {
now := time.Now()
ch := &service.Channel{
ID: 1,
Name: "ch",
CreatedAt: now,
UpdatedAt: now,
ModelPricing: []service.ChannelModelPricing{
{
Models: []string{"m1"},
BillingMode: service.BillingModePerRequest,
Intervals: []service.PricingInterval{
{
ID: 100,
MinTokens: 0,
MaxTokens: intPtr(1000),
TierLabel: "1K",
InputPrice: float64Ptr(0.01),
OutputPrice: float64Ptr(0.02),
CacheWritePrice: float64Ptr(0.003),
CacheReadPrice: float64Ptr(0.001),
PerRequestPrice: float64Ptr(0.1),
SortOrder: 1,
},
{
ID: 101,
MinTokens: 1000,
MaxTokens: nil,
TierLabel: "unlimited",
SortOrder: 2,
},
},
},
},
}
resp := channelToResponse(ch)
require.Len(t, resp.ModelPricing, 1)
intervals := resp.ModelPricing[0].Intervals
require.Len(t, intervals, 2)
iv0 := intervals[0]
require.Equal(t, int64(100), iv0.ID)
require.Equal(t, 0, iv0.MinTokens)
require.Equal(t, intPtr(1000), iv0.MaxTokens)
require.Equal(t, "1K", iv0.TierLabel)
require.Equal(t, float64Ptr(0.01), iv0.InputPrice)
require.Equal(t, float64Ptr(0.02), iv0.OutputPrice)
require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice)
require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice)
require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice)
require.Equal(t, 1, iv0.SortOrder)
iv1 := intervals[1]
require.Equal(t, int64(101), iv1.ID)
require.Equal(t, 1000, iv1.MinTokens)
require.Nil(t, iv1.MaxTokens)
require.Equal(t, "unlimited", iv1.TierLabel)
require.Equal(t, 2, iv1.SortOrder)
}
func TestChannelToResponse_MultipleEntries(t *testing.T) {
now := time.Now()
ch := &service.Channel{
ID: 1,
Name: "multi",
CreatedAt: now,
UpdatedAt: now,
ModelPricing: []service.ChannelModelPricing{
{
ID: 1,
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: service.BillingModeToken,
InputPrice: float64Ptr(0.003),
OutputPrice: float64Ptr(0.015),
},
{
ID: 2,
Platform: "openai",
Models: []string{"gpt-4", "gpt-4o"},
BillingMode: service.BillingModePerRequest,
PerRequestPrice: float64Ptr(1.0),
},
{
ID: 3,
Platform: "gemini",
Models: []string{"gemini-2.5-pro"},
BillingMode: service.BillingModeImage,
ImageOutputPrice: float64Ptr(0.05),
PerRequestPrice: float64Ptr(0.2),
},
},
}
resp := channelToResponse(ch)
require.Len(t, resp.ModelPricing, 3)
require.Equal(t, int64(1), resp.ModelPricing[0].ID)
require.Equal(t, "anthropic", resp.ModelPricing[0].Platform)
require.Equal(t, []string{"claude-sonnet-4"}, resp.ModelPricing[0].Models)
require.Equal(t, "token", resp.ModelPricing[0].BillingMode)
require.Equal(t, int64(2), resp.ModelPricing[1].ID)
require.Equal(t, "openai", resp.ModelPricing[1].Platform)
require.Equal(t, []string{"gpt-4", "gpt-4o"}, resp.ModelPricing[1].Models)
require.Equal(t, "per_request", resp.ModelPricing[1].BillingMode)
require.Equal(t, int64(3), resp.ModelPricing[2].ID)
require.Equal(t, "gemini", resp.ModelPricing[2].Platform)
require.Equal(t, []string{"gemini-2.5-pro"}, resp.ModelPricing[2].Models)
require.Equal(t, "image", resp.ModelPricing[2].BillingMode)
require.Equal(t, float64Ptr(0.05), resp.ModelPricing[2].ImageOutputPrice)
}
// ---------------------------------------------------------------------------
// 2. pricingRequestToService
// ---------------------------------------------------------------------------
func TestPricingRequestToService_Defaults(t *testing.T) {
tests := []struct {
name string
req channelModelPricingRequest
wantField string // which default field to check
wantValue string
}{
{
name: "empty billing mode defaults to token",
req: channelModelPricingRequest{
Models: []string{"m1"},
BillingMode: "",
},
wantField: "BillingMode",
wantValue: string(service.BillingModeToken),
},
{
name: "empty platform defaults to anthropic",
req: channelModelPricingRequest{
Models: []string{"m1"},
Platform: "",
},
wantField: "Platform",
wantValue: "anthropic",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := pricingRequestToService([]channelModelPricingRequest{tt.req})
require.Len(t, result, 1)
switch tt.wantField {
case "BillingMode":
require.Equal(t, service.BillingMode(tt.wantValue), result[0].BillingMode)
case "Platform":
require.Equal(t, tt.wantValue, result[0].Platform)
}
})
}
}
func TestPricingRequestToService_WithAllFields(t *testing.T) {
reqs := []channelModelPricingRequest{
{
Platform: "openai",
Models: []string{"gpt-4", "gpt-4o"},
BillingMode: "per_request",
InputPrice: float64Ptr(0.01),
OutputPrice: float64Ptr(0.03),
CacheWritePrice: float64Ptr(0.005),
CacheReadPrice: float64Ptr(0.002),
ImageOutputPrice: float64Ptr(0.04),
PerRequestPrice: float64Ptr(0.5),
},
}
result := pricingRequestToService(reqs)
require.Len(t, result, 1)
r := result[0]
require.Equal(t, "openai", r.Platform)
require.Equal(t, []string{"gpt-4", "gpt-4o"}, r.Models)
require.Equal(t, service.BillingModePerRequest, r.BillingMode)
require.Equal(t, float64Ptr(0.01), r.InputPrice)
require.Equal(t, float64Ptr(0.03), r.OutputPrice)
require.Equal(t, float64Ptr(0.005), r.CacheWritePrice)
require.Equal(t, float64Ptr(0.002), r.CacheReadPrice)
require.Equal(t, float64Ptr(0.04), r.ImageOutputPrice)
require.Equal(t, float64Ptr(0.5), r.PerRequestPrice)
}
func TestPricingRequestToService_WithIntervals(t *testing.T) {
reqs := []channelModelPricingRequest{
{
Models: []string{"m1"},
BillingMode: "per_request",
Intervals: []pricingIntervalRequest{
{
MinTokens: 0,
MaxTokens: intPtr(2000),
TierLabel: "small",
InputPrice: float64Ptr(0.01),
OutputPrice: float64Ptr(0.02),
CacheWritePrice: float64Ptr(0.003),
CacheReadPrice: float64Ptr(0.001),
PerRequestPrice: float64Ptr(0.1),
SortOrder: 1,
},
{
MinTokens: 2000,
MaxTokens: nil,
TierLabel: "large",
SortOrder: 2,
},
},
},
}
result := pricingRequestToService(reqs)
require.Len(t, result, 1)
require.Len(t, result[0].Intervals, 2)
iv0 := result[0].Intervals[0]
require.Equal(t, 0, iv0.MinTokens)
require.Equal(t, intPtr(2000), iv0.MaxTokens)
require.Equal(t, "small", iv0.TierLabel)
require.Equal(t, float64Ptr(0.01), iv0.InputPrice)
require.Equal(t, float64Ptr(0.02), iv0.OutputPrice)
require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice)
require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice)
require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice)
require.Equal(t, 1, iv0.SortOrder)
iv1 := result[0].Intervals[1]
require.Equal(t, 2000, iv1.MinTokens)
require.Nil(t, iv1.MaxTokens)
require.Equal(t, "large", iv1.TierLabel)
require.Equal(t, 2, iv1.SortOrder)
}
func TestPricingRequestToService_EmptySlice(t *testing.T) {
result := pricingRequestToService([]channelModelPricingRequest{})
require.NotNil(t, result)
require.Empty(t, result)
}
func TestPricingRequestToService_NilPriceFields(t *testing.T) {
reqs := []channelModelPricingRequest{
{
Models: []string{"m1"},
BillingMode: "token",
// all price fields are nil by default
},
}
result := pricingRequestToService(reqs)
require.Len(t, result, 1)
r := result[0]
require.Nil(t, r.InputPrice)
require.Nil(t, r.OutputPrice)
require.Nil(t, r.CacheWritePrice)
require.Nil(t, r.CacheReadPrice)
require.Nil(t, r.ImageOutputPrice)
require.Nil(t, r.PerRequestPrice)
}
// ---------------------------------------------------------------------------
// 3. validatePricingBillingMode
// ---------------------------------------------------------------------------
func TestValidatePricingBillingMode(t *testing.T) {
tests := []struct {
name string
pricing []service.ChannelModelPricing
wantErr bool
}{
{
name: "token mode - valid",
pricing: []service.ChannelModelPricing{
{BillingMode: service.BillingModeToken},
},
wantErr: false,
},
{
name: "per_request with price - valid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModePerRequest,
PerRequestPrice: float64Ptr(0.5),
},
},
wantErr: false,
},
{
name: "per_request with intervals - valid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModePerRequest,
Intervals: []service.PricingInterval{
{MinTokens: 0, MaxTokens: intPtr(1000), PerRequestPrice: float64Ptr(0.1)},
},
},
},
wantErr: false,
},
{
name: "per_request no price no intervals - invalid",
pricing: []service.ChannelModelPricing{
{BillingMode: service.BillingModePerRequest},
},
wantErr: true,
},
{
name: "image with price - valid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModeImage,
PerRequestPrice: float64Ptr(0.2),
},
},
wantErr: false,
},
{
name: "image no price no intervals - invalid",
pricing: []service.ChannelModelPricing{
{BillingMode: service.BillingModeImage},
},
wantErr: true,
},
{
name: "empty list - valid",
pricing: []service.ChannelModelPricing{},
wantErr: false,
},
{
name: "mixed modes with invalid image - invalid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModeToken,
InputPrice: float64Ptr(0.01),
},
{
BillingMode: service.BillingModePerRequest,
PerRequestPrice: float64Ptr(0.5),
},
{
BillingMode: service.BillingModeImage,
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validatePricingBillingMode(tt.pricing)
if tt.wantErr {
require.Error(t, err)
require.Contains(t, err.Error(), "per-request price or intervals required")
} else {
require.NoError(t, err)
}
})
}
}

View File

@ -636,6 +636,40 @@ func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) {
dim.Endpoint = c.Query("endpoint") dim.Endpoint = c.Query("endpoint")
dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound") dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound")
// Additional filter conditions
if v := c.Query("user_id"); v != "" {
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
dim.UserID = id
}
}
if v := c.Query("api_key_id"); v != "" {
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
dim.APIKeyID = id
}
}
if v := c.Query("account_id"); v != "" {
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
dim.AccountID = id
}
}
if v := c.Query("request_type"); v != "" {
if rt, err := strconv.ParseInt(v, 10, 16); err == nil {
rtVal := int16(rt)
dim.RequestType = &rtVal
}
}
if v := c.Query("stream"); v != "" {
if s, err := strconv.ParseBool(v); err == nil {
dim.Stream = &s
}
}
if v := c.Query("billing_type"); v != "" {
if bt, err := strconv.ParseInt(v, 10, 8); err == nil {
btVal := int8(bt)
dim.BillingType = &btVal
}
}
limit := 50 limit := 50
if v := c.Query("limit"); v != "" { if v := c.Query("limit"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 { if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 {

View File

@ -110,6 +110,7 @@ func (h *UsageHandler) List(c *gin.Context) {
} }
model := c.Query("model") model := c.Query("model")
billingMode := strings.TrimSpace(c.Query("billing_mode"))
var requestType *int16 var requestType *int16
var stream *bool var stream *bool
@ -174,6 +175,7 @@ func (h *UsageHandler) List(c *gin.Context) {
RequestType: requestType, RequestType: requestType,
Stream: stream, Stream: stream,
BillingType: billingType, BillingType: billingType,
BillingMode: billingMode,
StartTime: startTime, StartTime: startTime,
EndTime: endTime, EndTime: endTime,
ExactTotal: exactTotal, ExactTotal: exactTotal,
@ -234,6 +236,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
} }
model := c.Query("model") model := c.Query("model")
billingMode := strings.TrimSpace(c.Query("billing_mode"))
var requestType *int16 var requestType *int16
var stream *bool var stream *bool
@ -312,6 +315,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
RequestType: requestType, RequestType: requestType,
Stream: stream, Stream: stream,
BillingType: billingType, BillingType: billingType,
BillingMode: billingMode,
StartTime: &startTime, StartTime: &startTime,
EndTime: &endTime, EndTime: &endTime,
} }

View File

@ -577,6 +577,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
MediaType: l.MediaType, MediaType: l.MediaType,
UserAgent: l.UserAgent, UserAgent: l.UserAgent,
CacheTTLOverridden: l.CacheTTLOverridden, CacheTTLOverridden: l.CacheTTLOverridden,
BillingMode: l.BillingMode,
CreatedAt: l.CreatedAt, CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User), User: UserFromServiceShallow(l.User),
APIKey: APIKeyFromService(l.APIKey), APIKey: APIKeyFromService(l.APIKey),
@ -604,6 +605,9 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
return &AdminUsageLog{ return &AdminUsageLog{
UsageLog: usageLogFromServiceUser(l), UsageLog: usageLogFromServiceUser(l),
UpstreamModel: l.UpstreamModel, UpstreamModel: l.UpstreamModel,
ChannelID: l.ChannelID,
ModelMappingChain: l.ModelMappingChain,
BillingTier: l.BillingTier,
AccountRateMultiplier: l.AccountRateMultiplier, AccountRateMultiplier: l.AccountRateMultiplier,
IPAddress: l.IPAddress, IPAddress: l.IPAddress,
Account: AccountSummaryFromService(l.Account), Account: AccountSummaryFromService(l.Account),

View File

@ -390,6 +390,9 @@ type UsageLog struct {
// Cache TTL Override 标记 // Cache TTL Override 标记
CacheTTLOverridden bool `json:"cache_ttl_overridden"` CacheTTLOverridden bool `json:"cache_ttl_overridden"`
// BillingMode 计费模式token/image
BillingMode *string `json:"billing_mode,omitempty"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"` User *User `json:"user,omitempty"`
@ -406,6 +409,13 @@ type AdminUsageLog struct {
// Omitted when no mapping was applied (requested model was used as-is). // Omitted when no mapping was applied (requested model was used as-is).
UpstreamModel *string `json:"upstream_model,omitempty"` UpstreamModel *string `json:"upstream_model,omitempty"`
// ChannelID 渠道 ID
ChannelID *int64 `json:"channel_id,omitempty"`
// ModelMappingChain 模型映射链,如 "a→b→c"
ModelMappingChain *string `json:"model_mapping_chain,omitempty"`
// BillingTier 计费层级标签per_request/image 模式)
BillingTier *string `json:"billing_tier,omitempty"`
// AccountRateMultiplier 账号计费倍率快照nil 表示按 1.0 处理) // AccountRateMultiplier 账号计费倍率快照nil 表示按 1.0 处理)
AccountRateMultiplier *float64 `json:"account_rate_multiplier"` AccountRateMultiplier *float64 `json:"account_rate_multiplier"`

View File

@ -158,6 +158,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqStream := parsedReq.Stream reqStream := parsedReq.Stream
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中 // 设置 max_tokens=1 + haiku 探测请求标识到 context 中
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断 // 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) { if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
@ -292,7 +295,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
for { for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制 selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
if err != nil { if err != nil {
if len(fs.FailedAccountIDs) == 0 { if len(fs.FailedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@ -478,6 +481,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling, ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.gateway.messages"), zap.String("component", "handler.gateway.messages"),
@ -514,7 +518,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for { for {
// 选择支持该模型的账号 // 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, int64(0))
if err != nil { if err != nil {
if len(fs.FailedAccountIDs) == 0 { if len(fs.FailedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@ -660,6 +664,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
parsedReq.OnUpstreamAccepted = queueRelease parsedReq.OnUpstreamAccepted = queueRelease
// ===== 用户消息串行队列 END ===== // ===== 用户消息串行队列 END =====
// 应用渠道模型映射到请求
if channelMapping.Mapped {
parsedReq.Model = channelMapping.MappedModel
parsedReq.Body = h.gatewayService.ReplaceModelInBody(parsedReq.Body, channelMapping.MappedModel)
body = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
// 转发请求 - 根据账号平台分流 // 转发请求 - 根据账号平台分流
var result *service.ForwardResult var result *service.ForwardResult
requestCtx := c.Request.Context() requestCtx := c.Request.Context()
@ -810,6 +821,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling, ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.gateway.messages"), zap.String("component", "handler.gateway.messages"),

View File

@ -80,6 +80,9 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// Claude Code only restriction // Claude Code only restriction
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly { if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error", h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error",
@ -154,7 +157,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
fs := NewFailoverState(h.maxAccountSwitches, false) fs := NewFailoverState(h.maxAccountSwitches, false)
for { for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "") selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
if err != nil { if err != nil {
if len(fs.FailedAccountIDs) == 0 { if len(fs.FailedAccountIDs) == 0 {
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
@ -203,7 +206,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
// 5. Forward request // 5. Forward request
writerSizeBeforeForward := c.Writer.Size() writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, parsedReq) forwardBody := body
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
@ -255,6 +262,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
reqLog.Error("gateway.cc.record_usage_failed", reqLog.Error("gateway.cc.record_usage_failed",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),

View File

@ -80,6 +80,9 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// Claude Code only restriction: // Claude Code only restriction:
// /v1/responses is never a Claude Code endpoint. // /v1/responses is never a Claude Code endpoint.
// When claude_code_only is enabled, this endpoint is rejected. // When claude_code_only is enabled, this endpoint is rejected.
@ -159,7 +162,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
fs := NewFailoverState(h.maxAccountSwitches, false) fs := NewFailoverState(h.maxAccountSwitches, false)
for { for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "") selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
if err != nil { if err != nil {
if len(fs.FailedAccountIDs) == 0 { if len(fs.FailedAccountIDs) == 0 {
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
@ -208,7 +211,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
// 5. Forward request // 5. Forward request
writerSizeBeforeForward := c.Writer.Size() writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, body, parsedReq) forwardBody := body
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, forwardBody, parsedReq)
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
@ -261,6 +268,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
reqLog.Error("gateway.responses.record_usage_failed", reqLog.Error("gateway.responses.record_usage_failed",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),

View File

@ -161,6 +161,8 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // digestStore nil, // digestStore
nil, // settingService nil, // settingService
nil, // tlsFPProfileService nil, // tlsFPProfileService
nil, // channelService
nil, // resolver
) )
// RunModeSimple跳过计费检查避免引入 repo/cache 依赖。 // RunModeSimple跳过计费检查避免引入 repo/cache 依赖。

View File

@ -184,6 +184,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
setOpsRequestContext(c, modelName, stream, body) setOpsRequestContext(c, modelName, stream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false))) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
reqModel := modelName // 保存映射前的原始模型名
if channelMapping.Mapped {
modelName = channelMapping.MappedModel
}
// Get subscription (may be nil) // Get subscription (may be nil)
subscription, _ := middleware.GetSubscriptionFromContext(c) subscription, _ := middleware.GetSubscriptionFromContext(c)
@ -353,7 +360,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
} }
for { for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制 selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
if err != nil { if err != nil {
if len(fs.FailedAccountIDs) == 0 { if len(fs.FailedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
@ -523,6 +530,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
LongContextMultiplier: 2.0, // 超出部分双倍计费 LongContextMultiplier: 2.0, // 超出部分双倍计费
ForceCacheBilling: fs.ForceCacheBilling, ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.gemini_v1beta.models"), zap.String("component", "handler.gemini_v1beta.models"),

View File

@ -30,6 +30,7 @@ type AdminHandlers struct {
TLSFingerprintProfile *admin.TLSFingerprintProfileHandler TLSFingerprintProfile *admin.TLSFingerprintProfileHandler
APIKey *admin.AdminAPIKeyHandler APIKey *admin.AdminAPIKeyHandler
ScheduledTest *admin.ScheduledTestHandler ScheduledTest *admin.ScheduledTestHandler
Channel *admin.ChannelHandler
} }
// Handlers contains all HTTP handlers // Handlers contains all HTTP handlers

View File

@ -79,6 +79,9 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if h.errorPassthroughService != nil { if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService) service.BindErrorPassthroughService(c, h.errorPassthroughService)
} }
@ -183,7 +186,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
forwardStart := time.Now() forwardStart := time.Now()
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model")) defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model"))
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) forwardBody := body
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds() forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
@ -257,16 +264,17 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
h.submitUsageRecordTask(func(ctx context.Context) { h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result, Result: result,
APIKey: apiKey, APIKey: apiKey,
User: apiKey.User, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c), InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.openai_gateway.chat_completions"), zap.String("component", "handler.openai_gateway.chat_completions"),

View File

@ -185,6 +185,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
if !h.validateFunctionCallOutputRequest(c, body, reqLog) { if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
return return
@ -284,7 +287,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Forward request // Forward request
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now() forwardStart := time.Now()
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) // 应用渠道模型映射到请求体
forwardBody := body
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
forwardDurationMs := time.Since(forwardStart).Milliseconds() forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
@ -379,6 +387,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.openai_gateway.responses"), zap.String("component", "handler.openai_gateway.responses"),
@ -549,6 +558,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil { if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService) service.BindErrorPassthroughService(c, h.errorPassthroughService)
@ -673,7 +685,12 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
// Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的 // Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的
// Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。 // Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model")) defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model"))
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) // 应用渠道模型映射到请求体
forwardBody := body
if channelMappingMsg.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMappingMsg.MappedModel)
}
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds() forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
@ -759,6 +776,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMappingMsg.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.openai_gateway.messages"), zap.String("component", "handler.openai_gateway.messages"),
@ -1101,6 +1119,9 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
setOpsRequestContext(c, reqModel, true, firstMessage) setOpsRequestContext(c, reqModel, true, firstMessage)
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2)) setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
// 解析渠道级模型映射
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
var currentUserRelease func() var currentUserRelease func()
var currentAccountRelease func() var currentAccountRelease func()
releaseTurnSlots := func() { releaseTurnSlots := func() {
@ -1259,6 +1280,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
reqLog.Error("openai.websocket_record_usage_failed", reqLog.Error("openai.websocket_record_usage_failed",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
@ -1270,7 +1292,13 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
}, },
} }
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil { // 应用渠道模型映射到 WebSocket 首条消息
wsFirstMessage := firstMessage
if channelMappingWS.Mapped {
wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel)
}
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
closeStatus, closeReason := summarizeWSCloseErrorForLog(err) closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
reqLog.Warn("openai.websocket_proxy_failed", reqLog.Warn("openai.websocket_proxy_failed",

View File

@ -2225,6 +2225,7 @@ func newMinimalGatewayService(accountRepo service.AccountRepository) *service.Ga
return service.NewGatewayService( return service.NewGatewayService(
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil, accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil,
) )
} }

View File

@ -30,6 +30,8 @@ import (
) )
// SoraGatewayHandler handles Sora chat completions requests // SoraGatewayHandler handles Sora chat completions requests
//
// NOTE: Sora 平台计划后续移除不集成渠道Channel功能。
type SoraGatewayHandler struct { type SoraGatewayHandler struct {
gatewayService *service.GatewayService gatewayService *service.GatewayService
soraGatewayService *service.SoraGatewayService soraGatewayService *service.SoraGatewayService
@ -226,7 +228,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
var lastFailoverHeaders http.Header var lastFailoverHeaders http.Header
for { for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "") selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "", int64(0))
if err != nil { if err != nil {
reqLog.Warn("sora.account_select_failed", reqLog.Warn("sora.account_select_failed",
zap.Error(err), zap.Error(err),

View File

@ -465,6 +465,8 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
nil, // digestStore nil, // digestStore
nil, // settingService nil, // settingService
nil, // tlsFPProfileService nil, // tlsFPProfileService
nil, // channelService
nil, // resolver
) )
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}} soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}

View File

@ -33,6 +33,7 @@ func ProvideAdminHandlers(
tlsFingerprintProfileHandler *admin.TLSFingerprintProfileHandler, tlsFingerprintProfileHandler *admin.TLSFingerprintProfileHandler,
apiKeyHandler *admin.AdminAPIKeyHandler, apiKeyHandler *admin.AdminAPIKeyHandler,
scheduledTestHandler *admin.ScheduledTestHandler, scheduledTestHandler *admin.ScheduledTestHandler,
channelHandler *admin.ChannelHandler,
) *AdminHandlers { ) *AdminHandlers {
return &AdminHandlers{ return &AdminHandlers{
Dashboard: dashboardHandler, Dashboard: dashboardHandler,
@ -59,6 +60,7 @@ func ProvideAdminHandlers(
TLSFingerprintProfile: tlsFingerprintProfileHandler, TLSFingerprintProfile: tlsFingerprintProfileHandler,
APIKey: apiKeyHandler, APIKey: apiKeyHandler,
ScheduledTest: scheduledTestHandler, ScheduledTest: scheduledTestHandler,
Channel: channelHandler,
} }
} }
@ -150,6 +152,7 @@ var ProviderSet = wire.NewSet(
admin.NewTLSFingerprintProfileHandler, admin.NewTLSFingerprintProfileHandler,
admin.NewAdminAPIKeyHandler, admin.NewAdminAPIKeyHandler,
admin.NewScheduledTestHandler, admin.NewScheduledTestHandler,
admin.NewChannelHandler,
// AdminHandlers and Handlers constructors // AdminHandlers and Handlers constructors
ProvideAdminHandlers, ProvideAdminHandlers,

View File

@ -125,6 +125,7 @@ type ClaudeUsage struct {
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"` CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"` CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
} }
// ClaudeError Claude 错误响应 // ClaudeError Claude 错误响应

View File

@ -149,13 +149,31 @@ type GeminiCandidate struct {
GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"` GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"`
} }
// GeminiTokenDetail Gemini token 详情(按模态分类)
type GeminiTokenDetail struct {
Modality string `json:"modality"`
TokenCount int `json:"tokenCount"`
}
// GeminiUsageMetadata Gemini 用量元数据 // GeminiUsageMetadata Gemini 用量元数据
type GeminiUsageMetadata struct { type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount,omitempty"` PromptTokenCount int `json:"promptTokenCount,omitempty"`
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"` CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"` CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
TotalTokenCount int `json:"totalTokenCount,omitempty"` TotalTokenCount int `json:"totalTokenCount,omitempty"`
ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens按输出价格计费 ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens按输出价格计费
CandidatesTokensDetails []GeminiTokenDetail `json:"candidatesTokensDetails,omitempty"`
PromptTokensDetails []GeminiTokenDetail `json:"promptTokensDetails,omitempty"`
}
// ImageOutputTokens 从 CandidatesTokensDetails 中提取 IMAGE 模态的 token 数
func (m *GeminiUsageMetadata) ImageOutputTokens() int {
for _, d := range m.CandidatesTokensDetails {
if d.Modality == "IMAGE" {
return d.TokenCount
}
}
return 0
} }
// GeminiGroundingMetadata Gemini grounding 元数据Web Search // GeminiGroundingMetadata Gemini grounding 元数据Web Search

View File

@ -284,6 +284,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
usage.CacheReadInputTokens = cached usage.CacheReadInputTokens = cached
usage.ImageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens()
} }
// 生成响应 ID // 生成响应 ID

View File

@ -32,9 +32,10 @@ type StreamingProcessor struct {
groundingChunks []GeminiGroundingChunk groundingChunks []GeminiGroundingChunk
// 累计 usage // 累计 usage
inputTokens int inputTokens int
outputTokens int outputTokens int
cacheReadTokens int cacheReadTokens int
imageOutputTokens int
} }
// NewStreamingProcessor 创建流式响应处理器 // NewStreamingProcessor 创建流式响应处理器
@ -87,6 +88,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
p.cacheReadTokens = cached p.cacheReadTokens = cached
p.imageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens()
} }
// 处理 parts // 处理 parts
@ -127,6 +129,7 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
InputTokens: p.inputTokens, InputTokens: p.inputTokens,
OutputTokens: p.outputTokens, OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens, CacheReadInputTokens: p.cacheReadTokens,
ImageOutputTokens: p.imageOutputTokens,
} }
if !p.messageStartSent { if !p.messageStartSent {
@ -158,6 +161,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount
usage.CacheReadInputTokens = cached usage.CacheReadInputTokens = cached
usage.ImageOutputTokens = v1Resp.Response.UsageMetadata.ImageOutputTokens()
} }
responseID := v1Resp.ResponseID responseID := v1Resp.ResponseID
@ -485,6 +489,7 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
InputTokens: p.inputTokens, InputTokens: p.inputTokens,
OutputTokens: p.outputTokens, OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens, CacheReadInputTokens: p.cacheReadTokens,
ImageOutputTokens: p.imageOutputTokens,
} }
deltaEvent := map[string]any{ deltaEvent := map[string]any{

View File

@ -175,6 +175,13 @@ type UserBreakdownDimension struct {
ModelType string // "requested", "upstream", or "mapping" ModelType string // "requested", "upstream", or "mapping"
Endpoint string // filter by endpoint value (non-empty to enable) Endpoint string // filter by endpoint value (non-empty to enable)
EndpointType string // "inbound", "upstream", or "path" EndpointType string // "inbound", "upstream", or "path"
// Additional filter conditions
UserID int64 // filter by user_id (>0 to enable)
APIKeyID int64 // filter by api_key_id (>0 to enable)
AccountID int64 // filter by account_id (>0 to enable)
RequestType *int16 // filter by request_type (non-nil to enable)
Stream *bool // filter by stream flag (non-nil to enable)
BillingType *int8 // filter by billing_type (non-nil to enable)
} }
// APIKeyUsageTrendPoint represents API key usage trend data point // APIKeyUsageTrendPoint represents API key usage trend data point
@ -230,6 +237,7 @@ type UsageLogFilters struct {
RequestType *int16 RequestType *int16
Stream *bool Stream *bool
BillingType *int8 BillingType *int8
BillingMode string
StartTime *time.Time StartTime *time.Time
EndTime *time.Time EndTime *time.Time
// ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging. // ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging.

View File

@ -0,0 +1,461 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
type channelRepository struct {
db *sql.DB
}
// NewChannelRepository 创建渠道数据访问实例
func NewChannelRepository(db *sql.DB) service.ChannelRepository {
return &channelRepository{db: db}
}
// runInTx 在事务中执行 fn成功 commit失败 rollback。
func (r *channelRepository) runInTx(ctx context.Context, fn func(tx *sql.Tx) error) error {
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin tx: %w", err)
}
defer func() { _ = tx.Rollback() }()
if err := fn(tx); err != nil {
return err
}
return tx.Commit()
}
func (r *channelRepository) Create(ctx context.Context, channel *service.Channel) error {
return r.runInTx(ctx, func(tx *sql.Tx) error {
modelMappingJSON, err := marshalModelMapping(channel.ModelMapping)
if err != nil {
return err
}
err = tx.QueryRowContext(ctx,
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models) VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, created_at, updated_at`,
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels,
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
if err != nil {
if isUniqueViolation(err) {
return service.ErrChannelExists
}
return fmt.Errorf("insert channel: %w", err)
}
// 设置分组关联
if len(channel.GroupIDs) > 0 {
if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil {
return err
}
}
// 设置模型定价
if len(channel.ModelPricing) > 0 {
if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil {
return err
}
}
return nil
})
}
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
ch := &service.Channel{}
var modelMappingJSON []byte
err := r.db.QueryRowContext(ctx,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at
FROM channels WHERE id = $1`, id,
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt)
if err == sql.ErrNoRows {
return nil, service.ErrChannelNotFound
}
if err != nil {
return nil, fmt.Errorf("get channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
groupIDs, err := r.GetGroupIDs(ctx, id)
if err != nil {
return nil, err
}
ch.GroupIDs = groupIDs
pricing, err := r.ListModelPricing(ctx, id)
if err != nil {
return nil, err
}
ch.ModelPricing = pricing
return ch, nil
}
func (r *channelRepository) Update(ctx context.Context, channel *service.Channel) error {
return r.runInTx(ctx, func(tx *sql.Tx) error {
modelMappingJSON, err := marshalModelMapping(channel.ModelMapping)
if err != nil {
return err
}
result, err := tx.ExecContext(ctx,
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, updated_at = NOW()
WHERE id = $7`,
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.ID,
)
if err != nil {
if isUniqueViolation(err) {
return service.ErrChannelExists
}
return fmt.Errorf("update channel: %w", err)
}
rows, _ := result.RowsAffected()
if rows == 0 {
return service.ErrChannelNotFound
}
// 更新分组关联
if channel.GroupIDs != nil {
if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil {
return err
}
}
// 更新模型定价
if channel.ModelPricing != nil {
if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil {
return err
}
}
return nil
})
}
func (r *channelRepository) Delete(ctx context.Context, id int64) error {
result, err := r.db.ExecContext(ctx, `DELETE FROM channels WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("delete channel: %w", err)
}
rows, _ := result.RowsAffected()
if rows == 0 {
return service.ErrChannelNotFound
}
return nil
}
func (r *channelRepository) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.Channel, *pagination.PaginationResult, error) {
where := []string{"1=1"}
args := []any{}
argIdx := 1
if status != "" {
where = append(where, fmt.Sprintf("c.status = $%d", argIdx))
args = append(args, status)
argIdx++
}
if search != "" {
where = append(where, fmt.Sprintf("(c.name ILIKE $%d OR c.description ILIKE $%d)", argIdx, argIdx))
args = append(args, "%"+escapeLike(search)+"%")
argIdx++
}
whereClause := strings.Join(where, " AND ")
// 计数
var total int64
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM channels c WHERE %s", whereClause)
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
return nil, nil, fmt.Errorf("count channels: %w", err)
}
pageSize := params.Limit() // 约束在 [1, 100]
page := params.Page
if page < 1 {
page = 1
}
offset := (page - 1) * pageSize
// 查询 channel 列表
dataQuery := fmt.Sprintf(
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at
FROM channels c WHERE %s ORDER BY c.id ASC LIMIT $%d OFFSET $%d`,
whereClause, argIdx, argIdx+1,
)
args = append(args, pageSize, offset)
rows, err := r.db.QueryContext(ctx, dataQuery, args...)
if err != nil {
return nil, nil, fmt.Errorf("query channels: %w", err)
}
defer func() { _ = rows.Close() }()
var channels []service.Channel
var channelIDs []int64
for rows.Next() {
var ch service.Channel
var modelMappingJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
if err := rows.Err(); err != nil {
return nil, nil, fmt.Errorf("iterate channels: %w", err)
}
// 批量加载分组 ID 和模型定价(避免 N+1
if len(channelIDs) > 0 {
groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs)
if err != nil {
return nil, nil, err
}
pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs)
if err != nil {
return nil, nil, err
}
for i := range channels {
channels[i].GroupIDs = groupMap[channels[i].ID]
channels[i].ModelPricing = pricingMap[channels[i].ID]
}
}
pages := 0
if total > 0 {
pages = int((total + int64(pageSize) - 1) / int64(pageSize))
}
paginationResult := &pagination.PaginationResult{
Total: total,
Page: page,
PageSize: pageSize,
Pages: pages,
}
return channels, paginationResult, nil
}
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`,
)
if err != nil {
return nil, fmt.Errorf("query all channels: %w", err)
}
defer func() { _ = rows.Close() }()
var channels []service.Channel
var channelIDs []int64
for rows.Next() {
var ch service.Channel
var modelMappingJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate channels: %w", err)
}
if len(channelIDs) == 0 {
return channels, nil
}
// 批量加载分组 ID
groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs)
if err != nil {
return nil, err
}
// 批量加载模型定价
pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs)
if err != nil {
return nil, err
}
for i := range channels {
channels[i].GroupIDs = groupMap[channels[i].ID]
channels[i].ModelPricing = pricingMap[channels[i].ID]
}
return channels, nil
}
// --- 批量加载辅助方法 ---
// batchLoadGroupIDs 批量加载多个渠道的分组 ID
func (r *channelRepository) batchLoadGroupIDs(ctx context.Context, channelIDs []int64) (map[int64][]int64, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT channel_id, group_id FROM channel_groups
WHERE channel_id = ANY($1) ORDER BY channel_id, group_id`,
pq.Array(channelIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load group ids: %w", err)
}
defer func() { _ = rows.Close() }()
groupMap := make(map[int64][]int64, len(channelIDs))
for rows.Next() {
var channelID, groupID int64
if err := rows.Scan(&channelID, &groupID); err != nil {
return nil, fmt.Errorf("scan group id: %w", err)
}
groupMap[channelID] = append(groupMap[channelID], groupID)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate group ids: %w", err)
}
return groupMap, nil
}
func (r *channelRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
var exists bool
err := r.db.QueryRowContext(ctx,
`SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1)`, name,
).Scan(&exists)
return exists, err
}
func (r *channelRepository) ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error) {
var exists bool
err := r.db.QueryRowContext(ctx,
`SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1 AND id != $2)`, name, excludeID,
).Scan(&exists)
return exists, err
}
// --- 分组关联 ---
func (r *channelRepository) GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT group_id FROM channel_groups WHERE channel_id = $1 ORDER BY group_id`, channelID,
)
if err != nil {
return nil, fmt.Errorf("get group ids: %w", err)
}
defer func() { _ = rows.Close() }()
var ids []int64
for rows.Next() {
var id int64
if err := rows.Scan(&id); err != nil {
return nil, fmt.Errorf("scan group id: %w", err)
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate group ids: %w", err)
}
return ids, nil
}
func (r *channelRepository) SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error {
return setGroupIDsTx(ctx, r.db, channelID, groupIDs)
}
func (r *channelRepository) GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
var channelID int64
err := r.db.QueryRowContext(ctx,
`SELECT channel_id FROM channel_groups WHERE group_id = $1`, groupID,
).Scan(&channelID)
if err == sql.ErrNoRows {
return 0, nil
}
return channelID, err
}
func (r *channelRepository) GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) {
if len(groupIDs) == 0 {
return nil, nil
}
rows, err := r.db.QueryContext(ctx,
`SELECT group_id FROM channel_groups WHERE group_id = ANY($1) AND channel_id != $2`,
pq.Array(groupIDs), channelID,
)
if err != nil {
return nil, fmt.Errorf("get groups in other channels: %w", err)
}
defer func() { _ = rows.Close() }()
var conflicting []int64
for rows.Next() {
var id int64
if err := rows.Scan(&id); err != nil {
return nil, fmt.Errorf("scan conflicting group id: %w", err)
}
conflicting = append(conflicting, id)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate conflicting group ids: %w", err)
}
return conflicting, nil
}
// marshalModelMapping 将 model mapping 序列化为嵌套 JSON 字节
// 格式:{"platform": {"src": "dst"}, ...}
func marshalModelMapping(m map[string]map[string]string) ([]byte, error) {
if len(m) == 0 {
return []byte("{}"), nil
}
data, err := json.Marshal(m)
if err != nil {
return nil, fmt.Errorf("marshal model_mapping: %w", err)
}
return data, nil
}
// unmarshalModelMapping 将 JSON 字节反序列化为嵌套 model mapping
func unmarshalModelMapping(data []byte) map[string]map[string]string {
if len(data) == 0 {
return nil
}
var m map[string]map[string]string
if err := json.Unmarshal(data, &m); err != nil {
return nil
}
return m
}
// GetGroupPlatforms 批量查询分组 ID 对应的平台
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
if len(groupIDs) == 0 {
return make(map[int64]string), nil
}
rows, err := r.db.QueryContext(ctx,
`SELECT id, platform FROM groups WHERE id = ANY($1)`,
pq.Array(groupIDs),
)
if err != nil {
return nil, fmt.Errorf("get group platforms: %w", err)
}
defer rows.Close() //nolint:errcheck
result := make(map[int64]string, len(groupIDs))
for rows.Next() {
var id int64
var platform string
if err := rows.Scan(&id, &platform); err != nil {
return nil, fmt.Errorf("scan group platform: %w", err)
}
result[id] = platform
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate group platforms: %w", err)
}
return result, nil
}

View File

@ -0,0 +1,291 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
// --- 模型定价 ---
func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]service.ChannelModelPricing, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`, channelID,
)
if err != nil {
return nil, fmt.Errorf("list model pricing: %w", err)
}
defer func() { _ = rows.Close() }()
result, pricingIDs, err := scanModelPricingRows(rows)
if err != nil {
return nil, err
}
if len(pricingIDs) > 0 {
intervalMap, err := r.batchLoadIntervals(ctx, pricingIDs)
if err != nil {
return nil, err
}
for i := range result {
result[i].Intervals = intervalMap[result[i].ID]
}
}
return result, nil
}
func (r *channelRepository) CreateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error {
return createModelPricingExec(ctx, r.db, pricing)
}
func (r *channelRepository) UpdateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error {
modelsJSON, err := json.Marshal(pricing.Models)
if err != nil {
return fmt.Errorf("marshal models: %w", err)
}
billingMode := pricing.BillingMode
if billingMode == "" {
billingMode = service.BillingModeToken
}
result, err := r.db.ExecContext(ctx,
`UPDATE channel_model_pricing
SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, per_request_price = $8, platform = $9, updated_at = NOW()
WHERE id = $10`,
modelsJSON, billingMode, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
pricing.ImageOutputPrice, pricing.PerRequestPrice, pricing.Platform, pricing.ID,
)
if err != nil {
return fmt.Errorf("update model pricing: %w", err)
}
rows, _ := result.RowsAffected()
if rows == 0 {
return fmt.Errorf("pricing entry not found: %d", pricing.ID)
}
return nil
}
func (r *channelRepository) DeleteModelPricing(ctx context.Context, id int64) error {
_, err := r.db.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("delete model pricing: %w", err)
}
return nil
}
func (r *channelRepository) ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []service.ChannelModelPricing) error {
return r.runInTx(ctx, func(tx *sql.Tx) error {
return replaceModelPricingTx(ctx, tx, channelID, pricingList)
})
}
// --- 批量加载辅助方法 ---
// batchLoadModelPricing 批量加载多个渠道的模型定价(含区间)
func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`,
pq.Array(channelIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load model pricing: %w", err)
}
defer func() { _ = rows.Close() }()
allPricing, allPricingIDs, err := scanModelPricingRows(rows)
if err != nil {
return nil, err
}
// 按 channelID 分组
pricingMap := make(map[int64][]service.ChannelModelPricing, len(channelIDs))
for _, p := range allPricing {
pricingMap[p.ChannelID] = append(pricingMap[p.ChannelID], p)
}
// 批量加载所有区间
if len(allPricingIDs) > 0 {
intervalMap, err := r.batchLoadIntervals(ctx, allPricingIDs)
if err != nil {
return nil, err
}
for chID := range pricingMap {
for i := range pricingMap[chID] {
pricingMap[chID][i].Intervals = intervalMap[pricingMap[chID][i].ID]
}
}
}
return pricingMap, nil
}
// batchLoadIntervals 批量加载多个定价条目的区间
func (r *channelRepository) batchLoadIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, pricing_id, min_tokens, max_tokens, tier_label,
input_price, output_price, cache_write_price, cache_read_price,
per_request_price, sort_order, created_at, updated_at
FROM channel_pricing_intervals
WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`,
pq.Array(pricingIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load intervals: %w", err)
}
defer func() { _ = rows.Close() }()
intervalMap := make(map[int64][]service.PricingInterval, len(pricingIDs))
for rows.Next() {
var iv service.PricingInterval
if err := rows.Scan(
&iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel,
&iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice,
&iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan interval: %w", err)
}
intervalMap[iv.PricingID] = append(intervalMap[iv.PricingID], iv)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate intervals: %w", err)
}
return intervalMap, nil
}
// --- 共享 scan 辅助 ---
// scanModelPricingRows 扫描 model pricing 行,返回结果列表和 ID 列表
func scanModelPricingRows(rows *sql.Rows) ([]service.ChannelModelPricing, []int64, error) {
var result []service.ChannelModelPricing
var pricingIDs []int64
for rows.Next() {
var p service.ChannelModelPricing
var modelsJSON []byte
if err := rows.Scan(
&p.ID, &p.ChannelID, &p.Platform, &modelsJSON, &p.BillingMode,
&p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
&p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt,
); err != nil {
return nil, nil, fmt.Errorf("scan model pricing: %w", err)
}
if err := json.Unmarshal(modelsJSON, &p.Models); err != nil {
p.Models = []string{}
}
pricingIDs = append(pricingIDs, p.ID)
result = append(result, p)
}
if err := rows.Err(); err != nil {
return nil, nil, fmt.Errorf("iterate model pricing: %w", err)
}
return result, pricingIDs, nil
}
// --- 事务内辅助方法 ---
// dbExec 是 *sql.DB 和 *sql.Tx 共享的最小 SQL 执行接口
type dbExec interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
func setGroupIDsTx(ctx context.Context, exec dbExec, channelID int64, groupIDs []int64) error {
if _, err := exec.ExecContext(ctx, `DELETE FROM channel_groups WHERE channel_id = $1`, channelID); err != nil {
return fmt.Errorf("delete old group associations: %w", err)
}
if len(groupIDs) == 0 {
return nil
}
_, err := exec.ExecContext(ctx,
`INSERT INTO channel_groups (channel_id, group_id)
SELECT $1, unnest($2::bigint[])`,
channelID, pq.Array(groupIDs),
)
if err != nil {
return fmt.Errorf("insert group associations: %w", err)
}
return nil
}
func createModelPricingExec(ctx context.Context, exec dbExec, pricing *service.ChannelModelPricing) error {
modelsJSON, err := json.Marshal(pricing.Models)
if err != nil {
return fmt.Errorf("marshal models: %w", err)
}
billingMode := pricing.BillingMode
if billingMode == "" {
billingMode = service.BillingModeToken
}
platform := pricing.Platform
if platform == "" {
platform = "anthropic"
}
err = exec.QueryRowContext(ctx,
`INSERT INTO channel_model_pricing (channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
pricing.ChannelID, platform, modelsJSON, billingMode,
pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
pricing.ImageOutputPrice, pricing.PerRequestPrice,
).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
if err != nil {
return fmt.Errorf("insert model pricing: %w", err)
}
for i := range pricing.Intervals {
pricing.Intervals[i].PricingID = pricing.ID
if err := createIntervalExec(ctx, exec, &pricing.Intervals[i]); err != nil {
return err
}
}
return nil
}
func createIntervalExec(ctx context.Context, exec dbExec, iv *service.PricingInterval) error {
return exec.QueryRowContext(ctx,
`INSERT INTO channel_pricing_intervals
(pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel,
iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice,
iv.PerRequestPrice, iv.SortOrder,
).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt)
}
func replaceModelPricingTx(ctx context.Context, exec dbExec, channelID int64, pricingList []service.ChannelModelPricing) error {
if _, err := exec.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE channel_id = $1`, channelID); err != nil {
return fmt.Errorf("delete old model pricing: %w", err)
}
for i := range pricingList {
pricingList[i].ChannelID = channelID
if err := createModelPricingExec(ctx, exec, &pricingList[i]); err != nil {
return fmt.Errorf("insert model pricing: %w", err)
}
}
return nil
}
// isUniqueViolation 检查 pq 唯一约束违反错误
func isUniqueViolation(err error) bool {
var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr != nil {
return pqErr.Code == "23505"
}
return false
}
// escapeLike 转义 LIKE/ILIKE 模式中的特殊字符
func escapeLike(s string) string {
s = strings.ReplaceAll(s, `\`, `\\`)
s = strings.ReplaceAll(s, `%`, `\%`)
s = strings.ReplaceAll(s, `_`, `\_`)
return s
}

View File

@ -0,0 +1,227 @@
//go:build unit
package repository
import (
"encoding/json"
"errors"
"fmt"
"testing"
"github.com/lib/pq"
"github.com/stretchr/testify/require"
)
// --- marshalModelMapping ---
func TestMarshalModelMapping(t *testing.T) {
tests := []struct {
name string
input map[string]map[string]string
wantJSON string // expected JSON output (exact match)
}{
{
name: "empty map",
input: map[string]map[string]string{},
wantJSON: "{}",
},
{
name: "nil map",
input: nil,
wantJSON: "{}",
},
{
name: "populated map",
input: map[string]map[string]string{
"openai": {"gpt-4": "gpt-4-turbo"},
},
},
{
name: "nested values",
input: map[string]map[string]string{
"openai": {"*": "gpt-5.4"},
"anthropic": {"claude-old": "claude-new"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := marshalModelMapping(tt.input)
require.NoError(t, err)
if tt.wantJSON != "" {
require.Equal(t, []byte(tt.wantJSON), result)
} else {
// round-trip: unmarshal and compare with input
var parsed map[string]map[string]string
require.NoError(t, json.Unmarshal(result, &parsed))
require.Equal(t, tt.input, parsed)
}
})
}
}
// --- unmarshalModelMapping ---
func TestUnmarshalModelMapping(t *testing.T) {
tests := []struct {
name string
input []byte
wantNil bool
want map[string]map[string]string
}{
{
name: "nil data",
input: nil,
wantNil: true,
},
{
name: "empty data",
input: []byte{},
wantNil: true,
},
{
name: "invalid JSON",
input: []byte("not-json"),
wantNil: true,
},
{
name: "type error - number",
input: []byte("42"),
wantNil: true,
},
{
name: "type error - array",
input: []byte("[1,2,3]"),
wantNil: true,
},
{
name: "valid JSON",
input: []byte(`{"openai":{"gpt-4":"gpt-4-turbo"},"anthropic":{"old":"new"}}`),
want: map[string]map[string]string{
"openai": {"gpt-4": "gpt-4-turbo"},
"anthropic": {"old": "new"},
},
},
{
name: "empty object",
input: []byte("{}"),
want: map[string]map[string]string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := unmarshalModelMapping(tt.input)
if tt.wantNil {
require.Nil(t, result)
} else {
require.NotNil(t, result)
require.Equal(t, tt.want, result)
}
})
}
}
// --- escapeLike ---
func TestEscapeLike(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{
name: "no special chars",
input: "hello",
want: "hello",
},
{
name: "backslash",
input: `a\b`,
want: `a\\b`,
},
{
name: "percent",
input: "50%",
want: `50\%`,
},
{
name: "underscore",
input: "a_b",
want: `a\_b`,
},
{
name: "all special chars",
input: `a\b%c_d`,
want: `a\\b\%c\_d`,
},
{
name: "empty string",
input: "",
want: "",
},
{
name: "consecutive special chars",
input: "%_%",
want: `\%\_\%`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, escapeLike(tt.input))
})
}
}
// --- isUniqueViolation ---
func TestIsUniqueViolation(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{
name: "unique violation code 23505",
err: &pq.Error{Code: "23505"},
want: true,
},
{
name: "different pq error code",
err: &pq.Error{Code: "23503"},
want: false,
},
{
name: "non-pq error",
err: errors.New("some generic error"),
want: false,
},
{
name: "typed nil pq.Error",
err: func() error {
var pqErr *pq.Error
return pqErr
}(),
want: false,
},
{
name: "bare nil",
err: nil,
want: false,
},
{
name: "wrapped pq error with 23505",
err: fmt.Errorf("wrapped: %w", &pq.Error{Code: "23505"}),
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, isUniqueViolation(tt.err))
})
}
}

View File

@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache" gocache "github.com/patrickmn/go-cache"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
// usageLogInsertArgTypes must stay in the same order as: // usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args // 1. prepareUsageLogInsert().args
@ -53,6 +53,8 @@ var usageLogInsertArgTypes = [...]string{
"integer", // cache_read_tokens "integer", // cache_read_tokens
"integer", // cache_creation_5m_tokens "integer", // cache_creation_5m_tokens
"integer", // cache_creation_1h_tokens "integer", // cache_creation_1h_tokens
"integer", // image_output_tokens
"numeric", // image_output_cost
"numeric", // input_cost "numeric", // input_cost
"numeric", // output_cost "numeric", // output_cost
"numeric", // cache_creation_cost "numeric", // cache_creation_cost
@ -77,6 +79,10 @@ var usageLogInsertArgTypes = [...]string{
"text", // inbound_endpoint "text", // inbound_endpoint
"text", // upstream_endpoint "text", // upstream_endpoint
"boolean", // cache_ttl_overridden "boolean", // cache_ttl_overridden
"bigint", // channel_id
"text", // model_mapping_chain
"text", // billing_tier
"text", // billing_mode
"timestamptz", // created_at "timestamptz", // created_at
} }
@ -326,6 +332,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
@ -350,14 +358,18 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $6, $7, $1, $2, $3, $4, $5, $6, $7,
$8, $9, $8, $9,
$10, $11, $12, $13, $10, $11, $12, $13,
$14, $15, $14, $15, $16, $17,
$16, $17, $18, $19, $20, $21, $18, $19, $20, $21, $22, $23,
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40 $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
@ -758,6 +770,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
@ -782,10 +796,14 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(keys)*39) args := make([]any, 0, len(keys)*47)
argPos := 1 argPos := 1
for idx, key := range keys { for idx, key := range keys {
if idx > 0 { if idx > 0 {
@ -829,6 +847,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
@ -853,6 +873,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at created_at
) )
SELECT SELECT
@ -871,6 +895,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
@ -895,6 +921,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at created_at
FROM input FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
@ -953,6 +983,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
@ -977,10 +1009,14 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(preparedList)*40) args := make([]any, 0, len(preparedList)*46)
argPos := 1 argPos := 1
for idx, prepared := range preparedList { for idx, prepared := range preparedList {
if idx > 0 { if idx > 0 {
@ -1021,6 +1057,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
@ -1045,6 +1083,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at created_at
) )
SELECT SELECT
@ -1063,6 +1105,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
@ -1087,6 +1131,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at created_at
FROM input FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
@ -1113,6 +1161,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
cache_read_tokens, cache_read_tokens,
cache_creation_5m_tokens, cache_creation_5m_tokens,
cache_creation_1h_tokens, cache_creation_1h_tokens,
image_output_tokens,
image_output_cost,
input_cost, input_cost,
output_cost, output_cost,
cache_creation_cost, cache_creation_cost,
@ -1137,14 +1187,18 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
billing_mode,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $6, $7, $1, $2, $3, $4, $5, $6, $7,
$8, $9, $8, $9,
$10, $11, $12, $13, $10, $11, $12, $13,
$14, $15, $14, $15, $16, $17,
$16, $17, $18, $19, $20, $21, $18, $19, $20, $21, $22, $23,
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40 $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...) `, prepared.args...)
@ -1176,6 +1230,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
reasoningEffort := nullString(log.ReasoningEffort) reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint) inboundEndpoint := nullString(log.InboundEndpoint)
upstreamEndpoint := nullString(log.UpstreamEndpoint) upstreamEndpoint := nullString(log.UpstreamEndpoint)
channelID := nullInt64(log.ChannelID)
modelMappingChain := nullString(log.ModelMappingChain)
billingTier := nullString(log.BillingTier)
billingMode := nullString(log.BillingMode)
requestedModel := strings.TrimSpace(log.RequestedModel) requestedModel := strings.TrimSpace(log.RequestedModel)
if requestedModel == "" { if requestedModel == "" {
requestedModel = strings.TrimSpace(log.Model) requestedModel = strings.TrimSpace(log.Model)
@ -1208,6 +1266,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
log.CacheReadTokens, log.CacheReadTokens,
log.CacheCreation5mTokens, log.CacheCreation5mTokens,
log.CacheCreation1hTokens, log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost, log.InputCost,
log.OutputCost, log.OutputCost,
log.CacheCreationCost, log.CacheCreationCost,
@ -1232,6 +1292,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
inboundEndpoint, inboundEndpoint,
upstreamEndpoint, upstreamEndpoint,
log.CacheTTLOverridden, log.CacheTTLOverridden,
channelID,
modelMappingChain,
billingTier,
billingMode,
createdAt, createdAt,
}, },
} }
@ -2564,8 +2628,8 @@ type UsageLogFilters = usagestats.UsageLogFilters
// ListWithFilters lists usage logs with optional filters (for admin) // ListWithFilters lists usage logs with optional filters (for admin)
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
conditions := make([]string, 0, 8) conditions := make([]string, 0, 9)
args := make([]any, 0, 8) args := make([]any, 0, 9)
if filters.UserID > 0 { if filters.UserID > 0 {
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
@ -2589,6 +2653,10 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType)) args = append(args, int16(*filters.BillingType))
} }
if filters.BillingMode != "" {
conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1))
args = append(args, filters.BillingMode)
}
if filters.StartTime != nil { if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime) args = append(args, *filters.StartTime)
@ -3096,6 +3164,30 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1) query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1)
args = append(args, dim.Endpoint) args = append(args, dim.Endpoint)
} }
if dim.UserID > 0 {
query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1)
args = append(args, dim.UserID)
}
if dim.APIKeyID > 0 {
query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1)
args = append(args, dim.APIKeyID)
}
if dim.AccountID > 0 {
query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1)
args = append(args, dim.AccountID)
}
if dim.RequestType != nil {
query += fmt.Sprintf(" AND ul.request_type = $%d", len(args)+1)
args = append(args, *dim.RequestType)
}
if dim.Stream != nil {
query += fmt.Sprintf(" AND ul.stream = $%d", len(args)+1)
args = append(args, *dim.Stream)
}
if dim.BillingType != nil {
query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1)
args = append(args, *dim.BillingType)
}
query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC" query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC"
if limit > 0 { if limit > 0 {
@ -3256,6 +3348,10 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType)) args = append(args, int16(*filters.BillingType))
} }
if filters.BillingMode != "" {
conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1))
args = append(args, filters.BillingMode)
}
if filters.StartTime != nil { if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1)) conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime) args = append(args, *filters.StartTime)
@ -3935,6 +4031,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
cacheReadTokens int cacheReadTokens int
cacheCreation5m int cacheCreation5m int
cacheCreation1h int cacheCreation1h int
imageOutputTokens int
imageOutputCost float64
inputCost float64 inputCost float64
outputCost float64 outputCost float64
cacheCreationCost float64 cacheCreationCost float64
@ -3959,6 +4057,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
inboundEndpoint sql.NullString inboundEndpoint sql.NullString
upstreamEndpoint sql.NullString upstreamEndpoint sql.NullString
cacheTTLOverridden bool cacheTTLOverridden bool
channelID sql.NullInt64
modelMappingChain sql.NullString
billingTier sql.NullString
billingMode sql.NullString
createdAt time.Time createdAt time.Time
) )
@ -3979,6 +4081,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&cacheReadTokens, &cacheReadTokens,
&cacheCreation5m, &cacheCreation5m,
&cacheCreation1h, &cacheCreation1h,
&imageOutputTokens,
&imageOutputCost,
&inputCost, &inputCost,
&outputCost, &outputCost,
&cacheCreationCost, &cacheCreationCost,
@ -4003,6 +4107,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&inboundEndpoint, &inboundEndpoint,
&upstreamEndpoint, &upstreamEndpoint,
&cacheTTLOverridden, &cacheTTLOverridden,
&channelID,
&modelMappingChain,
&billingTier,
&billingMode,
&createdAt, &createdAt,
); err != nil { ); err != nil {
return nil, err return nil, err
@ -4021,6 +4129,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
CacheReadTokens: cacheReadTokens, CacheReadTokens: cacheReadTokens,
CacheCreation5mTokens: cacheCreation5m, CacheCreation5mTokens: cacheCreation5m,
CacheCreation1hTokens: cacheCreation1h, CacheCreation1hTokens: cacheCreation1h,
ImageOutputTokens: imageOutputTokens,
ImageOutputCost: imageOutputCost,
InputCost: inputCost, InputCost: inputCost,
OutputCost: outputCost, OutputCost: outputCost,
CacheCreationCost: cacheCreationCost, CacheCreationCost: cacheCreationCost,
@ -4087,6 +4197,19 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if upstreamModel.Valid { if upstreamModel.Valid {
log.UpstreamModel = &upstreamModel.String log.UpstreamModel = &upstreamModel.String
} }
if channelID.Valid {
value := channelID.Int64
log.ChannelID = &value
}
if modelMappingChain.Valid {
log.ModelMappingChain = &modelMappingChain.String
}
if billingTier.Valid {
log.BillingTier = &billingTier.String
}
if billingMode.Valid {
log.BillingMode = &billingMode.String
}
return log, nil return log, nil
} }

View File

@ -56,6 +56,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
log.CacheReadTokens, log.CacheReadTokens,
log.CacheCreation5mTokens, log.CacheCreation5mTokens,
log.CacheCreation1hTokens, log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost, log.InputCost,
log.OutputCost, log.OutputCost,
log.CacheCreationCost, log.CacheCreationCost,
@ -80,6 +82,10 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // inbound_endpoint sqlmock.AnyArg(), // inbound_endpoint
sqlmock.AnyArg(), // upstream_endpoint sqlmock.AnyArg(), // upstream_endpoint
log.CacheTTLOverridden, log.CacheTTLOverridden,
sqlmock.AnyArg(), // channel_id
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
createdAt, createdAt,
). ).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt)) WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
@ -129,6 +135,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
log.CacheReadTokens, log.CacheReadTokens,
log.CacheCreation5mTokens, log.CacheCreation5mTokens,
log.CacheCreation1hTokens, log.CacheCreation1hTokens,
log.ImageOutputTokens,
log.ImageOutputCost,
log.InputCost, log.InputCost,
log.OutputCost, log.OutputCost,
log.CacheCreationCost, log.CacheCreationCost,
@ -153,6 +161,10 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(), sqlmock.AnyArg(),
sqlmock.AnyArg(), sqlmock.AnyArg(),
log.CacheTTLOverridden, log.CacheTTLOverridden,
sqlmock.AnyArg(), // channel_id
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
createdAt, createdAt,
). ).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt)) WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
@ -439,6 +451,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
4, // cache_read_tokens 4, // cache_read_tokens
5, // cache_creation_5m_tokens 5, // cache_creation_5m_tokens
6, // cache_creation_1h_tokens 6, // cache_creation_1h_tokens
0, // image_output_tokens
0.0, // image_output_cost
0.1, // input_cost 0.1, // input_cost
0.2, // output_cost 0.2, // output_cost
0.3, // cache_creation_cost 0.3, // cache_creation_cost
@ -463,6 +477,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{}, sql.NullString{},
sql.NullString{}, sql.NullString{},
false, false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now, now,
}}) }})
require.NoError(t, err) require.NoError(t, err)
@ -487,6 +505,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullInt64{}, sql.NullInt64{},
sql.NullInt64{}, sql.NullInt64{},
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
0, 0.0, // image_output_tokens, image_output_cost
0.1, 0.2, 0.3, 0.4, 1.0, 0.9, 0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0, 1.0,
sql.NullFloat64{}, sql.NullFloat64{},
@ -506,6 +525,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{}, sql.NullString{},
sql.NullString{}, sql.NullString{},
false, false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now, now,
}}) }})
require.NoError(t, err) require.NoError(t, err)
@ -530,6 +553,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullInt64{}, sql.NullInt64{},
sql.NullInt64{}, sql.NullInt64{},
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
0, 0.0, // image_output_tokens, image_output_cost
0.1, 0.2, 0.3, 0.4, 1.0, 0.9, 0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0, 1.0,
sql.NullFloat64{}, sql.NullFloat64{},
@ -549,6 +573,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{}, sql.NullString{},
sql.NullString{}, sql.NullString{},
false, false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now, now,
}}) }})
require.NoError(t, err) require.NoError(t, err)

View File

@ -74,6 +74,7 @@ var ProviderSet = wire.NewSet(
NewUserGroupRateRepository, NewUserGroupRateRepository,
NewErrorPassthroughRepository, NewErrorPassthroughRepository,
NewTLSFingerprintProfileRepository, NewTLSFingerprintProfileRepository,
NewChannelRepository,
// Cache implementations // Cache implementations
NewGatewayCache, NewGatewayCache,

View File

@ -87,6 +87,9 @@ func RegisterAdminRoutes(
// 定时测试计划 // 定时测试计划
registerScheduledTestRoutes(admin, h) registerScheduledTestRoutes(admin, h)
// 渠道管理
registerChannelRoutes(admin, h)
} }
} }
@ -567,3 +570,15 @@ func registerTLSFingerprintProfileRoutes(admin *gin.RouterGroup, h *handler.Hand
profiles.DELETE("/:id", h.Admin.TLSFingerprintProfile.Delete) profiles.DELETE("/:id", h.Admin.TLSFingerprintProfile.Delete)
} }
} }
func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
channels := admin.Group("/channels")
{
channels.GET("", h.Admin.Channel.List)
channels.GET("/model-pricing", h.Admin.Channel.GetModelDefaultPricing)
channels.GET("/:id", h.Admin.Channel.GetByID)
channels.POST("", h.Admin.Channel.Create)
channels.PUT("/:id", h.Admin.Channel.Update)
channels.DELETE("/:id", h.Admin.Channel.Delete)
}
}

View File

@ -56,6 +56,7 @@ type ModelPricing struct {
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格 LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率 LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率 LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
ImageOutputPricePerToken float64 // 图片输出 token 价格 (USD)
} }
const ( const (
@ -94,16 +95,19 @@ type UsageTokens struct {
CacheReadTokens int CacheReadTokens int
CacheCreation5mTokens int CacheCreation5mTokens int
CacheCreation1hTokens int CacheCreation1hTokens int
ImageOutputTokens int
} }
// CostBreakdown 费用明细 // CostBreakdown 费用明细
type CostBreakdown struct { type CostBreakdown struct {
InputCost float64 InputCost float64
OutputCost float64 OutputCost float64
ImageOutputCost float64
CacheCreationCost float64 CacheCreationCost float64
CacheReadCost float64 CacheReadCost float64
TotalCost float64 TotalCost float64
ActualCost float64 // 应用倍率后的实际费用 ActualCost float64 // 应用倍率后的实际费用
BillingMode string // 计费模式("token"/"per_request"/"image"),由 CalculateCostUnified 填充
} }
// BillingService 计费服务 // BillingService 计费服务
@ -357,6 +361,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold, LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier, LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier, LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
ImageOutputPricePerToken: litellmPricing.OutputCostPerImageToken,
}), nil }), nil
} }
} }
@ -371,81 +376,252 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
return nil, fmt.Errorf("pricing not found for model: %s", model) return nil, fmt.Errorf("pricing not found for model: %s", model)
} }
// CalculateCost 计算使用费用 // GetModelPricingWithChannel 获取模型定价,渠道配置的价格覆盖默认值
func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) { // 仅覆盖渠道中非 nil 的价格字段nil 字段使用默认定价
return s.CalculateCostWithServiceTier(model, tokens, rateMultiplier, "") func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing *ChannelModelPricing) (*ModelPricing, error) {
}
func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) {
pricing, err := s.GetModelPricing(model) pricing, err := s.GetModelPricing(model)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if channelPricing == nil {
return pricing, nil
}
if channelPricing.InputPrice != nil {
pricing.InputPricePerToken = *channelPricing.InputPrice
pricing.InputPricePerTokenPriority = *channelPricing.InputPrice
}
if channelPricing.OutputPrice != nil {
pricing.OutputPricePerToken = *channelPricing.OutputPrice
pricing.OutputPricePerTokenPriority = *channelPricing.OutputPrice
}
if channelPricing.CacheWritePrice != nil {
pricing.CacheCreationPricePerToken = *channelPricing.CacheWritePrice
pricing.CacheCreation5mPrice = *channelPricing.CacheWritePrice
pricing.CacheCreation1hPrice = *channelPricing.CacheWritePrice
}
if channelPricing.CacheReadPrice != nil {
pricing.CacheReadPricePerToken = *channelPricing.CacheReadPrice
pricing.CacheReadPricePerTokenPriority = *channelPricing.CacheReadPrice
}
if channelPricing.ImageOutputPrice != nil {
pricing.ImageOutputPricePerToken = *channelPricing.ImageOutputPrice
}
return pricing, nil
}
breakdown := &CostBreakdown{} // --- 统一计费入口 ---
inputPricePerToken := pricing.InputPricePerToken
outputPricePerToken := pricing.OutputPricePerToken // CostInput 统一计费输入
cacheReadPricePerToken := pricing.CacheReadPricePerToken type CostInput struct {
Ctx context.Context
Model string
GroupID *int64 // 用于渠道定价查找
Tokens UsageTokens
RequestCount int // 按次计费时使用
SizeTier string // 按次/图片模式的层级标签("1K","2K","4K","HD" 等)
RateMultiplier float64
ServiceTier string // "priority","flex","" 等
Resolver *ModelPricingResolver // 定价解析器
Resolved *ResolvedPricing // 可选:预解析的定价结果(避免重复 Resolve 调用)
}
// CalculateCostUnified 统一计费入口,支持三种计费模式。
// 使用 ModelPricingResolver 解析定价,然后根据 BillingMode 分发计算。
func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown, error) {
if input.Resolver == nil {
// 无 Resolver回退到旧路径
return s.calculateCostInternal(input.Model, input.Tokens, input.RateMultiplier, input.ServiceTier, nil)
}
// 优先使用预解析结果,避免重复 Resolve 调用
resolved := input.Resolved
if resolved == nil {
resolved = input.Resolver.Resolve(input.Ctx, PricingInput{
Model: input.Model,
GroupID: input.GroupID,
})
}
if input.RateMultiplier <= 0 {
input.RateMultiplier = 1.0
}
var breakdown *CostBreakdown
var err error
switch resolved.Mode {
case BillingModePerRequest, BillingModeImage:
breakdown, err = s.calculatePerRequestCost(resolved, input)
default: // BillingModeToken
breakdown, err = s.calculateTokenCost(resolved, input)
}
if err == nil && breakdown != nil {
breakdown.BillingMode = string(resolved.Mode)
if breakdown.BillingMode == "" {
breakdown.BillingMode = string(BillingModeToken)
}
}
return breakdown, err
}
// calculateTokenCost 按 token 区间计费
func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input CostInput) (*CostBreakdown, error) {
totalContext := input.Tokens.InputTokens + input.Tokens.CacheReadTokens
pricing := input.Resolver.GetIntervalPricing(resolved, totalContext)
if pricing == nil {
return nil, fmt.Errorf("no pricing available for model: %s", input.Model)
}
pricing = s.applyModelSpecificPricingPolicy(input.Model, pricing)
// 长上下文定价仅在无区间定价时应用(区间定价已包含上下文分层)
applyLongCtx := len(resolved.Intervals) == 0
return s.computeTokenBreakdown(pricing, input.Tokens, input.RateMultiplier, input.ServiceTier, applyLongCtx), nil
}
// computeTokenBreakdown 是 token 计费的核心逻辑,由 calculateTokenCost 和 calculateCostInternal 共用。
// applyLongCtx 控制是否检查长上下文定价(区间定价已自含上下文分层,不需要额外应用)。
func (s *BillingService) computeTokenBreakdown(
pricing *ModelPricing, tokens UsageTokens,
rateMultiplier float64, serviceTier string,
applyLongCtx bool,
) *CostBreakdown {
if rateMultiplier <= 0 {
rateMultiplier = 1.0
}
inputPrice := pricing.InputPricePerToken
outputPrice := pricing.OutputPricePerToken
cacheReadPrice := pricing.CacheReadPricePerToken
tierMultiplier := 1.0 tierMultiplier := 1.0
if usePriorityServiceTierPricing(serviceTier, pricing) { if usePriorityServiceTierPricing(serviceTier, pricing) {
if pricing.InputPricePerTokenPriority > 0 { if pricing.InputPricePerTokenPriority > 0 {
inputPricePerToken = pricing.InputPricePerTokenPriority inputPrice = pricing.InputPricePerTokenPriority
} }
if pricing.OutputPricePerTokenPriority > 0 { if pricing.OutputPricePerTokenPriority > 0 {
outputPricePerToken = pricing.OutputPricePerTokenPriority outputPrice = pricing.OutputPricePerTokenPriority
} }
if pricing.CacheReadPricePerTokenPriority > 0 { if pricing.CacheReadPricePerTokenPriority > 0 {
cacheReadPricePerToken = pricing.CacheReadPricePerTokenPriority cacheReadPrice = pricing.CacheReadPricePerTokenPriority
} }
} else { } else {
tierMultiplier = serviceTierCostMultiplier(serviceTier) tierMultiplier = serviceTierCostMultiplier(serviceTier)
} }
if s.shouldApplySessionLongContextPricing(tokens, pricing) {
inputPricePerToken *= pricing.LongContextInputMultiplier if applyLongCtx && s.shouldApplySessionLongContextPricing(tokens, pricing) {
outputPricePerToken *= pricing.LongContextOutputMultiplier inputPrice *= pricing.LongContextInputMultiplier
outputPrice *= pricing.LongContextOutputMultiplier
} }
// 计算输入token费用使用per-token价格 bd := &CostBreakdown{}
breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken bd.InputCost = float64(tokens.InputTokens) * inputPrice
// 计算输出token费用 // 分离图片输出 token 与文本输出 token
breakdown.OutputCost = float64(tokens.OutputTokens) * outputPricePerToken textOutputTokens := tokens.OutputTokens - tokens.ImageOutputTokens
if textOutputTokens < 0 {
textOutputTokens = 0
}
bd.OutputCost = float64(textOutputTokens) * outputPrice
// 计算缓存费用 // 图片输出 token 费用(独立费率)
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { if tokens.ImageOutputTokens > 0 {
// 支持详细缓存分类的模型5分钟/1小时缓存价格为 per-token imgPrice := pricing.ImageOutputPricePerToken
if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 { if imgPrice == 0 {
// API 未返回 ephemeral 明细,回退到全部按 5m 单价计费 imgPrice = outputPrice // 回退到常规输出价格
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
} else {
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
} }
} else { bd.ImageOutputCost = float64(tokens.ImageOutputTokens) * imgPrice
// 标准缓存创建价格per-token
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
} }
breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPricePerToken // 缓存创建费用
bd.CacheCreationCost = s.computeCacheCreationCost(pricing, tokens)
bd.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPrice
if tierMultiplier != 1.0 { if tierMultiplier != 1.0 {
breakdown.InputCost *= tierMultiplier bd.InputCost *= tierMultiplier
breakdown.OutputCost *= tierMultiplier bd.OutputCost *= tierMultiplier
breakdown.CacheCreationCost *= tierMultiplier bd.ImageOutputCost *= tierMultiplier
breakdown.CacheReadCost *= tierMultiplier bd.CacheCreationCost *= tierMultiplier
bd.CacheReadCost *= tierMultiplier
} }
// 计算总费用 bd.TotalCost = bd.InputCost + bd.OutputCost + bd.ImageOutputCost +
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + bd.CacheCreationCost + bd.CacheReadCost
breakdown.CacheCreationCost + breakdown.CacheReadCost bd.ActualCost = bd.TotalCost * rateMultiplier
// 应用倍率计算实际费用 return bd
if rateMultiplier <= 0 { }
rateMultiplier = 1.0
// computeCacheCreationCost 计算缓存创建费用(支持 5m/1h 分类或标准计费)。
func (s *BillingService) computeCacheCreationCost(pricing *ModelPricing, tokens UsageTokens) float64 {
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 {
// API 未返回 ephemeral 明细,回退到全部按 5m 单价计费
return float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
}
return float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
} }
breakdown.ActualCost = breakdown.TotalCost * rateMultiplier return float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
}
return breakdown, nil // calculatePerRequestCost 按次/图片计费
func (s *BillingService) calculatePerRequestCost(resolved *ResolvedPricing, input CostInput) (*CostBreakdown, error) {
count := input.RequestCount
if count <= 0 {
count = 1
}
var unitPrice float64
if input.SizeTier != "" {
unitPrice = input.Resolver.GetRequestTierPrice(resolved, input.SizeTier)
}
if unitPrice == 0 {
totalContext := input.Tokens.InputTokens + input.Tokens.CacheReadTokens
unitPrice = input.Resolver.GetRequestTierPriceByContext(resolved, totalContext)
}
// 回退到默认按次价格
if unitPrice == 0 {
unitPrice = resolved.DefaultPerRequestPrice
}
totalCost := unitPrice * float64(count)
actualCost := totalCost * input.RateMultiplier
return &CostBreakdown{
TotalCost: totalCost,
ActualCost: actualCost,
}, nil
}
// CalculateCost 计算使用费用
func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) {
return s.calculateCostInternal(model, tokens, rateMultiplier, "", nil)
}
func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) {
return s.calculateCostInternal(model, tokens, rateMultiplier, serviceTier, nil)
}
func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string, channelPricing *ChannelModelPricing) (*CostBreakdown, error) {
var pricing *ModelPricing
var err error
if channelPricing != nil {
pricing, err = s.GetModelPricingWithChannel(model, channelPricing)
} else {
pricing, err = s.GetModelPricing(model)
}
if err != nil {
return nil, err
}
// 旧路径始终检查长上下文定价(无区间定价概念)
return s.computeTokenBreakdown(pricing, tokens, rateMultiplier, serviceTier, true), nil
} }
func (s *BillingService) applyModelSpecificPricingPolicy(model string, pricing *ModelPricing) *ModelPricing { func (s *BillingService) applyModelSpecificPricingPolicy(model string, pricing *ModelPricing) *ModelPricing {
@ -541,6 +717,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
CacheReadTokens: inRangeCacheTokens, CacheReadTokens: inRangeCacheTokens,
CacheCreation5mTokens: tokens.CacheCreation5mTokens, CacheCreation5mTokens: tokens.CacheCreation5mTokens,
CacheCreation1hTokens: tokens.CacheCreation1hTokens, CacheCreation1hTokens: tokens.CacheCreation1hTokens,
ImageOutputTokens: tokens.ImageOutputTokens,
} }
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier) inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
if err != nil { if err != nil {
@ -561,6 +738,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
return &CostBreakdown{ return &CostBreakdown{
InputCost: inRangeCost.InputCost + outRangeCost.InputCost, InputCost: inRangeCost.InputCost + outRangeCost.InputCost,
OutputCost: inRangeCost.OutputCost, OutputCost: inRangeCost.OutputCost,
ImageOutputCost: inRangeCost.ImageOutputCost,
CacheCreationCost: inRangeCost.CacheCreationCost, CacheCreationCost: inRangeCost.CacheCreationCost,
CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost, CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost,
TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost, TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost,
@ -662,8 +840,9 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
actualCost := totalCost * rateMultiplier actualCost := totalCost * rateMultiplier
return &CostBreakdown{ return &CostBreakdown{
TotalCost: totalCost, TotalCost: totalCost,
ActualCost: actualCost, ActualCost: actualCost,
BillingMode: string(BillingModeImage),
} }
} }

View File

@ -0,0 +1,277 @@
package service
import (
"fmt"
"sort"
"strings"
"time"
)
// BillingMode 计费模式
type BillingMode string
const (
BillingModeToken BillingMode = "token" // 按 token 区间计费
BillingModePerRequest BillingMode = "per_request" // 按次计费(支持上下文窗口分层)
BillingModeImage BillingMode = "image" // 图片计费(当前按次,预留 token 计费)
)
// IsValid 检查 BillingMode 是否为合法值
func (m BillingMode) IsValid() bool {
switch m {
case BillingModeToken, BillingModePerRequest, BillingModeImage, "":
return true
}
return false
}
const (
BillingModelSourceRequested = "requested"
BillingModelSourceUpstream = "upstream"
BillingModelSourceChannelMapped = "channel_mapped"
)
// Channel 渠道实体
type Channel struct {
ID int64
Name string
Description string
Status string
BillingModelSource string // "requested", "upstream", or "channel_mapped"
RestrictModels bool // 是否限制模型(仅允许定价列表中的模型)
CreatedAt time.Time
UpdatedAt time.Time
// 关联的分组 ID 列表
GroupIDs []int64
// 模型定价列表(每条含 Platform 字段)
ModelPricing []ChannelModelPricing
// 渠道级模型映射按平台分组platform → {src→dst}
ModelMapping map[string]map[string]string
}
// ChannelModelPricing 渠道模型定价条目
type ChannelModelPricing struct {
ID int64
ChannelID int64
Platform string // 所属平台anthropic/openai/gemini/...
Models []string // 绑定的模型列表
BillingMode BillingMode // 计费模式
InputPrice *float64 // 每 token 输入价格USD— 向后兼容 flat 定价
OutputPrice *float64 // 每 token 输出价格USD
CacheWritePrice *float64 // 缓存写入价格
CacheReadPrice *float64 // 缓存读取价格
ImageOutputPrice *float64 // 图片输出价格(向后兼容)
PerRequestPrice *float64 // 默认按次计费价格USD
Intervals []PricingInterval // 区间定价列表
CreatedAt time.Time
UpdatedAt time.Time
}
// PricingInterval 定价区间token 区间 / 按次分层 / 图片分辨率分层)
type PricingInterval struct {
ID int64
PricingID int64
MinTokens int // 区间下界(含)
MaxTokens *int // 区间上界不含nil = 无上限
TierLabel string // 层级标签(按次/图片模式1K, 2K, 4K, HD 等)
InputPrice *float64 // token 模式:每 token 输入价
OutputPrice *float64 // token 模式:每 token 输出价
CacheWritePrice *float64 // token 模式:缓存写入价
CacheReadPrice *float64 // token 模式:缓存读取价
PerRequestPrice *float64 // 按次/图片模式:每次请求价格
SortOrder int
CreatedAt time.Time
UpdatedAt time.Time
}
// IsActive 判断渠道是否启用
func (c *Channel) IsActive() bool {
return c.Status == StatusActive
}
// GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。
// 精确匹配,大小写不敏感。返回值拷贝,不污染缓存。
func (c *Channel) GetModelPricing(model string) *ChannelModelPricing {
modelLower := strings.ToLower(model)
for i := range c.ModelPricing {
for _, m := range c.ModelPricing[i].Models {
if strings.ToLower(m) == modelLower {
cp := c.ModelPricing[i].Clone()
return &cp
}
}
}
return nil
}
// FindMatchingInterval 在区间列表中查找匹配 totalTokens 的区间。
// 区间为左开右闭 (min, max]min 不含max 包含。
// 第一个区间 min=0 时0 token 不匹配任何区间(回退到默认价格)。
func FindMatchingInterval(intervals []PricingInterval, totalTokens int) *PricingInterval {
for i := range intervals {
iv := &intervals[i]
if totalTokens > iv.MinTokens && (iv.MaxTokens == nil || totalTokens <= *iv.MaxTokens) {
return iv
}
}
return nil
}
// GetIntervalForContext 根据总 context token 数查找匹配的区间。
func (p *ChannelModelPricing) GetIntervalForContext(totalTokens int) *PricingInterval {
return FindMatchingInterval(p.Intervals, totalTokens)
}
// GetTierByLabel 根据标签查找层级(用于 per_request / image 模式)
func (p *ChannelModelPricing) GetTierByLabel(label string) *PricingInterval {
labelLower := strings.ToLower(label)
for i := range p.Intervals {
if strings.ToLower(p.Intervals[i].TierLabel) == labelLower {
return &p.Intervals[i]
}
}
return nil
}
// Clone 返回 ChannelModelPricing 的拷贝(切片独立,指针字段共享,调用方只读安全)
func (p ChannelModelPricing) Clone() ChannelModelPricing {
cp := p
if p.Models != nil {
cp.Models = make([]string, len(p.Models))
copy(cp.Models, p.Models)
}
if p.Intervals != nil {
cp.Intervals = make([]PricingInterval, len(p.Intervals))
copy(cp.Intervals, p.Intervals)
}
return cp
}
// Clone 返回 Channel 的深拷贝
func (c *Channel) Clone() *Channel {
if c == nil {
return nil
}
cp := *c
if c.GroupIDs != nil {
cp.GroupIDs = make([]int64, len(c.GroupIDs))
copy(cp.GroupIDs, c.GroupIDs)
}
if c.ModelPricing != nil {
cp.ModelPricing = make([]ChannelModelPricing, len(c.ModelPricing))
for i := range c.ModelPricing {
cp.ModelPricing[i] = c.ModelPricing[i].Clone()
}
}
if c.ModelMapping != nil {
cp.ModelMapping = make(map[string]map[string]string, len(c.ModelMapping))
for platform, mapping := range c.ModelMapping {
inner := make(map[string]string, len(mapping))
for k, v := range mapping {
inner[k] = v
}
cp.ModelMapping[platform] = inner
}
}
return &cp
}
// ValidateIntervals 校验区间列表的合法性。
// 规则MinTokens >= 0MaxTokens 若非 nil 则 > 0 且 > MinTokens
// 所有价格字段 >= 0区间按 MinTokens 排序后无重叠((min, max] 语义);
// 无界区间MaxTokens=nil必须是最后一个。间隙允许回退默认价格
func ValidateIntervals(intervals []PricingInterval) error {
if len(intervals) == 0 {
return nil
}
sorted := make([]PricingInterval, len(intervals))
copy(sorted, intervals)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].MinTokens < sorted[j].MinTokens
})
for i := range sorted {
if err := validateSingleInterval(&sorted[i], i); err != nil {
return err
}
}
return validateIntervalOverlap(sorted)
}
// validateSingleInterval 校验单个区间的字段合法性
func validateSingleInterval(iv *PricingInterval, idx int) error {
if iv.MinTokens < 0 {
return fmt.Errorf("interval #%d: min_tokens (%d) must be >= 0", idx+1, iv.MinTokens)
}
if iv.MaxTokens != nil {
if *iv.MaxTokens <= 0 {
return fmt.Errorf("interval #%d: max_tokens (%d) must be > 0", idx+1, *iv.MaxTokens)
}
if *iv.MaxTokens <= iv.MinTokens {
return fmt.Errorf("interval #%d: max_tokens (%d) must be > min_tokens (%d)",
idx+1, *iv.MaxTokens, iv.MinTokens)
}
}
return validateIntervalPrices(iv, idx)
}
// validateIntervalPrices 校验区间内所有价格字段 >= 0
func validateIntervalPrices(iv *PricingInterval, idx int) error {
prices := []struct {
name string
val *float64
}{
{"input_price", iv.InputPrice},
{"output_price", iv.OutputPrice},
{"cache_write_price", iv.CacheWritePrice},
{"cache_read_price", iv.CacheReadPrice},
{"per_request_price", iv.PerRequestPrice},
}
for _, p := range prices {
if p.val != nil && *p.val < 0 {
return fmt.Errorf("interval #%d: %s must be >= 0", idx+1, p.name)
}
}
return nil
}
// validateIntervalOverlap 校验排序后的区间列表无重叠,且无界区间在最后
func validateIntervalOverlap(sorted []PricingInterval) error {
for i, iv := range sorted {
// 无界区间必须是最后一个
if iv.MaxTokens == nil && i < len(sorted)-1 {
return fmt.Errorf("interval #%d: unbounded interval (max_tokens=null) must be the last one",
i+1)
}
if i == 0 {
continue
}
prev := sorted[i-1]
// 检查重叠:前一个区间的上界 > 当前区间的下界则重叠
// (min, max] 语义prev 覆盖 (prev.Min, prev.Max]cur 覆盖 (cur.Min, cur.Max]
if prev.MaxTokens == nil || *prev.MaxTokens > iv.MinTokens {
return fmt.Errorf("interval #%d and #%d overlap: prev max=%s > cur min=%d",
i, i+1, formatMaxTokensLabel(prev.MaxTokens), iv.MinTokens)
}
}
return nil
}
func formatMaxTokensLabel(max *int) string {
if max == nil {
return "∞"
}
return fmt.Sprintf("%d", *max)
}
// ChannelUsageFields 渠道相关的使用记录字段(嵌入到各平台的 RecordUsageInput 中)
type ChannelUsageFields struct {
ChannelID int64 // 渠道 ID0 = 无渠道)
OriginalModel string // 用户原始请求模型(渠道映射前)
ChannelMappedModel string // 渠道映射后的模型名(无映射时等于 OriginalModel
BillingModelSource string // 计费模型来源:"requested" / "upstream" / "channel_mapped"
ModelMappingChain string // 映射链描述,如 "a→b→c"
}

View File

@ -0,0 +1,857 @@
package service
import (
"context"
"fmt"
"log/slog"
"strings"
"sync/atomic"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/sync/singleflight"
)
var (
ErrChannelNotFound = infraerrors.NotFound("CHANNEL_NOT_FOUND", "channel not found")
ErrChannelExists = infraerrors.Conflict("CHANNEL_EXISTS", "channel name already exists")
ErrGroupAlreadyInChannel = infraerrors.Conflict(
"GROUP_ALREADY_IN_CHANNEL",
"one or more groups already belong to another channel",
)
)
// ChannelRepository 渠道数据访问接口
type ChannelRepository interface {
Create(ctx context.Context, channel *Channel) error
GetByID(ctx context.Context, id int64) (*Channel, error)
Update(ctx context.Context, channel *Channel) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error)
ListAll(ctx context.Context) ([]Channel, error)
ExistsByName(ctx context.Context, name string) (bool, error)
ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error)
// 分组关联
GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error)
SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error
GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error)
GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error)
// 分组平台查询
GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error)
// 模型定价
ListModelPricing(ctx context.Context, channelID int64) ([]ChannelModelPricing, error)
CreateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error
UpdateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error
DeleteModelPricing(ctx context.Context, id int64) error
ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error
}
// channelModelKey 渠道缓存复合键(显式包含 platform 防止跨平台同名模型冲突)
type channelModelKey struct {
groupID int64
platform string // 平台标识
model string // lowercase
}
// channelGroupPlatformKey 通配符定价缓存键
type channelGroupPlatformKey struct {
groupID int64
platform string
}
// wildcardPricingEntry 通配符定价条目
type wildcardPricingEntry struct {
prefix string
pricing *ChannelModelPricing
}
// wildcardMappingEntry 通配符映射条目
type wildcardMappingEntry struct {
prefix string
target string
}
// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找)
type channelCache struct {
// 热路径查找
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序)
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序)
channelByGroupID map[int64]*Channel // groupID → 渠道
groupPlatform map[int64]string // groupID → platform
// 冷路径CRUD 操作)
byID map[int64]*Channel
loadedAt time.Time
}
// ChannelMappingResult 渠道映射查找结果
type ChannelMappingResult struct {
MappedModel string // 映射后的模型名(无映射时等于原始模型名)
ChannelID int64 // 渠道 ID0 = 无渠道关联)
Mapped bool // 是否发生了映射
BillingModelSource string // 计费模型来源("requested" / "upstream" / "channel_mapped"
}
// BuildModelMappingChain 根据映射结果和上游实际模型构建映射链描述。
// reqModel: 客户端请求的原始模型名。
// upstreamModel: 上游实际使用的模型名ForwardResult.UpstreamModel
// 返回空字符串表示无映射。
func (r ChannelMappingResult) BuildModelMappingChain(reqModel, upstreamModel string) string {
if !r.Mapped {
if upstreamModel != "" && upstreamModel != reqModel {
return reqModel + "→" + upstreamModel
}
return ""
}
if upstreamModel != "" && upstreamModel != r.MappedModel {
return reqModel + "→" + r.MappedModel + "→" + upstreamModel
}
return reqModel + "→" + r.MappedModel
}
// ToUsageFields 将渠道映射结果转为使用记录字段
func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) ChannelUsageFields {
channelMappedModel := reqModel
if r.Mapped {
channelMappedModel = r.MappedModel
}
return ChannelUsageFields{
ChannelID: r.ChannelID,
OriginalModel: reqModel,
ChannelMappedModel: channelMappedModel,
BillingModelSource: r.BillingModelSource,
ModelMappingChain: r.BuildModelMappingChain(reqModel, upstreamModel),
}
}
const (
channelCacheTTL = 10 * time.Minute
channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
channelCacheDBTimeout = 10 * time.Second
)
// ChannelService 渠道管理服务
type ChannelService struct {
repo ChannelRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
cache atomic.Value // *channelCache
cacheSF singleflight.Group
}
// NewChannelService 创建渠道服务实例
func NewChannelService(repo ChannelRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *ChannelService {
s := &ChannelService{
repo: repo,
authCacheInvalidator: authCacheInvalidator,
}
return s
}
// loadCache 加载或返回缓存的渠道数据
func (s *ChannelService) loadCache(ctx context.Context) (*channelCache, error) {
if cached, ok := s.cache.Load().(*channelCache); ok && cached != nil {
if time.Since(cached.loadedAt) < channelCacheTTL {
return cached, nil
}
}
result, err, _ := s.cacheSF.Do("channel_cache", func() (any, error) {
// 双重检查
if cached, ok := s.cache.Load().(*channelCache); ok && cached != nil {
if time.Since(cached.loadedAt) < channelCacheTTL {
return cached, nil
}
}
return s.buildCache(ctx)
})
if err != nil {
return nil, err
}
cache, ok := result.(*channelCache)
if !ok {
return nil, fmt.Errorf("unexpected cache type")
}
return cache, nil
}
// newEmptyChannelCache 创建空的渠道缓存(所有 map 已初始化)
func newEmptyChannelCache() *channelCache {
return &channelCache{
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
mappingByGroupModel: make(map[channelModelKey]string),
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
channelByGroupID: make(map[int64]*Channel),
groupPlatform: make(map[int64]string),
byID: make(map[int64]*Channel),
}
}
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。
// 缓存 key 使用定价条目的原始平台pricing.Platform而非分组平台
// 避免跨平台同名模型(如 anthropic 和 gemini 都有 "model-x")互相覆盖。
// 查找时通过 lookupPricingAcrossPlatforms() 依次尝试所有匹配平台。
func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
for j := range ch.ModelPricing {
pricing := &ch.ModelPricing[j]
if !isPlatformPricingMatch(platform, pricing.Platform) {
continue // 跳过非本平台的定价
}
// 使用定价条目的原始平台作为缓存 key防止跨平台同名模型冲突
pricingPlatform := pricing.Platform
gpKey := channelGroupPlatformKey{groupID: gid, platform: pricingPlatform}
for _, model := range pricing.Models {
if strings.HasSuffix(model, "*") {
prefix := strings.ToLower(strings.TrimSuffix(model, "*"))
cache.wildcardByGroupPlatform[gpKey] = append(cache.wildcardByGroupPlatform[gpKey], &wildcardPricingEntry{
prefix: prefix,
pricing: pricing,
})
} else {
key := channelModelKey{groupID: gid, platform: pricingPlatform, model: strings.ToLower(model)}
cache.pricingByGroupModel[key] = pricing
}
}
}
}
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
// antigravity 平台同时服务 Claude 和 Gemini 模型。
// 缓存 key 使用映射条目的原始平台mappingPlatform避免跨平台同名映射覆盖。
func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
for _, mappingPlatform := range matchingPlatforms(platform) {
platformMapping, ok := ch.ModelMapping[mappingPlatform]
if !ok {
continue
}
// 使用映射条目的原始平台作为缓存 key防止跨平台同名映射冲突
gpKey := channelGroupPlatformKey{groupID: gid, platform: mappingPlatform}
for src, dst := range platformMapping {
if strings.HasSuffix(src, "*") {
prefix := strings.ToLower(strings.TrimSuffix(src, "*"))
cache.wildcardMappingByGP[gpKey] = append(cache.wildcardMappingByGP[gpKey], &wildcardMappingEntry{
prefix: prefix,
target: dst,
})
} else {
key := channelModelKey{groupID: gid, platform: mappingPlatform, model: strings.ToLower(src)}
cache.mappingByGroupModel[key] = dst
}
}
}
}
// buildCache 从数据库构建渠道缓存。
// 使用独立 context 避免请求取消导致空值被长期缓存。
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
// 断开请求取消链,避免客户端断连导致空值被长期缓存
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout)
defer cancel()
channels, err := s.repo.ListAll(dbCtx)
if err != nil {
// error-TTL失败时存入短 TTL 空缓存,防止紧密重试
slog.Warn("failed to build channel cache", "error", err)
errorCache := newEmptyChannelCache()
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL
s.cache.Store(errorCache)
return nil, fmt.Errorf("list all channels: %w", err)
}
// 收集所有 groupID批量查询 platform
var allGroupIDs []int64
for i := range channels {
allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...)
}
groupPlatforms := make(map[int64]string)
if len(allGroupIDs) > 0 {
groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs)
if err != nil {
slog.Warn("failed to load group platforms for channel cache", "error", err)
errorCache := newEmptyChannelCache()
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
s.cache.Store(errorCache)
return nil, fmt.Errorf("get group platforms: %w", err)
}
}
cache := newEmptyChannelCache()
cache.groupPlatform = groupPlatforms
cache.byID = make(map[int64]*Channel, len(channels))
cache.loadedAt = time.Now()
for i := range channels {
ch := &channels[i]
cache.byID[ch.ID] = ch
for _, gid := range ch.GroupIDs {
cache.channelByGroupID[gid] = ch
platform := groupPlatforms[gid]
expandPricingToCache(cache, ch, gid, platform)
expandMappingToCache(cache, ch, gid, platform)
}
}
// 通配符条目保持配置顺序(最先匹配到优先)
s.cache.Store(cache)
return cache, nil
}
// invalidateCache 使缓存失效,让下次读取时自然重建
// isPlatformPricingMatch 判断定价条目的平台是否匹配分组平台。
// antigravity 平台同时服务 Claudeanthropic和 Geminigemini模型
// 因此 antigravity 分组应匹配 anthropic 和 gemini 的定价条目。
func isPlatformPricingMatch(groupPlatform, pricingPlatform string) bool {
if groupPlatform == pricingPlatform {
return true
}
if groupPlatform == PlatformAntigravity {
return pricingPlatform == PlatformAnthropic || pricingPlatform == PlatformGemini
}
return false
}
// matchingPlatforms 返回分组平台对应的所有可匹配平台列表。
func matchingPlatforms(groupPlatform string) []string {
if groupPlatform == PlatformAntigravity {
return []string{PlatformAntigravity, PlatformAnthropic, PlatformGemini}
}
return []string{groupPlatform}
}
func (s *ChannelService) invalidateCache() {
s.cache.Store((*channelCache)(nil))
s.cacheSF.Forget("channel_cache")
// 主动重建缓存,确保 CRUD 后立即生效
if _, err := s.buildCache(context.Background()); err != nil {
slog.Warn("failed to rebuild channel cache after invalidation", "error", err)
}
}
// matchWildcard 在通配符定价中查找匹配项(最先匹配到优先)
func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) *ChannelModelPricing {
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
wildcards := c.wildcardByGroupPlatform[gpKey]
for _, wc := range wildcards {
if strings.HasPrefix(modelLower, wc.prefix) {
return wc.pricing
}
}
return nil
}
// matchWildcardMapping 在通配符映射中查找匹配项(最先匹配到优先)
func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower string) string {
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
wildcards := c.wildcardMappingByGP[gpKey]
for _, wc := range wildcards {
if strings.HasPrefix(modelLower, wc.prefix) {
return wc.target
}
}
return ""
}
// lookupPricingAcrossPlatforms 在所有匹配平台中查找模型定价。
// antigravity 分组的缓存 key 使用定价条目的原始平台,因此查找时需依次尝试
// matchingPlatforms() 返回的所有平台antigravity → anthropic → gemini
// 返回第一个命中的结果。非 antigravity 平台只尝试自身。
func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing {
for _, p := range matchingPlatforms(groupPlatform) {
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
if pricing, ok := cache.pricingByGroupModel[key]; ok {
return pricing
}
}
// 精确查找全部失败,依次尝试通配符匹配
for _, p := range matchingPlatforms(groupPlatform) {
if pricing := cache.matchWildcard(groupID, p, modelLower); pricing != nil {
return pricing
}
}
return nil
}
// lookupMappingAcrossPlatforms 在所有匹配平台中查找模型映射。
// 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。
func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string {
for _, p := range matchingPlatforms(groupPlatform) {
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
if mapped, ok := cache.mappingByGroupModel[key]; ok {
return mapped
}
}
for _, p := range matchingPlatforms(groupPlatform) {
if mapped := cache.matchWildcardMapping(groupID, p, modelLower); mapped != "" {
return mapped
}
}
return ""
}
// GetChannelForGroup 获取分组关联的渠道(热路径 O(1)
func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) {
cache, err := s.loadCache(ctx)
if err != nil {
return nil, err
}
ch, ok := cache.channelByGroupID[groupID]
if !ok || !ch.IsActive() {
return nil, nil
}
return ch.Clone(), nil
}
// channelLookup 热路径公共查找结果
type channelLookup struct {
cache *channelCache
channel *Channel
platform string
}
// lookupGroupChannel 加载缓存并查找分组对应的渠道信息(公共热路径前置逻辑)。
// 返回 nil 且 err==nil 表示分组无活跃渠道err!=nil 表示缓存加载失败。
func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64) (*channelLookup, error) {
cache, err := s.loadCache(ctx)
if err != nil {
return nil, err
}
ch, ok := cache.channelByGroupID[groupID]
if !ok || !ch.IsActive() {
return nil, nil
}
return &channelLookup{
cache: cache,
channel: ch,
platform: cache.groupPlatform[groupID],
}, nil
}
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。
// antigravity 分组依次尝试所有匹配平台antigravity → anthropic → gemini
// 确保跨平台同名模型各自独立匹配。
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
lk, err := s.lookupGroupChannel(ctx, groupID)
if err != nil {
slog.Warn("failed to load channel cache", "group_id", groupID, "error", err)
return nil
}
if lk == nil {
return nil
}
modelLower := strings.ToLower(model)
pricing := lookupPricingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower)
if pricing == nil {
return nil
}
cp := pricing.Clone()
return &cp
}
// ResolveChannelMapping 解析渠道级模型映射(热路径 O(1)
// 返回映射结果,包含映射后的模型名、渠道 ID、计费模型来源。
func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
lk, err := s.lookupGroupChannel(ctx, groupID)
if err != nil {
slog.Warn("failed to load channel cache for mapping", "group_id", groupID, "error", err)
}
if lk == nil {
return ChannelMappingResult{MappedModel: model}
}
return resolveMapping(lk, groupID, model)
}
// IsModelRestricted 检查模型是否被渠道限制。
// 返回 true 表示模型被限制(不在允许列表中)。
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
lk, _ := s.lookupGroupChannel(ctx, groupID)
if lk == nil {
return false
}
return checkRestricted(lk, groupID, model)
}
// ResolveChannelMappingAndRestrict 解析渠道映射。
// 返回映射结果。模型限制检查已移至调度阶段GatewayService.checkChannelPricingRestriction
// restricted 始终返回 false保留签名兼容性。
func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
if groupID == nil {
return ChannelMappingResult{MappedModel: model}, false
}
lk, _ := s.lookupGroupChannel(ctx, *groupID)
if lk == nil {
return ChannelMappingResult{MappedModel: model}, false
}
return resolveMapping(lk, *groupID, model), false
}
// resolveMapping 基于已查找的渠道信息解析模型映射。
// antigravity 分组依次尝试所有匹配平台,确保跨平台同名映射各自独立。
func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappingResult {
result := ChannelMappingResult{
MappedModel: model,
ChannelID: lk.channel.ID,
BillingModelSource: lk.channel.BillingModelSource,
}
if result.BillingModelSource == "" {
result.BillingModelSource = BillingModelSourceChannelMapped
}
modelLower := strings.ToLower(model)
if mapped := lookupMappingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower); mapped != "" {
result.MappedModel = mapped
result.Mapped = true
}
return result
}
// checkRestricted 基于已查找的渠道信息检查模型是否被限制。
// antigravity 分组依次尝试所有匹配平台的定价列表。
func checkRestricted(lk *channelLookup, groupID int64, model string) bool {
if !lk.channel.RestrictModels {
return false
}
modelLower := strings.ToLower(model)
// 使用与查找定价相同的跨平台逻辑
if lookupPricingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower) != nil {
return false
}
return true
}
// ReplaceModelInBody 替换请求体 JSON 中的 model 字段。
func ReplaceModelInBody(body []byte, newModel string) []byte {
if len(body) == 0 {
return body
}
if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel {
return body
}
newBody, err := sjson.SetBytes(body, "model", newModel)
if err != nil {
return body
}
return newBody
}
// --- CRUD ---
// Create 创建渠道
func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) (*Channel, error) {
exists, err := s.repo.ExistsByName(ctx, input.Name)
if err != nil {
return nil, fmt.Errorf("check channel exists: %w", err)
}
if exists {
return nil, ErrChannelExists
}
// 检查分组冲突
if len(input.GroupIDs) > 0 {
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs)
if err != nil {
return nil, fmt.Errorf("check group conflicts: %w", err)
}
if len(conflicting) > 0 {
return nil, ErrGroupAlreadyInChannel
}
}
channel := &Channel{
Name: input.Name,
Description: input.Description,
Status: StatusActive,
BillingModelSource: input.BillingModelSource,
RestrictModels: input.RestrictModels,
GroupIDs: input.GroupIDs,
ModelPricing: input.ModelPricing,
ModelMapping: input.ModelMapping,
}
if channel.BillingModelSource == "" {
channel.BillingModelSource = BillingModelSourceChannelMapped
}
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
return nil, err
}
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
return nil, err
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
return nil, err
}
if err := s.repo.Create(ctx, channel); err != nil {
return nil, fmt.Errorf("create channel: %w", err)
}
s.invalidateCache()
return s.repo.GetByID(ctx, channel.ID)
}
// GetByID 获取渠道详情
func (s *ChannelService) GetByID(ctx context.Context, id int64) (*Channel, error) {
return s.repo.GetByID(ctx, id)
}
// Update 更新渠道
func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChannelInput) (*Channel, error) {
channel, err := s.repo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get channel: %w", err)
}
if input.Name != "" && input.Name != channel.Name {
exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id)
if err != nil {
return nil, fmt.Errorf("check channel exists: %w", err)
}
if exists {
return nil, ErrChannelExists
}
channel.Name = input.Name
}
if input.Description != nil {
channel.Description = *input.Description
}
if input.Status != "" {
channel.Status = input.Status
}
if input.RestrictModels != nil {
channel.RestrictModels = *input.RestrictModels
}
// 检查分组冲突
if input.GroupIDs != nil {
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs)
if err != nil {
return nil, fmt.Errorf("check group conflicts: %w", err)
}
if len(conflicting) > 0 {
return nil, ErrGroupAlreadyInChannel
}
channel.GroupIDs = *input.GroupIDs
}
if input.ModelPricing != nil {
channel.ModelPricing = *input.ModelPricing
}
if input.ModelMapping != nil {
channel.ModelMapping = input.ModelMapping
}
if input.BillingModelSource != "" {
channel.BillingModelSource = input.BillingModelSource
}
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
return nil, err
}
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
return nil, err
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
return nil, err
}
// 先获取旧分组Update 后旧分组关联已删除,无法再查到
var oldGroupIDs []int64
if s.authCacheInvalidator != nil {
var err2 error
oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id)
if err2 != nil {
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2)
}
}
if err := s.repo.Update(ctx, channel); err != nil {
return nil, fmt.Errorf("update channel: %w", err)
}
s.invalidateCache()
// 失效新旧分组的 auth 缓存
if s.authCacheInvalidator != nil {
seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs))
for _, gid := range oldGroupIDs {
if _, ok := seen[gid]; !ok {
seen[gid] = struct{}{}
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
}
}
for _, gid := range channel.GroupIDs {
if _, ok := seen[gid]; !ok {
seen[gid] = struct{}{}
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
}
}
}
return s.repo.GetByID(ctx, id)
}
// Delete 删除渠道
func (s *ChannelService) Delete(ctx context.Context, id int64) error {
// 先获取关联分组用于失效缓存
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
if err != nil {
slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err)
}
if err := s.repo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete channel: %w", err)
}
s.invalidateCache()
if s.authCacheInvalidator != nil {
for _, gid := range groupIDs {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
}
}
return nil
}
// List 获取渠道列表
func (s *ChannelService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) {
return s.repo.List(ctx, params, status, search)
}
// modelEntry 表示一个模型模式条目(用于冲突检测)
type modelEntry struct {
pattern string // 原始模式(如 "claude-*" 或 "claude-opus-4"
prefix string // lowercase 前缀(通配符去掉 *,精确名保持原样)
wildcard bool
}
// conflictsBetween 检查两个模型模式是否冲突
func conflictsBetween(a, b modelEntry) bool {
switch {
case !a.wildcard && !b.wildcard:
return a.prefix == b.prefix
case a.wildcard && !b.wildcard:
return strings.HasPrefix(b.prefix, a.prefix)
case !a.wildcard && b.wildcard:
return strings.HasPrefix(a.prefix, b.prefix)
default:
return strings.HasPrefix(a.prefix, b.prefix) ||
strings.HasPrefix(b.prefix, a.prefix)
}
}
// toModelEntry 将模型名转换为 modelEntry
func toModelEntry(pattern string) modelEntry {
lower := strings.ToLower(pattern)
isWild := strings.HasSuffix(lower, "*")
prefix := lower
if isWild {
prefix = strings.TrimSuffix(lower, "*")
}
return modelEntry{pattern: pattern, prefix: prefix, wildcard: isWild}
}
// validateNoConflictingModels 检查定价列表中是否有冲突模型模式(同一平台下)。
// 冲突包括:精确重复、通配符之间的前缀包含、通配符与精确名的前缀匹配。
func validateNoConflictingModels(pricingList []ChannelModelPricing) error {
byPlatform := make(map[string][]modelEntry)
for _, p := range pricingList {
for _, model := range p.Models {
byPlatform[p.Platform] = append(byPlatform[p.Platform], toModelEntry(model))
}
}
for platform, entries := range byPlatform {
if err := detectConflicts(entries, platform, "MODEL_PATTERN_CONFLICT", "model patterns"); err != nil {
return err
}
}
return nil
}
// validateNoConflictingMappings 检查模型映射中是否有冲突的源模式
func validateNoConflictingMappings(mapping map[string]map[string]string) error {
for platform, platformMapping := range mapping {
entries := make([]modelEntry, 0, len(platformMapping))
for src := range platformMapping {
entries = append(entries, toModelEntry(src))
}
if err := detectConflicts(entries, platform, "MAPPING_PATTERN_CONFLICT", "mapping source patterns"); err != nil {
return err
}
}
return nil
}
func validatePricingIntervals(pricingList []ChannelModelPricing) error {
for _, pricing := range pricingList {
if err := ValidateIntervals(pricing.Intervals); err != nil {
return infraerrors.BadRequest(
"INVALID_PRICING_INTERVALS",
fmt.Sprintf("invalid pricing intervals for platform '%s' models %v: %v",
pricing.Platform, pricing.Models, err),
)
}
}
return nil
}
// detectConflicts 在一组 modelEntry 中检测冲突,返回带有 errCode 和 label 的错误
func detectConflicts(entries []modelEntry, platform, errCode, label string) error {
for i := 0; i < len(entries); i++ {
for j := i + 1; j < len(entries); j++ {
if conflictsBetween(entries[i], entries[j]) {
return infraerrors.BadRequest(errCode,
fmt.Sprintf("%s '%s' and '%s' conflict in platform '%s': overlapping match range",
label, entries[i].pattern, entries[j].pattern, platform))
}
}
}
return nil
}
// --- Input types ---
// CreateChannelInput 创建渠道输入
type CreateChannelInput struct {
Name string
Description string
GroupIDs []int64
ModelPricing []ChannelModelPricing
ModelMapping map[string]map[string]string // platform → {src→dst}
BillingModelSource string
RestrictModels bool
}
// UpdateChannelInput 更新渠道输入
type UpdateChannelInput struct {
Name string
Description *string
Status string
GroupIDs *[]int64
ModelPricing *[]ChannelModelPricing
ModelMapping map[string]map[string]string // platform → {src→dst}
BillingModelSource string
RestrictModels *bool
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,435 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestGetModelPricing(t *testing.T) {
ch := &Channel{
ModelPricing: []ChannelModelPricing{
{ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: testPtrFloat64(3e-6)},
{ID: 3, Models: []string{"gpt-5.1"}, BillingMode: BillingModePerRequest},
},
}
tests := []struct {
name string
model string
wantID int64
wantNil bool
}{
{"exact match", "claude-sonnet-4", 1, false},
{"case insensitive", "Claude-Sonnet-4", 1, false},
{"not found", "gemini-3.1-pro", 0, true},
{"wildcard pattern not matched", "claude-opus-4-20250514", 0, true},
{"per_request model", "gpt-5.1", 3, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ch.GetModelPricing(tt.model)
if tt.wantNil {
require.Nil(t, result)
return
}
require.NotNil(t, result)
require.Equal(t, tt.wantID, result.ID)
})
}
}
func TestGetModelPricing_ReturnsCopy(t *testing.T) {
ch := &Channel{
ModelPricing: []ChannelModelPricing{
{ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: testPtrFloat64(3e-6)},
},
}
result := ch.GetModelPricing("claude-sonnet-4")
require.NotNil(t, result)
// Modify the returned copy's slice — original should be unchanged
result.Models = append(result.Models, "hacked")
// Original should be unchanged
require.Equal(t, 1, len(ch.ModelPricing[0].Models))
}
func TestGetModelPricing_EmptyPricing(t *testing.T) {
ch := &Channel{ModelPricing: nil}
require.Nil(t, ch.GetModelPricing("any-model"))
ch2 := &Channel{ModelPricing: []ChannelModelPricing{}}
require.Nil(t, ch2.GetModelPricing("any-model"))
}
func TestGetIntervalForContext(t *testing.T) {
p := &ChannelModelPricing{
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
},
}
tests := []struct {
name string
tokens int
wantPrice *float64
wantNil bool
}{
{"first interval", 50000, testPtrFloat64(1e-6), false},
// (min, max] — 128000 在第一个区间的 max包含所以匹配第一个
{"boundary: max of first (inclusive)", 128000, testPtrFloat64(1e-6), false},
// 128001 > 128000匹配第二个区间
{"boundary: just above first max", 128001, testPtrFloat64(2e-6), false},
{"unbounded interval", 500000, testPtrFloat64(2e-6), false},
// (0, max] — 0 不匹配任何区间(左开)
{"zero tokens: no match", 0, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := p.GetIntervalForContext(tt.tokens)
if tt.wantNil {
require.Nil(t, result)
return
}
require.NotNil(t, result)
require.InDelta(t, *tt.wantPrice, *result.InputPrice, 1e-12)
})
}
}
func TestGetIntervalForContext_NoMatch(t *testing.T) {
p := &ChannelModelPricing{
Intervals: []PricingInterval{
{MinTokens: 10000, MaxTokens: testPtrInt(50000)},
},
}
require.Nil(t, p.GetIntervalForContext(5000)) // 5000 <= 10000, not > min
require.Nil(t, p.GetIntervalForContext(10000)) // 10000 not > 10000 (left-open)
require.NotNil(t, p.GetIntervalForContext(50000)) // 50000 <= 50000 (right-closed)
require.Nil(t, p.GetIntervalForContext(50001)) // 50001 > 50000
}
func TestGetIntervalForContext_Empty(t *testing.T) {
p := &ChannelModelPricing{Intervals: nil}
require.Nil(t, p.GetIntervalForContext(1000))
}
func TestGetTierByLabel(t *testing.T) {
p := &ChannelModelPricing{
Intervals: []PricingInterval{
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
{TierLabel: "HD", PerRequestPrice: testPtrFloat64(0.12)},
},
}
tests := []struct {
name string
label string
wantNil bool
want float64
}{
{"exact match", "1K", false, 0.04},
{"case insensitive", "hd", false, 0.12},
{"not found", "4K", true, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := p.GetTierByLabel(tt.label)
if tt.wantNil {
require.Nil(t, result)
return
}
require.NotNil(t, result)
require.InDelta(t, tt.want, *result.PerRequestPrice, 1e-12)
})
}
}
func TestGetTierByLabel_Empty(t *testing.T) {
p := &ChannelModelPricing{Intervals: nil}
require.Nil(t, p.GetTierByLabel("1K"))
}
func TestChannelClone(t *testing.T) {
original := &Channel{
ID: 1,
Name: "test",
GroupIDs: []int64{10, 20},
ModelPricing: []ChannelModelPricing{
{
ID: 100,
Models: []string{"model-a"},
InputPrice: testPtrFloat64(5e-6),
},
},
}
cloned := original.Clone()
require.NotNil(t, cloned)
require.Equal(t, original.ID, cloned.ID)
require.Equal(t, original.Name, cloned.Name)
// Modify clone slices — original should not change
cloned.GroupIDs[0] = 999
require.Equal(t, int64(10), original.GroupIDs[0])
cloned.ModelPricing[0].Models[0] = "hacked"
require.Equal(t, "model-a", original.ModelPricing[0].Models[0])
}
func TestChannelClone_Nil(t *testing.T) {
var ch *Channel
require.Nil(t, ch.Clone())
}
func TestChannelModelPricingClone(t *testing.T) {
original := ChannelModelPricing{
Models: []string{"a", "b"},
Intervals: []PricingInterval{
{MinTokens: 0, TierLabel: "tier1"},
},
}
cloned := original.Clone()
// Modify clone slices — original unchanged
cloned.Models[0] = "hacked"
require.Equal(t, "a", original.Models[0])
cloned.Intervals[0].TierLabel = "hacked"
require.Equal(t, "tier1", original.Intervals[0].TierLabel)
}
// --- BillingMode.IsValid ---
func TestBillingModeIsValid(t *testing.T) {
tests := []struct {
name string
mode BillingMode
want bool
}{
{"token", BillingModeToken, true},
{"per_request", BillingModePerRequest, true},
{"image", BillingModeImage, true},
{"empty", BillingMode(""), true},
{"unknown", BillingMode("unknown"), false},
{"random", BillingMode("xyz"), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, tt.mode.IsValid())
})
}
}
// --- Channel.IsActive ---
func TestChannelIsActive(t *testing.T) {
tests := []struct {
name string
status string
want bool
}{
{"active", StatusActive, true},
{"disabled", "disabled", false},
{"empty", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ch := &Channel{Status: tt.status}
require.Equal(t, tt.want, ch.IsActive())
})
}
}
// --- ChannelModelPricing.Clone edge cases ---
func TestChannelModelPricingClone_EdgeCases(t *testing.T) {
t.Run("nil models", func(t *testing.T) {
original := ChannelModelPricing{Models: nil}
cloned := original.Clone()
require.Nil(t, cloned.Models)
})
t.Run("nil intervals", func(t *testing.T) {
original := ChannelModelPricing{Intervals: nil}
cloned := original.Clone()
require.Nil(t, cloned.Intervals)
})
t.Run("empty models", func(t *testing.T) {
original := ChannelModelPricing{Models: []string{}}
cloned := original.Clone()
require.NotNil(t, cloned.Models)
require.Empty(t, cloned.Models)
})
}
// --- Channel.Clone edge cases ---
func TestChannelClone_EdgeCases(t *testing.T) {
t.Run("nil model mapping", func(t *testing.T) {
original := &Channel{ID: 1, ModelMapping: nil}
cloned := original.Clone()
require.Nil(t, cloned.ModelMapping)
})
t.Run("nil model pricing", func(t *testing.T) {
original := &Channel{ID: 1, ModelPricing: nil}
cloned := original.Clone()
require.Nil(t, cloned.ModelPricing)
})
t.Run("deep copy model mapping", func(t *testing.T) {
original := &Channel{
ID: 1,
ModelMapping: map[string]map[string]string{
"openai": {"gpt-4": "gpt-4-turbo"},
},
}
cloned := original.Clone()
// Modify the cloned nested map
cloned.ModelMapping["openai"]["gpt-4"] = "hacked"
// Original must remain unchanged
require.Equal(t, "gpt-4-turbo", original.ModelMapping["openai"]["gpt-4"])
})
}
// --- ValidateIntervals ---
func TestValidateIntervals_Empty(t *testing.T) {
require.NoError(t, ValidateIntervals(nil))
require.NoError(t, ValidateIntervals([]PricingInterval{}))
}
func TestValidateIntervals_ValidIntervals(t *testing.T) {
tests := []struct {
name string
intervals []PricingInterval
}{
{
name: "single bounded interval",
intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
},
},
{
name: "two intervals with gap",
intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(100000), InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
},
},
{
name: "two contiguous intervals",
intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
},
},
{
name: "unsorted input (auto-sorted by validator)",
intervals: []PricingInterval{
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
},
},
{
name: "single unbounded interval",
intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.NoError(t, ValidateIntervals(tt.intervals))
})
}
}
func TestValidateIntervals_NegativeMinTokens(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: -1, MaxTokens: testPtrInt(100), InputPrice: testPtrFloat64(1e-6)},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "min_tokens")
require.Contains(t, err.Error(), ">= 0")
}
func TestValidateIntervals_MaxTokensZero(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(0), InputPrice: testPtrFloat64(1e-6)},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "max_tokens")
require.Contains(t, err.Error(), "> 0")
}
func TestValidateIntervals_MaxLessThanMin(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: 100, MaxTokens: testPtrInt(50), InputPrice: testPtrFloat64(1e-6)},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "max_tokens")
require.Contains(t, err.Error(), "> min_tokens")
}
func TestValidateIntervals_MaxEqualsMin(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: 100, MaxTokens: testPtrInt(100), InputPrice: testPtrFloat64(1e-6)},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "max_tokens")
require.Contains(t, err.Error(), "> min_tokens")
}
func TestValidateIntervals_NegativePrice(t *testing.T) {
negPrice := -0.01
intervals := []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(100), InputPrice: &negPrice},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "input_price")
require.Contains(t, err.Error(), ">= 0")
}
func TestValidateIntervals_OverlappingIntervals(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(200), InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 100, MaxTokens: testPtrInt(300), InputPrice: testPtrFloat64(2e-6)},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "overlap")
}
func TestValidateIntervals_UnboundedNotLast(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 128000, MaxTokens: testPtrInt(256000), InputPrice: testPtrFloat64(2e-6)},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "unbounded")
require.Contains(t, err.Error(), "last")
}

View File

@ -0,0 +1,130 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
func TestSelectAccountForModelWithExclusions_UsesFallbackGroupForChannelRestriction(t *testing.T) {
t.Parallel()
groupID := int64(10)
fallbackID := int64(11)
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{fallbackID},
RestrictModels: true,
ModelPricing: []ChannelModelPricing{
{Platform: PlatformAnthropic, Models: []string{"claude-sonnet-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{
fallbackID: PlatformAnthropic,
}))
accountRepo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range accountRepo.accounts {
accountRepo.accountsByID[accountRepo.accounts[i].ID] = &accountRepo.accounts[i]
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &fallbackID,
Hydrated: true,
},
fallbackID: {
ID: fallbackID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
},
},
}
svc := &GatewayService{
accountRepo: accountRepo,
groupRepo: groupRepo,
channelService: channelSvc,
cfg: testConfig(),
}
ctx := context.WithValue(context.Background(), ctxkey.Group, groupRepo.groups[groupID])
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-sonnet-4-6", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(1), account.ID)
}
func TestSelectAccountWithLoadAwareness_UsesFallbackGroupForChannelRestriction(t *testing.T) {
t.Parallel()
groupID := int64(10)
fallbackID := int64(11)
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{fallbackID},
RestrictModels: true,
ModelPricing: []ChannelModelPricing{
{Platform: PlatformAnthropic, Models: []string{"claude-sonnet-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{
fallbackID: PlatformAnthropic,
}))
accountRepo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range accountRepo.accounts {
accountRepo.accountsByID[accountRepo.accounts[i].ID] = &accountRepo.accounts[i]
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &fallbackID,
Hydrated: true,
},
fallbackID: {
ID: fallbackID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
},
},
}
svc := &GatewayService{
accountRepo: accountRepo,
groupRepo: groupRepo,
channelService: channelSvc,
cfg: testConfig(),
}
ctx := context.WithValue(context.Background(), ctxkey.Group, groupRepo.groups[groupID])
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-sonnet-4-6", nil, "", 0)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID)
}

View File

@ -0,0 +1,293 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
// --- billingModelForRestriction ---
func TestBillingModelForRestriction_Requested(t *testing.T) {
t.Parallel()
got := billingModelForRestriction(BillingModelSourceRequested, "claude-sonnet-4-5", "claude-sonnet-4-6")
require.Equal(t, "claude-sonnet-4-5", got)
}
func TestBillingModelForRestriction_ChannelMapped(t *testing.T) {
t.Parallel()
got := billingModelForRestriction(BillingModelSourceChannelMapped, "claude-sonnet-4-5", "claude-sonnet-4-6")
require.Equal(t, "claude-sonnet-4-6", got)
}
func TestBillingModelForRestriction_Upstream(t *testing.T) {
t.Parallel()
got := billingModelForRestriction(BillingModelSourceUpstream, "claude-sonnet-4-5", "claude-sonnet-4-6")
require.Equal(t, "", got, "upstream should return empty (per-account check needed)")
}
func TestBillingModelForRestriction_Empty(t *testing.T) {
t.Parallel()
got := billingModelForRestriction("", "claude-sonnet-4-5", "claude-sonnet-4-6")
require.Equal(t, "claude-sonnet-4-6", got, "empty source defaults to channel_mapped")
}
// --- resolveAccountUpstreamModel ---
func TestResolveAccountUpstreamModel_Antigravity(t *testing.T) {
t.Parallel()
account := &Account{
Platform: PlatformAntigravity,
}
// Antigravity 平台使用 DefaultAntigravityModelMapping
got := resolveAccountUpstreamModel(account, "claude-sonnet-4-6")
require.Equal(t, "claude-sonnet-4-6", got)
}
func TestResolveAccountUpstreamModel_Antigravity_Unsupported(t *testing.T) {
t.Parallel()
account := &Account{
Platform: PlatformAntigravity,
}
got := resolveAccountUpstreamModel(account, "totally-unknown-model")
require.Equal(t, "", got, "unsupported model should return empty")
}
func TestResolveAccountUpstreamModel_NonAntigravity(t *testing.T) {
t.Parallel()
account := &Account{
Platform: PlatformAnthropic,
}
got := resolveAccountUpstreamModel(account, "claude-sonnet-4-6")
require.Equal(t, "claude-sonnet-4-6", got, "no mapping = passthrough")
}
// --- checkChannelPricingRestriction ---
func TestCheckChannelPricingRestriction_NilGroupID(t *testing.T) {
t.Parallel()
svc := &GatewayService{channelService: &ChannelService{}}
require.False(t, svc.checkChannelPricingRestriction(context.Background(), nil, "claude-sonnet-4"))
}
func TestCheckChannelPricingRestriction_NilChannelService(t *testing.T) {
t.Parallel()
svc := &GatewayService{}
gid := int64(10)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4"))
}
func TestCheckChannelPricingRestriction_EmptyModel(t *testing.T) {
t.Parallel()
svc := &GatewayService{channelService: &ChannelService{}}
gid := int64(10)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, ""))
}
func TestCheckChannelPricingRestriction_ChannelMapped_Restricted(t *testing.T) {
t.Parallel()
// 渠道映射 claude-sonnet-4-5 → claude-sonnet-4-6但定价列表只有 claude-opus-4-6
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceChannelMapped,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
},
ModelMapping: map[string]map[string]string{
"anthropic": {"claude-sonnet-4-5": "claude-sonnet-4-6"},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
gid := int64(10)
require.True(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
"mapped model claude-sonnet-4-6 is NOT in pricing → restricted")
}
func TestCheckChannelPricingRestriction_ChannelMapped_Allowed(t *testing.T) {
t.Parallel()
// 渠道映射 claude-sonnet-4-5 → claude-sonnet-4-6定价列表包含 claude-sonnet-4-6
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceChannelMapped,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
},
ModelMapping: map[string]map[string]string{
"anthropic": {"claude-sonnet-4-5": "claude-sonnet-4-6"},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
gid := int64(10)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
"mapped model claude-sonnet-4-6 IS in pricing → allowed")
}
func TestCheckChannelPricingRestriction_Requested_Restricted(t *testing.T) {
t.Parallel()
// billing_model_source=requested定价列表有 claude-sonnet-4-6 但请求的是 claude-sonnet-4-5
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceRequested,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
gid := int64(10)
require.True(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
"requested model claude-sonnet-4-5 is NOT in pricing → restricted")
}
func TestCheckChannelPricingRestriction_Requested_Allowed(t *testing.T) {
t.Parallel()
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceRequested,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-sonnet-4-5"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
gid := int64(10)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
"requested model IS in pricing → allowed")
}
func TestCheckChannelPricingRestriction_Upstream_SkipsPreCheck(t *testing.T) {
t.Parallel()
// upstream 模式:预检查始终跳过(返回 false需逐账号检查
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceUpstream,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
gid := int64(10)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "unknown-model"),
"upstream mode should skip pre-check (per-account check needed)")
}
func TestCheckChannelPricingRestriction_RestrictModelsDisabled(t *testing.T) {
t.Parallel()
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: false, // 未开启模型限制
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
gid := int64(10)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "any-model"),
"RestrictModels=false → always allowed")
}
func TestCheckChannelPricingRestriction_NoChannel(t *testing.T) {
t.Parallel()
// 分组没有关联渠道
repo := &mockChannelRepository{
listAllFn: func(_ context.Context) ([]Channel, error) { return nil, nil },
}
channelSvc := newTestChannelService(repo)
svc := &GatewayService{channelService: channelSvc}
gid := int64(999)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "any-model"),
"no channel for group → allowed")
}
// --- isUpstreamModelRestrictedByChannel ---
func TestIsUpstreamModelRestrictedByChannel_Restricted(t *testing.T) {
t.Parallel()
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
account := &Account{Platform: PlatformAntigravity}
// claude-sonnet-4-6 在 DefaultAntigravityModelMapping 中,映射后仍为 claude-sonnet-4-6
// 但定价列表只有 claude-opus-4-6
require.True(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "claude-sonnet-4-6"),
"upstream model claude-sonnet-4-6 NOT in pricing → restricted")
}
func TestIsUpstreamModelRestrictedByChannel_Allowed(t *testing.T) {
t.Parallel()
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
account := &Account{Platform: PlatformAntigravity}
require.False(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "claude-sonnet-4-6"),
"upstream model claude-sonnet-4-6 IS in pricing → allowed")
}
func TestIsUpstreamModelRestrictedByChannel_UnsupportedModel(t *testing.T) {
t.Parallel()
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
account := &Account{Platform: PlatformAntigravity}
// totally-unknown-model 不在 DefaultAntigravityModelMapping 中 → 映射结果为空
require.False(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "totally-unknown-model"),
"unmappable model → upstream model empty → not restricted (account filter handles this)")
}

View File

@ -732,7 +732,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
modelsListCacheTTL: time.Minute, modelsListCacheTTL: time.Minute,
} }
result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "") result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
@ -754,7 +754,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID) ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID)
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0)) ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0))
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
@ -776,7 +776,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, int64(999)) ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, int64(999))
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(77)) ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(77))
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)

View File

@ -2031,7 +2031,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, // No concurrency service concurrencyService: nil, // No concurrency service
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
@ -2084,7 +2084,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, // legacy path concurrencyService: nil, // legacy path
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
@ -2116,7 +2116,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, concurrencyService: nil,
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
@ -2148,7 +2148,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
} }
excludedIDs := map[int64]struct{}{1: {}} excludedIDs := map[int64]struct{}{1: {}}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
@ -2182,7 +2182,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
@ -2218,7 +2218,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
@ -2259,7 +2259,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
result, err := svc.SelectAccountWithLoadAwareness(testCtx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "") result, err := svc.SelectAccountWithLoadAwareness(testCtx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
@ -2287,7 +2287,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, concurrencyService: nil,
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.Error(t, err) require.Error(t, err)
require.Nil(t, result) require.Nil(t, result)
require.ErrorIs(t, err, ErrNoAvailableAccounts) require.ErrorIs(t, err, ErrNoAvailableAccounts)
@ -2319,7 +2319,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, concurrencyService: nil,
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
@ -2352,7 +2352,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, concurrencyService: nil,
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
@ -2390,7 +2390,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.WaitPlan) require.NotNil(t, result.WaitPlan)
@ -2426,7 +2426,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "legacy", "claude-3-5-sonnet-20241022", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "legacy", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
@ -2485,7 +2485,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.WaitPlan) require.NotNil(t, result.WaitPlan)
@ -2539,7 +2539,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
@ -2593,7 +2593,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
@ -2651,7 +2651,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route", "claude-3-5-sonnet-20241022", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
@ -2709,7 +2709,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route-full", "claude-3-5-sonnet-20241022", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route-full", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.WaitPlan) require.NotNil(t, result.WaitPlan)
@ -2767,7 +2767,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "fallback", "claude-3-5-sonnet-20241022", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "fallback", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
@ -2804,7 +2804,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.WaitPlan) require.NotNil(t, result.WaitPlan)
@ -2856,7 +2856,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "gemini", "gemini-2.5-pro", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "gemini", "gemini-2.5-pro", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
@ -2934,7 +2934,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
} }
excluded := map[int64]struct{}{1: {}} excluded := map[int64]struct{}{1: {}}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", excluded, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", excluded, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
@ -2988,7 +2988,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, concurrencyService: nil,
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "gemini-2.5-pro", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "gemini-2.5-pro", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
@ -3021,7 +3021,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, concurrencyService: nil,
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.Error(t, err) require.Error(t, err)
require.Nil(t, result) require.Nil(t, result)
require.ErrorIs(t, err, ErrClaudeCodeOnly) require.ErrorIs(t, err, ErrClaudeCodeOnly)
@ -3059,7 +3059,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "wait", "claude-3-5-sonnet-20241022", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "wait", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.WaitPlan) require.NotNil(t, result.WaitPlan)
@ -3097,7 +3097,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "missing-load", "claude-3-5-sonnet-20241022", nil, "") result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "missing-load", "claude-3-5-sonnet-20241022", nil, "", int64(0))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)

View File

@ -41,6 +41,8 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
nil, nil,
nil, nil,
nil, nil,
nil,
nil,
) )
} }

View File

@ -60,6 +60,13 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info" claudeMimicDebugInfoKey = "claude_mimic_debug_info"
) )
// MediaType 媒体类型常量
const (
MediaTypeImage = "image"
MediaTypeVideo = "video"
MediaTypePrompt = "prompt"
)
// ForceCacheBillingContextKey 强制缓存计费上下文键 // ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费 // 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{} type forceCacheBillingKeyType struct{}
@ -483,6 +490,7 @@ type ClaudeUsage struct {
CacheReadInputTokens int `json:"cache_read_input_tokens"` CacheReadInputTokens int `json:"cache_read_input_tokens"`
CacheCreation5mTokens int // 5分钟缓存创建token来自嵌套 cache_creation 对象) CacheCreation5mTokens int // 5分钟缓存创建token来自嵌套 cache_creation 对象)
CacheCreation1hTokens int // 1小时缓存创建token来自嵌套 cache_creation 对象) CacheCreation1hTokens int // 1小时缓存创建token来自嵌套 cache_creation 对象)
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
} }
// ForwardResult 转发结果 // ForwardResult 转发结果
@ -568,6 +576,8 @@ type GatewayService struct {
responseHeaderFilter *responseheaders.CompiledHeaderFilter responseHeaderFilter *responseheaders.CompiledHeaderFilter
debugModelRouting atomic.Bool debugModelRouting atomic.Bool
debugClaudeMimic atomic.Bool debugClaudeMimic atomic.Bool
channelService *ChannelService
resolver *ModelPricingResolver
debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set
tlsFPProfileService *TLSFingerprintProfileService tlsFPProfileService *TLSFingerprintProfileService
} }
@ -597,6 +607,8 @@ func NewGatewayService(
digestStore *DigestSessionStore, digestStore *DigestSessionStore,
settingService *SettingService, settingService *SettingService,
tlsFPProfileService *TLSFingerprintProfileService, tlsFPProfileService *TLSFingerprintProfileService,
channelService *ChannelService,
resolver *ModelPricingResolver,
) *GatewayService { ) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
modelsListTTL := resolveModelsListCacheTTL(cfg) modelsListTTL := resolveModelsListCacheTTL(cfg)
@ -629,6 +641,8 @@ func NewGatewayService(
modelsListCacheTTL: modelsListTTL, modelsListCacheTTL: modelsListTTL,
responseHeaderFilter: compileResponseHeaderFilter(cfg), responseHeaderFilter: compileResponseHeaderFilter(cfg),
tlsFPProfileService: tlsFPProfileService, tlsFPProfileService: tlsFPProfileService,
channelService: channelService,
resolver: resolver,
} }
svc.userGroupRateResolver = newUserGroupRateResolver( svc.userGroupRateResolver = newUserGroupRateResolver(
userGroupRateRepo, userGroupRateRepo,
@ -866,17 +880,7 @@ type anthropicMetadataPayload struct {
// replaceModelInBody 替换请求体中的model字段 // replaceModelInBody 替换请求体中的model字段
// 优先使用定点修改,尽量保持客户端原始字段顺序。 // 优先使用定点修改,尽量保持客户端原始字段顺序。
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte { func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
if len(body) == 0 { return ReplaceModelInBody(body, newModel)
return body
}
if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel {
return body
}
newBody, err := sjson.SetBytes(body, "model", newModel)
if err != nil {
return body
}
return newBody
} }
type claudeOAuthNormalizeOptions struct { type claudeOAuthNormalizeOptions struct {
@ -1186,6 +1190,15 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
platform = PlatformAnthropic platform = PlatformAnthropic
} }
// Claude Code 限制可能已将 groupID 解析为 fallback group
// 渠道限制预检查必须使用解析后的分组。
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
"model", requestedModel)
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度 // 注意:强制平台模式不走混合调度
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
@ -1198,8 +1211,10 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
} }
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
// metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash // 调度流程文档见 docs/ACCOUNT_SCHEDULING_FLOW.md 。
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) { // metadataUserID: 用于客户端亲和调度,从中提取客户端 ID
// sub2apiUserID: 系统用户 ID用于二维亲和调度
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) {
// 调试日志:记录调度入口参数 // 调试日志:记录调度入口参数
excludedIDsList := make([]int64, 0, len(excludedIDs)) excludedIDsList := make([]int64, 0, len(excludedIDs))
for id := range excludedIDs { for id := range excludedIDs {
@ -1220,6 +1235,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
ctx = s.withGroupContext(ctx, group) ctx = s.withGroupContext(ctx, group)
// Claude Code 限制可能已将 groupID 解析为 fallback group
// 渠道限制预检查必须使用解析后的分组。
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
"model", requestedModel)
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
var stickyAccountID int64 var stickyAccountID int64
if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 { if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 {
stickyAccountID = prefetch stickyAccountID = prefetch
@ -2945,6 +2969,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
ctx = s.withRPMPrefetch(ctx, accounts) ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持) // 3. 按优先级+最久未用选择(考虑模型支持)
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查,
// 因为粘性会话优先保持连接一致性,且 upstream 计费基准极少使用。
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account var selected *Account
for i := range accounts { for i := range accounts {
acc := &accounts[i] acc := &accounts[i]
@ -2965,6 +2992,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue continue
} }
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue continue
} }
@ -3197,6 +3227,8 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
ctx = s.withRPMPrefetch(ctx, accounts) ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度) // 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account var selected *Account
for i := range accounts { for i := range accounts {
acc := &accounts[i] acc := &accounts[i]
@ -3221,6 +3253,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue continue
} }
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue continue
} }
@ -7410,6 +7445,8 @@ type RecordUsageInput struct {
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService APIKeyQuotaUpdater // 可选用于更新API Key配额 APIKeyService APIKeyQuotaUpdater // 可选用于更新API Key配额
ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析)
} }
// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage // APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage
@ -7439,6 +7476,18 @@ type postUsageBillingParams struct {
APIKeyService APIKeyQuotaUpdater APIKeyService APIKeyQuotaUpdater
} }
func (p *postUsageBillingParams) shouldDeductAPIKeyQuota() bool {
return p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil
}
func (p *postUsageBillingParams) shouldUpdateRateLimits() bool {
return p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil
}
func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool {
return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit()
}
// postUsageBilling 统一处理使用量记录后的扣费逻辑: // postUsageBilling 统一处理使用量记录后的扣费逻辑:
// - 订阅/余额扣费 // - 订阅/余额扣费
// - API Key 配额更新 // - API Key 配额更新
@ -7468,21 +7517,21 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
} }
// 2. API Key 配额 // 2. API Key 配额
if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { if p.shouldDeductAPIKeyQuota() {
if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err) slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
} }
} }
// 3. API Key 限速用量 // 3. API Key 限速用量
if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { if p.shouldUpdateRateLimits() {
if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err) slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
} }
} }
// 4. 账号配额用量账号口径TotalCost × 账号计费倍率) // 4. 账号配额用量账号口径TotalCost × 账号计费倍率)
if cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() { if p.shouldUpdateAccountQuota() {
accountCost := cost.TotalCost * p.AccountRateMultiplier accountCost := cost.TotalCost * p.AccountRateMultiplier
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil { if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err) slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
@ -7564,13 +7613,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
cmd.BalanceCost = p.Cost.ActualCost cmd.BalanceCost = p.Cost.ActualCost
} }
if p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { if p.shouldDeductAPIKeyQuota() {
cmd.APIKeyQuotaCost = p.Cost.ActualCost cmd.APIKeyQuotaCost = p.Cost.ActualCost
} }
if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { if p.shouldUpdateRateLimits() {
cmd.APIKeyRateLimitCost = p.Cost.ActualCost cmd.APIKeyRateLimitCost = p.Cost.ActualCost
} }
if p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() { if p.shouldUpdateAccountQuota() {
cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier
} }
@ -7694,191 +7743,41 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage
} }
} }
// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。
type recordUsageOpts struct {
// Claude Max 策略所需的 ParsedRequest可选仅 Claude 路径传入)
ParsedRequest *ParsedRequest
// EnableClaudePath 启用 Claude 路径特有逻辑:
// - Claude Max 缓存计费策略
// - Sora 媒体类型分支image/video/prompt
// - MediaType 字段写入使用日志
EnableClaudePath bool
// 长上下文计费(仅 Gemini 路径需要)
LongContextThreshold int
LongContextMultiplier float64
}
// RecordUsage 记录使用量并扣费(或更新订阅用量) // RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result return s.recordUsageCore(ctx, &recordUsageCoreInput{
apiKey := input.APIKey Result: input.Result,
user := input.User APIKey: input.APIKey,
account := input.Account User: input.User,
subscription := input.Subscription Account: input.Account,
Subscription: input.Subscription,
// 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens InboundEndpoint: input.InboundEndpoint,
// 用于粘性会话切换时的特殊计费处理 UpstreamEndpoint: input.UpstreamEndpoint,
if input.ForceCacheBilling && result.Usage.InputTokens > 0 { UserAgent: input.UserAgent,
logger.LegacyPrintf("service.gateway", "force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", IPAddress: input.IPAddress,
result.Usage.InputTokens, account.ID) RequestPayloadHash: input.RequestPayloadHash,
result.Usage.CacheReadInputTokens += result.Usage.InputTokens ForceCacheBilling: input.ForceCacheBilling,
result.Usage.InputTokens = 0 APIKeyService: input.APIKeyService,
} ChannelUsageFields: input.ChannelUsageFields,
}, &recordUsageOpts{
// Cache TTL Override: 确保计费时 token 分类与账号设置一致 EnableClaudePath: true,
cacheTTLOverridden := false })
if account.IsCacheTTLOverrideEnabled() {
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
}
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := 1.0
if s.cfg != nil {
multiplier = s.cfg.Default.RateMultiplier
}
if apiKey.GroupID != nil && apiKey.Group != nil {
groupDefault := apiKey.Group.RateMultiplier
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
}
var cost *CostBreakdown
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
// 根据请求类型选择计费方式
if result.MediaType == "image" || result.MediaType == "video" {
var soraConfig *SoraPriceConfig
if apiKey.Group != nil {
soraConfig = &SoraPriceConfig{
ImagePrice360: apiKey.Group.SoraImagePrice360,
ImagePrice540: apiKey.Group.SoraImagePrice540,
VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
}
}
if result.MediaType == "image" {
cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
} else {
cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
}
} else if result.MediaType == "prompt" {
cost = &CostBreakdown{}
} else if result.ImageCount > 0 {
// 图片生成计费
var groupConfig *ImagePriceConfig
if apiKey.Group != nil {
groupConfig = &ImagePriceConfig{
Price1K: apiKey.Group.ImagePrice1K,
Price2K: apiKey.Group.ImagePrice2K,
Price4K: apiKey.Group.ImagePrice4K,
}
}
cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
} else {
// Token 计费
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
}
var err error
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
cost = &CostBreakdown{ActualCost: 0}
}
}
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
billingType := BillingTypeBalance
if isSubscriptionBilling {
billingType = BillingTypeSubscription
}
// 创建使用日志
durationMs := int(result.Duration.Milliseconds())
var imageSize *string
if result.ImageSize != "" {
imageSize = &result.ImageSize
}
var mediaType *string
if strings.TrimSpace(result.MediaType) != "" {
mediaType = &result.MediaType
}
accountRateMultiplier := account.BillingRateMultiplier()
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: result.Model,
RequestedModel: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount,
ImageSize: imageSize,
MediaType: mediaType,
CacheTTLOverridden: cacheTTLOverridden,
CreatedAt: time.Now(),
}
// 添加 UserAgent
if input.UserAgent != "" {
usageLog.UserAgent = &input.UserAgent
}
// 添加 IPAddress
if input.IPAddress != "" {
usageLog.IPAddress = &input.IPAddress
}
// 添加分组和订阅关联
if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID
}
if subscription != nil {
usageLog.SubscriptionID = &subscription.ID
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
billingErr := func() error {
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
Cost: cost,
User: user,
APIKey: apiKey,
Account: account,
Subscription: subscription,
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps(), s.usageBillingRepo)
return err
}()
if billingErr != nil {
return billingErr
}
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
return nil
} }
// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费) // RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费)
@ -7897,10 +7796,55 @@ type RecordUsageLongContextInput struct {
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0 LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选) APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选)
ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析)
} }
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini // RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini
func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error { func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error {
return s.recordUsageCore(ctx, &recordUsageCoreInput{
Result: input.Result,
APIKey: input.APIKey,
User: input.User,
Account: input.Account,
Subscription: input.Subscription,
InboundEndpoint: input.InboundEndpoint,
UpstreamEndpoint: input.UpstreamEndpoint,
UserAgent: input.UserAgent,
IPAddress: input.IPAddress,
RequestPayloadHash: input.RequestPayloadHash,
ForceCacheBilling: input.ForceCacheBilling,
APIKeyService: input.APIKeyService,
ChannelUsageFields: input.ChannelUsageFields,
}, &recordUsageOpts{
LongContextThreshold: input.LongContextThreshold,
LongContextMultiplier: input.LongContextMultiplier,
})
}
// recordUsageCoreInput 是 recordUsageCore 的公共输入字段,从两种输入结构体中提取。
type recordUsageCoreInput struct {
Result *ForwardResult
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription
InboundEndpoint string
UpstreamEndpoint string
UserAgent string
IPAddress string
RequestPayloadHash string
ForceCacheBilling bool
APIKeyService APIKeyQuotaUpdater
ChannelUsageFields
}
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
// opts 中的字段控制两者之间的差异行为:
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
// - EnableSoraMedia → 启用 Sora MediaType 分支image/video/prompt
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
result := input.Result result := input.Result
apiKey := input.APIKey apiKey := input.APIKey
user := input.User user := input.User
@ -7933,38 +7877,23 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
} }
var cost *CostBreakdown // 确定计费模型
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
// 根据请求类型选择计费方式 billingModel = input.ChannelMappedModel
if result.ImageCount > 0 {
// 图片生成计费
var groupConfig *ImagePriceConfig
if apiKey.Group != nil {
groupConfig = &ImagePriceConfig{
Price1K: apiKey.Group.ImagePrice1K,
Price2K: apiKey.Group.ImagePrice2K,
Price4K: apiKey.Group.ImagePrice4K,
}
}
cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
} else {
// Token 计费(使用长上下文计费方法)
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
}
var err error
cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
cost = &CostBreakdown{ActualCost: 0}
}
} }
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
billingModel = input.OriginalModel
}
// 确定 RequestedModel渠道映射前的原始模型
requestedModel := result.Model
if input.OriginalModel != "" {
requestedModel = input.OriginalModel
}
// 计算费用
cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, opts)
// 判断计费方式:订阅模式 vs 余额模式 // 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
@ -7974,12 +7903,214 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
} }
// 创建使用日志 // 创建使用日志
durationMs := int(result.Duration.Milliseconds())
var imageSize *string
if result.ImageSize != "" {
imageSize = &result.ImageSize
}
accountRateMultiplier := account.BillingRateMultiplier() accountRateMultiplier := account.BillingRateMultiplier()
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
requestID := usageLog.RequestID
_, billingErr := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
Cost: cost,
User: user,
APIKey: apiKey,
Account: account,
Subscription: subscription,
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps(), s.usageBillingRepo)
if billingErr != nil {
return billingErr
}
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
return nil
}
// calculateRecordUsageCost 根据请求类型和选项计算费用。
func (s *GatewayService) calculateRecordUsageCost(
ctx context.Context,
result *ForwardResult,
apiKey *APIKey,
billingModel string,
multiplier float64,
opts *recordUsageOpts,
) *CostBreakdown {
// Sora 媒体类型分支(仅 Claude 路径启用)
if opts.EnableClaudePath {
if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo {
return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier)
}
if result.MediaType == MediaTypePrompt {
return &CostBreakdown{}
}
}
// 图片生成计费
if result.ImageCount > 0 {
return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier)
}
// Token 计费
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
}
// calculateSoraMediaCost 计算 Sora 图片/视频的费用。
func (s *GatewayService) calculateSoraMediaCost(
result *ForwardResult,
apiKey *APIKey,
billingModel string,
multiplier float64,
) *CostBreakdown {
var soraConfig *SoraPriceConfig
if apiKey.Group != nil {
soraConfig = &SoraPriceConfig{
ImagePrice360: apiKey.Group.SoraImagePrice360,
ImagePrice540: apiKey.Group.SoraImagePrice540,
VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
}
}
if result.MediaType == MediaTypeImage {
return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
}
return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
}
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
// 返回非 nil 的 ResolvedPricing 表示有渠道定价nil 表示走默认定价路径。
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
if s.resolver == nil || apiKey.Group == nil {
return nil
}
gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
if resolved.Source == PricingSourceChannel {
return resolved
}
return nil
}
// calculateImageCost 计算图片生成费用:渠道级别定价优先,否则走按次计费。
func (s *GatewayService) calculateImageCost(
ctx context.Context,
result *ForwardResult,
apiKey *APIKey,
billingModel string,
multiplier float64,
) *CostBreakdown {
if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil {
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
}
gid := apiKey.Group.ID
cost, err := s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
Resolved: resolved,
})
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err)
return &CostBreakdown{ActualCost: 0}
}
return cost
}
var groupConfig *ImagePriceConfig
if apiKey.Group != nil {
groupConfig = &ImagePriceConfig{
Price1K: apiKey.Group.ImagePrice1K,
Price2K: apiKey.Group.ImagePrice2K,
Price4K: apiKey.Group.ImagePrice4K,
}
}
return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
}
// calculateTokenCost 计算 Token 计费:根据 opts 决定走普通/长上下文/渠道统一计费。
func (s *GatewayService) calculateTokenCost(
ctx context.Context,
result *ForwardResult,
apiKey *APIKey,
billingModel string,
multiplier float64,
opts *recordUsageOpts,
) *CostBreakdown {
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
}
var cost *CostBreakdown
var err error
// 优先尝试渠道定价 → CalculateCostUnified
if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil {
gid := apiKey.Group.ID
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
Resolved: resolved,
})
} else if opts.LongContextThreshold > 0 {
// 长上下文双倍计费(如 Gemini 200K 阈值)
cost, err = s.billingService.CalculateCostWithLongContext(
billingModel, tokens, multiplier,
opts.LongContextThreshold, opts.LongContextMultiplier,
)
} else {
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
}
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
return &CostBreakdown{ActualCost: 0}
}
return cost
}
// buildRecordUsageLog 构建使用日志并设置计费模式。
func (s *GatewayService) buildRecordUsageLog(
ctx context.Context,
input *recordUsageCoreInput,
result *ForwardResult,
apiKey *APIKey,
user *User,
account *Account,
subscription *UserSubscription,
requestedModel string,
multiplier float64,
accountRateMultiplier float64,
billingType int8,
cacheTTLOverridden bool,
cost *CostBreakdown,
opts *recordUsageOpts,
) *UsageLog {
durationMs := int(result.Duration.Milliseconds())
requestID := resolveUsageBillingRequestID(ctx, result.RequestID) requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
@ -7987,7 +8118,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
AccountID: account.ID, AccountID: account.ID,
RequestID: requestID, RequestID: requestID,
Model: result.Model, Model: result.Model,
RequestedModel: result.Model, RequestedModel: requestedModel,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
@ -7998,72 +8129,170 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
InputCost: cost.InputCost, ImageOutputTokens: result.Usage.ImageOutputTokens,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier, RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier, AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType, BillingType: billingType,
BillingMode: resolveBillingMode(opts, result, cost),
Stream: result.Stream, Stream: result.Stream,
DurationMs: &durationMs, DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs, FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount, ImageCount: result.ImageCount,
ImageSize: imageSize, ImageSize: optionalTrimmedStringPtr(result.ImageSize),
MediaType: resolveMediaType(opts, result),
CacheTTLOverridden: cacheTTLOverridden, CacheTTLOverridden: cacheTTLOverridden,
ChannelID: optionalInt64Ptr(input.ChannelID),
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
UserAgent: optionalTrimmedStringPtr(input.UserAgent),
IPAddress: optionalTrimmedStringPtr(input.IPAddress),
GroupID: apiKey.GroupID,
SubscriptionID: optionalSubscriptionID(subscription),
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
if cost != nil {
// 添加 UserAgent usageLog.InputCost = cost.InputCost
if input.UserAgent != "" { usageLog.OutputCost = cost.OutputCost
usageLog.UserAgent = &input.UserAgent usageLog.ImageOutputCost = cost.ImageOutputCost
usageLog.CacheCreationCost = cost.CacheCreationCost
usageLog.CacheReadCost = cost.CacheReadCost
usageLog.TotalCost = cost.TotalCost
usageLog.ActualCost = cost.ActualCost
} }
// 添加 IPAddress return usageLog
if input.IPAddress != "" { }
usageLog.IPAddress = &input.IPAddress
}
// 添加分组和订阅关联 // resolveBillingMode 根据计费结果和请求类型确定计费模式。
if apiKey.GroupID != nil { // Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。
usageLog.GroupID = apiKey.GroupID func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
} isSoraMedia := opts.EnableClaudePath &&
if subscription != nil { (result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt)
usageLog.SubscriptionID = &subscription.ID if isSoraMedia {
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil return nil
} }
var mode string
billingErr := func() error { switch {
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ case cost != nil && cost.BillingMode != "":
Cost: cost, mode = cost.BillingMode
User: user, case result.ImageCount > 0:
APIKey: apiKey, mode = string(BillingModeImage)
Account: account, default:
Subscription: subscription, mode = string(BillingModeToken)
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps(), s.usageBillingRepo)
return err
}()
if billingErr != nil {
return billingErr
} }
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") return &mode
}
func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
return &result.MediaType
}
return nil return nil
} }
func optionalSubscriptionID(subscription *UserSubscription) *int64 {
if subscription != nil {
return &subscription.ID
}
return nil
}
// ResolveChannelMapping 委托渠道服务解析模型映射
func (s *GatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
if s.channelService == nil {
return ChannelMappingResult{MappedModel: model}
}
return s.channelService.ResolveChannelMapping(ctx, groupID, model)
}
// ReplaceModelInBody 替换请求体中的模型名(导出供 handler 使用)
func (s *GatewayService) ReplaceModelInBody(body []byte, newModel string) []byte {
return ReplaceModelInBody(body, newModel)
}
// IsModelRestricted 检查模型是否被渠道限制
func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
if s.channelService == nil {
return false
}
return s.channelService.IsModelRestricted(ctx, groupID, model)
}
// ResolveChannelMappingAndRestrict 解析渠道映射。
// 模型限制检查已移至调度阶段checkChannelPricingRestrictionrestricted 始终返回 false。
func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
if s.channelService == nil {
return ChannelMappingResult{MappedModel: model}, false
}
return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
}
// checkChannelPricingRestriction 根据渠道计费基准检查模型是否受定价列表限制。
// 供调度阶段预检查requested / channel_mapped
// upstream 需逐账号检查,此处返回 false。
func (s *GatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool {
if groupID == nil || s.channelService == nil || requestedModel == "" {
return false
}
mapping := s.channelService.ResolveChannelMapping(ctx, *groupID, requestedModel)
billingModel := billingModelForRestriction(mapping.BillingModelSource, requestedModel, mapping.MappedModel)
if billingModel == "" {
return false
}
return s.channelService.IsModelRestricted(ctx, *groupID, billingModel)
}
// billingModelForRestriction 根据计费基准确定限制检查使用的模型。
// upstream 返回空(需逐账号检查)。
func billingModelForRestriction(source, requestedModel, channelMappedModel string) string {
switch source {
case BillingModelSourceRequested:
return requestedModel
case BillingModelSourceUpstream:
return ""
case BillingModelSourceChannelMapped:
return channelMappedModel
default:
return channelMappedModel
}
}
// isUpstreamModelRestrictedByChannel 检查账号映射后的上游模型是否受渠道定价限制。
// 仅在 BillingModelSource="upstream" 且 RestrictModels=true 时由调度循环调用。
func (s *GatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool {
if s.channelService == nil {
return false
}
upstreamModel := resolveAccountUpstreamModel(account, requestedModel)
if upstreamModel == "" {
return false
}
return s.channelService.IsModelRestricted(ctx, groupID, upstreamModel)
}
// resolveAccountUpstreamModel 确定账号将请求模型映射为什么上游模型。
func resolveAccountUpstreamModel(account *Account, requestedModel string) string {
if account.Platform == PlatformAntigravity {
return mapAntigravityModel(account, requestedModel)
}
return account.GetMappedModel(requestedModel)
}
// needsUpstreamChannelRestrictionCheck 判断是否需要在调度循环中逐账号检查上游模型的渠道限制。
func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool {
if groupID == nil || s.channelService == nil {
return false
}
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
if err != nil {
slog.Warn("failed to check channel upstream restriction", "group_id", *groupID, "error", err)
return false
}
if ch == nil || !ch.RestrictModels {
return false
}
return ch.BillingModelSource == BillingModelSourceUpstream
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API // ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应 // 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error { func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {

View File

@ -2692,12 +2692,27 @@ func extractGeminiUsage(data []byte) *ClaudeUsage {
cand := int(usage.Get("candidatesTokenCount").Int()) cand := int(usage.Get("candidatesTokenCount").Int())
cached := int(usage.Get("cachedContentTokenCount").Int()) cached := int(usage.Get("cachedContentTokenCount").Int())
thoughts := int(usage.Get("thoughtsTokenCount").Int()) thoughts := int(usage.Get("thoughtsTokenCount").Int())
// 从 candidatesTokensDetails 提取 IMAGE 模态 token 数
imageTokens := 0
candidateDetails := usage.Get("candidatesTokensDetails")
if candidateDetails.Exists() {
candidateDetails.ForEach(func(_, detail gjson.Result) bool {
if detail.Get("modality").String() == "IMAGE" {
imageTokens = int(detail.Get("tokenCount").Int())
return false
}
return true
})
}
// 注意Gemini 的 promptTokenCount 包含 cachedContentTokenCount // 注意Gemini 的 promptTokenCount 包含 cachedContentTokenCount
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens需要减去 // 但 Claude 的 input_tokens 不包含 cache_read_input_tokens需要减去
return &ClaudeUsage{ return &ClaudeUsage{
InputTokens: prompt - cached, InputTokens: prompt - cached,
OutputTokens: cand + thoughts, OutputTokens: cand + thoughts,
CacheReadInputTokens: cached, CacheReadInputTokens: cached,
ImageOutputTokens: imageTokens,
} }
} }

View File

@ -0,0 +1,231 @@
package service
import (
"context"
"log/slog"
)
// PricingSource 定价来源标识
const (
PricingSourceChannel = "channel"
PricingSourceLiteLLM = "litellm"
PricingSourceFallback = "fallback"
)
// ResolvedPricing 统一定价解析结果
type ResolvedPricing struct {
// Mode 计费模式
Mode BillingMode
// Token 模式:基础定价(来自 LiteLLM 或 fallback
BasePricing *ModelPricing
// Token 模式:区间定价列表(如有,覆盖 BasePricing 中的对应字段)
Intervals []PricingInterval
// 按次/图片模式:分层定价
RequestTiers []PricingInterval
// 按次/图片模式:默认价格(未命中层级时使用)
DefaultPerRequestPrice float64
// 来源标识
Source string // "channel", "litellm", "fallback"
// 是否支持缓存细分
SupportsCacheBreakdown bool
}
// ModelPricingResolver 统一模型定价解析器。
// 解析链Channel → LiteLLM → Fallback。
type ModelPricingResolver struct {
channelService *ChannelService
billingService *BillingService
}
// NewModelPricingResolver 创建定价解析器实例
func NewModelPricingResolver(channelService *ChannelService, billingService *BillingService) *ModelPricingResolver {
return &ModelPricingResolver{
channelService: channelService,
billingService: billingService,
}
}
// PricingInput 定价解析输入
type PricingInput struct {
Model string
GroupID *int64 // nil 表示不检查渠道
}
// Resolve 解析模型定价。
// 1. 获取基础定价LiteLLM → Fallback
// 2. 如果指定了 GroupID查找渠道定价并覆盖
func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput) *ResolvedPricing {
// 1. 获取基础定价
basePricing, source := r.resolveBasePricing(input.Model)
resolved := &ResolvedPricing{
Mode: BillingModeToken,
BasePricing: basePricing,
Source: source,
SupportsCacheBreakdown: basePricing != nil && basePricing.SupportsCacheBreakdown,
}
// 2. 如果有 GroupID尝试渠道覆盖
if input.GroupID != nil {
r.applyChannelOverrides(ctx, *input.GroupID, input.Model, resolved)
}
return resolved
}
// resolveBasePricing 从 LiteLLM 或 Fallback 获取基础定价
func (r *ModelPricingResolver) resolveBasePricing(model string) (*ModelPricing, string) {
pricing, err := r.billingService.GetModelPricing(model)
if err != nil {
slog.Debug("failed to get model pricing from LiteLLM, using fallback",
"model", model, "error", err)
return nil, PricingSourceFallback
}
return pricing, PricingSourceLiteLLM
}
// applyChannelOverrides 应用渠道定价覆盖
func (r *ModelPricingResolver) applyChannelOverrides(ctx context.Context, groupID int64, model string, resolved *ResolvedPricing) {
chPricing := r.channelService.GetChannelModelPricing(ctx, groupID, model)
if chPricing == nil {
return
}
resolved.Source = PricingSourceChannel
resolved.Mode = chPricing.BillingMode
if resolved.Mode == "" {
resolved.Mode = BillingModeToken
}
switch resolved.Mode {
case BillingModeToken:
r.applyTokenOverrides(chPricing, resolved)
case BillingModePerRequest, BillingModeImage:
r.applyRequestTierOverrides(chPricing, resolved)
}
}
// applyTokenOverrides 应用 token 模式的渠道覆盖
func (r *ModelPricingResolver) applyTokenOverrides(chPricing *ChannelModelPricing, resolved *ResolvedPricing) {
// 过滤掉所有价格字段都为空的无效 interval
validIntervals := filterValidIntervals(chPricing.Intervals)
// 如果有有效的区间定价,使用区间
if len(validIntervals) > 0 {
resolved.Intervals = validIntervals
return
}
// 否则用 flat 字段覆盖 BasePricing
if resolved.BasePricing == nil {
resolved.BasePricing = &ModelPricing{}
}
if chPricing.InputPrice != nil {
resolved.BasePricing.InputPricePerToken = *chPricing.InputPrice
resolved.BasePricing.InputPricePerTokenPriority = *chPricing.InputPrice
}
if chPricing.OutputPrice != nil {
resolved.BasePricing.OutputPricePerToken = *chPricing.OutputPrice
resolved.BasePricing.OutputPricePerTokenPriority = *chPricing.OutputPrice
}
if chPricing.CacheWritePrice != nil {
resolved.BasePricing.CacheCreationPricePerToken = *chPricing.CacheWritePrice
resolved.BasePricing.CacheCreation5mPrice = *chPricing.CacheWritePrice
resolved.BasePricing.CacheCreation1hPrice = *chPricing.CacheWritePrice
}
if chPricing.CacheReadPrice != nil {
resolved.BasePricing.CacheReadPricePerToken = *chPricing.CacheReadPrice
resolved.BasePricing.CacheReadPricePerTokenPriority = *chPricing.CacheReadPrice
}
if chPricing.ImageOutputPrice != nil {
resolved.BasePricing.ImageOutputPricePerToken = *chPricing.ImageOutputPrice
}
}
// applyRequestTierOverrides 应用按次/图片模式的渠道覆盖
func (r *ModelPricingResolver) applyRequestTierOverrides(chPricing *ChannelModelPricing, resolved *ResolvedPricing) {
resolved.RequestTiers = filterValidIntervals(chPricing.Intervals)
if chPricing.PerRequestPrice != nil {
resolved.DefaultPerRequestPrice = *chPricing.PerRequestPrice
}
}
// filterValidIntervals 过滤掉所有价格字段都为空的无效 interval。
// 前端可能创建了只有 min/max 但无价格的空 interval。
func filterValidIntervals(intervals []PricingInterval) []PricingInterval {
var valid []PricingInterval
for _, iv := range intervals {
if iv.InputPrice != nil || iv.OutputPrice != nil ||
iv.CacheWritePrice != nil || iv.CacheReadPrice != nil ||
iv.PerRequestPrice != nil {
valid = append(valid, iv)
}
}
return valid
}
// GetIntervalPricing 根据 context token 数获取区间定价。
// 如果有区间列表,找到匹配区间并构造 ModelPricing否则直接返回 BasePricing。
func (r *ModelPricingResolver) GetIntervalPricing(resolved *ResolvedPricing, totalContextTokens int) *ModelPricing {
if len(resolved.Intervals) == 0 {
return resolved.BasePricing
}
iv := FindMatchingInterval(resolved.Intervals, totalContextTokens)
if iv == nil {
return resolved.BasePricing
}
return intervalToModelPricing(iv, resolved.SupportsCacheBreakdown)
}
// intervalToModelPricing 将区间定价转换为 ModelPricing
func intervalToModelPricing(iv *PricingInterval, supportsCacheBreakdown bool) *ModelPricing {
pricing := &ModelPricing{
SupportsCacheBreakdown: supportsCacheBreakdown,
}
if iv.InputPrice != nil {
pricing.InputPricePerToken = *iv.InputPrice
pricing.InputPricePerTokenPriority = *iv.InputPrice
}
if iv.OutputPrice != nil {
pricing.OutputPricePerToken = *iv.OutputPrice
pricing.OutputPricePerTokenPriority = *iv.OutputPrice
}
if iv.CacheWritePrice != nil {
pricing.CacheCreationPricePerToken = *iv.CacheWritePrice
pricing.CacheCreation5mPrice = *iv.CacheWritePrice
pricing.CacheCreation1hPrice = *iv.CacheWritePrice
}
if iv.CacheReadPrice != nil {
pricing.CacheReadPricePerToken = *iv.CacheReadPrice
pricing.CacheReadPricePerTokenPriority = *iv.CacheReadPrice
}
return pricing
}
// GetRequestTierPrice 根据层级标签获取按次价格
func (r *ModelPricingResolver) GetRequestTierPrice(resolved *ResolvedPricing, tierLabel string) float64 {
for _, tier := range resolved.RequestTiers {
if tier.TierLabel == tierLabel && tier.PerRequestPrice != nil {
return *tier.PerRequestPrice
}
}
return 0
}
// GetRequestTierPriceByContext 根据 context token 数获取按次价格
func (r *ModelPricingResolver) GetRequestTierPriceByContext(resolved *ResolvedPricing, totalContextTokens int) float64 {
iv := FindMatchingInterval(resolved.RequestTiers, totalContextTokens)
if iv != nil && iv.PerRequestPrice != nil {
return *iv.PerRequestPrice
}
return 0
}

View File

@ -0,0 +1,663 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/require"
)
func newTestBillingServiceForResolver() *BillingService {
bs := &BillingService{
fallbackPrices: make(map[string]*ModelPricing),
}
bs.fallbackPrices["claude-sonnet-4"] = &ModelPricing{
InputPricePerToken: 3e-6,
OutputPricePerToken: 15e-6,
CacheCreationPricePerToken: 3.75e-6,
CacheReadPricePerToken: 0.3e-6,
SupportsCacheBreakdown: false,
}
return bs
}
func TestResolve_NoGroupID(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: nil,
})
require.NotNil(t, resolved)
require.Equal(t, BillingModeToken, resolved.Mode)
require.NotNil(t, resolved.BasePricing)
require.InDelta(t, 3e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
// BillingService.GetModelPricing uses fallback internally, but resolveBasePricing
// reports "litellm" when GetModelPricing succeeds (regardless of internal source)
require.Equal(t, "litellm", resolved.Source)
}
func TestResolve_UnknownModel(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
resolved := r.Resolve(context.Background(), PricingInput{
Model: "unknown-model-xyz",
GroupID: nil,
})
require.NotNil(t, resolved)
require.Nil(t, resolved.BasePricing)
// Unknown model: GetModelPricing returns error, source is "fallback"
require.Equal(t, "fallback", resolved.Source)
}
func TestGetIntervalPricing_NoIntervals(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
basePricing := &ModelPricing{InputPricePerToken: 5e-6}
resolved := &ResolvedPricing{
Mode: BillingModeToken,
BasePricing: basePricing,
Intervals: nil,
}
result := r.GetIntervalPricing(resolved, 50000)
require.Equal(t, basePricing, result)
}
func TestGetIntervalPricing_MatchesInterval(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
resolved := &ResolvedPricing{
Mode: BillingModeToken,
BasePricing: &ModelPricing{InputPricePerToken: 5e-6},
SupportsCacheBreakdown: true,
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6), OutputPrice: testPtrFloat64(2e-6)},
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(3e-6), OutputPrice: testPtrFloat64(6e-6)},
},
}
result := r.GetIntervalPricing(resolved, 50000)
require.NotNil(t, result)
require.InDelta(t, 1e-6, result.InputPricePerToken, 1e-12)
require.InDelta(t, 2e-6, result.OutputPricePerToken, 1e-12)
require.True(t, result.SupportsCacheBreakdown)
result2 := r.GetIntervalPricing(resolved, 200000)
require.NotNil(t, result2)
require.InDelta(t, 3e-6, result2.InputPricePerToken, 1e-12)
}
func TestGetIntervalPricing_NoMatch_FallsBackToBase(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
basePricing := &ModelPricing{InputPricePerToken: 99e-6}
resolved := &ResolvedPricing{
Mode: BillingModeToken,
BasePricing: basePricing,
Intervals: []PricingInterval{
{MinTokens: 10000, MaxTokens: testPtrInt(50000), InputPrice: testPtrFloat64(1e-6)},
},
}
result := r.GetIntervalPricing(resolved, 5000)
require.Equal(t, basePricing, result)
}
func TestGetRequestTierPrice(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
resolved := &ResolvedPricing{
Mode: BillingModePerRequest,
RequestTiers: []PricingInterval{
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
},
}
require.InDelta(t, 0.04, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
require.InDelta(t, 0.08, r.GetRequestTierPrice(resolved, "2K"), 1e-12)
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "4K"), 1e-12)
}
func TestGetRequestTierPriceByContext(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
resolved := &ResolvedPricing{
Mode: BillingModePerRequest,
RequestTiers: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.05)},
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)},
},
}
require.InDelta(t, 0.05, r.GetRequestTierPriceByContext(resolved, 50000), 1e-12)
require.InDelta(t, 0.10, r.GetRequestTierPriceByContext(resolved, 200000), 1e-12)
}
func TestGetRequestTierPrice_NilPerRequestPrice(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
resolved := &ResolvedPricing{
Mode: BillingModePerRequest,
RequestTiers: []PricingInterval{
{TierLabel: "1K", PerRequestPrice: nil},
},
}
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
}
// ===========================================================================
// Channel override tests — exercises applyChannelOverrides via Resolve
// ===========================================================================
// helper: creates a resolver wired to a ChannelService that returns the given
// channel (active, groupID=100, platform=anthropic) with the specified pricing.
func newResolverWithChannel(t *testing.T, pricing []ChannelModelPricing) *ModelPricingResolver {
t.Helper()
const groupID = 100
repo := &mockChannelRepository{
listAllFn: func(_ context.Context) ([]Channel, error) {
return []Channel{{
ID: 1,
Name: "test-channel",
Status: StatusActive,
GroupIDs: []int64{groupID},
ModelPricing: pricing,
}}, nil
},
getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
return map[int64]string{groupID: "anthropic"}, nil
},
}
cs := NewChannelService(repo, nil)
bs := newTestBillingServiceForResolver()
return NewModelPricingResolver(cs, bs)
}
// groupIDPtr returns a pointer to groupID 100 (the test constant).
func groupIDPtr() *int64 { v := int64(100); return &v }
// ---------------------------------------------------------------------------
// 1. Token mode overrides
// ---------------------------------------------------------------------------
func TestResolve_WithChannelOverride_TokenFlat(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(10e-6),
OutputPrice: testPtrFloat64(50e-6),
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, BillingModeToken, resolved.Mode)
require.Equal(t, "channel", resolved.Source)
require.NotNil(t, resolved.BasePricing)
require.InDelta(t, 10e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
require.InDelta(t, 10e-6, resolved.BasePricing.InputPricePerTokenPriority, 1e-12)
require.InDelta(t, 50e-6, resolved.BasePricing.OutputPricePerToken, 1e-12)
require.InDelta(t, 50e-6, resolved.BasePricing.OutputPricePerTokenPriority, 1e-12)
}
func TestResolve_WithChannelOverride_TokenPartialOverride(t *testing.T) {
// Channel only sets InputPrice; OutputPrice should remain from the base (LiteLLM/fallback).
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(20e-6),
// OutputPrice intentionally nil
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, "channel", resolved.Source)
require.NotNil(t, resolved.BasePricing)
// InputPrice overridden by channel
require.InDelta(t, 20e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
// OutputPrice kept from base (fallback: 15e-6)
require.InDelta(t, 15e-6, resolved.BasePricing.OutputPricePerToken, 1e-12)
}
func TestResolve_WithChannelOverride_TokenWithIntervals(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(2e-6), OutputPrice: testPtrFloat64(8e-6)},
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(4e-6), OutputPrice: testPtrFloat64(16e-6)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, "channel", resolved.Source)
require.Len(t, resolved.Intervals, 2)
// GetIntervalPricing should use channel intervals
iv := r.GetIntervalPricing(resolved, 50000)
require.NotNil(t, iv)
require.InDelta(t, 2e-6, iv.InputPricePerToken, 1e-12)
require.InDelta(t, 8e-6, iv.OutputPricePerToken, 1e-12)
iv2 := r.GetIntervalPricing(resolved, 200000)
require.NotNil(t, iv2)
require.InDelta(t, 4e-6, iv2.InputPricePerToken, 1e-12)
require.InDelta(t, 16e-6, iv2.OutputPricePerToken, 1e-12)
}
func TestResolve_WithChannelOverride_TokenNilBasePricing(t *testing.T) {
// Base pricing is nil (unknown model), channel has flat prices → creates new BasePricing.
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"unknown-model-xyz"},
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(7e-6),
OutputPrice: testPtrFloat64(21e-6),
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "unknown-model-xyz",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, "channel", resolved.Source)
// BasePricing was nil from resolveBasePricing but applyTokenOverrides creates a new one
require.NotNil(t, resolved.BasePricing)
require.InDelta(t, 7e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
require.InDelta(t, 21e-6, resolved.BasePricing.OutputPricePerToken, 1e-12)
}
// ---------------------------------------------------------------------------
// 2. Per-request mode overrides
// ---------------------------------------------------------------------------
func TestResolve_WithChannelOverride_PerRequest(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModePerRequest,
PerRequestPrice: testPtrFloat64(0.05),
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.03)},
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, BillingModePerRequest, resolved.Mode)
require.Equal(t, "channel", resolved.Source)
require.InDelta(t, 0.05, resolved.DefaultPerRequestPrice, 1e-12)
require.Len(t, resolved.RequestTiers, 2)
// Verify tier lookups
require.InDelta(t, 0.03, r.GetRequestTierPriceByContext(resolved, 50000), 1e-12)
require.InDelta(t, 0.10, r.GetRequestTierPriceByContext(resolved, 200000), 1e-12)
}
func TestResolve_WithChannelOverride_PerRequestNilPrice(t *testing.T) {
// PerRequestPrice nil → DefaultPerRequestPrice stays 0.
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModePerRequest,
// PerRequestPrice intentionally nil
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.02)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, BillingModePerRequest, resolved.Mode)
require.InDelta(t, 0.0, resolved.DefaultPerRequestPrice, 1e-12)
require.Len(t, resolved.RequestTiers, 1)
}
// ---------------------------------------------------------------------------
// 3. Image mode overrides
// ---------------------------------------------------------------------------
func TestResolve_WithChannelOverride_Image(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeImage,
PerRequestPrice: testPtrFloat64(0.08),
Intervals: []PricingInterval{
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
{TierLabel: "4K", PerRequestPrice: testPtrFloat64(0.16)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.NotNil(t, resolved)
require.Equal(t, BillingModeImage, resolved.Mode)
require.Equal(t, "channel", resolved.Source)
require.InDelta(t, 0.08, resolved.DefaultPerRequestPrice, 1e-12)
require.Len(t, resolved.RequestTiers, 3)
}
func TestResolve_WithChannelOverride_ImageTierLabels(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeImage,
Intervals: []PricingInterval{
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
{TierLabel: "4K", PerRequestPrice: testPtrFloat64(0.16)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.InDelta(t, 0.04, r.GetRequestTierPrice(resolved, "1K"), 1e-12)
require.InDelta(t, 0.08, r.GetRequestTierPrice(resolved, "2K"), 1e-12)
require.InDelta(t, 0.16, r.GetRequestTierPrice(resolved, "4K"), 1e-12)
require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "8K"), 1e-12) // not found
}
// ---------------------------------------------------------------------------
// 4. Source tracking & default mode
// ---------------------------------------------------------------------------
func TestResolve_WithChannelOverride_SourceIsChannel(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(1e-6),
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.Equal(t, "channel", resolved.Source)
}
func TestResolve_WithChannelOverride_DefaultMode(t *testing.T) {
// Channel pricing with empty BillingMode → defaults to BillingModeToken.
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: "", // intentionally empty
InputPrice: testPtrFloat64(5e-6),
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
require.Equal(t, "channel", resolved.Source)
require.Equal(t, BillingModeToken, resolved.Mode)
require.NotNil(t, resolved.BasePricing)
require.InDelta(t, 5e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
}
// ---------------------------------------------------------------------------
// 5. GetIntervalPricing integration after channel override
// ---------------------------------------------------------------------------
func TestGetIntervalPricing_WithChannelIntervals(t *testing.T) {
// Channel provides intervals that override the base pricing path.
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(100000), InputPrice: testPtrFloat64(1e-6), OutputPrice: testPtrFloat64(5e-6)},
{MinTokens: 100000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6), OutputPrice: testPtrFloat64(10e-6)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
// Token count 50000 matches first interval
pricing := r.GetIntervalPricing(resolved, 50000)
require.NotNil(t, pricing)
require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, 5e-6, pricing.OutputPricePerToken, 1e-12)
// Token count 150000 matches second interval
pricing2 := r.GetIntervalPricing(resolved, 150000)
require.NotNil(t, pricing2)
require.InDelta(t, 2e-6, pricing2.InputPricePerToken, 1e-12)
require.InDelta(t, 10e-6, pricing2.OutputPricePerToken, 1e-12)
}
func TestGetIntervalPricing_ChannelIntervalsNoMatch(t *testing.T) {
// Channel intervals don't match token count → falls back to BasePricing.
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
Intervals: []PricingInterval{
// Only covers tokens > 50000
{MinTokens: 50000, MaxTokens: testPtrInt(200000), InputPrice: testPtrFloat64(9e-6)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})
// Token count 1000 doesn't match any interval (1000 <= 50000 minTokens)
pricing := r.GetIntervalPricing(resolved, 1000)
// Should fall back to BasePricing (from the billing service fallback)
require.NotNil(t, pricing)
require.Equal(t, resolved.BasePricing, pricing)
require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12) // original base price
}
// ===========================================================================
// 6. Error path tests
// ===========================================================================
func TestResolve_WithChannelOverride_CacheError(t *testing.T) {
// When ListAll returns an error, the ChannelService cache build fails.
// Resolve should gracefully fall back to base pricing without panicking.
repo := &mockChannelRepository{
listAllFn: func(_ context.Context) ([]Channel, error) {
return nil, errors.New("database unavailable")
},
}
cs := NewChannelService(repo, nil)
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(cs, bs)
gid := int64(100)
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: &gid,
})
require.NotNil(t, resolved)
// Should NOT panic, should NOT have source "channel"
require.NotEqual(t, "channel", resolved.Source)
// Base pricing should still be present (from BillingService fallback)
require.NotNil(t, resolved.BasePricing)
require.InDelta(t, 3e-6, resolved.BasePricing.InputPricePerToken, 1e-12)
}
// ===========================================================================
// 7. GetRequestTierPriceByContext boundary tests
// ===========================================================================
func TestGetRequestTierPriceByContext_EmptyTiers(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
resolved := &ResolvedPricing{
Mode: BillingModePerRequest,
RequestTiers: nil, // empty
}
price := r.GetRequestTierPriceByContext(resolved, 50000)
require.InDelta(t, 0.0, price, 1e-12)
// Also test with explicit empty slice
resolved2 := &ResolvedPricing{
Mode: BillingModePerRequest,
RequestTiers: []PricingInterval{},
}
price2 := r.GetRequestTierPriceByContext(resolved2, 50000)
require.InDelta(t, 0.0, price2, 1e-12)
}
func TestGetRequestTierPriceByContext_ExactBoundary(t *testing.T) {
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(&ChannelService{}, bs)
resolved := &ResolvedPricing{
Mode: BillingModePerRequest,
RequestTiers: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.05)},
{MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)},
},
}
// totalContextTokens = 128000 exactly:
// FindMatchingInterval checks: totalTokens > MinTokens && totalTokens <= MaxTokens
// For first interval: 128000 > 0 (true) && 128000 <= 128000 (true) → matches first interval
price := r.GetRequestTierPriceByContext(resolved, 128000)
require.InDelta(t, 0.05, price, 1e-12)
// totalContextTokens = 128001 should match second interval
// For first interval: 128001 > 0 (true) && 128001 <= 128000 (false) → no match
// For second interval: 128001 > 128000 (true) && MaxTokens == nil → matches
price2 := r.GetRequestTierPriceByContext(resolved, 128001)
require.InDelta(t, 0.10, price2, 1e-12)
}
// ===========================================================================
// 8. filterValidIntervals
// ===========================================================================
func TestFilterValidIntervals(t *testing.T) {
tests := []struct {
name string
intervals []PricingInterval
wantLen int
}{
{
name: "empty list",
intervals: nil,
wantLen: 0,
},
{
name: "all-nil interval filtered out",
intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000)},
},
wantLen: 0,
},
{
name: "interval with only InputPrice kept",
intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
},
wantLen: 1,
},
{
name: "interval with only OutputPrice kept",
intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), OutputPrice: testPtrFloat64(2e-6)},
},
wantLen: 1,
},
{
name: "interval with only CacheWritePrice kept",
intervals: []PricingInterval{
{MinTokens: 0, CacheWritePrice: testPtrFloat64(3e-6)},
},
wantLen: 1,
},
{
name: "interval with only CacheReadPrice kept",
intervals: []PricingInterval{
{MinTokens: 0, CacheReadPrice: testPtrFloat64(0.5e-6)},
},
wantLen: 1,
},
{
name: "interval with only PerRequestPrice kept",
intervals: []PricingInterval{
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
},
wantLen: 1,
},
{
name: "mixed valid and invalid",
intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 128000, MaxTokens: nil}, // all-nil → filtered out
{MinTokens: 256000, OutputPrice: testPtrFloat64(5e-6)},
},
wantLen: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := filterValidIntervals(tt.intervals)
require.Len(t, result, tt.wantLen)
})
}
}

View File

@ -0,0 +1,140 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestOpenAISelectAccountForModelWithExclusions_ChannelMappedRestrictionRejectsEarly(t *testing.T) {
t.Parallel()
channelSvc := newTestChannelService(makeStandardRepo(Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceChannelMapped,
ModelPricing: []ChannelModelPricing{
{Platform: PlatformOpenAI, Models: []string{"gpt-4o"}},
},
ModelMapping: map[string]map[string]string{
PlatformOpenAI: {"gpt-4.1": "o3-mini"},
},
}, map[int64]string{10: PlatformOpenAI}))
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true},
}},
channelService: channelSvc,
}
groupID := int64(10)
_, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "", "gpt-4.1", nil)
require.ErrorIs(t, err, ErrNoAvailableAccounts)
require.Contains(t, err.Error(), "channel pricing restriction")
}
func TestOpenAISelectAccountForModelWithExclusions_UpstreamRestrictionSkipsDisallowedAccount(t *testing.T) {
t.Parallel()
channelSvc := newTestChannelService(makeStandardRepo(Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceUpstream,
ModelPricing: []ChannelModelPricing{
{Platform: PlatformOpenAI, Models: []string{"o3-mini"}},
},
}, map[int64]string{10: PlatformOpenAI}))
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{
{
ID: 1,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Priority: 10,
Credentials: map[string]any{
"model_mapping": map[string]any{"gpt-4.1": "gpt-4o"},
},
},
{
ID: 2,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Priority: 20,
Credentials: map[string]any{
"model_mapping": map[string]any{"gpt-4.1": "o3-mini"},
},
},
}},
channelService: channelSvc,
}
groupID := int64(10)
account, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "", "gpt-4.1", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(2), account.ID)
}
func TestOpenAISelectAccountForModelWithExclusions_StickyRestrictedUpstreamFallsBack(t *testing.T) {
t.Parallel()
channelSvc := newTestChannelService(makeStandardRepo(Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceUpstream,
ModelPricing: []ChannelModelPricing{
{Platform: PlatformOpenAI, Models: []string{"o3-mini"}},
},
}, map[int64]string{10: PlatformOpenAI}))
cache := &stubGatewayCache{
sessionBindings: map[string]int64{"openai:sticky-session": 1},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{
{
ID: 1,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Priority: 10,
Credentials: map[string]any{
"model_mapping": map[string]any{"gpt-4.1": "gpt-4o"},
},
},
{
ID: 2,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Priority: 20,
Credentials: map[string]any{
"model_mapping": map[string]any{"gpt-4.1": "o3-mini"},
},
},
}},
channelService: channelSvc,
cache: cache,
}
groupID := int64(10)
account, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "sticky-session", "gpt-4.1", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(2), account.ID)
require.Equal(t, 1, cache.deletedSessions["openai:sticky-session"])
require.Equal(t, int64(2), cache.sessionBindings["openai:sticky-session"])
}

View File

@ -10,8 +10,8 @@ import (
const compatPromptCacheKeyPrefix = "compat_cc_" const compatPromptCacheKeyPrefix = "compat_cc_"
func shouldAutoInjectPromptCacheKeyForCompat(model string) bool { func shouldAutoInjectPromptCacheKeyForCompat(model string) bool {
switch resolveOpenAIUpstreamModel(strings.TrimSpace(model)) { switch normalizeCodexModel(strings.TrimSpace(model)) {
case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark": case "gpt-5.4", "gpt-5.3-codex":
return true return true
default: default:
return false return false
@ -23,9 +23,9 @@ func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedMod
return "" return ""
} }
normalizedModel := resolveOpenAIUpstreamModel(strings.TrimSpace(mappedModel)) normalizedModel := normalizeCodexModel(strings.TrimSpace(mappedModel))
if normalizedModel == "" { if normalizedModel == "" {
normalizedModel = resolveOpenAIUpstreamModel(strings.TrimSpace(req.Model)) normalizedModel = normalizeCodexModel(strings.TrimSpace(req.Model))
} }
if normalizedModel == "" { if normalizedModel == "" {
normalizedModel = strings.TrimSpace(req.Model) normalizedModel = strings.TrimSpace(req.Model)

View File

@ -46,7 +46,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
// 2. Resolve model mapping early so compat prompt_cache_key injection can // 2. Resolve model mapping early so compat prompt_cache_key injection can
// derive a stable seed from the final upstream model family. // derive a stable seed from the final upstream model family.
billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
upstreamModel := resolveOpenAIUpstreamModel(billingModel) upstreamModel := normalizeCodexModel(billingModel)
promptCacheKey = strings.TrimSpace(promptCacheKey) promptCacheKey = strings.TrimSpace(promptCacheKey)
compatPromptCacheInjected := false compatPromptCacheInjected := false

View File

@ -62,7 +62,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
// 3. Model mapping // 3. Model mapping
billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel) billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel)
upstreamModel := resolveOpenAIUpstreamModel(billingModel) upstreamModel := normalizeCodexModel(billingModel)
responsesReq.Model = upstreamModel responsesReq.Model = upstreamModel
logger.L().Debug("openai messages: model mapping applied", logger.L().Debug("openai messages: model mapping applied",

View File

@ -145,6 +145,8 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil, nil,
&DeferredService{}, &DeferredService{},
nil, nil,
nil,
nil,
) )
svc.userGroupRateResolver = newUserGroupRateResolver( svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo, rateRepo,

View File

@ -10,6 +10,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"math/rand" "math/rand"
"net/http" "net/http"
"sort" "sort"
@ -204,6 +205,7 @@ type OpenAIUsage struct {
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"` CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"` CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
ImageOutputTokens int `json:"image_output_tokens,omitempty"`
} }
// OpenAIForwardResult represents the result of forwarding // OpenAIForwardResult represents the result of forwarding
@ -322,6 +324,8 @@ type OpenAIGatewayService struct {
openAITokenProvider *OpenAITokenProvider openAITokenProvider *OpenAITokenProvider
toolCorrector *CodexToolCorrector toolCorrector *CodexToolCorrector
openaiWSResolver OpenAIWSProtocolResolver openaiWSResolver OpenAIWSProtocolResolver
resolver *ModelPricingResolver
channelService *ChannelService
openaiWSPoolOnce sync.Once openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once openaiWSStateStoreOnce sync.Once
@ -357,6 +361,8 @@ func NewOpenAIGatewayService(
httpUpstream HTTPUpstream, httpUpstream HTTPUpstream,
deferredService *DeferredService, deferredService *DeferredService,
openAITokenProvider *OpenAITokenProvider, openAITokenProvider *OpenAITokenProvider,
resolver *ModelPricingResolver,
channelService *ChannelService,
) *OpenAIGatewayService { ) *OpenAIGatewayService {
svc := &OpenAIGatewayService{ svc := &OpenAIGatewayService{
accountRepo: accountRepo, accountRepo: accountRepo,
@ -384,6 +390,8 @@ func NewOpenAIGatewayService(
openAITokenProvider: openAITokenProvider, openAITokenProvider: openAITokenProvider,
toolCorrector: NewCodexToolCorrector(), toolCorrector: NewCodexToolCorrector(),
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
resolver: resolver,
channelService: channelService,
responseHeaderFilter: compileResponseHeaderFilter(cfg), responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
} }
@ -391,6 +399,74 @@ func NewOpenAIGatewayService(
return svc return svc
} }
// ResolveChannelMapping 解析渠道级模型映射(代理到 ChannelService
func (s *OpenAIGatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
if s.channelService == nil {
return ChannelMappingResult{MappedModel: model}
}
return s.channelService.ResolveChannelMapping(ctx, groupID, model)
}
// IsModelRestricted 检查模型是否被渠道限制(代理到 ChannelService
func (s *OpenAIGatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
if s.channelService == nil {
return false
}
return s.channelService.IsModelRestricted(ctx, groupID, model)
}
// ResolveChannelMappingAndRestrict 解析渠道映射。
// 模型限制检查已移至调度阶段restricted 始终返回 false。
func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
if s.channelService == nil {
return ChannelMappingResult{MappedModel: model}, false
}
return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
}
func (s *OpenAIGatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool {
if groupID == nil || s.channelService == nil || requestedModel == "" {
return false
}
mapping := s.channelService.ResolveChannelMapping(ctx, *groupID, requestedModel)
billingModel := billingModelForRestriction(mapping.BillingModelSource, requestedModel, mapping.MappedModel)
if billingModel == "" {
return false
}
return s.channelService.IsModelRestricted(ctx, *groupID, billingModel)
}
func (s *OpenAIGatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool {
if s.channelService == nil {
return false
}
upstreamModel := resolveOpenAIForwardModel(account, requestedModel, "")
if upstreamModel == "" {
return false
}
return s.channelService.IsModelRestricted(ctx, groupID, upstreamModel)
}
func (s *OpenAIGatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool {
if groupID == nil || s.channelService == nil {
return false
}
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
if err != nil {
slog.Warn("failed to check openai channel upstream restriction", "group_id", *groupID, "error", err)
return false
}
if ch == nil || !ch.RestrictModels {
return false
}
return ch.BillingModelSource == BillingModelSourceUpstream
}
// ReplaceModelInBody 替换请求体中的 JSON model 字段(通用 gjson/sjson 实现)。
func (s *OpenAIGatewayService) ReplaceModelInBody(body []byte, newModel string) []byte {
return ReplaceModelInBody(body, newModel)
}
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle { func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
if s != nil && s.codexSnapshotThrottle != nil { if s != nil && s.codexSnapshotThrottle != nil {
return s.codexSnapshotThrottle return s.codexSnapshotThrottle
@ -1125,6 +1201,13 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
} }
func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) { func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) {
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
"model", requestedModel)
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
// 1. 尝试粘性会话命中 // 1. 尝试粘性会话命中
// Try sticky session hit // Try sticky session hit
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil { if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil {
@ -1140,7 +1223,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
// 3. 按优先级 + LRU 选择最佳账号 // 3. 按优先级 + LRU 选择最佳账号
// Select by priority + LRU // Select by priority + LRU
selected := s.selectBestAccount(ctx, accounts, requestedModel, excludedIDs) selected := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs)
if selected == nil { if selected == nil {
if requestedModel != "" { if requestedModel != "" {
@ -1206,6 +1289,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil return nil
} }
if groupID != nil && s.needsUpstreamChannelRestrictionCheck(ctx, groupID) &&
s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
}
// 刷新会话 TTL 并返回账号 // 刷新会话 TTL 并返回账号
// Refresh session TTL and return account // Refresh session TTL and return account
@ -1218,8 +1306,9 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
// //
// selectBestAccount selects the best account from candidates (priority + LRU). // selectBestAccount selects the best account from candidates (priority + LRU).
// Returns nil if no available account. // Returns nil if no available account.
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
var selected *Account var selected *Account
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
for i := range accounts { for i := range accounts {
acc := &accounts[i] acc := &accounts[i]
@ -1238,6 +1327,9 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [
if fresh == nil { if fresh == nil {
continue continue
} }
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue
}
// 选择优先级最高且最久未使用的账号 // 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used // Select highest priority and least recently used
@ -1289,7 +1381,15 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan. // SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
"model", requestedModel)
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
cfg := s.schedulingConfig() cfg := s.schedulingConfig()
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var stickyAccountID int64 var stickyAccountID int64
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil { if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil {
@ -1365,6 +1465,8 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel) account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if account == nil { if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else { } else {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
@ -1410,6 +1512,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if requestedModel != "" && !acc.IsModelSupported(requestedModel) { if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue continue
} }
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
continue
}
candidates = append(candidates, acc) candidates = append(candidates, acc)
} }
@ -1434,6 +1539,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if fresh == nil { if fresh == nil {
continue continue
} }
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue
}
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
if sessionHash != "" { if sessionHash != "" {
@ -1488,6 +1596,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if fresh == nil { if fresh == nil {
continue continue
} }
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue
}
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
if sessionHash != "" { if sessionHash != "" {
@ -1510,6 +1621,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if fresh == nil { if fresh == nil {
continue continue
} }
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue
}
return &AccountSelectionResult{ return &AccountSelectionResult{
Account: fresh, Account: fresh,
WaitPlan: &AccountWaitPlan{ WaitPlan: &AccountWaitPlan{
@ -1825,7 +1939,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。 // 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
if model, ok := reqBody["model"].(string); ok { if model, ok := reqBody["model"].(string); ok {
upstreamModel = resolveOpenAIUpstreamModel(model) upstreamModel = normalizeCodexModel(model)
if upstreamModel != "" && upstreamModel != model { if upstreamModel != "" && upstreamModel != model {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
model, upstreamModel, account.Name, account.Type, isCodexCLI) model, upstreamModel, account.Name, account.Type, isCodexCLI)
@ -4110,6 +4224,7 @@ type OpenAIRecordUsageInput struct {
IPAddress string // 请求的客户端 IP 地址 IPAddress string // 请求的客户端 IP 地址
RequestPayloadHash string RequestPayloadHash string
APIKeyService APIKeyQuotaUpdater APIKeyService APIKeyQuotaUpdater
ChannelUsageFields
} }
// RecordUsage records usage and deducts balance // RecordUsage records usage and deducts balance
@ -4140,10 +4255,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
} }
// Get rate multiplier // Get rate multiplier
multiplier := s.cfg.Default.RateMultiplier multiplier := 1.0
if s.cfg != nil {
multiplier = s.cfg.Default.RateMultiplier
}
if apiKey.GroupID != nil && apiKey.Group != nil { if apiKey.GroupID != nil && apiKey.Group != nil {
resolver := s.userGroupRateResolver resolver := s.userGroupRateResolver
if resolver == nil { if resolver == nil {
@ -4152,12 +4271,37 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier) multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
} }
var cost *CostBreakdown
var err error
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if result.BillingModel != "" {
billingModel = strings.TrimSpace(result.BillingModel)
}
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
billingModel = input.ChannelMappedModel
}
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
billingModel = input.OriginalModel
}
serviceTier := "" serviceTier := ""
if result.ServiceTier != nil { if result.ServiceTier != nil {
serviceTier = strings.TrimSpace(*result.ServiceTier) serviceTier = strings.TrimSpace(*result.ServiceTier)
} }
cost, err := s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
ServiceTier: serviceTier,
Resolver: s.resolver,
})
} else {
cost, err = s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
}
if err != nil { if err != nil {
cost = &CostBreakdown{ActualCost: 0} cost = &CostBreakdown{ActualCost: 0}
} }
@ -4173,36 +4317,58 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
durationMs := int(result.Duration.Milliseconds()) durationMs := int(result.Duration.Milliseconds())
accountRateMultiplier := account.BillingRateMultiplier() accountRateMultiplier := account.BillingRateMultiplier()
requestID := resolveUsageBillingRequestID(ctx, result.RequestID) requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
// 确定 RequestedModel渠道映射前的原始模型
requestedModel := result.Model
if input.OriginalModel != "" {
requestedModel = input.OriginalModel
}
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
RequestID: requestID, RequestID: requestID,
Model: result.Model, Model: result.Model,
RequestedModel: result.Model, RequestedModel: requestedModel,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ServiceTier: result.ServiceTier, ServiceTier: result.ServiceTier,
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
InputTokens: actualInputTokens, InputTokens: actualInputTokens,
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost, ImageOutputTokens: result.Usage.ImageOutputTokens,
OutputCost: cost.OutputCost, }
CacheCreationCost: cost.CacheCreationCost, if cost != nil {
CacheReadCost: cost.CacheReadCost, usageLog.InputCost = cost.InputCost
TotalCost: cost.TotalCost, usageLog.OutputCost = cost.OutputCost
ActualCost: cost.ActualCost, usageLog.ImageOutputCost = cost.ImageOutputCost
RateMultiplier: multiplier, usageLog.CacheCreationCost = cost.CacheCreationCost
AccountRateMultiplier: &accountRateMultiplier, usageLog.CacheReadCost = cost.CacheReadCost
BillingType: billingType, usageLog.TotalCost = cost.TotalCost
Stream: result.Stream, usageLog.ActualCost = cost.ActualCost
OpenAIWSMode: result.OpenAIWSMode, }
DurationMs: &durationMs, usageLog.RateMultiplier = multiplier
FirstTokenMs: result.FirstTokenMs, usageLog.AccountRateMultiplier = &accountRateMultiplier
CreatedAt: time.Now(), usageLog.BillingType = billingType
usageLog.Stream = result.Stream
usageLog.OpenAIWSMode = result.OpenAIWSMode
usageLog.DurationMs = &durationMs
usageLog.FirstTokenMs = result.FirstTokenMs
usageLog.CreatedAt = time.Now()
// 设置渠道信息
usageLog.ChannelID = optionalInt64Ptr(input.ChannelID)
usageLog.ModelMappingChain = optionalTrimmedStringPtr(input.ModelMappingChain)
// 设置计费模式
if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else {
billingMode := string(BillingModeToken)
usageLog.BillingMode = &billingMode
} }
// 添加 UserAgent // 添加 UserAgent
if input.UserAgent != "" { if input.UserAgent != "" {

View File

@ -1,10 +1,8 @@
package service package service
import "strings" // resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
// forwarding. Group-level default mapping only applies when the account itself
// resolveOpenAIForwardModel resolves the account/group mapping result for // did not match any explicit model_mapping rule.
// OpenAI-compatible forwarding. Group-level default mapping only applies when
// the account itself did not match any explicit model_mapping rule.
func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string { func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string {
if account == nil { if account == nil {
if defaultMappedModel != "" { if defaultMappedModel != "" {
@ -19,23 +17,3 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo
} }
return mappedModel return mappedModel
} }
func resolveOpenAIUpstreamModel(model string) string {
if isBareGPT53CodexSparkModel(model) {
return "gpt-5.3-codex-spark"
}
return normalizeCodexModel(strings.TrimSpace(model))
}
func isBareGPT53CodexSparkModel(model string) bool {
modelID := strings.TrimSpace(model)
if modelID == "" {
return false
}
if strings.Contains(modelID, "/") {
parts := strings.Split(modelID, "/")
modelID = parts[len(parts)-1]
}
normalized := strings.ToLower(strings.TrimSpace(modelID))
return normalized == "gpt-5.3-codex-spark" || normalized == "gpt 5.3 codex spark"
}

View File

@ -74,30 +74,28 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *
Credentials: map[string]any{}, Credentials: map[string]any{},
} }
withoutDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "")) withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", ""))
if withoutDefault != "gpt-5.1" { if withoutDefault != "gpt-5.1" {
t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withoutDefault, "gpt-5.1") t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.1")
} }
withDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")) withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4"))
if withDefault != "gpt-5.4" { if withDefault != "gpt-5.4" {
t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withDefault, "gpt-5.4") t.Fatalf("normalizeCodexModel(...) = %q, want %q", withDefault, "gpt-5.4")
} }
} }
func TestResolveOpenAIUpstreamModel(t *testing.T) { func TestNormalizeCodexModel(t *testing.T) {
cases := map[string]string{ cases := map[string]string{
"gpt-5.3-codex-spark": "gpt-5.3-codex-spark", "gpt-5.3-codex-spark": "gpt-5.3-codex",
"gpt 5.3 codex spark": "gpt-5.3-codex-spark", "gpt-5.3-codex-spark-high": "gpt-5.3-codex",
" openai/gpt-5.3-codex-spark ": "gpt-5.3-codex-spark", "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
"gpt-5.3-codex-spark-high": "gpt-5.3-codex", "gpt-5.3": "gpt-5.3-codex",
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
"gpt-5.3": "gpt-5.3-codex",
} }
for input, expected := range cases { for input, expected := range cases {
if got := resolveOpenAIUpstreamModel(input); got != expected { if got := normalizeCodexModel(input); got != expected {
t.Fatalf("resolveOpenAIUpstreamModel(%q) = %q, want %q", input, got, expected) t.Fatalf("normalizeCodexModel(%q) = %q, want %q", input, got, expected)
} }
} }
} }

View File

@ -2515,7 +2515,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
} }
normalized = next normalized = next
} }
upstreamModel := resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel)) upstreamModel := normalizeCodexModel(account.GetMappedModel(originalModel))
if upstreamModel != originalModel { if upstreamModel != originalModel {
next, setErr := applyPayloadMutation(normalized, "model", upstreamModel) next, setErr := applyPayloadMutation(normalized, "model", upstreamModel)
if setErr != nil { if setErr != nil {
@ -2773,7 +2773,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
mappedModel := "" mappedModel := ""
var mappedModelBytes []byte var mappedModelBytes []byte
if originalModel != "" { if originalModel != "" {
mappedModel = resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel)) mappedModel = normalizeCodexModel(account.GetMappedModel(originalModel))
needModelReplace = mappedModel != "" && mappedModel != originalModel needModelReplace = mappedModel != "" && mappedModel != originalModel
if needModelReplace { if needModelReplace {
mappedModelBytes = []byte(mappedModel) mappedModelBytes = []byte(mappedModel)

View File

@ -615,6 +615,8 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil, nil,
nil, nil,
nil, nil,
nil,
nil,
) )
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil) decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)

View File

@ -519,7 +519,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
if s.gatewayService == nil { if s.gatewayService == nil {
return nil, fmt.Errorf("gateway service not available") return nil, fmt.Errorf("gateway service not available")
} }
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "") // 重试不使用会话限制 return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "", int64(0)) // 重试不使用会话限制
default: default:
return nil, fmt.Errorf("unsupported retry type: %s", reqType) return nil, fmt.Errorf("unsupported retry type: %s", reqType)
} }

View File

@ -70,7 +70,8 @@ type LiteLLMModelPricing struct {
LiteLLMProvider string `json:"litellm_provider"` LiteLLMProvider string `json:"litellm_provider"`
Mode string `json:"mode"` Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"` SupportsPromptCaching bool `json:"supports_prompt_caching"`
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格 OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
OutputCostPerImageToken float64 `json:"output_cost_per_image_token"` // 图片输出 token 价格
} }
// PricingRemoteClient 远程价格数据获取接口 // PricingRemoteClient 远程价格数据获取接口
@ -94,6 +95,7 @@ type LiteLLMRawEntry struct {
Mode string `json:"mode"` Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"` SupportsPromptCaching bool `json:"supports_prompt_caching"`
OutputCostPerImage *float64 `json:"output_cost_per_image"` OutputCostPerImage *float64 `json:"output_cost_per_image"`
OutputCostPerImageToken *float64 `json:"output_cost_per_image_token"`
} }
// PricingService 动态价格服务 // PricingService 动态价格服务
@ -408,6 +410,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
if entry.OutputCostPerImage != nil { if entry.OutputCostPerImage != nil {
pricing.OutputCostPerImage = *entry.OutputCostPerImage pricing.OutputCostPerImage = *entry.OutputCostPerImage
} }
if entry.OutputCostPerImageToken != nil {
pricing.OutputCostPerImageToken = *entry.OutputCostPerImageToken
}
result[modelName] = pricing result[modelName] = pricing
} }

View File

@ -0,0 +1,15 @@
//go:build unit
package service
// testPtrFloat64 returns a pointer to the given float64 value.
func testPtrFloat64(v float64) *float64 { return &v }
// testPtrInt returns a pointer to the given int value.
func testPtrInt(v int) *int { return &v }
// testPtrString returns a pointer to the given string value.
func testPtrString(v string) *string { return &v }
// testPtrBool returns a pointer to the given bool value.
func testPtrBool(v bool) *bool { return &v }

View File

@ -104,6 +104,14 @@ type UsageLog struct {
// UpstreamModel is the actual model sent to the upstream provider after mapping. // UpstreamModel is the actual model sent to the upstream provider after mapping.
// Nil means no mapping was applied (requested model was used as-is). // Nil means no mapping was applied (requested model was used as-is).
UpstreamModel *string UpstreamModel *string
// ChannelID 渠道 ID
ChannelID *int64
// ModelMappingChain 模型映射链,如 "a→b→c"
ModelMappingChain *string
// BillingTier 计费层级标签per_request/image 模式)
BillingTier *string
// BillingMode 计费模式token/imagesora 路径为 nil
BillingMode *string
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string ServiceTier *string
// ReasoningEffort is the request's reasoning effort level. // ReasoningEffort is the request's reasoning effort level.
@ -126,6 +134,9 @@ type UsageLog struct {
CacheCreation5mTokens int `gorm:"column:cache_creation_5m_tokens"` CacheCreation5mTokens int `gorm:"column:cache_creation_5m_tokens"`
CacheCreation1hTokens int `gorm:"column:cache_creation_1h_tokens"` CacheCreation1hTokens int `gorm:"column:cache_creation_1h_tokens"`
ImageOutputTokens int
ImageOutputCost float64
InputCost float64 InputCost float64
OutputCost float64 OutputCost float64
CacheCreationCost float64 CacheCreationCost float64

View File

@ -26,3 +26,10 @@ func forwardResultBillingModel(requestedModel, upstreamModel string) string {
} }
return strings.TrimSpace(upstreamModel) return strings.TrimSpace(upstreamModel)
} }
func optionalInt64Ptr(v int64) *int64 {
if v == 0 {
return nil
}
return &v
}

View File

@ -490,4 +490,6 @@ var ProviderSet = wire.NewSet(
ProvideScheduledTestService, ProvideScheduledTestService,
ProvideScheduledTestRunnerService, ProvideScheduledTestRunnerService,
NewGroupCapacityService, NewGroupCapacityService,
NewChannelService,
NewModelPricingResolver,
) )

View File

@ -0,0 +1,56 @@
-- Create channels table for managing pricing channels.
-- A channel groups multiple groups together and provides custom model pricing.
SET LOCAL lock_timeout = '5s';
SET LOCAL statement_timeout = '10min';
-- 渠道表
CREATE TABLE IF NOT EXISTS channels (
id BIGSERIAL PRIMARY KEY,
name VARCHAR(100) NOT NULL,
description TEXT DEFAULT '',
status VARCHAR(20) NOT NULL DEFAULT 'active',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- 渠道名称唯一索引
CREATE UNIQUE INDEX IF NOT EXISTS idx_channels_name ON channels (name);
CREATE INDEX IF NOT EXISTS idx_channels_status ON channels (status);
-- 渠道-分组关联表(每个分组只能属于一个渠道)
CREATE TABLE IF NOT EXISTS channel_groups (
id BIGSERIAL PRIMARY KEY,
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_channel_groups_group_id ON channel_groups (group_id);
CREATE INDEX IF NOT EXISTS idx_channel_groups_channel_id ON channel_groups (channel_id);
-- 渠道模型定价表(一条定价可绑定多个模型)
CREATE TABLE IF NOT EXISTS channel_model_pricing (
id BIGSERIAL PRIMARY KEY,
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
models JSONB NOT NULL DEFAULT '[]',
input_price NUMERIC(20,12),
output_price NUMERIC(20,12),
cache_write_price NUMERIC(20,12),
cache_read_price NUMERIC(20,12),
image_output_price NUMERIC(20,8),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_channel_model_pricing_channel_id ON channel_model_pricing (channel_id);
COMMENT ON TABLE channels IS '渠道管理:关联多个分组,提供自定义模型定价';
COMMENT ON TABLE channel_groups IS '渠道-分组关联表:每个分组最多属于一个渠道';
COMMENT ON TABLE channel_model_pricing IS '渠道模型定价:一条定价可绑定多个模型,价格一致';
COMMENT ON COLUMN channel_model_pricing.models IS '绑定的模型列表JSON 数组,如 ["claude-opus-4-6","claude-opus-4-6-thinking"]';
COMMENT ON COLUMN channel_model_pricing.input_price IS '每 token 输入价格USDNULL 表示使用默认';
COMMENT ON COLUMN channel_model_pricing.output_price IS '每 token 输出价格USDNULL 表示使用默认';
COMMENT ON COLUMN channel_model_pricing.cache_write_price IS '缓存写入每 token 价格NULL 表示使用默认';
COMMENT ON COLUMN channel_model_pricing.cache_read_price IS '缓存读取每 token 价格NULL 表示使用默认';
COMMENT ON COLUMN channel_model_pricing.image_output_price IS '图片输出价格Gemini Image 等NULL 表示使用默认';

View File

@ -0,0 +1,67 @@
-- Extend channel_model_pricing with billing_mode and add context-interval child table.
-- Supports three billing modes: token (per-token with context intervals),
-- per_request (per-request with context-size tiers), and image (per-image).
SET LOCAL lock_timeout = '5s';
SET LOCAL statement_timeout = '10min';
-- 1. 为 channel_model_pricing 添加 billing_mode 列
ALTER TABLE channel_model_pricing
ADD COLUMN IF NOT EXISTS billing_mode VARCHAR(20) NOT NULL DEFAULT 'token';
COMMENT ON COLUMN channel_model_pricing.billing_mode IS '计费模式token按 token 区间计费、per_request按次计费、image图片计费';
-- 2. 创建区间定价子表
CREATE TABLE IF NOT EXISTS channel_pricing_intervals (
id BIGSERIAL PRIMARY KEY,
pricing_id BIGINT NOT NULL REFERENCES channel_model_pricing(id) ON DELETE CASCADE,
min_tokens INT NOT NULL DEFAULT 0,
max_tokens INT,
tier_label VARCHAR(50),
input_price NUMERIC(20,12),
output_price NUMERIC(20,12),
cache_write_price NUMERIC(20,12),
cache_read_price NUMERIC(20,12),
per_request_price NUMERIC(20,12),
sort_order INT NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_channel_pricing_intervals_pricing_id
ON channel_pricing_intervals (pricing_id);
COMMENT ON TABLE channel_pricing_intervals IS '渠道定价区间:支持按 token 区间、按次分层、图片分辨率分层';
COMMENT ON COLUMN channel_pricing_intervals.min_tokens IS '区间下界token 模式使用';
COMMENT ON COLUMN channel_pricing_intervals.max_tokens IS '区间上界不含NULL 表示无上限';
COMMENT ON COLUMN channel_pricing_intervals.tier_label IS '层级标签,按次/图片模式使用(如 1K、2K、4K、HD';
COMMENT ON COLUMN channel_pricing_intervals.input_price IS 'token 模式:每 token 输入价';
COMMENT ON COLUMN channel_pricing_intervals.output_price IS 'token 模式:每 token 输出价';
COMMENT ON COLUMN channel_pricing_intervals.cache_write_price IS 'token 模式:缓存写入价';
COMMENT ON COLUMN channel_pricing_intervals.cache_read_price IS 'token 模式:缓存读取价';
COMMENT ON COLUMN channel_pricing_intervals.per_request_price IS '按次/图片模式:每次请求价格';
-- 3. 迁移现有 flat 定价为单区间 [0, +inf)
-- 仅迁移有明确定价(至少一个价格字段非 NULL的条目
INSERT INTO channel_pricing_intervals (pricing_id, min_tokens, max_tokens, input_price, output_price, cache_write_price, cache_read_price, sort_order)
SELECT
cmp.id,
0,
NULL,
cmp.input_price,
cmp.output_price,
cmp.cache_write_price,
cmp.cache_read_price,
0
FROM channel_model_pricing cmp
WHERE cmp.billing_mode = 'token'
AND (cmp.input_price IS NOT NULL OR cmp.output_price IS NOT NULL
OR cmp.cache_write_price IS NOT NULL OR cmp.cache_read_price IS NOT NULL)
AND NOT EXISTS (
SELECT 1 FROM channel_pricing_intervals cpi WHERE cpi.pricing_id = cmp.id
);
-- 4. 迁移 image_output_price 为 image 模式的区间条目
-- 将有 image_output_price 的现有条目复制为 billing_mode='image' 的独立条目
-- 注意:这里不改变原条目的 billing_mode而是将 image_output_price 作为向后兼容字段保留
-- 实际的 image 计费在未来由独立的 billing_mode='image' 条目处理

View File

@ -0,0 +1,5 @@
SET LOCAL lock_timeout = '5s';
SET LOCAL statement_timeout = '10min';
ALTER TABLE channels ADD COLUMN IF NOT EXISTS model_mapping JSONB DEFAULT '{}';
COMMENT ON COLUMN channels.model_mapping IS '渠道级模型映射,在账号映射之前执行。格式:{"source_model": "target_model"}';

View File

@ -0,0 +1,7 @@
-- Add billing_model_source to channels (controls whether billing uses requested or upstream model)
ALTER TABLE channels ADD COLUMN IF NOT EXISTS billing_model_source VARCHAR(20) DEFAULT 'requested';
-- Add channel tracking fields to usage_logs
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS channel_id BIGINT;
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS model_mapping_chain VARCHAR(500);
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS billing_tier VARCHAR(50);

View File

@ -0,0 +1,5 @@
-- Add model restriction switch to channels
ALTER TABLE channels ADD COLUMN IF NOT EXISTS restrict_models BOOLEAN DEFAULT false;
-- Add default per_request_price to channel_model_pricing (fallback when no tier matches)
ALTER TABLE channel_model_pricing ADD COLUMN IF NOT EXISTS per_request_price NUMERIC(20,10);

View File

@ -0,0 +1,21 @@
-- 086_channel_platform_pricing.sql
-- 渠道按平台维度model_pricing 加 platform 列model_mapping 改为嵌套格式
-- 1. channel_model_pricing 加 platform 列
ALTER TABLE channel_model_pricing
ADD COLUMN IF NOT EXISTS platform VARCHAR(50) NOT NULL DEFAULT 'anthropic';
CREATE INDEX IF NOT EXISTS idx_channel_model_pricing_platform
ON channel_model_pricing (platform);
-- 2. model_mapping: 从扁平 {"src":"dst"} 迁移为嵌套 {"anthropic":{"src":"dst"}}
-- 仅迁移非空、非 '{}' 的旧格式数据(通过检查第一个 value 是否为字符串来判断是否为旧格式)
UPDATE channels
SET model_mapping = jsonb_build_object('anthropic', model_mapping)
WHERE model_mapping IS NOT NULL
AND model_mapping::text NOT IN ('{}', 'null', '')
AND NOT EXISTS (
SELECT 1 FROM jsonb_each(model_mapping) AS kv
WHERE jsonb_typeof(kv.value) = 'object'
LIMIT 1
);

View File

@ -0,0 +1,2 @@
-- Add billing_mode to usage_logs (records the billing mode: token/per_request/image)
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS billing_mode VARCHAR(20);

View File

@ -0,0 +1,3 @@
-- Change default billing_model_source for new channels to 'channel_mapped'
-- Existing channels keep their current setting (no UPDATE on existing rows)
ALTER TABLE channels ALTER COLUMN billing_model_source SET DEFAULT 'channel_mapped';

View File

@ -0,0 +1,2 @@
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS image_output_tokens INTEGER NOT NULL DEFAULT 0;
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS image_output_cost DECIMAL(20, 10) NOT NULL DEFAULT 0;

View File

@ -0,0 +1,148 @@
/**
* Admin Channels API endpoints
* Handles channel management for administrators
*/
import { apiClient } from '../client'
export type BillingMode = 'token' | 'per_request' | 'image'
export interface PricingInterval {
id?: number
min_tokens: number
max_tokens: number | null
tier_label: string
input_price: number | null
output_price: number | null
cache_write_price: number | null
cache_read_price: number | null
per_request_price: number | null
sort_order: number
}
export interface ChannelModelPricing {
id?: number
platform: string
models: string[]
billing_mode: BillingMode
input_price: number | null
output_price: number | null
cache_write_price: number | null
cache_read_price: number | null
image_output_price: number | null
per_request_price: number | null
intervals: PricingInterval[]
}
export interface Channel {
id: number
name: string
description: string
status: string
billing_model_source: string // "requested" | "upstream"
restrict_models: boolean
group_ids: number[]
model_pricing: ChannelModelPricing[]
model_mapping: Record<string, Record<string, string>> // platform → {src→dst}
created_at: string
updated_at: string
}
export interface CreateChannelRequest {
name: string
description?: string
group_ids?: number[]
model_pricing?: ChannelModelPricing[]
model_mapping?: Record<string, Record<string, string>>
billing_model_source?: string
restrict_models?: boolean
}
export interface UpdateChannelRequest {
name?: string
description?: string
status?: string
group_ids?: number[]
model_pricing?: ChannelModelPricing[]
model_mapping?: Record<string, Record<string, string>>
billing_model_source?: string
restrict_models?: boolean
}
interface PaginatedResponse<T> {
items: T[]
total: number
}
/**
* List channels with pagination
*/
export async function list(
page: number = 1,
pageSize: number = 20,
filters?: {
status?: string
search?: string
},
options?: { signal?: AbortSignal }
): Promise<PaginatedResponse<Channel>> {
const { data } = await apiClient.get<PaginatedResponse<Channel>>('/admin/channels', {
params: {
page,
page_size: pageSize,
...filters
},
signal: options?.signal
})
return data
}
/**
* Get channel by ID
*/
export async function getById(id: number): Promise<Channel> {
const { data } = await apiClient.get<Channel>(`/admin/channels/${id}`)
return data
}
/**
* Create a new channel
*/
export async function create(req: CreateChannelRequest): Promise<Channel> {
const { data } = await apiClient.post<Channel>('/admin/channels', req)
return data
}
/**
* Update a channel
*/
export async function update(id: number, req: UpdateChannelRequest): Promise<Channel> {
const { data } = await apiClient.put<Channel>(`/admin/channels/${id}`, req)
return data
}
/**
* Delete a channel
*/
export async function remove(id: number): Promise<void> {
await apiClient.delete(`/admin/channels/${id}`)
}
export interface ModelDefaultPricing {
found: boolean
input_price?: number // per-token price
output_price?: number
cache_write_price?: number
cache_read_price?: number
image_output_price?: number
}
export async function getModelDefaultPricing(model: string): Promise<ModelDefaultPricing> {
const { data } = await apiClient.get<ModelDefaultPricing>('/admin/channels/model-pricing', {
params: { model }
})
return data
}
const channelsAPI = { list, getById, create, update, remove, getModelDefaultPricing }
export default channelsAPI

View File

@ -167,6 +167,13 @@ export interface UserBreakdownParams {
endpoint?: string endpoint?: string
endpoint_type?: 'inbound' | 'upstream' | 'path' endpoint_type?: 'inbound' | 'upstream' | 'path'
limit?: number limit?: number
// Additional filter conditions
user_id?: number
api_key_id?: number
account_id?: number
request_type?: number
stream?: boolean
billing_type?: number | null
} }
export interface UserBreakdownResponse { export interface UserBreakdownResponse {

View File

@ -25,6 +25,7 @@ import apiKeysAPI from './apiKeys'
import scheduledTestsAPI from './scheduledTests' import scheduledTestsAPI from './scheduledTests'
import backupAPI from './backup' import backupAPI from './backup'
import tlsFingerprintProfileAPI from './tlsFingerprintProfile' import tlsFingerprintProfileAPI from './tlsFingerprintProfile'
import channelsAPI from './channels'
/** /**
* Unified admin API object for convenient access * Unified admin API object for convenient access
@ -51,7 +52,8 @@ export const adminAPI = {
apiKeys: apiKeysAPI, apiKeys: apiKeysAPI,
scheduledTests: scheduledTestsAPI, scheduledTests: scheduledTestsAPI,
backup: backupAPI, backup: backupAPI,
tlsFingerprintProfiles: tlsFingerprintProfileAPI tlsFingerprintProfiles: tlsFingerprintProfileAPI,
channels: channelsAPI
} }
export { export {
@ -76,7 +78,8 @@ export {
apiKeysAPI, apiKeysAPI,
scheduledTestsAPI, scheduledTestsAPI,
backupAPI, backupAPI,
tlsFingerprintProfileAPI tlsFingerprintProfileAPI,
channelsAPI
} }
export default adminAPI export default adminAPI

View File

@ -80,6 +80,7 @@ export interface CreateUsageCleanupTaskRequest {
export interface AdminUsageQueryParams extends UsageQueryParams { export interface AdminUsageQueryParams extends UsageQueryParams {
user_id?: number user_id?: number
exact_total?: boolean exact_total?: boolean
billing_mode?: string
} }
// ==================== API Functions ==================== // ==================== API Functions ====================

View File

@ -0,0 +1,113 @@
<template>
<div class="flex items-start gap-2 rounded border p-2"
:class="isEmpty ? 'border-red-400 bg-red-50 dark:border-red-500 dark:bg-red-950/20' : 'border-gray-200 bg-white dark:border-dark-500 dark:bg-dark-700'">
<!-- Token mode: context range + prices ($/MTok) -->
<template v-if="mode === 'token'">
<div class="w-20">
<label class="text-xs text-gray-400">Min</label>
<input :value="interval.min_tokens" @input="emitField('min_tokens', toInt(($event.target as HTMLInputElement).value))"
type="number" min="0" class="input mt-0.5 text-xs" />
</div>
<div class="w-20">
<label class="text-xs text-gray-400">Max <span class="text-gray-300">()</span></label>
<input :value="interval.max_tokens ?? ''" @input="emitField('max_tokens', toIntOrNull(($event.target as HTMLInputElement).value))"
type="number" min="0" class="input mt-0.5 text-xs" :placeholder="'∞'" />
</div>
<div class="flex-1">
<label class="text-xs text-gray-400">{{ t('admin.channels.form.inputPrice', '输入') }} <span v-if="isEmpty" class="text-red-500">*</span> <span class="text-gray-300">$/M</span></label>
<input :value="interval.input_price" @input="emitField('input_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-xs" />
</div>
<div class="flex-1">
<label class="text-xs text-gray-400">{{ t('admin.channels.form.outputPrice', '输出') }} <span v-if="isEmpty" class="text-red-500">*</span> <span class="text-gray-300">$/M</span></label>
<input :value="interval.output_price" @input="emitField('output_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-xs" />
</div>
<div class="flex-1">
<label class="text-xs text-gray-400">{{ t('admin.channels.form.cacheWritePrice', '缓存W') }} <span class="text-gray-300">$/M</span></label>
<input :value="interval.cache_write_price" @input="emitField('cache_write_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-xs" />
</div>
<div class="flex-1">
<label class="text-xs text-gray-400">{{ t('admin.channels.form.cacheReadPrice', '缓存R') }} <span class="text-gray-300">$/M</span></label>
<input :value="interval.cache_read_price" @input="emitField('cache_read_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-xs" />
</div>
</template>
<!-- Per-request / Image mode: tier label + context range + price -->
<template v-else>
<div class="w-24">
<label class="text-xs text-gray-400">
{{ mode === 'image' ? t('admin.channels.form.resolution', '分辨率') : t('admin.channels.form.tierLabel', '层级') }}
</label>
<input :value="interval.tier_label" @input="emitField('tier_label', ($event.target as HTMLInputElement).value)"
type="text" class="input mt-0.5 text-xs" :placeholder="mode === 'image' ? '1K / 2K / 4K' : ''" />
</div>
<div class="w-20">
<label class="text-xs text-gray-400">Min</label>
<input :value="interval.min_tokens" @input="emitField('min_tokens', toInt(($event.target as HTMLInputElement).value))"
type="number" min="0" class="input mt-0.5 text-xs" />
</div>
<div class="w-20">
<label class="text-xs text-gray-400">Max <span class="text-gray-300">()</span></label>
<input :value="interval.max_tokens ?? ''" @input="emitField('max_tokens', toIntOrNull(($event.target as HTMLInputElement).value))"
type="number" min="0" class="input mt-0.5 text-xs" :placeholder="'∞'" />
</div>
<div class="flex-1">
<label class="text-xs text-gray-400">{{ t('admin.channels.form.perRequestPrice', '单次价格') }} <span v-if="isEmpty" class="text-red-500">*</span> <span class="text-gray-300">$</span></label>
<input :value="interval.per_request_price" @input="emitField('per_request_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-xs" />
</div>
</template>
<button type="button" @click="emit('remove')" class="mt-4 rounded p-0.5 text-gray-400 hover:text-red-500">
<Icon name="x" size="sm" />
</button>
</div>
</template>
<script setup lang="ts">
import { computed } from 'vue'
import { useI18n } from 'vue-i18n'
import Icon from '@/components/icons/Icon.vue'
import type { IntervalFormEntry } from './types'
import type { BillingMode } from '@/api/admin/channels'
const { t } = useI18n()
const props = defineProps<{
interval: IntervalFormEntry
mode: BillingMode
}>()
const emit = defineEmits<{
update: [interval: IntervalFormEntry]
remove: []
}>()
//
const isEmpty = computed(() => {
const iv = props.interval
return (iv.input_price == null || iv.input_price === '') &&
(iv.output_price == null || iv.output_price === '') &&
(iv.cache_write_price == null || iv.cache_write_price === '') &&
(iv.cache_read_price == null || iv.cache_read_price === '') &&
(iv.per_request_price == null || iv.per_request_price === '')
})
function emitField(field: keyof IntervalFormEntry, value: string | number | null) {
emit('update', { ...props.interval, [field]: value === '' ? null : value })
}
function toInt(val: string): number {
const n = parseInt(val, 10)
return isNaN(n) ? 0 : n
}
function toIntOrNull(val: string): number | null {
if (val === '') return null
const n = parseInt(val, 10)
return isNaN(n) ? null : n
}
</script>

View File

@ -0,0 +1,89 @@
<template>
<div>
<!-- Tags display -->
<div class="flex flex-wrap gap-1.5 rounded-lg border border-gray-200 bg-white p-2 dark:border-dark-600 dark:bg-dark-800 min-h-[2.5rem]">
<span
v-for="(model, idx) in models"
:key="idx"
class="inline-flex items-center gap-1 rounded-md px-2 py-0.5 text-sm"
:class="getPlatformTagClass(props.platform || '')"
>
{{ model }}
<button
type="button"
@click="removeModel(idx)"
class="ml-0.5 rounded-full p-0.5 hover:bg-primary-200 dark:hover:bg-primary-800"
>
<Icon name="x" size="xs" />
</button>
</span>
<input
ref="inputRef"
v-model="inputValue"
type="text"
class="flex-1 min-w-[120px] border-none bg-transparent text-sm outline-none placeholder:text-gray-400 dark:text-white"
:placeholder="models.length === 0 ? placeholder : ''"
@keydown.enter.prevent="addModel"
@keydown.tab.prevent="addModel"
@keydown.delete="handleBackspace"
@paste="handlePaste"
/>
</div>
<p class="mt-1 text-xs text-gray-400">
{{ t('admin.channels.form.modelInputHint', 'Press Enter to add, supports paste for batch import.') }}
</p>
</div>
</template>
<script setup lang="ts">
import { ref } from 'vue'
import { useI18n } from 'vue-i18n'
import Icon from '@/components/icons/Icon.vue'
import { getPlatformTagClass } from './types'
const { t } = useI18n()
const props = defineProps<{
models: string[]
placeholder?: string
platform?: string
}>()
const emit = defineEmits<{
'update:models': [models: string[]]
}>()
const inputValue = ref('')
const inputRef = ref<HTMLInputElement>()
function addModel() {
const val = inputValue.value.trim()
if (!val) return
if (!props.models.includes(val)) {
emit('update:models', [...props.models, val])
}
inputValue.value = ''
}
function removeModel(idx: number) {
const newModels = [...props.models]
newModels.splice(idx, 1)
emit('update:models', newModels)
}
function handleBackspace() {
if (inputValue.value === '' && props.models.length > 0) {
removeModel(props.models.length - 1)
}
}
function handlePaste(e: ClipboardEvent) {
e.preventDefault()
const text = e.clipboardData?.getData('text') || ''
const items = text.split(/[,\n;]+/).map(s => s.trim()).filter(Boolean)
if (items.length === 0) return
const unique = [...new Set([...props.models, ...items])]
emit('update:models', unique)
inputValue.value = ''
}
</script>

View File

@ -0,0 +1,354 @@
<template>
<div class="rounded-lg border border-gray-200 bg-gray-50 p-3 dark:border-dark-600 dark:bg-dark-800">
<!-- Collapsed summary header (clickable) -->
<div
class="flex cursor-pointer select-none items-center gap-2"
@click="collapsed = !collapsed"
>
<Icon
:name="collapsed ? 'chevronRight' : 'chevronDown'"
size="sm"
:stroke-width="2"
class="flex-shrink-0 text-gray-400 transition-transform duration-200"
/>
<!-- Summary: model tags + billing badge -->
<div v-if="collapsed" class="flex min-w-0 flex-1 items-center gap-2 overflow-hidden">
<!-- Compact model tags (show first 3) -->
<div class="flex min-w-0 flex-1 flex-wrap items-center gap-1">
<span
v-for="(m, i) in entry.models.slice(0, 3)"
:key="i"
class="inline-flex shrink-0 rounded px-1.5 py-0.5 text-xs"
:class="getPlatformTagClass(props.platform || '')"
>
{{ m }}
</span>
<span
v-if="entry.models.length > 3"
class="whitespace-nowrap text-xs text-gray-400"
>
+{{ entry.models.length - 3 }}
</span>
<span
v-if="entry.models.length === 0"
class="text-xs italic text-gray-400"
>
{{ t('admin.channels.form.noModels', '未添加模型') }}
</span>
</div>
<!-- Billing mode badge -->
<span
class="flex-shrink-0 rounded-full bg-primary-100 px-2 py-0.5 text-xs font-medium text-primary-700 dark:bg-primary-900/30 dark:text-primary-300"
>
{{ billingModeLabel }}
</span>
</div>
<!-- Expanded: show the label "Pricing Entry" or similar -->
<div v-else class="flex-1 text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.pricingEntry', '定价配置') }}
</div>
<!-- Remove button (always visible, stop propagation) -->
<button
type="button"
@click.stop="emit('remove')"
class="flex-shrink-0 rounded p-1 text-gray-400 hover:text-red-500"
>
<Icon name="trash" size="sm" />
</button>
</div>
<!-- Expandable content with transition -->
<div
class="collapsible-content"
:class="{ 'collapsible-content--collapsed': collapsed }"
>
<div class="collapsible-inner">
<!-- Header: Models + Billing Mode -->
<div class="mt-3 flex items-start gap-2">
<div class="flex-1">
<label class="text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.models', '模型列表') }} <span class="text-red-500">*</span>
</label>
<ModelTagInput
:models="entry.models"
:platform="props.platform"
@update:models="onModelsUpdate($event)"
:placeholder="t('admin.channels.form.modelsPlaceholder', '输入模型名后按回车添加,支持通配符 *')"
class="mt-1"
/>
</div>
<div class="w-40">
<label class="text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.billingMode', '计费模式') }}
</label>
<Select
:modelValue="entry.billing_mode"
@update:modelValue="emit('update', { ...entry, billing_mode: $event as BillingMode, intervals: [] })"
:options="billingModeOptions"
class="mt-1"
/>
</div>
</div>
<!-- Token mode -->
<div v-if="entry.billing_mode === 'token'">
<!-- Default prices (fallback when no interval matches) -->
<label class="mt-3 block text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.defaultPrices', '默认价格(未命中区间时使用)') }}
<span class="ml-1 font-normal text-gray-400">$/MTok</span>
</label>
<div class="mt-1 grid grid-cols-2 gap-2 sm:grid-cols-5">
<div>
<label class="text-xs text-gray-400">{{ t('admin.channels.form.inputPrice', '输入') }}</label>
<input :value="entry.input_price" @input="emitField('input_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
</div>
<div>
<label class="text-xs text-gray-400">{{ t('admin.channels.form.outputPrice', '输出') }}</label>
<input :value="entry.output_price" @input="emitField('output_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
</div>
<div>
<label class="text-xs text-gray-400">{{ t('admin.channels.form.cacheWritePrice', '缓存写入') }}</label>
<input :value="entry.cache_write_price" @input="emitField('cache_write_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
</div>
<div>
<label class="text-xs text-gray-400">{{ t('admin.channels.form.cacheReadPrice', '缓存读取') }}</label>
<input :value="entry.cache_read_price" @input="emitField('cache_read_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
</div>
<div>
<label class="text-xs text-gray-400">{{ t('admin.channels.form.imageTokenPrice', '图片输出') }}</label>
<input :value="entry.image_output_price" @input="emitField('image_output_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
</div>
</div>
<!-- Token intervals -->
<div class="mt-3">
<div class="flex items-center justify-between">
<label class="text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.intervals', '上下文区间定价(可选)') }}
<span class="ml-1 font-normal text-gray-400">(min, max]</span>
</label>
<button type="button" @click="addInterval" class="text-xs text-primary-600 hover:text-primary-700">
+ {{ t('admin.channels.form.addInterval', '添加区间') }}
</button>
</div>
<div v-if="entry.intervals && entry.intervals.length > 0" class="mt-2 space-y-2">
<IntervalRow
v-for="(iv, idx) in entry.intervals"
:key="idx"
:interval="iv"
:mode="entry.billing_mode"
@update="updateInterval(idx, $event)"
@remove="removeInterval(idx)"
/>
</div>
</div>
</div>
<!-- Per-request mode -->
<div v-else-if="entry.billing_mode === 'per_request'">
<!-- Default per-request price -->
<label class="mt-3 block text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.defaultPerRequestPrice', '默认单次价格(未命中层级时使用)') }}
<span class="ml-1 font-normal text-gray-400">$</span>
</label>
<div class="mt-1 w-48">
<input :value="entry.per_request_price" @input="emitField('per_request_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
</div>
<!-- Tiers -->
<div class="mt-3 flex items-center justify-between">
<label class="text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.requestTiers', '按次计费层级') }}
</label>
<button type="button" @click="addInterval" class="text-xs text-primary-600 hover:text-primary-700">
+ {{ t('admin.channels.form.addTier', '添加层级') }}
</button>
</div>
<div v-if="entry.intervals && entry.intervals.length > 0" class="mt-2 space-y-2">
<IntervalRow
v-for="(iv, idx) in entry.intervals"
:key="idx"
:interval="iv"
:mode="entry.billing_mode"
@update="updateInterval(idx, $event)"
@remove="removeInterval(idx)"
/>
</div>
<div v-else class="mt-2 rounded border border-dashed border-gray-300 p-3 text-center text-xs text-gray-400 dark:border-dark-500">
{{ t('admin.channels.form.noTiersYet', '暂无层级,点击添加配置按次计费价格') }}
</div>
</div>
<!-- Image mode -->
<div v-else-if="entry.billing_mode === 'image'">
<!-- Default image price (per-request, same as per_request mode) -->
<label class="mt-3 block text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.defaultImagePrice', '默认图片价格(未命中层级时使用)') }}
<span class="ml-1 font-normal text-gray-400">$</span>
</label>
<div class="mt-1 w-48">
<input :value="entry.per_request_price" @input="emitField('per_request_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
</div>
<!-- Image tiers -->
<div class="mt-3 flex items-center justify-between">
<label class="text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.imageTiers', '图片计费层级(按次)') }}
</label>
<button type="button" @click="addImageTier" class="text-xs text-primary-600 hover:text-primary-700">
+ {{ t('admin.channels.form.addTier', '添加层级') }}
</button>
</div>
<div v-if="entry.intervals && entry.intervals.length > 0" class="mt-2 space-y-2">
<IntervalRow
v-for="(iv, idx) in entry.intervals"
:key="idx"
:interval="iv"
:mode="entry.billing_mode"
@update="updateInterval(idx, $event)"
@remove="removeInterval(idx)"
/>
</div>
</div>
</div>
</div>
</div>
</template>
<script setup lang="ts">
import { ref, computed } from 'vue'
import { useI18n } from 'vue-i18n'
import Select from '@/components/common/Select.vue'
import Icon from '@/components/icons/Icon.vue'
import IntervalRow from './IntervalRow.vue'
import ModelTagInput from './ModelTagInput.vue'
import type { PricingFormEntry, IntervalFormEntry } from './types'
import { perTokenToMTok, getPlatformTagClass } from './types'
import type { BillingMode } from '@/api/admin/channels'
import channelsAPI from '@/api/admin/channels'
const { t } = useI18n()
const props = defineProps<{
entry: PricingFormEntry
platform?: string
}>()
const emit = defineEmits<{
update: [entry: PricingFormEntry]
remove: []
}>()
// Collapse state: entries with existing models default to collapsed
const collapsed = ref(props.entry.models.length > 0)
const billingModeOptions = computed(() => [
{ value: 'token', label: 'Token' },
{ value: 'per_request', label: t('admin.channels.billingMode.perRequest', '按次') },
{ value: 'image', label: t('admin.channels.billingMode.image', '图片(按次)') }
])
const billingModeLabel = computed(() => {
const opt = billingModeOptions.value.find(o => o.value === props.entry.billing_mode)
return opt ? opt.label : props.entry.billing_mode
})
function emitField(field: keyof PricingFormEntry, value: string) {
emit('update', { ...props.entry, [field]: value === '' ? null : value })
}
function addInterval() {
const intervals = [...(props.entry.intervals || [])]
intervals.push({
min_tokens: 0, max_tokens: null, tier_label: '',
input_price: null, output_price: null, cache_write_price: null,
cache_read_price: null, per_request_price: null,
sort_order: intervals.length
})
emit('update', { ...props.entry, intervals })
}
function addImageTier() {
const intervals = [...(props.entry.intervals || [])]
const labels = ['1K', '2K', '4K', 'HD']
intervals.push({
min_tokens: 0, max_tokens: null, tier_label: labels[intervals.length] || '',
input_price: null, output_price: null, cache_write_price: null,
cache_read_price: null, per_request_price: null,
sort_order: intervals.length
})
emit('update', { ...props.entry, intervals })
}
function updateInterval(idx: number, updated: IntervalFormEntry) {
const intervals = [...(props.entry.intervals || [])]
intervals[idx] = updated
emit('update', { ...props.entry, intervals })
}
function removeInterval(idx: number) {
const intervals = [...(props.entry.intervals || [])]
intervals.splice(idx, 1)
emit('update', { ...props.entry, intervals })
}
async function onModelsUpdate(newModels: string[]) {
const oldModels = props.entry.models
emit('update', { ...props.entry, models: newModels })
//
const addedModels = newModels.filter(m => !oldModels.includes(m))
if (addedModels.length === 0) return
//
const e = props.entry
const hasPrice = e.input_price != null || e.output_price != null ||
e.cache_write_price != null || e.cache_read_price != null
if (hasPrice) return
//
try {
const result = await channelsAPI.getModelDefaultPricing(addedModels[0])
if (result.found) {
emit('update', {
...props.entry,
models: newModels,
input_price: perTokenToMTok(result.input_price ?? null),
output_price: perTokenToMTok(result.output_price ?? null),
cache_write_price: perTokenToMTok(result.cache_write_price ?? null),
cache_read_price: perTokenToMTok(result.cache_read_price ?? null),
image_output_price: perTokenToMTok(result.image_output_price ?? null),
})
}
} catch {
//
}
}
</script>
<style scoped>
.collapsible-content {
display: grid;
grid-template-rows: 1fr;
transition: grid-template-rows 0.25s ease;
}
.collapsible-content--collapsed {
grid-template-rows: 0fr;
}
.collapsible-inner {
overflow: hidden;
}
</style>

View File

@ -0,0 +1,190 @@
import type { BillingMode, PricingInterval } from '@/api/admin/channels'
export interface IntervalFormEntry {
min_tokens: number
max_tokens: number | null
tier_label: string
input_price: number | string | null
output_price: number | string | null
cache_write_price: number | string | null
cache_read_price: number | string | null
per_request_price: number | string | null
sort_order: number
}
export interface PricingFormEntry {
models: string[]
billing_mode: BillingMode
input_price: number | string | null
output_price: number | string | null
cache_write_price: number | string | null
cache_read_price: number | string | null
image_output_price: number | string | null
per_request_price: number | string | null
intervals: IntervalFormEntry[]
}
// 价格转换:后端存 per-token前端显示 per-MTok ($/1M tokens)
const MTOK = 1_000_000
export function toNullableNumber(val: number | string | null | undefined): number | null {
if (val === null || val === undefined || val === '') return null
const num = Number(val)
return isNaN(num) ? null : num
}
/** 前端显示值($/MTok) → 后端存储值(per-token) */
export function mTokToPerToken(val: number | string | null | undefined): number | null {
const num = toNullableNumber(val)
return num === null ? null : parseFloat((num / MTOK).toPrecision(10))
}
/** 后端存储值(per-token) → 前端显示值($/MTok) */
export function perTokenToMTok(val: number | null | undefined): number | null {
if (val === null || val === undefined) return null
// toPrecision(10) 消除 IEEE 754 浮点乘法精度误差,如 5e-8 * 1e6 = 0.04999...96 → 0.05
return parseFloat((val * MTOK).toPrecision(10))
}
export function apiIntervalsToForm(intervals: PricingInterval[]): IntervalFormEntry[] {
return (intervals || []).map(iv => ({
min_tokens: iv.min_tokens,
max_tokens: iv.max_tokens,
tier_label: iv.tier_label || '',
input_price: perTokenToMTok(iv.input_price),
output_price: perTokenToMTok(iv.output_price),
cache_write_price: perTokenToMTok(iv.cache_write_price),
cache_read_price: perTokenToMTok(iv.cache_read_price),
per_request_price: iv.per_request_price,
sort_order: iv.sort_order
}))
}
export function formIntervalsToAPI(intervals: IntervalFormEntry[]): PricingInterval[] {
return (intervals || []).map(iv => ({
min_tokens: iv.min_tokens,
max_tokens: iv.max_tokens,
tier_label: iv.tier_label,
input_price: mTokToPerToken(iv.input_price),
output_price: mTokToPerToken(iv.output_price),
cache_write_price: mTokToPerToken(iv.cache_write_price),
cache_read_price: mTokToPerToken(iv.cache_read_price),
per_request_price: toNullableNumber(iv.per_request_price),
sort_order: iv.sort_order
}))
}
// ── 模型模式冲突检测 ──────────────────────────────────────
interface ModelPattern {
pattern: string
prefix: string // lowercase, 通配符去掉尾部 *
wildcard: boolean
}
function toModelPattern(model: string): ModelPattern {
const lower = model.toLowerCase()
const wildcard = lower.endsWith('*')
return {
pattern: model,
prefix: wildcard ? lower.slice(0, -1) : lower,
wildcard,
}
}
function patternsConflict(a: ModelPattern, b: ModelPattern): boolean {
if (!a.wildcard && !b.wildcard) return a.prefix === b.prefix
if (a.wildcard && !b.wildcard) return b.prefix.startsWith(a.prefix)
if (!a.wildcard && b.wildcard) return a.prefix.startsWith(b.prefix)
// 双通配符:任一前缀是另一前缀的前缀即冲突
return a.prefix.startsWith(b.prefix) || b.prefix.startsWith(a.prefix)
}
/** 检测模型模式列表中的冲突,返回冲突的两个模式名;无冲突返回 null */
export function findModelConflict(models: string[]): [string, string] | null {
const patterns = models.map(toModelPattern)
for (let i = 0; i < patterns.length; i++) {
for (let j = i + 1; j < patterns.length; j++) {
if (patternsConflict(patterns[i], patterns[j])) {
return [patterns[i].pattern, patterns[j].pattern]
}
}
}
return null
}
// ── 区间校验 ──────────────────────────────────────────────
/** 校验区间列表的合法性,返回错误消息;通过则返回 null */
export function validateIntervals(intervals: IntervalFormEntry[]): string | null {
if (!intervals || intervals.length === 0) return null
// 按 min_tokens 排序(不修改原数组)
const sorted = [...intervals].sort((a, b) => a.min_tokens - b.min_tokens)
for (let i = 0; i < sorted.length; i++) {
const err = validateSingleInterval(sorted[i], i)
if (err) return err
}
return checkIntervalOverlap(sorted)
}
function validateSingleInterval(iv: IntervalFormEntry, idx: number): string | null {
if (iv.min_tokens < 0) {
return `区间 #${idx + 1}: 最小 token 数 (${iv.min_tokens}) 不能为负数`
}
if (iv.max_tokens != null) {
if (iv.max_tokens <= 0) {
return `区间 #${idx + 1}: 最大 token 数 (${iv.max_tokens}) 必须大于 0`
}
if (iv.max_tokens <= iv.min_tokens) {
return `区间 #${idx + 1}: 最大 token 数 (${iv.max_tokens}) 必须大于最小 token 数 (${iv.min_tokens})`
}
}
return validateIntervalPrices(iv, idx)
}
function validateIntervalPrices(iv: IntervalFormEntry, idx: number): string | null {
const prices: [string, number | string | null][] = [
['输入价格', iv.input_price],
['输出价格', iv.output_price],
['缓存写入价格', iv.cache_write_price],
['缓存读取价格', iv.cache_read_price],
['单次价格', iv.per_request_price],
]
for (const [name, val] of prices) {
if (val != null && val !== '' && Number(val) < 0) {
return `区间 #${idx + 1}: ${name}不能为负数`
}
}
return null
}
function checkIntervalOverlap(sorted: IntervalFormEntry[]): string | null {
for (let i = 0; i < sorted.length; i++) {
// 无上限区间必须是最后一个
if (sorted[i].max_tokens == null && i < sorted.length - 1) {
return `区间 #${i + 1}: 无上限区间(最大 token 数为空)只能是最后一个`
}
if (i === 0) continue
const prev = sorted[i - 1]
// (min, max] 语义:前一个区间上界 > 当前区间下界则重叠
if (prev.max_tokens == null || prev.max_tokens > sorted[i].min_tokens) {
const prevMax = prev.max_tokens == null ? '∞' : String(prev.max_tokens)
return `区间 #${i} 和 #${i + 1} 重叠:前一个区间上界 (${prevMax}) 大于当前区间下界 (${sorted[i].min_tokens})`
}
}
return null
}
/** 平台对应的模型 tag 样式(背景+文字) */
export function getPlatformTagClass(platform: string): string {
switch (platform) {
case 'anthropic': return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
case 'openai': return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400'
case 'gemini': return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400'
case 'antigravity': return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
case 'sora': return 'bg-rose-100 text-rose-700 dark:bg-rose-900/30 dark:text-rose-400'
default: return 'bg-gray-100 text-gray-700 dark:bg-gray-900/30 dark:text-gray-400'
}
}

View File

@ -133,6 +133,12 @@
<Select v-model="filters.billing_type" :options="billingTypeOptions" @change="emitChange" /> <Select v-model="filters.billing_type" :options="billingTypeOptions" @change="emitChange" />
</div> </div>
<!-- Billing Mode Filter -->
<div class="w-full sm:w-auto sm:min-w-[200px]">
<label class="input-label">{{ t('admin.usage.billingMode') }}</label>
<Select v-model="filters.billing_mode" :options="billingModeOptions" @change="emitChange" />
</div>
<!-- Group Filter --> <!-- Group Filter -->
<div class="w-full sm:w-auto sm:min-w-[200px]"> <div class="w-full sm:w-auto sm:min-w-[200px]">
<label class="input-label">{{ t('admin.usage.group') }}</label> <label class="input-label">{{ t('admin.usage.group') }}</label>
@ -232,6 +238,13 @@ const billingTypeOptions = ref<SelectOption[]>([
{ value: 1, label: t('admin.usage.billingTypeSubscription') } { value: 1, label: t('admin.usage.billingTypeSubscription') }
]) ])
const billingModeOptions = ref<SelectOption[]>([
{ value: null, label: t('admin.usage.allBillingModes') },
{ value: 'token', label: t('admin.usage.billingModeToken') },
{ value: 'per_request', label: t('admin.usage.billingModePerRequest') },
{ value: 'image', label: t('admin.usage.billingModeImage') }
])
const emitChange = () => emit('change') const emitChange = () => emit('change')
const debounceUserSearch = () => { const debounceUserSearch = () => {

View File

@ -26,7 +26,15 @@
</template> </template>
<template #cell-model="{ row }"> <template #cell-model="{ row }">
<div v-if="row.upstream_model && row.upstream_model !== row.model" class="space-y-0.5 text-xs"> <div v-if="row.model_mapping_chain && row.model_mapping_chain.includes('→')" class="space-y-0.5 text-xs">
<div v-for="(step, i) in row.model_mapping_chain.split('→')" :key="i"
class="break-all"
:class="i === 0 ? 'font-medium text-gray-900 dark:text-white' : 'text-gray-500 dark:text-gray-400'"
:style="i > 0 ? `padding-left: ${i * 0.75}rem` : ''">
<span v-if="i > 0" class="mr-0.5"></span>{{ step }}
</div>
</div>
<div v-else-if="row.upstream_model && row.upstream_model !== row.model" class="space-y-0.5 text-xs">
<div class="break-all font-medium text-gray-900 dark:text-white"> <div class="break-all font-medium text-gray-900 dark:text-white">
{{ row.model }} {{ row.model }}
</div> </div>
@ -69,9 +77,15 @@
</span> </span>
</template> </template>
<template #cell-billing_mode="{ row }">
<span class="inline-flex items-center rounded px-2 py-0.5 text-xs font-medium" :class="getBillingModeBadgeClass(row.billing_mode)">
{{ getBillingModeLabel(row.billing_mode) }}
</span>
</template>
<template #cell-tokens="{ row }"> <template #cell-tokens="{ row }">
<!-- 图片生成请求 --> <!-- 图片生成请求仅按次计费时显示图片格式 -->
<div v-if="row.image_count > 0" class="flex items-center gap-1.5"> <div v-if="row.image_count > 0 && row.billing_mode === 'image'" class="flex items-center gap-1.5">
<svg class="h-4 w-4 text-indigo-500" fill="none" stroke="currentColor" viewBox="0 0 24 24"> <svg class="h-4 w-4 text-indigo-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16l4.586-4.586a2 2 0 012.828 0L16 16m-2-2l1.586-1.586a2 2 0 012.828 0L20 14m-6-6h.01M6 20h12a2 2 0 002-2V6a2 2 0 00-2-2H6a2 2 0 00-2 2v12a2 2 0 002 2z" /> <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16l4.586-4.586a2 2 0 012.828 0L16 16m-2-2l1.586-1.586a2 2 0 012.828 0L20 14m-6-6h.01M6 20h12a2 2 0 002-2V6a2 2 0 00-2-2H6a2 2 0 00-2 2v12a2 2 0 002 2z" />
</svg> </svg>
@ -281,11 +295,11 @@
</div> </div>
<div class="flex items-center justify-between gap-6"> <div class="flex items-center justify-between gap-6">
<span class="text-gray-400">{{ t('usage.rate') }}</span> <span class="text-gray-400">{{ t('usage.rate') }}</span>
<span class="font-semibold text-blue-400">{{ (tooltipData?.rate_multiplier || 1).toFixed(2) }}x</span> <span class="font-semibold text-blue-400">{{ formatMultiplier(tooltipData?.rate_multiplier || 1) }}x</span>
</div> </div>
<div class="flex items-center justify-between gap-6"> <div class="flex items-center justify-between gap-6">
<span class="text-gray-400">{{ t('usage.accountMultiplier') }}</span> <span class="text-gray-400">{{ t('usage.accountMultiplier') }}</span>
<span class="font-semibold text-blue-400">{{ (tooltipData?.account_rate_multiplier ?? 1).toFixed(2) }}x</span> <span class="font-semibold text-blue-400">{{ formatMultiplier(tooltipData?.account_rate_multiplier ?? 1) }}x</span>
</div> </div>
<div class="flex items-center justify-between gap-6"> <div class="flex items-center justify-between gap-6">
<span class="text-gray-400">{{ t('usage.original') }}</span> <span class="text-gray-400">{{ t('usage.original') }}</span>
@ -312,6 +326,7 @@
import { ref } from 'vue' import { ref } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import { formatDateTime, formatReasoningEffort } from '@/utils/format' import { formatDateTime, formatReasoningEffort } from '@/utils/format'
import { formatCacheTokens, formatMultiplier } from '@/utils/formatters'
import { formatTokenPricePerMillion } from '@/utils/usagePricing' import { formatTokenPricePerMillion } from '@/utils/usagePricing'
import { getUsageServiceTierLabel } from '@/utils/usageServiceTier' import { getUsageServiceTierLabel } from '@/utils/usageServiceTier'
import { resolveUsageRequestType } from '@/utils/usageRequestType' import { resolveUsageRequestType } from '@/utils/usageRequestType'
@ -350,12 +365,19 @@ const getRequestTypeBadgeClass = (row: AdminUsageLog): string => {
return 'bg-amber-100 text-amber-800 dark:bg-amber-900 dark:text-amber-200' return 'bg-amber-100 text-amber-800 dark:bg-amber-900 dark:text-amber-200'
} }
const formatCacheTokens = (tokens: number): string => { const getBillingModeLabel = (mode: string | null | undefined): string => {
if (tokens >= 1000000) return `${(tokens / 1000000).toFixed(1)}M` if (mode === 'per_request') return t('admin.usage.billingModePerRequest')
if (tokens >= 1000) return `${(tokens / 1000).toFixed(1)}K` if (mode === 'image') return t('admin.usage.billingModeImage')
return tokens.toString() return t('admin.usage.billingModeToken')
} }
const getBillingModeBadgeClass = (mode: string | null | undefined): string => {
if (mode === 'per_request') return 'bg-blue-100 text-blue-800 dark:bg-blue-900 dark:text-blue-200'
if (mode === 'image') return 'bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-200'
return 'bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-200'
}
const formatUserAgent = (ua: string): string => { const formatUserAgent = (ua: string): string => {
return ua return ua
} }

View File

@ -161,6 +161,7 @@ const props = withDefaults(
showSourceToggle?: boolean showSourceToggle?: boolean
startDate?: string startDate?: string
endDate?: string endDate?: string
filters?: Record<string, any>
}>(), }>(),
{ {
upstreamEndpointStats: () => [], upstreamEndpointStats: () => [],
@ -193,6 +194,7 @@ const toggleBreakdown = async (endpoint: string) => {
breakdownItems.value = [] breakdownItems.value = []
try { try {
const res = await getUserBreakdown({ const res = await getUserBreakdown({
...props.filters,
start_date: props.startDate, start_date: props.startDate,
end_date: props.endDate, end_date: props.endDate,
endpoint, endpoint,

View File

@ -125,6 +125,7 @@ const props = withDefaults(defineProps<{
showMetricToggle?: boolean showMetricToggle?: boolean
startDate?: string startDate?: string
endDate?: string endDate?: string
filters?: Record<string, any>
}>(), { }>(), {
loading: false, loading: false,
metric: 'tokens', metric: 'tokens',
@ -150,6 +151,7 @@ const toggleBreakdown = async (type: string, id: number | string) => {
breakdownItems.value = [] breakdownItems.value = []
try { try {
const res = await getUserBreakdown({ const res = await getUserBreakdown({
...props.filters,
start_date: props.startDate, start_date: props.startDate,
end_date: props.endDate, end_date: props.endDate,
group_id: Number(id), group_id: Number(id),

View File

@ -270,6 +270,7 @@ const props = withDefaults(defineProps<{
rankingError?: boolean rankingError?: boolean
startDate?: string startDate?: string
endDate?: string endDate?: string
filters?: Record<string, any>
}>(), { }>(), {
upstreamModelStats: () => [], upstreamModelStats: () => [],
mappingModelStats: () => [], mappingModelStats: () => [],
@ -302,6 +303,7 @@ const toggleBreakdown = async (type: string, id: string) => {
breakdownItems.value = [] breakdownItems.value = []
try { try {
const res = await getUserBreakdown({ const res = await getUserBreakdown({
...props.filters,
start_date: props.startDate, start_date: props.startDate,
end_date: props.endDate, end_date: props.endDate,
model: id, model: id,

View File

@ -287,6 +287,21 @@ const FolderIcon = {
) )
} }
const ChannelIcon = {
render: () =>
h(
'svg',
{ fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' },
[
h('path', {
'stroke-linecap': 'round',
'stroke-linejoin': 'round',
d: 'M6.429 9.75L2.25 12l4.179 2.25m0-4.5l5.571 3 5.571-3m-11.142 0L2.25 7.5 12 2.25l9.75 5.25-4.179 2.25m0 0l4.179 2.25L12 17.25 2.25 12m15.321-2.25l4.179 2.25L12 17.25l-9.75-5.25'
})
]
)
}
const CreditCardIcon = { const CreditCardIcon = {
render: () => render: () =>
h( h(
@ -568,6 +583,7 @@ const adminNavItems = computed((): NavItem[] => {
: []), : []),
{ path: '/admin/users', label: t('nav.users'), icon: UsersIcon, hideInSimpleMode: true }, { path: '/admin/users', label: t('nav.users'), icon: UsersIcon, hideInSimpleMode: true },
{ path: '/admin/groups', label: t('nav.groups'), icon: FolderIcon, hideInSimpleMode: true }, { path: '/admin/groups', label: t('nav.groups'), icon: FolderIcon, hideInSimpleMode: true },
{ path: '/admin/channels', label: t('nav.channels', '渠道管理'), icon: ChannelIcon, hideInSimpleMode: true },
{ path: '/admin/subscriptions', label: t('nav.subscriptions'), icon: CreditCardIcon, hideInSimpleMode: true }, { path: '/admin/subscriptions', label: t('nav.subscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
{ path: '/admin/accounts', label: t('nav.accounts'), icon: GlobeIcon }, { path: '/admin/accounts', label: t('nav.accounts'), icon: GlobeIcon },
{ path: '/admin/announcements', label: t('nav.announcements'), icon: BellIcon }, { path: '/admin/announcements', label: t('nav.announcements'), icon: BellIcon },

View File

@ -335,6 +335,7 @@ export default {
profile: 'Profile', profile: 'Profile',
users: 'Users', users: 'Users',
groups: 'Groups', groups: 'Groups',
channels: 'Channels',
subscriptions: 'Subscriptions', subscriptions: 'Subscriptions',
accounts: 'Accounts', accounts: 'Accounts',
proxies: 'Proxies', proxies: 'Proxies',
@ -1719,6 +1720,107 @@ export default {
} }
}, },
// Channel Management
channels: {
title: 'Channel Management',
description: 'Manage channels and custom model pricing',
searchChannels: 'Search channels...',
createChannel: 'Create Channel',
editChannel: 'Edit Channel',
deleteChannel: 'Delete Channel',
statusActive: 'Active',
statusDisabled: 'Disabled',
allStatus: 'All Status',
groupsUnit: 'groups',
pricingUnit: 'pricing rules',
noChannelsYet: 'No Channels Yet',
createFirstChannel: 'Create your first channel to manage model pricing',
loadError: 'Failed to load channels',
createSuccess: 'Channel created',
updateSuccess: 'Channel updated',
deleteSuccess: 'Channel deleted',
createError: 'Failed to create channel',
updateError: 'Failed to update channel',
deleteError: 'Failed to delete channel',
nameRequired: 'Please enter a channel name',
duplicateModels: 'Model "{0}" appears in multiple pricing entries',
modelConflict: "Model patterns '{model1}' and '{model2}' conflict: overlapping match range",
mappingConflict: "Mapping source patterns '{model1}' and '{model2}' conflict: overlapping match range",
deleteConfirm: 'Are you sure you want to delete channel "{name}"? This cannot be undone.',
columns: {
name: 'Name',
description: 'Description',
status: 'Status',
groups: 'Groups',
pricing: 'Pricing',
createdAt: 'Created',
actions: 'Actions'
},
billingMode: {
token: 'Token',
perRequest: 'Per Request',
image: 'Image (Per Request)'
},
form: {
name: 'Name',
namePlaceholder: 'Enter channel name',
description: 'Description',
descriptionPlaceholder: 'Optional description',
status: 'Status',
groups: 'Associated Groups',
noGroupsAvailable: 'No groups available',
inOtherChannel: 'In "{name}"',
modelPricing: 'Model Pricing',
models: 'Models',
modelsPlaceholder: 'Type full model name and press Enter',
modelInputHint: 'Press Enter to add, supports paste for batch import.',
billingMode: 'Billing Mode',
defaultPrices: 'Default prices (fallback when no interval matches)',
inputPrice: 'Input',
outputPrice: 'Output',
cacheWritePrice: 'Cache Write',
cacheReadPrice: 'Cache Read',
imageTokenPrice: 'Image Output',
imageOutputPrice: 'Image Output Price',
pricePlaceholder: 'Default',
intervals: 'Context Intervals (optional)',
addInterval: 'Add Interval',
requestTiers: 'Request Tiers',
imageTiers: 'Image Tiers (Per Request)',
addTier: 'Add Tier',
noTiersYet: 'No tiers yet. Click add to configure per-request pricing.',
noPricingRules: 'No pricing rules yet. Click "Add" to create one.',
perRequestPrice: 'Price per Request',
perRequestPriceRequired: 'Per-request price or billing tiers required for per-request/image billing mode',
tierLabel: 'Tier',
resolution: 'Resolution',
modelMapping: 'Model Mapping',
modelMappingHint: 'Map request model names to actual model names. Runs before account-level mapping.',
noMappingRules: 'No mapping rules. Click "Add" to create one.',
mappingSource: 'Source model',
mappingTarget: 'Target model',
billingModelSource: 'Billing Model',
billingModelSourceChannelMapped: 'Bill by channel-mapped model',
billingModelSourceRequested: 'Bill by requested model',
billingModelSourceUpstream: 'Bill by final upstream model',
billingModelSourceHint: 'Controls which model name is used for pricing lookup',
selectedCount: '{count} selected',
searchGroups: 'Search groups...',
noGroupsMatch: 'No groups match your search',
restrictModels: 'Restrict Models',
restrictModelsHint: 'When enabled, only models in the pricing list are allowed. Others will be rejected.',
defaultPerRequestPrice: 'Default per-request price (fallback when no tier matches)',
defaultImagePrice: 'Default image price (fallback when no tier matches)',
platformConfig: 'Platform Configuration',
basicSettings: 'Basic Settings',
addPlatform: 'Add Platform',
noPlatforms: 'Click "Add Platform" to start configuring the channel',
mappingCount: 'mappings',
pricingEntry: 'Pricing Entry',
noModels: 'No models added'
}
},
// Subscriptions // Subscriptions
subscriptions: { subscriptions: {
title: 'Subscription Management', title: 'Subscription Management',
@ -3258,6 +3360,11 @@ export default {
allBillingTypes: 'All Billing Types', allBillingTypes: 'All Billing Types',
billingTypeBalance: 'Balance', billingTypeBalance: 'Balance',
billingTypeSubscription: 'Subscription', billingTypeSubscription: 'Subscription',
billingMode: 'Billing Mode',
billingModeToken: 'Token',
billingModePerRequest: 'Per Request',
billingModeImage: 'Image',
allBillingModes: 'All Billing Modes',
ipAddress: 'IP', ipAddress: 'IP',
clickToViewBalance: 'Click to view balance history', clickToViewBalance: 'Click to view balance history',
failedToLoadUser: 'Failed to load user info', failedToLoadUser: 'Failed to load user info',

View File

@ -335,6 +335,7 @@ export default {
profile: '个人资料', profile: '个人资料',
users: '用户管理', users: '用户管理',
groups: '分组管理', groups: '分组管理',
channels: '渠道管理',
subscriptions: '订阅管理', subscriptions: '订阅管理',
accounts: '账号管理', accounts: '账号管理',
proxies: 'IP管理', proxies: 'IP管理',
@ -1799,6 +1800,107 @@ export default {
} }
}, },
// Channel Management
channels: {
title: '渠道管理',
description: '管理渠道和自定义模型定价',
searchChannels: '搜索渠道...',
createChannel: '创建渠道',
editChannel: '编辑渠道',
deleteChannel: '删除渠道',
statusActive: '启用',
statusDisabled: '停用',
allStatus: '全部状态',
groupsUnit: '个分组',
pricingUnit: '条定价',
noChannelsYet: '暂无渠道',
createFirstChannel: '创建第一个渠道来管理模型定价',
loadError: '加载渠道列表失败',
createSuccess: '渠道创建成功',
updateSuccess: '渠道更新成功',
deleteSuccess: '渠道删除成功',
createError: '创建渠道失败',
updateError: '更新渠道失败',
deleteError: '删除渠道失败',
nameRequired: '请输入渠道名称',
duplicateModels: '模型「{0}」在多个定价条目中重复',
modelConflict: "模型模式 '{model1}' 和 '{model2}' 冲突:匹配范围重叠",
mappingConflict: "模型映射源 '{model1}' 和 '{model2}' 冲突:匹配范围重叠",
deleteConfirm: '确定要删除渠道「{name}」吗?此操作不可撤销。',
columns: {
name: '名称',
description: '描述',
status: '状态',
groups: '分组',
pricing: '定价',
createdAt: '创建时间',
actions: '操作'
},
billingMode: {
token: 'Token',
perRequest: '按次',
image: '图片(按次)'
},
form: {
name: '名称',
namePlaceholder: '输入渠道名称',
description: '描述',
descriptionPlaceholder: '可选描述',
status: '状态',
groups: '关联分组',
noGroupsAvailable: '暂无可用分组',
inOtherChannel: '已属于「{name}」',
modelPricing: '模型定价',
models: '模型列表',
modelsPlaceholder: '输入完整模型名后按回车添加',
modelInputHint: '按回车添加,支持粘贴批量导入',
billingMode: '计费模式',
defaultPrices: '默认价格(未命中区间时使用)',
inputPrice: '输入',
outputPrice: '输出',
cacheWritePrice: '缓存写入',
cacheReadPrice: '缓存读取',
imageTokenPrice: '图片输出',
imageOutputPrice: '图片输出价格',
pricePlaceholder: '默认',
intervals: '上下文区间定价(可选)',
addInterval: '添加区间',
requestTiers: '按次计费层级',
imageTiers: '图片计费层级(按次)',
addTier: '添加层级',
noTiersYet: '暂无层级,点击添加配置按次计费价格',
noPricingRules: '暂无定价规则,点击"添加"创建',
perRequestPrice: '单次价格',
perRequestPriceRequired: '按次/图片计费模式必须设置默认价格或至少一个计费层级',
tierLabel: '层级',
resolution: '分辨率',
modelMapping: '模型映射',
modelMappingHint: '将请求中的模型名映射为实际模型名。在账号级别映射之前执行。',
noMappingRules: '暂无映射规则,点击"添加"创建',
mappingSource: '源模型',
mappingTarget: '目标模型',
billingModelSource: '计费基准',
billingModelSourceChannelMapped: '以渠道映射后的模型计费',
billingModelSourceRequested: '以请求模型计费',
billingModelSourceUpstream: '以最终模型计费',
billingModelSourceHint: '控制使用哪个模型名称进行定价查找',
selectedCount: '已选 {count} 个',
searchGroups: '搜索分组...',
noGroupsMatch: '没有匹配的分组',
restrictModels: '限制模型',
restrictModelsHint: '开启后,仅允许模型定价列表中的模型。不在列表中的模型请求将被拒绝。',
defaultPerRequestPrice: '默认单次价格(未命中层级时使用)',
defaultImagePrice: '默认图片价格(未命中层级时使用)',
platformConfig: '平台配置',
basicSettings: '基础设置',
addPlatform: '添加平台',
noPlatforms: '点击"添加平台"开始配置渠道',
mappingCount: '条映射',
pricingEntry: '定价配置',
noModels: '未添加模型'
}
},
// Subscriptions Management // Subscriptions Management
subscriptions: { subscriptions: {
title: '订阅管理', title: '订阅管理',
@ -3417,6 +3519,11 @@ export default {
allBillingTypes: '全部计费类型', allBillingTypes: '全部计费类型',
billingTypeBalance: '钱包余额', billingTypeBalance: '钱包余额',
billingTypeSubscription: '订阅套餐', billingTypeSubscription: '订阅套餐',
billingMode: '计费模式',
billingModeToken: '按量',
billingModePerRequest: '按次',
billingModeImage: '按次(图片)',
allBillingModes: '全部计费模式',
ipAddress: 'IP', ipAddress: 'IP',
clickToViewBalance: '点击查看充值记录', clickToViewBalance: '点击查看充值记录',
failedToLoadUser: '加载用户信息失败', failedToLoadUser: '加载用户信息失败',

View File

@ -278,6 +278,18 @@ const routes: RouteRecordRaw[] = [
descriptionKey: 'admin.groups.description' descriptionKey: 'admin.groups.description'
} }
}, },
{
path: '/admin/channels',
name: 'AdminChannels',
component: () => import('@/views/admin/ChannelsView.vue'),
meta: {
requiresAuth: true,
requiresAdmin: true,
title: 'Channel Management',
titleKey: 'admin.channels.title',
descriptionKey: 'admin.channels.description'
}
},
{ {
path: '/admin/subscriptions', path: '/admin/subscriptions',
name: 'AdminSubscriptions', name: 'AdminSubscriptions',

View File

@ -1036,6 +1036,9 @@ export interface UsageLog {
// Cache TTL Override // Cache TTL Override
cache_ttl_overridden: boolean cache_ttl_overridden: boolean
// 计费模式
billing_mode?: string | null
created_at: string created_at: string
user?: User user?: User
@ -1051,6 +1054,7 @@ export interface UsageLogAccountSummary {
export interface AdminUsageLog extends UsageLog { export interface AdminUsageLog extends UsageLog {
upstream_model?: string | null upstream_model?: string | null
model_mapping_chain?: string | null
// 账号计费倍率(仅管理员可见) // 账号计费倍率(仅管理员可见)
account_rate_multiplier?: number | null account_rate_multiplier?: number | null

View File

@ -0,0 +1,18 @@
/**
* token 1K/1M
*/
export function formatCacheTokens(tokens: number): string {
if (tokens >= 1000000) return `${(tokens / 1000000).toFixed(1)}M`
if (tokens >= 1000) return `${(tokens / 1000).toFixed(1)}K`
return tokens.toLocaleString()
}
/**
* 0.001
*/
export function formatMultiplier(val: number): string {
if (val >= 0.01) return val.toFixed(2)
if (val >= 0.001) return val.toFixed(3)
if (val >= 0.0001) return val.toFixed(4)
return val.toPrecision(2)
}

File diff suppressed because it is too large Load Diff

View File

@ -34,6 +34,7 @@
:show-metric-toggle="true" :show-metric-toggle="true"
:start-date="startDate" :start-date="startDate"
:end-date="endDate" :end-date="endDate"
:filters="breakdownFilters"
/> />
<GroupDistributionChart <GroupDistributionChart
v-model:metric="groupDistributionMetric" v-model:metric="groupDistributionMetric"
@ -42,6 +43,7 @@
:show-metric-toggle="true" :show-metric-toggle="true"
:start-date="startDate" :start-date="startDate"
:end-date="endDate" :end-date="endDate"
:filters="breakdownFilters"
/> />
</div> </div>
<div class="grid grid-cols-1 gap-6 lg:grid-cols-2"> <div class="grid grid-cols-1 gap-6 lg:grid-cols-2">
@ -57,6 +59,7 @@
:title="t('usage.endpointDistribution')" :title="t('usage.endpointDistribution')"
:start-date="startDate" :start-date="startDate"
:end-date="endDate" :end-date="endDate"
:filters="breakdownFilters"
/> />
<TokenUsageTrend :trend-data="trendData" :loading="chartsLoading" /> <TokenUsageTrend :trend-data="trendData" :loading="chartsLoading" />
</div> </div>
@ -169,6 +172,17 @@ const cleanupDialogVisible = ref(false)
const showBalanceHistoryModal = ref(false) const showBalanceHistoryModal = ref(false)
const balanceHistoryUser = ref<AdminUser | null>(null) const balanceHistoryUser = ref<AdminUser | null>(null)
const breakdownFilters = computed(() => {
const f: Record<string, any> = {}
if (filters.value.user_id) f.user_id = filters.value.user_id
if (filters.value.api_key_id) f.api_key_id = filters.value.api_key_id
if (filters.value.account_id) f.account_id = filters.value.account_id
if (filters.value.group_id) f.group_id = filters.value.group_id
if (filters.value.request_type != null) f.request_type = filters.value.request_type
if (filters.value.billing_type != null) f.billing_type = filters.value.billing_type
return f
})
const handleUserClick = async (userId: number) => { const handleUserClick = async (userId: number) => {
try { try {
const user = await adminAPI.users.getById(userId) const user = await adminAPI.users.getById(userId)
@ -392,7 +406,7 @@ const resetFilters = () => {
const range = getLast24HoursRangeDates() const range = getLast24HoursRangeDates()
startDate.value = range.start startDate.value = range.start
endDate.value = range.end endDate.value = range.end
filters.value = { start_date: startDate.value, end_date: endDate.value, request_type: undefined, billing_type: null } filters.value = { start_date: startDate.value, end_date: endDate.value, request_type: undefined, billing_type: null, billing_mode: undefined }
granularity.value = getGranularityForRange(startDate.value, endDate.value) granularity.value = getGranularityForRange(startDate.value, endDate.value)
applyFilters() applyFilters()
} }
@ -440,7 +454,7 @@ const exportToExcel = async () => {
log.input_tokens, log.output_tokens, log.cache_read_tokens, log.cache_creation_tokens, log.input_tokens, log.output_tokens, log.cache_read_tokens, log.cache_creation_tokens,
log.input_cost?.toFixed(6) || '0.000000', log.output_cost?.toFixed(6) || '0.000000', log.input_cost?.toFixed(6) || '0.000000', log.output_cost?.toFixed(6) || '0.000000',
log.cache_read_cost?.toFixed(6) || '0.000000', log.cache_creation_cost?.toFixed(6) || '0.000000', log.cache_read_cost?.toFixed(6) || '0.000000', log.cache_creation_cost?.toFixed(6) || '0.000000',
log.rate_multiplier?.toFixed(2) || '1.00', (log.account_rate_multiplier ?? 1).toFixed(2), log.rate_multiplier?.toPrecision(4) || '1.00', (log.account_rate_multiplier ?? 1).toPrecision(4),
log.total_cost?.toFixed(6) || '0.000000', log.actual_cost?.toFixed(6) || '0.000000', log.total_cost?.toFixed(6) || '0.000000', log.actual_cost?.toFixed(6) || '0.000000',
(log.total_cost * (log.account_rate_multiplier ?? 1)).toFixed(6), log.first_token_ms ?? '', log.duration_ms, (log.total_cost * (log.account_rate_multiplier ?? 1)).toFixed(6), log.first_token_ms ?? '', log.duration_ms,
log.request_id || '', log.user_agent || '', log.ip_address || '' log.request_id || '', log.user_agent || '', log.ip_address || ''
@ -477,6 +491,7 @@ const allColumns = computed(() => [
{ key: 'endpoint', label: t('usage.endpoint'), sortable: false }, { key: 'endpoint', label: t('usage.endpoint'), sortable: false },
{ key: 'group', label: t('admin.usage.group'), sortable: false }, { key: 'group', label: t('admin.usage.group'), sortable: false },
{ key: 'stream', label: t('usage.type'), sortable: false }, { key: 'stream', label: t('usage.type'), sortable: false },
{ key: 'billing_mode', label: t('admin.usage.billingMode'), sortable: false },
{ key: 'tokens', label: t('usage.tokens'), sortable: false }, { key: 'tokens', label: t('usage.tokens'), sortable: false },
{ key: 'cost', label: t('usage.cost'), sortable: false }, { key: 'cost', label: t('usage.cost'), sortable: false },
{ key: 'first_token', label: t('usage.firstToken'), sortable: false }, { key: 'first_token', label: t('usage.firstToken'), sortable: false },

Some files were not shown because too many files have changed in this diff Show More