merge: sync upstream/main before PR

This commit is contained in:
Wang Lvyuan 2026-03-19 16:37:28 +08:00
commit 1de18b89dd
107 changed files with 2973 additions and 341 deletions

View File

@ -110,7 +110,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient() claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient() openAIOAuthClient := repository.NewOpenAIOAuthClient()
@ -143,6 +142,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
rpmCache := repository.NewRPMCache(redisClient) rpmCache := repository.NewRPMCache(redisClient)
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
dataManagementService := service.NewDataManagementService() dataManagementService := service.NewDataManagementService()

View File

@ -716,6 +716,7 @@ var (
{Name: "id", Type: field.TypeInt64, Increment: true}, {Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "request_id", Type: field.TypeString, Size: 64}, {Name: "request_id", Type: field.TypeString, Size: 64},
{Name: "model", Type: field.TypeString, Size: 100}, {Name: "model", Type: field.TypeString, Size: 100},
{Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100},
{Name: "input_tokens", Type: field.TypeInt, Default: 0}, {Name: "input_tokens", Type: field.TypeInt, Default: 0},
{Name: "output_tokens", Type: field.TypeInt, Default: 0}, {Name: "output_tokens", Type: field.TypeInt, Default: 0},
{Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0}, {Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0},
@ -755,31 +756,31 @@ var (
ForeignKeys: []*schema.ForeignKey{ ForeignKeys: []*schema.ForeignKey{
{ {
Symbol: "usage_logs_api_keys_usage_logs", Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[28]}, Columns: []*schema.Column{UsageLogsColumns[29]},
RefColumns: []*schema.Column{APIKeysColumns[0]}, RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_accounts_usage_logs", Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[29]}, Columns: []*schema.Column{UsageLogsColumns[30]},
RefColumns: []*schema.Column{AccountsColumns[0]}, RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_groups_usage_logs", Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[30]}, Columns: []*schema.Column{UsageLogsColumns[31]},
RefColumns: []*schema.Column{GroupsColumns[0]}, RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull, OnDelete: schema.SetNull,
}, },
{ {
Symbol: "usage_logs_users_usage_logs", Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[31]}, Columns: []*schema.Column{UsageLogsColumns[32]},
RefColumns: []*schema.Column{UsersColumns[0]}, RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction, OnDelete: schema.NoAction,
}, },
{ {
Symbol: "usage_logs_user_subscriptions_usage_logs", Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[32]}, Columns: []*schema.Column{UsageLogsColumns[33]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull, OnDelete: schema.SetNull,
}, },
@ -788,32 +789,32 @@ var (
{ {
Name: "usagelog_user_id", Name: "usagelog_user_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[31]}, Columns: []*schema.Column{UsageLogsColumns[32]},
}, },
{ {
Name: "usagelog_api_key_id", Name: "usagelog_api_key_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28]}, Columns: []*schema.Column{UsageLogsColumns[29]},
}, },
{ {
Name: "usagelog_account_id", Name: "usagelog_account_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29]}, Columns: []*schema.Column{UsageLogsColumns[30]},
}, },
{ {
Name: "usagelog_group_id", Name: "usagelog_group_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30]}, Columns: []*schema.Column{UsageLogsColumns[31]},
}, },
{ {
Name: "usagelog_subscription_id", Name: "usagelog_subscription_id",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[32]}, Columns: []*schema.Column{UsageLogsColumns[33]},
}, },
{ {
Name: "usagelog_created_at", Name: "usagelog_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27]}, Columns: []*schema.Column{UsageLogsColumns[28]},
}, },
{ {
Name: "usagelog_model", Name: "usagelog_model",
@ -828,17 +829,17 @@ var (
{ {
Name: "usagelog_user_id_created_at", Name: "usagelog_user_id_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]}, Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[28]},
}, },
{ {
Name: "usagelog_api_key_id_created_at", Name: "usagelog_api_key_id_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]}, Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[28]},
}, },
{ {
Name: "usagelog_group_id_created_at", Name: "usagelog_group_id_created_at",
Unique: false, Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[27]}, Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[28]},
}, },
}, },
} }

View File

@ -18239,6 +18239,7 @@ type UsageLogMutation struct {
id *int64 id *int64
request_id *string request_id *string
model *string model *string
upstream_model *string
input_tokens *int input_tokens *int
addinput_tokens *int addinput_tokens *int
output_tokens *int output_tokens *int
@ -18576,6 +18577,55 @@ func (m *UsageLogMutation) ResetModel() {
m.model = nil m.model = nil
} }
// SetUpstreamModel sets the "upstream_model" field.
func (m *UsageLogMutation) SetUpstreamModel(s string) {
m.upstream_model = &s
}
// UpstreamModel returns the value of the "upstream_model" field in the mutation.
func (m *UsageLogMutation) UpstreamModel() (r string, exists bool) {
v := m.upstream_model
if v == nil {
return
}
return *v, true
}
// OldUpstreamModel returns the old "upstream_model" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldUpstreamModel(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldUpstreamModel is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldUpstreamModel requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldUpstreamModel: %w", err)
}
return oldValue.UpstreamModel, nil
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (m *UsageLogMutation) ClearUpstreamModel() {
m.upstream_model = nil
m.clearedFields[usagelog.FieldUpstreamModel] = struct{}{}
}
// UpstreamModelCleared returns if the "upstream_model" field was cleared in this mutation.
func (m *UsageLogMutation) UpstreamModelCleared() bool {
_, ok := m.clearedFields[usagelog.FieldUpstreamModel]
return ok
}
// ResetUpstreamModel resets all changes to the "upstream_model" field.
func (m *UsageLogMutation) ResetUpstreamModel() {
m.upstream_model = nil
delete(m.clearedFields, usagelog.FieldUpstreamModel)
}
// SetGroupID sets the "group_id" field. // SetGroupID sets the "group_id" field.
func (m *UsageLogMutation) SetGroupID(i int64) { func (m *UsageLogMutation) SetGroupID(i int64) {
m.group = &i m.group = &i
@ -20197,7 +20247,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *UsageLogMutation) Fields() []string { func (m *UsageLogMutation) Fields() []string {
fields := make([]string, 0, 32) fields := make([]string, 0, 33)
if m.user != nil { if m.user != nil {
fields = append(fields, usagelog.FieldUserID) fields = append(fields, usagelog.FieldUserID)
} }
@ -20213,6 +20263,9 @@ func (m *UsageLogMutation) Fields() []string {
if m.model != nil { if m.model != nil {
fields = append(fields, usagelog.FieldModel) fields = append(fields, usagelog.FieldModel)
} }
if m.upstream_model != nil {
fields = append(fields, usagelog.FieldUpstreamModel)
}
if m.group != nil { if m.group != nil {
fields = append(fields, usagelog.FieldGroupID) fields = append(fields, usagelog.FieldGroupID)
} }
@ -20312,6 +20365,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.RequestID() return m.RequestID()
case usagelog.FieldModel: case usagelog.FieldModel:
return m.Model() return m.Model()
case usagelog.FieldUpstreamModel:
return m.UpstreamModel()
case usagelog.FieldGroupID: case usagelog.FieldGroupID:
return m.GroupID() return m.GroupID()
case usagelog.FieldSubscriptionID: case usagelog.FieldSubscriptionID:
@ -20385,6 +20440,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldRequestID(ctx) return m.OldRequestID(ctx)
case usagelog.FieldModel: case usagelog.FieldModel:
return m.OldModel(ctx) return m.OldModel(ctx)
case usagelog.FieldUpstreamModel:
return m.OldUpstreamModel(ctx)
case usagelog.FieldGroupID: case usagelog.FieldGroupID:
return m.OldGroupID(ctx) return m.OldGroupID(ctx)
case usagelog.FieldSubscriptionID: case usagelog.FieldSubscriptionID:
@ -20483,6 +20540,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
} }
m.SetModel(v) m.SetModel(v)
return nil return nil
case usagelog.FieldUpstreamModel:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetUpstreamModel(v)
return nil
case usagelog.FieldGroupID: case usagelog.FieldGroupID:
v, ok := value.(int64) v, ok := value.(int64)
if !ok { if !ok {
@ -20921,6 +20985,9 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error {
// mutation. // mutation.
func (m *UsageLogMutation) ClearedFields() []string { func (m *UsageLogMutation) ClearedFields() []string {
var fields []string var fields []string
if m.FieldCleared(usagelog.FieldUpstreamModel) {
fields = append(fields, usagelog.FieldUpstreamModel)
}
if m.FieldCleared(usagelog.FieldGroupID) { if m.FieldCleared(usagelog.FieldGroupID) {
fields = append(fields, usagelog.FieldGroupID) fields = append(fields, usagelog.FieldGroupID)
} }
@ -20962,6 +21029,9 @@ func (m *UsageLogMutation) FieldCleared(name string) bool {
// error if the field is not defined in the schema. // error if the field is not defined in the schema.
func (m *UsageLogMutation) ClearField(name string) error { func (m *UsageLogMutation) ClearField(name string) error {
switch name { switch name {
case usagelog.FieldUpstreamModel:
m.ClearUpstreamModel()
return nil
case usagelog.FieldGroupID: case usagelog.FieldGroupID:
m.ClearGroupID() m.ClearGroupID()
return nil return nil
@ -21012,6 +21082,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldModel: case usagelog.FieldModel:
m.ResetModel() m.ResetModel()
return nil return nil
case usagelog.FieldUpstreamModel:
m.ResetUpstreamModel()
return nil
case usagelog.FieldGroupID: case usagelog.FieldGroupID:
m.ResetGroupID() m.ResetGroupID()
return nil return nil

View File

@ -821,92 +821,96 @@ func init() {
return nil return nil
} }
}() }()
// usagelogDescUpstreamModel is the schema descriptor for upstream_model field.
usagelogDescUpstreamModel := usagelogFields[5].Descriptor()
// usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error)
// usagelogDescInputTokens is the schema descriptor for input_tokens field. // usagelogDescInputTokens is the schema descriptor for input_tokens field.
usagelogDescInputTokens := usagelogFields[7].Descriptor() usagelogDescInputTokens := usagelogFields[8].Descriptor()
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field. // usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int) usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
// usagelogDescOutputTokens is the schema descriptor for output_tokens field. // usagelogDescOutputTokens is the schema descriptor for output_tokens field.
usagelogDescOutputTokens := usagelogFields[8].Descriptor() usagelogDescOutputTokens := usagelogFields[9].Descriptor()
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field. // usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int) usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field. // usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
usagelogDescCacheCreationTokens := usagelogFields[9].Descriptor() usagelogDescCacheCreationTokens := usagelogFields[10].Descriptor()
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field. // usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int) usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field. // usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
usagelogDescCacheReadTokens := usagelogFields[10].Descriptor() usagelogDescCacheReadTokens := usagelogFields[11].Descriptor()
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field. // usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int) usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field. // usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
usagelogDescCacheCreation5mTokens := usagelogFields[11].Descriptor() usagelogDescCacheCreation5mTokens := usagelogFields[12].Descriptor()
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field. // usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int) usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field. // usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
usagelogDescCacheCreation1hTokens := usagelogFields[12].Descriptor() usagelogDescCacheCreation1hTokens := usagelogFields[13].Descriptor()
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field. // usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int) usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
// usagelogDescInputCost is the schema descriptor for input_cost field. // usagelogDescInputCost is the schema descriptor for input_cost field.
usagelogDescInputCost := usagelogFields[13].Descriptor() usagelogDescInputCost := usagelogFields[14].Descriptor()
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field. // usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64) usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
// usagelogDescOutputCost is the schema descriptor for output_cost field. // usagelogDescOutputCost is the schema descriptor for output_cost field.
usagelogDescOutputCost := usagelogFields[14].Descriptor() usagelogDescOutputCost := usagelogFields[15].Descriptor()
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field. // usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64) usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field. // usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
usagelogDescCacheCreationCost := usagelogFields[15].Descriptor() usagelogDescCacheCreationCost := usagelogFields[16].Descriptor()
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field. // usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64) usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field. // usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
usagelogDescCacheReadCost := usagelogFields[16].Descriptor() usagelogDescCacheReadCost := usagelogFields[17].Descriptor()
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field. // usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64) usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
// usagelogDescTotalCost is the schema descriptor for total_cost field. // usagelogDescTotalCost is the schema descriptor for total_cost field.
usagelogDescTotalCost := usagelogFields[17].Descriptor() usagelogDescTotalCost := usagelogFields[18].Descriptor()
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field. // usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64) usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
// usagelogDescActualCost is the schema descriptor for actual_cost field. // usagelogDescActualCost is the schema descriptor for actual_cost field.
usagelogDescActualCost := usagelogFields[18].Descriptor() usagelogDescActualCost := usagelogFields[19].Descriptor()
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field. // usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64) usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field. // usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
usagelogDescRateMultiplier := usagelogFields[19].Descriptor() usagelogDescRateMultiplier := usagelogFields[20].Descriptor()
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field. // usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64) usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
// usagelogDescBillingType is the schema descriptor for billing_type field. // usagelogDescBillingType is the schema descriptor for billing_type field.
usagelogDescBillingType := usagelogFields[21].Descriptor() usagelogDescBillingType := usagelogFields[22].Descriptor()
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field. // usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8) usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
// usagelogDescStream is the schema descriptor for stream field. // usagelogDescStream is the schema descriptor for stream field.
usagelogDescStream := usagelogFields[22].Descriptor() usagelogDescStream := usagelogFields[23].Descriptor()
// usagelog.DefaultStream holds the default value on creation for the stream field. // usagelog.DefaultStream holds the default value on creation for the stream field.
usagelog.DefaultStream = usagelogDescStream.Default.(bool) usagelog.DefaultStream = usagelogDescStream.Default.(bool)
// usagelogDescUserAgent is the schema descriptor for user_agent field. // usagelogDescUserAgent is the schema descriptor for user_agent field.
usagelogDescUserAgent := usagelogFields[25].Descriptor() usagelogDescUserAgent := usagelogFields[26].Descriptor()
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save. // usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error) usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
// usagelogDescIPAddress is the schema descriptor for ip_address field. // usagelogDescIPAddress is the schema descriptor for ip_address field.
usagelogDescIPAddress := usagelogFields[26].Descriptor() usagelogDescIPAddress := usagelogFields[27].Descriptor()
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save. // usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error) usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
// usagelogDescImageCount is the schema descriptor for image_count field. // usagelogDescImageCount is the schema descriptor for image_count field.
usagelogDescImageCount := usagelogFields[27].Descriptor() usagelogDescImageCount := usagelogFields[28].Descriptor()
// usagelog.DefaultImageCount holds the default value on creation for the image_count field. // usagelog.DefaultImageCount holds the default value on creation for the image_count field.
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int) usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
// usagelogDescImageSize is the schema descriptor for image_size field. // usagelogDescImageSize is the schema descriptor for image_size field.
usagelogDescImageSize := usagelogFields[28].Descriptor() usagelogDescImageSize := usagelogFields[29].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescMediaType is the schema descriptor for media_type field. // usagelogDescMediaType is the schema descriptor for media_type field.
usagelogDescMediaType := usagelogFields[29].Descriptor() usagelogDescMediaType := usagelogFields[30].Descriptor()
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error) usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field. // usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor() usagelogDescCacheTTLOverridden := usagelogFields[31].Descriptor()
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field. // usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool) usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
// usagelogDescCreatedAt is the schema descriptor for created_at field. // usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[31].Descriptor() usagelogDescCreatedAt := usagelogFields[32].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin() userMixin := schema.User{}.Mixin()

View File

@ -41,6 +41,12 @@ func (UsageLog) Fields() []ent.Field {
field.String("model"). field.String("model").
MaxLen(100). MaxLen(100).
NotEmpty(), NotEmpty(),
// UpstreamModel stores the actual upstream model name when model mapping
// is applied. NULL means no mapping — the requested model was used as-is.
field.String("upstream_model").
MaxLen(100).
Optional().
Nillable(),
field.Int64("group_id"). field.Int64("group_id").
Optional(). Optional().
Nillable(), Nillable(),

View File

@ -32,6 +32,8 @@ type UsageLog struct {
RequestID string `json:"request_id,omitempty"` RequestID string `json:"request_id,omitempty"`
// Model holds the value of the "model" field. // Model holds the value of the "model" field.
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
// UpstreamModel holds the value of the "upstream_model" field.
UpstreamModel *string `json:"upstream_model,omitempty"`
// GroupID holds the value of the "group_id" field. // GroupID holds the value of the "group_id" field.
GroupID *int64 `json:"group_id,omitempty"` GroupID *int64 `json:"group_id,omitempty"`
// SubscriptionID holds the value of the "subscription_id" field. // SubscriptionID holds the value of the "subscription_id" field.
@ -175,7 +177,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64) values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
values[i] = new(sql.NullString) values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt: case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime) values[i] = new(sql.NullTime)
@ -230,6 +232,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
} else if value.Valid { } else if value.Valid {
_m.Model = value.String _m.Model = value.String
} }
case usagelog.FieldUpstreamModel:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field upstream_model", values[i])
} else if value.Valid {
_m.UpstreamModel = new(string)
*_m.UpstreamModel = value.String
}
case usagelog.FieldGroupID: case usagelog.FieldGroupID:
if value, ok := values[i].(*sql.NullInt64); !ok { if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field group_id", values[i]) return fmt.Errorf("unexpected type %T for field group_id", values[i])
@ -477,6 +486,11 @@ func (_m *UsageLog) String() string {
builder.WriteString("model=") builder.WriteString("model=")
builder.WriteString(_m.Model) builder.WriteString(_m.Model)
builder.WriteString(", ") builder.WriteString(", ")
if v := _m.UpstreamModel; v != nil {
builder.WriteString("upstream_model=")
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.GroupID; v != nil { if v := _m.GroupID; v != nil {
builder.WriteString("group_id=") builder.WriteString("group_id=")
builder.WriteString(fmt.Sprintf("%v", *v)) builder.WriteString(fmt.Sprintf("%v", *v))

View File

@ -24,6 +24,8 @@ const (
FieldRequestID = "request_id" FieldRequestID = "request_id"
// FieldModel holds the string denoting the model field in the database. // FieldModel holds the string denoting the model field in the database.
FieldModel = "model" FieldModel = "model"
// FieldUpstreamModel holds the string denoting the upstream_model field in the database.
FieldUpstreamModel = "upstream_model"
// FieldGroupID holds the string denoting the group_id field in the database. // FieldGroupID holds the string denoting the group_id field in the database.
FieldGroupID = "group_id" FieldGroupID = "group_id"
// FieldSubscriptionID holds the string denoting the subscription_id field in the database. // FieldSubscriptionID holds the string denoting the subscription_id field in the database.
@ -135,6 +137,7 @@ var Columns = []string{
FieldAccountID, FieldAccountID,
FieldRequestID, FieldRequestID,
FieldModel, FieldModel,
FieldUpstreamModel,
FieldGroupID, FieldGroupID,
FieldSubscriptionID, FieldSubscriptionID,
FieldInputTokens, FieldInputTokens,
@ -179,6 +182,8 @@ var (
RequestIDValidator func(string) error RequestIDValidator func(string) error
// ModelValidator is a validator for the "model" field. It is called by the builders before save. // ModelValidator is a validator for the "model" field. It is called by the builders before save.
ModelValidator func(string) error ModelValidator func(string) error
// UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
UpstreamModelValidator func(string) error
// DefaultInputTokens holds the default value on creation for the "input_tokens" field. // DefaultInputTokens holds the default value on creation for the "input_tokens" field.
DefaultInputTokens int DefaultInputTokens int
// DefaultOutputTokens holds the default value on creation for the "output_tokens" field. // DefaultOutputTokens holds the default value on creation for the "output_tokens" field.
@ -258,6 +263,11 @@ func ByModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldModel, opts...).ToFunc() return sql.OrderByField(FieldModel, opts...).ToFunc()
} }
// ByUpstreamModel orders the results by the upstream_model field.
func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc()
}
// ByGroupID orders the results by the group_id field. // ByGroupID orders the results by the group_id field.
func ByGroupID(opts ...sql.OrderTermOption) OrderOption { func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldGroupID, opts...).ToFunc() return sql.OrderByField(FieldGroupID, opts...).ToFunc()

View File

@ -80,6 +80,11 @@ func Model(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldModel, v)) return predicate.UsageLog(sql.FieldEQ(FieldModel, v))
} }
// UpstreamModel applies equality check predicate on the "upstream_model" field. It's identical to UpstreamModelEQ.
func UpstreamModel(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
}
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ. // GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
func GroupID(v int64) predicate.UsageLog { func GroupID(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v)) return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
@ -405,6 +410,81 @@ func ModelContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v)) return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v))
} }
// UpstreamModelEQ applies the EQ predicate on the "upstream_model" field.
func UpstreamModelEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
}
// UpstreamModelNEQ applies the NEQ predicate on the "upstream_model" field.
func UpstreamModelNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldUpstreamModel, v))
}
// UpstreamModelIn applies the In predicate on the "upstream_model" field.
func UpstreamModelIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldUpstreamModel, vs...))
}
// UpstreamModelNotIn applies the NotIn predicate on the "upstream_model" field.
func UpstreamModelNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldUpstreamModel, vs...))
}
// UpstreamModelGT applies the GT predicate on the "upstream_model" field.
func UpstreamModelGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldUpstreamModel, v))
}
// UpstreamModelGTE applies the GTE predicate on the "upstream_model" field.
func UpstreamModelGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldUpstreamModel, v))
}
// UpstreamModelLT applies the LT predicate on the "upstream_model" field.
func UpstreamModelLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldUpstreamModel, v))
}
// UpstreamModelLTE applies the LTE predicate on the "upstream_model" field.
func UpstreamModelLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldUpstreamModel, v))
}
// UpstreamModelContains applies the Contains predicate on the "upstream_model" field.
func UpstreamModelContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldUpstreamModel, v))
}
// UpstreamModelHasPrefix applies the HasPrefix predicate on the "upstream_model" field.
func UpstreamModelHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldUpstreamModel, v))
}
// UpstreamModelHasSuffix applies the HasSuffix predicate on the "upstream_model" field.
func UpstreamModelHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldUpstreamModel, v))
}
// UpstreamModelIsNil applies the IsNil predicate on the "upstream_model" field.
func UpstreamModelIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldUpstreamModel))
}
// UpstreamModelNotNil applies the NotNil predicate on the "upstream_model" field.
func UpstreamModelNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldUpstreamModel))
}
// UpstreamModelEqualFold applies the EqualFold predicate on the "upstream_model" field.
func UpstreamModelEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldUpstreamModel, v))
}
// UpstreamModelContainsFold applies the ContainsFold predicate on the "upstream_model" field.
func UpstreamModelContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldUpstreamModel, v))
}
// GroupIDEQ applies the EQ predicate on the "group_id" field. // GroupIDEQ applies the EQ predicate on the "group_id" field.
func GroupIDEQ(v int64) predicate.UsageLog { func GroupIDEQ(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v)) return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))

View File

@ -57,6 +57,20 @@ func (_c *UsageLogCreate) SetModel(v string) *UsageLogCreate {
return _c return _c
} }
// SetUpstreamModel sets the "upstream_model" field.
func (_c *UsageLogCreate) SetUpstreamModel(v string) *UsageLogCreate {
_c.mutation.SetUpstreamModel(v)
return _c
}
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableUpstreamModel(v *string) *UsageLogCreate {
if v != nil {
_c.SetUpstreamModel(*v)
}
return _c
}
// SetGroupID sets the "group_id" field. // SetGroupID sets the "group_id" field.
func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate { func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate {
_c.mutation.SetGroupID(v) _c.mutation.SetGroupID(v)
@ -596,6 +610,11 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
} }
} }
if v, ok := _c.mutation.UpstreamModel(); ok {
if err := usagelog.UpstreamModelValidator(v); err != nil {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
}
}
if _, ok := _c.mutation.InputTokens(); !ok { if _, ok := _c.mutation.InputTokens(); !ok {
return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)} return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)}
} }
@ -714,6 +733,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldModel, field.TypeString, value) _spec.SetField(usagelog.FieldModel, field.TypeString, value)
_node.Model = value _node.Model = value
} }
if value, ok := _c.mutation.UpstreamModel(); ok {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
_node.UpstreamModel = &value
}
if value, ok := _c.mutation.InputTokens(); ok { if value, ok := _c.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
_node.InputTokens = value _node.InputTokens = value
@ -1011,6 +1034,24 @@ func (u *UsageLogUpsert) UpdateModel() *UsageLogUpsert {
return u return u
} }
// SetUpstreamModel sets the "upstream_model" field.
func (u *UsageLogUpsert) SetUpstreamModel(v string) *UsageLogUpsert {
u.Set(usagelog.FieldUpstreamModel, v)
return u
}
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateUpstreamModel() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldUpstreamModel)
return u
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (u *UsageLogUpsert) ClearUpstreamModel() *UsageLogUpsert {
u.SetNull(usagelog.FieldUpstreamModel)
return u
}
// SetGroupID sets the "group_id" field. // SetGroupID sets the "group_id" field.
func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert { func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert {
u.Set(usagelog.FieldGroupID, v) u.Set(usagelog.FieldGroupID, v)
@ -1600,6 +1641,27 @@ func (u *UsageLogUpsertOne) UpdateModel() *UsageLogUpsertOne {
}) })
} }
// SetUpstreamModel sets the "upstream_model" field.
func (u *UsageLogUpsertOne) SetUpstreamModel(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetUpstreamModel(v)
})
}
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateUpstreamModel() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateUpstreamModel()
})
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (u *UsageLogUpsertOne) ClearUpstreamModel() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearUpstreamModel()
})
}
// SetGroupID sets the "group_id" field. // SetGroupID sets the "group_id" field.
func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne { func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) { return u.Update(func(s *UsageLogUpsert) {
@ -2434,6 +2496,27 @@ func (u *UsageLogUpsertBulk) UpdateModel() *UsageLogUpsertBulk {
}) })
} }
// SetUpstreamModel sets the "upstream_model" field.
func (u *UsageLogUpsertBulk) SetUpstreamModel(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetUpstreamModel(v)
})
}
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateUpstreamModel() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateUpstreamModel()
})
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (u *UsageLogUpsertBulk) ClearUpstreamModel() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearUpstreamModel()
})
}
// SetGroupID sets the "group_id" field. // SetGroupID sets the "group_id" field.
func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk { func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) { return u.Update(func(s *UsageLogUpsert) {

View File

@ -102,6 +102,26 @@ func (_u *UsageLogUpdate) SetNillableModel(v *string) *UsageLogUpdate {
return _u return _u
} }
// SetUpstreamModel sets the "upstream_model" field.
func (_u *UsageLogUpdate) SetUpstreamModel(v string) *UsageLogUpdate {
_u.mutation.SetUpstreamModel(v)
return _u
}
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableUpstreamModel(v *string) *UsageLogUpdate {
if v != nil {
_u.SetUpstreamModel(*v)
}
return _u
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (_u *UsageLogUpdate) ClearUpstreamModel() *UsageLogUpdate {
_u.mutation.ClearUpstreamModel()
return _u
}
// SetGroupID sets the "group_id" field. // SetGroupID sets the "group_id" field.
func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate { func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate {
_u.mutation.SetGroupID(v) _u.mutation.SetGroupID(v)
@ -745,6 +765,11 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
} }
} }
if v, ok := _u.mutation.UpstreamModel(); ok {
if err := usagelog.UpstreamModelValidator(v); err != nil {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
}
}
if v, ok := _u.mutation.UserAgent(); ok { if v, ok := _u.mutation.UserAgent(); ok {
if err := usagelog.UserAgentValidator(v); err != nil { if err := usagelog.UserAgentValidator(v); err != nil {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
@ -795,6 +820,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Model(); ok { if value, ok := _u.mutation.Model(); ok {
_spec.SetField(usagelog.FieldModel, field.TypeString, value) _spec.SetField(usagelog.FieldModel, field.TypeString, value)
} }
if value, ok := _u.mutation.UpstreamModel(); ok {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
}
if _u.mutation.UpstreamModelCleared() {
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
}
if value, ok := _u.mutation.InputTokens(); ok { if value, ok := _u.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
} }
@ -1177,6 +1208,26 @@ func (_u *UsageLogUpdateOne) SetNillableModel(v *string) *UsageLogUpdateOne {
return _u return _u
} }
// SetUpstreamModel sets the "upstream_model" field.
func (_u *UsageLogUpdateOne) SetUpstreamModel(v string) *UsageLogUpdateOne {
_u.mutation.SetUpstreamModel(v)
return _u
}
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableUpstreamModel(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetUpstreamModel(*v)
}
return _u
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (_u *UsageLogUpdateOne) ClearUpstreamModel() *UsageLogUpdateOne {
_u.mutation.ClearUpstreamModel()
return _u
}
// SetGroupID sets the "group_id" field. // SetGroupID sets the "group_id" field.
func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne { func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne {
_u.mutation.SetGroupID(v) _u.mutation.SetGroupID(v)
@ -1833,6 +1884,11 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
} }
} }
if v, ok := _u.mutation.UpstreamModel(); ok {
if err := usagelog.UpstreamModelValidator(v); err != nil {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
}
}
if v, ok := _u.mutation.UserAgent(); ok { if v, ok := _u.mutation.UserAgent(); ok {
if err := usagelog.UserAgentValidator(v); err != nil { if err := usagelog.UserAgentValidator(v); err != nil {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
@ -1900,6 +1956,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if value, ok := _u.mutation.Model(); ok { if value, ok := _u.mutation.Model(); ok {
_spec.SetField(usagelog.FieldModel, field.TypeString, value) _spec.SetField(usagelog.FieldModel, field.TypeString, value)
} }
if value, ok := _u.mutation.UpstreamModel(); ok {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
}
if _u.mutation.UpstreamModelCleared() {
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
}
if value, ok := _u.mutation.InputTokens(); ok { if value, ok := _u.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
} }

View File

@ -22,8 +22,6 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY= github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA= github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
@ -60,8 +58,6 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 h1:edCcNp9eGIUDUCrzoCu1jWA
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c= github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs= github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q= github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
@ -98,10 +94,6 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
@ -238,8 +230,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@ -273,8 +263,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@ -326,8 +314,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=

View File

@ -82,8 +82,8 @@ var DefaultAntigravityModelMapping = map[string]string{
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking", // 迁移旧模型 "claude-opus-4-5-20251101": "claude-opus-4-6-thinking", // 迁移旧模型
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5", "claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
// Claude Haiku → Sonnet无 Haiku 支持) // Claude Haiku → Sonnet无 Haiku 支持)
"claude-haiku-4-5": "claude-sonnet-4-5", "claude-haiku-4-5": "claude-sonnet-4-6",
"claude-haiku-4-5-20251001": "claude-sonnet-4-5", "claude-haiku-4-5-20251001": "claude-sonnet-4-6",
// Gemini 2.5 白名单 // Gemini 2.5 白名单
"gemini-2.5-flash": "gemini-2.5-flash", "gemini-2.5-flash": "gemini-2.5-flash",
"gemini-2.5-flash-image": "gemini-2.5-flash-image", "gemini-2.5-flash-image": "gemini-2.5-flash-image",

View File

@ -17,7 +17,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
adminSvc := newStubAdminService() adminSvc := newStubAdminService()
userHandler := NewUserHandler(adminSvc, nil) userHandler := NewUserHandler(adminSvc, nil)
groupHandler := NewGroupHandler(adminSvc) groupHandler := NewGroupHandler(adminSvc, nil, nil)
proxyHandler := NewProxyHandler(adminSvc) proxyHandler := NewProxyHandler(adminSvc)
redeemHandler := NewRedeemHandler(adminSvc, nil) redeemHandler := NewRedeemHandler(adminSvc, nil)

View File

@ -273,6 +273,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
// Parse optional filter params // Parse optional filter params
var userID, apiKeyID, accountID, groupID int64 var userID, apiKeyID, accountID, groupID int64
modelSource := usagestats.ModelSourceRequested
var requestType *int16 var requestType *int16
var stream *bool var stream *bool
var billingType *int8 var billingType *int8
@ -297,6 +298,13 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
groupID = id groupID = id
} }
} }
if rawModelSource := strings.TrimSpace(c.Query("model_source")); rawModelSource != "" {
if !usagestats.IsValidModelSource(rawModelSource) {
response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping")
return
}
modelSource = rawModelSource
}
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr) parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil { if err != nil {
@ -323,7 +331,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
} }
} }
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, modelSource, requestType, stream, billingType)
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get model statistics") response.Error(c, 500, "Failed to get model statistics")
return return
@ -619,6 +627,12 @@ func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) {
} }
} }
dim.Model = c.Query("model") dim.Model = c.Query("model")
rawModelSource := strings.TrimSpace(c.DefaultQuery("model_source", usagestats.ModelSourceRequested))
if !usagestats.IsValidModelSource(rawModelSource) {
response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping")
return
}
dim.ModelType = rawModelSource
dim.Endpoint = c.Query("endpoint") dim.Endpoint = c.Query("endpoint")
dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound") dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound")

View File

@ -149,6 +149,28 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) {
require.Equal(t, http.StatusBadRequest, rec.Code) require.Equal(t, http.StatusBadRequest, rec.Code)
} }
func TestDashboardModelStatsInvalidModelSource(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=invalid", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardModelStatsValidModelSource(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=upstream", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestDashboardUsersRankingLimitAndCache(t *testing.T) { func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute) dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
repo := &dashboardUsageRepoCapture{ repo := &dashboardUsageRepoCapture{

View File

@ -73,9 +73,35 @@ func TestGetUserBreakdown_ModelFilter(t *testing.T) {
require.Equal(t, http.StatusOK, w.Code) require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, "claude-opus-4-6", repo.capturedDim.Model) require.Equal(t, "claude-opus-4-6", repo.capturedDim.Model)
require.Equal(t, usagestats.ModelSourceRequested, repo.capturedDim.ModelType)
require.Equal(t, int64(0), repo.capturedDim.GroupID) require.Equal(t, int64(0), repo.capturedDim.GroupID)
} }
func TestGetUserBreakdown_ModelSourceFilter(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6&model_source=upstream", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, usagestats.ModelSourceUpstream, repo.capturedDim.ModelType)
}
func TestGetUserBreakdown_InvalidModelSource(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model_source=foobar", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code)
}
func TestGetUserBreakdown_EndpointFilter(t *testing.T) { func TestGetUserBreakdown_EndpointFilter(t *testing.T) {
repo := &userBreakdownRepoCapture{} repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo) router := newUserBreakdownRouter(repo)

View File

@ -38,6 +38,7 @@ type dashboardModelGroupCacheKey struct {
APIKeyID int64 `json:"api_key_id"` APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"` AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"` GroupID int64 `json:"group_id"`
ModelSource string `json:"model_source,omitempty"`
RequestType *int16 `json:"request_type"` RequestType *int16 `json:"request_type"`
Stream *bool `json:"stream"` Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"` BillingType *int8 `json:"billing_type"`
@ -111,6 +112,7 @@ func (h *DashboardHandler) getModelStatsCached(
ctx context.Context, ctx context.Context,
startTime, endTime time.Time, startTime, endTime time.Time,
userID, apiKeyID, accountID, groupID int64, userID, apiKeyID, accountID, groupID int64,
modelSource string,
requestType *int16, requestType *int16,
stream *bool, stream *bool,
billingType *int8, billingType *int8,
@ -122,12 +124,13 @@ func (h *DashboardHandler) getModelStatsCached(
APIKeyID: apiKeyID, APIKeyID: apiKeyID,
AccountID: accountID, AccountID: accountID,
GroupID: groupID, GroupID: groupID,
ModelSource: usagestats.NormalizeModelSource(modelSource),
RequestType: requestType, RequestType: requestType,
Stream: stream, Stream: stream,
BillingType: billingType, BillingType: billingType,
}) })
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) { entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) return h.dashboardService.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, modelSource)
}) })
if err != nil { if err != nil {
return nil, hit, err return nil, hit, err

View File

@ -200,6 +200,7 @@ func (h *DashboardHandler) buildSnapshotV2Response(
filters.APIKeyID, filters.APIKeyID,
filters.AccountID, filters.AccountID,
filters.GroupID, filters.GroupID,
usagestats.ModelSourceRequested,
filters.RequestType, filters.RequestType,
filters.Stream, filters.Stream,
filters.BillingType, filters.BillingType,

View File

@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -16,7 +17,9 @@ import (
// GroupHandler handles admin group management // GroupHandler handles admin group management
type GroupHandler struct { type GroupHandler struct {
adminService service.AdminService adminService service.AdminService
dashboardService *service.DashboardService
groupCapacityService *service.GroupCapacityService
} }
type optionalLimitField struct { type optionalLimitField struct {
@ -69,9 +72,11 @@ func (f optionalLimitField) ToServiceInput() *float64 {
} }
// NewGroupHandler creates a new admin group handler // NewGroupHandler creates a new admin group handler
func NewGroupHandler(adminService service.AdminService) *GroupHandler { func NewGroupHandler(adminService service.AdminService, dashboardService *service.DashboardService, groupCapacityService *service.GroupCapacityService) *GroupHandler {
return &GroupHandler{ return &GroupHandler{
adminService: adminService, adminService: adminService,
dashboardService: dashboardService,
groupCapacityService: groupCapacityService,
} }
} }
@ -363,6 +368,33 @@ func (h *GroupHandler) GetStats(c *gin.Context) {
_ = groupID // TODO: implement actual stats _ = groupID // TODO: implement actual stats
} }
// GetUsageSummary returns today's and cumulative cost for all groups.
// GET /api/v1/admin/groups/usage-summary?timezone=Asia/Shanghai
func (h *GroupHandler) GetUsageSummary(c *gin.Context) {
userTZ := c.Query("timezone")
now := timezone.NowInUserLocation(userTZ)
todayStart := timezone.StartOfDayInUserLocation(now, userTZ)
results, err := h.dashboardService.GetGroupUsageSummary(c.Request.Context(), todayStart)
if err != nil {
response.Error(c, 500, "Failed to get group usage summary")
return
}
response.Success(c, results)
}
// GetCapacitySummary returns aggregated capacity (concurrency/sessions/RPM) for all active groups.
// GET /api/v1/admin/groups/capacity-summary
func (h *GroupHandler) GetCapacitySummary(c *gin.Context) {
results, err := h.groupCapacityService.GetAllGroupCapacity(c.Request.Context())
if err != nil {
response.Error(c, 500, "Failed to get group capacity summary")
return
}
response.Success(c, results)
}
// GetGroupAPIKeys handles getting API keys in a group // GetGroupAPIKeys handles getting API keys in a group
// GET /api/v1/admin/groups/:id/api-keys // GET /api/v1/admin/groups/:id/api-keys
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) { func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {

View File

@ -977,6 +977,58 @@ func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) {
response.Success(c, gin.H{"message": "Admin API key deleted"}) response.Success(c, gin.H{"message": "Admin API key deleted"})
} }
// GetOverloadCooldownSettings 获取529过载冷却配置
// GET /api/v1/admin/settings/overload-cooldown
func (h *SettingHandler) GetOverloadCooldownSettings(c *gin.Context) {
settings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.OverloadCooldownSettings{
Enabled: settings.Enabled,
CooldownMinutes: settings.CooldownMinutes,
})
}
// UpdateOverloadCooldownSettingsRequest 更新529过载冷却配置请求
type UpdateOverloadCooldownSettingsRequest struct {
Enabled bool `json:"enabled"`
CooldownMinutes int `json:"cooldown_minutes"`
}
// UpdateOverloadCooldownSettings 更新529过载冷却配置
// PUT /api/v1/admin/settings/overload-cooldown
func (h *SettingHandler) UpdateOverloadCooldownSettings(c *gin.Context) {
var req UpdateOverloadCooldownSettingsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
settings := &service.OverloadCooldownSettings{
Enabled: req.Enabled,
CooldownMinutes: req.CooldownMinutes,
}
if err := h.settingService.SetOverloadCooldownSettings(c.Request.Context(), settings); err != nil {
response.BadRequest(c, err.Error())
return
}
updatedSettings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.OverloadCooldownSettings{
Enabled: updatedSettings.Enabled,
CooldownMinutes: updatedSettings.CooldownMinutes,
})
}
// GetStreamTimeoutSettings 获取流超时处理配置 // GetStreamTimeoutSettings 获取流超时处理配置
// GET /api/v1/admin/settings/stream-timeout // GET /api/v1/admin/settings/stream-timeout
func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) { func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {

View File

@ -77,12 +77,13 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
} }
} }
status := c.Query("status") status := c.Query("status")
platform := c.Query("platform")
// Parse sorting parameters // Parse sorting parameters
sortBy := c.DefaultQuery("sort_by", "created_at") sortBy := c.DefaultQuery("sort_by", "created_at")
sortOrder := c.DefaultQuery("sort_order", "desc") sortOrder := c.DefaultQuery("sort_order", "desc")
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder) subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, platform, sortBy, sortOrder)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return

View File

@ -135,14 +135,16 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
return nil return nil
} }
out := &AdminGroup{ out := &AdminGroup{
Group: groupFromServiceBase(g), Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting, ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled, ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject, MCPXMLInject: g.MCPXMLInject,
DefaultMappedModel: g.DefaultMappedModel, DefaultMappedModel: g.DefaultMappedModel,
SupportedModelScopes: g.SupportedModelScopes, SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount, AccountCount: g.AccountCount,
SortOrder: g.SortOrder, ActiveAccountCount: g.ActiveAccountCount,
RateLimitedAccountCount: g.RateLimitedAccountCount,
SortOrder: g.SortOrder,
} }
if len(g.AccountGroups) > 0 { if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups)) out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
@ -521,6 +523,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
AccountID: l.AccountID, AccountID: l.AccountID,
RequestID: l.RequestID, RequestID: l.RequestID,
Model: l.Model, Model: l.Model,
UpstreamModel: l.UpstreamModel,
ServiceTier: l.ServiceTier, ServiceTier: l.ServiceTier,
ReasoningEffort: l.ReasoningEffort, ReasoningEffort: l.ReasoningEffort,
InboundEndpoint: l.InboundEndpoint, InboundEndpoint: l.InboundEndpoint,

View File

@ -157,6 +157,12 @@ type ListSoraS3ProfilesResponse struct {
Items []SoraS3Profile `json:"items"` Items []SoraS3Profile `json:"items"`
} }
// OverloadCooldownSettings 529过载冷却配置 DTO
type OverloadCooldownSettings struct {
Enabled bool `json:"enabled"`
CooldownMinutes int `json:"cooldown_minutes"`
}
// StreamTimeoutSettings 流超时处理配置 DTO // StreamTimeoutSettings 流超时处理配置 DTO
type StreamTimeoutSettings struct { type StreamTimeoutSettings struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`

View File

@ -122,9 +122,11 @@ type AdminGroup struct {
DefaultMappedModel string `json:"default_mapped_model"` DefaultMappedModel string `json:"default_mapped_model"`
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"` SupportedModelScopes []string `json:"supported_model_scopes"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"` AccountCount int64 `json:"account_count,omitempty"`
ActiveAccountCount int64 `json:"active_account_count,omitempty"`
RateLimitedAccountCount int64 `json:"rate_limited_account_count,omitempty"`
// 分组排序 // 分组排序
SortOrder int `json:"sort_order"` SortOrder int `json:"sort_order"`
@ -332,6 +334,9 @@ type UsageLog struct {
AccountID int64 `json:"account_id"` AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"` RequestID string `json:"request_id"`
Model string `json:"model"` Model string `json:"model"`
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Omitted when no mapping was applied (requested model was used as-is).
UpstreamModel *string `json:"upstream_model,omitempty"`
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string `json:"service_tier,omitempty"` ServiceTier *string `json:"service_tier,omitempty"`
// ReasoningEffort is the request's reasoning effort level. // ReasoningEffort is the request's reasoning effort level.

View File

@ -76,7 +76,7 @@ func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service
return nil, nil return nil, nil
} }
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil } func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil } func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, int64, error) { return 0, 0, nil }
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
return 0, nil return 0, nil
} }

View File

@ -136,7 +136,7 @@ func validClaudeCodeBodyJSON() []byte {
return []byte(`{ return []byte(`{
"model":"claude-3-5-sonnet-20241022", "model":"claude-3-5-sonnet-20241022",
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}], "system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"} "metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}
}`) }`)
} }
@ -190,7 +190,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
System: []any{ System: []any{
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."}, map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
}, },
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123", MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
} }
// body 非法 JSON如果函数复用 parsedReq 成功则仍应判定为 Claude Code。 // body 非法 JSON如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
@ -209,7 +209,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
"system": []any{ "system": []any{
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."}, map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
}, },
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}, "metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"},
}) })
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil) SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)

View File

@ -273,8 +273,8 @@ func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform strin
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) { func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil return false, nil
} }
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, nil return 0, 0, nil
} }
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil return 0, nil
@ -348,6 +348,9 @@ func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTi
func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) { func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
return nil, nil return nil, nil
} }
func (s *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) { func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, nil return nil, nil
} }

View File

@ -3,6 +3,28 @@ package usagestats
import "time" import "time"
const (
ModelSourceRequested = "requested"
ModelSourceUpstream = "upstream"
ModelSourceMapping = "mapping"
)
func IsValidModelSource(source string) bool {
switch source {
case ModelSourceRequested, ModelSourceUpstream, ModelSourceMapping:
return true
default:
return false
}
}
func NormalizeModelSource(source string) string {
if IsValidModelSource(source) {
return source
}
return ModelSourceRequested
}
// DashboardStats 仪表盘统计 // DashboardStats 仪表盘统计
type DashboardStats struct { type DashboardStats struct {
// 用户统计 // 用户统计
@ -90,6 +112,13 @@ type EndpointStat struct {
ActualCost float64 `json:"actual_cost"` // 实际扣除 ActualCost float64 `json:"actual_cost"` // 实际扣除
} }
// GroupUsageSummary represents today's and cumulative cost for a single group.
type GroupUsageSummary struct {
GroupID int64 `json:"group_id"`
TodayCost float64 `json:"today_cost"`
TotalCost float64 `json:"total_cost"`
}
// GroupStat represents usage statistics for a single group // GroupStat represents usage statistics for a single group
type GroupStat struct { type GroupStat struct {
GroupID int64 `json:"group_id"` GroupID int64 `json:"group_id"`
@ -143,6 +172,7 @@ type UserBreakdownItem struct {
type UserBreakdownDimension struct { type UserBreakdownDimension struct {
GroupID int64 // filter by group_id (>0 to enable) GroupID int64 // filter by group_id (>0 to enable)
Model string // filter by model name (non-empty to enable) Model string // filter by model name (non-empty to enable)
ModelType string // "requested", "upstream", or "mapping"
Endpoint string // filter by endpoint value (non-empty to enable) Endpoint string // filter by endpoint value (non-empty to enable)
EndpointType string // "inbound", "upstream", or "path" EndpointType string // "inbound", "upstream", or "path"
} }

View File

@ -0,0 +1,47 @@
package usagestats
import "testing"
func TestIsValidModelSource(t *testing.T) {
tests := []struct {
name string
source string
want bool
}{
{name: "requested", source: ModelSourceRequested, want: true},
{name: "upstream", source: ModelSourceUpstream, want: true},
{name: "mapping", source: ModelSourceMapping, want: true},
{name: "invalid", source: "foobar", want: false},
{name: "empty", source: "", want: false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := IsValidModelSource(tc.source); got != tc.want {
t.Fatalf("IsValidModelSource(%q)=%v want %v", tc.source, got, tc.want)
}
})
}
}
func TestNormalizeModelSource(t *testing.T) {
tests := []struct {
name string
source string
want string
}{
{name: "requested", source: ModelSourceRequested, want: ModelSourceRequested},
{name: "upstream", source: ModelSourceUpstream, want: ModelSourceUpstream},
{name: "mapping", source: ModelSourceMapping, want: ModelSourceMapping},
{name: "invalid falls back", source: "foobar", want: ModelSourceRequested},
{name: "empty falls back", source: "", want: ModelSourceRequested},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := NormalizeModelSource(tc.source); got != tc.want {
t.Fatalf("NormalizeModelSource(%q)=%q want %q", tc.source, got, tc.want)
}
})
}
}

View File

@ -88,8 +88,9 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
if err != nil { if err != nil {
return nil, err return nil, err
} }
count, _ := r.GetAccountCount(ctx, out.ID) total, active, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = count out.AccountCount = total
out.ActiveAccountCount = active
return out, nil return out, nil
} }
@ -256,7 +257,10 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
counts, err := r.loadAccountCounts(ctx, groupIDs) counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil { if err == nil {
for i := range outGroups { for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID] c := counts[outGroups[i].ID]
outGroups[i].AccountCount = c.Total
outGroups[i].ActiveAccountCount = c.Active
outGroups[i].RateLimitedAccountCount = c.RateLimited
} }
} }
@ -283,7 +287,10 @@ func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, erro
counts, err := r.loadAccountCounts(ctx, groupIDs) counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil { if err == nil {
for i := range outGroups { for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID] c := counts[outGroups[i].ID]
outGroups[i].AccountCount = c.Total
outGroups[i].ActiveAccountCount = c.Active
outGroups[i].RateLimitedAccountCount = c.RateLimited
} }
} }
@ -310,7 +317,10 @@ func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform str
counts, err := r.loadAccountCounts(ctx, groupIDs) counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil { if err == nil {
for i := range outGroups { for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID] c := counts[outGroups[i].ID]
outGroups[i].AccountCount = c.Total
outGroups[i].ActiveAccountCount = c.Active
outGroups[i].RateLimitedAccountCount = c.RateLimited
} }
} }
@ -369,12 +379,20 @@ func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int
return result, nil return result, nil
} }
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) {
var count int64 var rateLimited int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil { err = scanSingleRow(ctx, r.sql,
return 0, err `SELECT COUNT(*),
} COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true),
return count, nil COUNT(*) FILTER (WHERE a.status = 'active' AND (
a.rate_limit_reset_at > NOW() OR
a.overload_until > NOW() OR
a.temp_unschedulable_until > NOW()
))
FROM account_groups ag JOIN accounts a ON a.id = ag.account_id
WHERE ag.group_id = $1`,
[]any{groupID}, &total, &active, &rateLimited)
return
} }
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
@ -500,15 +518,32 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return affectedUserIDs, nil return affectedUserIDs, nil
} }
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]int64, err error) { type groupAccountCounts struct {
counts = make(map[int64]int64, len(groupIDs)) Total int64
Active int64
RateLimited int64
}
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]groupAccountCounts, err error) {
counts = make(map[int64]groupAccountCounts, len(groupIDs))
if len(groupIDs) == 0 { if len(groupIDs) == 0 {
return counts, nil return counts, nil
} }
rows, err := r.sql.QueryContext( rows, err := r.sql.QueryContext(
ctx, ctx,
"SELECT group_id, COUNT(*) FROM account_groups WHERE group_id = ANY($1) GROUP BY group_id", `SELECT ag.group_id,
COUNT(*) AS total,
COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true) AS active,
COUNT(*) FILTER (WHERE a.status = 'active' AND (
a.rate_limit_reset_at > NOW() OR
a.overload_until > NOW() OR
a.temp_unschedulable_until > NOW()
)) AS rate_limited
FROM account_groups ag
JOIN accounts a ON a.id = ag.account_id
WHERE ag.group_id = ANY($1)
GROUP BY ag.group_id`,
pq.Array(groupIDs), pq.Array(groupIDs),
) )
if err != nil { if err != nil {
@ -523,11 +558,11 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
for rows.Next() { for rows.Next() {
var groupID int64 var groupID int64
var count int64 var c groupAccountCounts
if err = rows.Scan(&groupID, &count); err != nil { if err = rows.Scan(&groupID, &c.Total, &c.Active, &c.RateLimited); err != nil {
return nil, err return nil, err
} }
counts[groupID] = count counts[groupID] = c
} }
if err = rows.Err(); err != nil { if err = rows.Err(); err != nil {
return nil, err return nil, err

View File

@ -603,7 +603,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2) _, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2)
s.Require().NoError(err) s.Require().NoError(err)
count, err := s.repo.GetAccountCount(s.ctx, group.ID) count, _, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err, "GetAccountCount") s.Require().NoError(err, "GetAccountCount")
s.Require().Equal(int64(2), count) s.Require().Equal(int64(2), count)
} }
@ -619,7 +619,7 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
} }
s.Require().NoError(s.repo.Create(s.ctx, group)) s.Require().NoError(s.repo.Create(s.ctx, group))
count, err := s.repo.GetAccountCount(s.ctx, group.ID) count, _, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Zero(count) s.Require().Zero(count)
} }
@ -651,7 +651,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
s.Require().NoError(err, "DeleteAccountGroupsByGroupID") s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
s.Require().Equal(int64(1), affected, "expected 1 affected row") s.Require().Equal(int64(1), affected, "expected 1 affected row")
count, err := s.repo.GetAccountCount(s.ctx, g.ID) count, _, err := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().NoError(err, "GetAccountCount") s.Require().NoError(err, "GetAccountCount")
s.Require().Equal(int64(0), count, "expected 0 account groups") s.Require().Equal(int64(0), count, "expected 0 account groups")
} }
@ -692,7 +692,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal(int64(3), affected) s.Require().Equal(int64(3), affected)
count, _ := s.repo.GetAccountCount(s.ctx, g.ID) count, _, _ := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().Zero(count) s.Require().Zero(count)
} }

View File

@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache" gocache "github.com/patrickmn/go-cache"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
var usageLogInsertArgTypes = [...]string{ var usageLogInsertArgTypes = [...]string{
"bigint", "bigint",
@ -36,6 +36,7 @@ var usageLogInsertArgTypes = [...]string{
"bigint", "bigint",
"text", "text",
"text", "text",
"text",
"bigint", "bigint",
"bigint", "bigint",
"integer", "integer",
@ -277,6 +278,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
account_id, account_id,
request_id, request_id,
model, model,
upstream_model,
group_id, group_id,
subscription_id, subscription_id,
input_tokens, input_tokens,
@ -311,12 +313,12 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
cache_ttl_overridden, cache_ttl_overridden,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $1, $2, $3, $4, $5, $6,
$6, $7, $7, $8,
$8, $9, $10, $11, $9, $10, $11, $12,
$12, $13, $13, $14,
$14, $15, $16, $17, $18, $19, $15, $16, $17, $18, $19, $20,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38 $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
@ -707,6 +709,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id, account_id,
request_id, request_id,
model, model,
upstream_model,
group_id, group_id,
subscription_id, subscription_id,
input_tokens, input_tokens,
@ -742,7 +745,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(keys)*38) args := make([]any, 0, len(keys)*39)
argPos := 1 argPos := 1
for idx, key := range keys { for idx, key := range keys {
if idx > 0 { if idx > 0 {
@ -776,6 +779,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id, account_id,
request_id, request_id,
model, model,
upstream_model,
group_id, group_id,
subscription_id, subscription_id,
input_tokens, input_tokens,
@ -816,6 +820,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id, account_id,
request_id, request_id,
model, model,
upstream_model,
group_id, group_id,
subscription_id, subscription_id,
input_tokens, input_tokens,
@ -896,6 +901,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id, account_id,
request_id, request_id,
model, model,
upstream_model,
group_id, group_id,
subscription_id, subscription_id,
input_tokens, input_tokens,
@ -931,7 +937,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(preparedList)*38) args := make([]any, 0, len(preparedList)*39)
argPos := 1 argPos := 1
for idx, prepared := range preparedList { for idx, prepared := range preparedList {
if idx > 0 { if idx > 0 {
@ -962,6 +968,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id, account_id,
request_id, request_id,
model, model,
upstream_model,
group_id, group_id,
subscription_id, subscription_id,
input_tokens, input_tokens,
@ -1002,6 +1009,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id, account_id,
request_id, request_id,
model, model,
upstream_model,
group_id, group_id,
subscription_id, subscription_id,
input_tokens, input_tokens,
@ -1050,6 +1058,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
account_id, account_id,
request_id, request_id,
model, model,
upstream_model,
group_id, group_id,
subscription_id, subscription_id,
input_tokens, input_tokens,
@ -1084,12 +1093,12 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
cache_ttl_overridden, cache_ttl_overridden,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $1, $2, $3, $4, $5, $6,
$6, $7, $7, $8,
$8, $9, $10, $11, $9, $10, $11, $12,
$12, $13, $13, $14,
$14, $15, $16, $17, $18, $19, $15, $16, $17, $18, $19, $20,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38 $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...) `, prepared.args...)
@ -1121,6 +1130,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
reasoningEffort := nullString(log.ReasoningEffort) reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint) inboundEndpoint := nullString(log.InboundEndpoint)
upstreamEndpoint := nullString(log.UpstreamEndpoint) upstreamEndpoint := nullString(log.UpstreamEndpoint)
upstreamModel := nullString(log.UpstreamModel)
var requestIDArg any var requestIDArg any
if requestID != "" { if requestID != "" {
@ -1138,6 +1148,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
log.AccountID, log.AccountID,
requestIDArg, requestIDArg,
log.Model, log.Model,
upstreamModel,
groupID, groupID,
subscriptionID, subscriptionID,
log.InputTokens, log.InputTokens,
@ -2864,15 +2875,26 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
// GetModelStatsWithFilters returns model statistics with optional filters // GetModelStatsWithFilters returns model statistics with optional filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) { func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) {
return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, usagestats.ModelSourceRequested)
}
// GetModelStatsWithFiltersBySource returns model statistics with optional filters and model source dimension.
// source: requested | upstream | mapping.
func (r *usageLogRepository) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) {
return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, source)
}
func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时实际费用使用账号倍率total_cost * account_rate_multiplier // 当仅按 account_id 聚合时实际费用使用账号倍率total_cost * account_rate_multiplier
if accountID > 0 && userID == 0 && apiKeyID == 0 { if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
} }
modelExpr := resolveModelDimensionExpression(source)
query := fmt.Sprintf(` query := fmt.Sprintf(`
SELECT SELECT
model, %s as model,
COUNT(*) as requests, COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens, COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens,
@ -2883,7 +2905,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
%s %s
FROM usage_logs FROM usage_logs
WHERE created_at >= $1 AND created_at < $2 WHERE created_at >= $1 AND created_at < $2
`, actualCostExpr) `, modelExpr, actualCostExpr)
args := []any{startTime, endTime} args := []any{startTime, endTime}
if userID > 0 { if userID > 0 {
@ -2907,7 +2929,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType)) args = append(args, int16(*billingType))
} }
query += " GROUP BY model ORDER BY total_tokens DESC" query += fmt.Sprintf(" GROUP BY %s ORDER BY total_tokens DESC", modelExpr)
rows, err := r.sql.QueryContext(ctx, query, args...) rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
@ -3021,7 +3043,7 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
args = append(args, dim.GroupID) args = append(args, dim.GroupID)
} }
if dim.Model != "" { if dim.Model != "" {
query += fmt.Sprintf(" AND ul.model = $%d", len(args)+1) query += fmt.Sprintf(" AND %s = $%d", resolveModelDimensionExpression(dim.ModelType), len(args)+1)
args = append(args, dim.Model) args = append(args, dim.Model)
} }
if dim.Endpoint != "" { if dim.Endpoint != "" {
@ -3067,6 +3089,53 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
return results, nil return results, nil
} }
// GetAllGroupUsageSummary returns today's and cumulative actual_cost for every group.
// todayStart is the start-of-day in the caller's timezone (UTC-based).
// TODO(perf): This query scans ALL usage_logs rows for total_cost aggregation.
// When usage_logs exceeds ~1M rows, consider adding a short-lived cache (30s)
// or a materialized view / pre-aggregation table for cumulative costs.
func (r *usageLogRepository) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
query := `
SELECT
g.id AS group_id,
COALESCE(SUM(ul.actual_cost), 0) AS total_cost,
COALESCE(SUM(CASE WHEN ul.created_at >= $1 THEN ul.actual_cost ELSE 0 END), 0) AS today_cost
FROM groups g
LEFT JOIN usage_logs ul ON ul.group_id = g.id
GROUP BY g.id
`
rows, err := r.sql.QueryContext(ctx, query, todayStart)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
var results []usagestats.GroupUsageSummary
for rows.Next() {
var row usagestats.GroupUsageSummary
if err := rows.Scan(&row.GroupID, &row.TotalCost, &row.TodayCost); err != nil {
return nil, err
}
results = append(results, row)
}
if err := rows.Err(); err != nil {
return nil, err
}
return results, nil
}
// resolveModelDimensionExpression maps model source type to a safe SQL expression.
func resolveModelDimensionExpression(modelType string) string {
switch usagestats.NormalizeModelSource(modelType) {
case usagestats.ModelSourceUpstream:
return "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"
case usagestats.ModelSourceMapping:
return "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"
default:
return "model"
}
}
// resolveEndpointColumn maps endpoint type to the corresponding DB column name. // resolveEndpointColumn maps endpoint type to the corresponding DB column name.
func resolveEndpointColumn(endpointType string) string { func resolveEndpointColumn(endpointType string) string {
switch endpointType { switch endpointType {
@ -3819,6 +3888,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
accountID int64 accountID int64
requestID sql.NullString requestID sql.NullString
model string model string
upstreamModel sql.NullString
groupID sql.NullInt64 groupID sql.NullInt64
subscriptionID sql.NullInt64 subscriptionID sql.NullInt64
inputTokens int inputTokens int
@ -3861,6 +3931,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&accountID, &accountID,
&requestID, &requestID,
&model, &model,
&upstreamModel,
&groupID, &groupID,
&subscriptionID, &subscriptionID,
&inputTokens, &inputTokens,
@ -3973,6 +4044,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if upstreamEndpoint.Valid { if upstreamEndpoint.Valid {
log.UpstreamEndpoint = &upstreamEndpoint.String log.UpstreamEndpoint = &upstreamEndpoint.String
} }
if upstreamModel.Valid {
log.UpstreamModel = &upstreamModel.String
}
return log, nil return log, nil
} }

View File

@ -5,6 +5,7 @@ package repository
import ( import (
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -16,8 +17,8 @@ func TestResolveEndpointColumn(t *testing.T) {
{"inbound", "ul.inbound_endpoint"}, {"inbound", "ul.inbound_endpoint"},
{"upstream", "ul.upstream_endpoint"}, {"upstream", "ul.upstream_endpoint"},
{"path", "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"}, {"path", "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"},
{"", "ul.inbound_endpoint"}, // default {"", "ul.inbound_endpoint"}, // default
{"unknown", "ul.inbound_endpoint"}, // fallback {"unknown", "ul.inbound_endpoint"}, // fallback
} }
for _, tc := range tests { for _, tc := range tests {
@ -27,3 +28,23 @@ func TestResolveEndpointColumn(t *testing.T) {
}) })
} }
} }
func TestResolveModelDimensionExpression(t *testing.T) {
tests := []struct {
modelType string
want string
}{
{usagestats.ModelSourceRequested, "model"},
{usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"},
{usagestats.ModelSourceMapping, "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"},
{"", "model"},
{"invalid", "model"},
}
for _, tc := range tests {
t.Run(tc.modelType, func(t *testing.T) {
got := resolveModelDimensionExpression(tc.modelType)
require.Equal(t, tc.want, got)
})
}
}

View File

@ -44,6 +44,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
log.AccountID, log.AccountID,
log.RequestID, log.RequestID,
log.Model, log.Model,
sqlmock.AnyArg(), // upstream_model
sqlmock.AnyArg(), // group_id sqlmock.AnyArg(), // group_id
sqlmock.AnyArg(), // subscription_id sqlmock.AnyArg(), // subscription_id
log.InputTokens, log.InputTokens,
@ -116,6 +117,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
log.Model, log.Model,
sqlmock.AnyArg(), sqlmock.AnyArg(),
sqlmock.AnyArg(), sqlmock.AnyArg(),
sqlmock.AnyArg(),
log.InputTokens, log.InputTokens,
log.OutputTokens, log.OutputTokens,
log.CacheCreationTokens, log.CacheCreationTokens,
@ -353,6 +355,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(30), // account_id int64(30), // account_id
sql.NullString{Valid: true, String: "req-1"}, sql.NullString{Valid: true, String: "req-1"},
"gpt-5", // model "gpt-5", // model
sql.NullString{}, // upstream_model
sql.NullInt64{}, // group_id sql.NullInt64{}, // group_id
sql.NullInt64{}, // subscription_id sql.NullInt64{}, // subscription_id
1, // input_tokens 1, // input_tokens
@ -404,6 +407,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(31), int64(31),
sql.NullString{Valid: true, String: "req-2"}, sql.NullString{Valid: true, String: "req-2"},
"gpt-5", "gpt-5",
sql.NullString{},
sql.NullInt64{}, sql.NullInt64{},
sql.NullInt64{}, sql.NullInt64{},
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,
@ -445,6 +449,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(32), int64(32),
sql.NullString{Valid: true, String: "req-3"}, sql.NullString{Valid: true, String: "req-3"},
"gpt-5.4", "gpt-5.4",
sql.NullString{},
sql.NullInt64{}, sql.NullInt64{},
sql.NullInt64{}, sql.NullInt64{},
1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6,

View File

@ -5,6 +5,7 @@ import (
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@ -190,7 +191,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
} }
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
client := clientFromContext(ctx, r.client) client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query() q := client.UserSubscription.Query()
if userID != nil { if userID != nil {
@ -199,6 +200,9 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
if groupID != nil { if groupID != nil {
q = q.Where(usersubscription.GroupIDEQ(*groupID)) q = q.Where(usersubscription.GroupIDEQ(*groupID))
} }
if platform != "" {
q = q.Where(usersubscription.HasGroupWith(group.PlatformEQ(platform)))
}
// Status filtering with real-time expiration check // Status filtering with real-time expiration check
now := time.Now() now := time.Now()

View File

@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
group := s.mustCreateGroup("g-list") group := s.mustCreateGroup("g-list")
s.mustCreateSubscription(user.ID, group.ID, nil) s.mustCreateSubscription(user.ID, group.ID, nil)
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "") subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "", "")
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(int64(1), page.Total) s.Require().Equal(int64(1), page.Total)
@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
s.mustCreateSubscription(user1.ID, group.ID, nil) s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, nil) s.mustCreateSubscription(user2.ID, group.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "") subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "", "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(user1.ID, subs[0].UserID) s.Require().Equal(user1.ID, subs[0].UserID)
@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
s.mustCreateSubscription(user.ID, g1.ID, nil) s.mustCreateSubscription(user.ID, g1.ID, nil)
s.mustCreateSubscription(user.ID, g2.ID, nil) s.mustCreateSubscription(user.ID, g2.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "") subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "", "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(g1.ID, subs[0].GroupID) s.Require().Equal(g1.ID, subs[0].GroupID)
@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
c.SetExpiresAt(time.Now().Add(-24 * time.Hour)) c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
}) })
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "") subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "", "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status) s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)

View File

@ -924,8 +924,8 @@ func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error
return false, errors.New("not implemented") return false, errors.New("not implemented")
} }
func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, errors.New("not implemented") return 0, 0, errors.New("not implemented")
} }
func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
@ -1289,7 +1289,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
@ -1786,6 +1786,9 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i
func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) { func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
return nil, errors.New("not implemented")
}
type stubSettingRepo struct { type stubSettingRepo struct {
all map[string]string all map[string]string

View File

@ -135,7 +135,7 @@ func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, user
func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {

View File

@ -646,7 +646,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }

View File

@ -227,6 +227,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{ {
groups.GET("", h.Admin.Group.List) groups.GET("", h.Admin.Group.List)
groups.GET("/all", h.Admin.Group.GetAll) groups.GET("/all", h.Admin.Group.GetAll)
groups.GET("/usage-summary", h.Admin.Group.GetUsageSummary)
groups.GET("/capacity-summary", h.Admin.Group.GetCapacitySummary)
groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder) groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder)
groups.GET("/:id", h.Admin.Group.GetByID) groups.GET("/:id", h.Admin.Group.GetByID)
groups.POST("", h.Admin.Group.Create) groups.POST("", h.Admin.Group.Create)
@ -400,6 +402,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey) adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey) adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey) adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
// 529过载冷却配置
adminSettings.GET("/overload-cooldown", h.Admin.Setting.GetOverloadCooldownSettings)
adminSettings.PUT("/overload-cooldown", h.Admin.Setting.UpdateOverloadCooldownSettings)
// 流超时处理配置 // 流超时处理配置
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings) adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings) adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)

View File

@ -113,15 +113,18 @@ func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error)
return normalized, nil return normalized, nil
} }
// generateSessionString generates a Claude Code style session string // generateSessionString generates a Claude Code style session string.
// The output format is determined by the UA version in claude.DefaultHeaders,
// ensuring consistency between the user_id format and the UA sent to upstream.
func generateSessionString() (string, error) { func generateSessionString() (string, error) {
bytes := make([]byte, 32) b := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil { if _, err := rand.Read(b); err != nil {
return "", err return "", err
} }
hex64 := hex.EncodeToString(bytes) hex64 := hex.EncodeToString(b)
sessionUUID := uuid.New().String() sessionUUID := uuid.New().String()
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil uaVersion := ExtractCLIVersion(claude.DefaultHeaders["User-Agent"])
return FormatMetadataUserID(hex64, "", sessionUUID, uaVersion), nil
} }
// createTestPayload creates a Claude Code style test request payload // createTestPayload creates a Claude Code style test request payload

View File

@ -49,6 +49,7 @@ type UsageLogRepository interface {
GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error)
GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)

View File

@ -194,7 +194,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri
func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) { func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
panic("unexpected") panic("unexpected")
} }
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) { func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, int64, error) {
panic("unexpected") panic("unexpected")
} }
func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {

View File

@ -160,7 +160,7 @@ func (s *groupRepoStub) ExistsByName(ctx context.Context, name string) (bool, er
panic("unexpected ExistsByName call") panic("unexpected ExistsByName call")
} }
func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
panic("unexpected GetAccountCount call") panic("unexpected GetAccountCount call")
} }

View File

@ -100,7 +100,7 @@ func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool,
panic("unexpected ExistsByName call") panic("unexpected ExistsByName call")
} }
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, error) { func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call") panic("unexpected GetAccountCount call")
} }
@ -383,7 +383,7 @@ func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string
panic("unexpected ExistsByName call") panic("unexpected ExistsByName call")
} }
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) { func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call") panic("unexpected GetAccountCount call")
} }
@ -458,7 +458,7 @@ func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context,
panic("unexpected ExistsByName call") panic("unexpected ExistsByName call")
} }
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) { func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call") panic("unexpected GetAccountCount call")
} }

View File

@ -57,16 +57,16 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "claude-opus-4-6-thinking", expected: "claude-opus-4-6-thinking",
}, },
{ {
name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-5", name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-6",
requestedModel: "claude-haiku-4-5", requestedModel: "claude-haiku-4-5",
accountMapping: nil, accountMapping: nil,
expected: "claude-sonnet-4-5", expected: "claude-sonnet-4-6",
}, },
{ {
name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5", name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-6",
requestedModel: "claude-haiku-4-5-20251001", requestedModel: "claude-haiku-4-5-20251001",
accountMapping: nil, accountMapping: nil,
expected: "claude-sonnet-4-5", expected: "claude-sonnet-4-6",
}, },
{ {
name: "默认映射 - claude-sonnet-4-5-20250929 → claude-sonnet-4-5", name: "默认映射 - claude-sonnet-4-5-20250929 → claude-sonnet-4-5",

View File

@ -21,9 +21,6 @@ var (
// 带捕获组的版本提取正则 // 带捕获组的版本提取正则
claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`) claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`)
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致) // System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
systemPromptThreshold = 0.5 systemPromptThreshold = 0.5
) )
@ -124,7 +121,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
return false return false
} }
if !userIDPattern.MatchString(userID) { if ParseMetadataUserID(userID) == nil {
return false return false
} }
@ -278,11 +275,7 @@ func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context
// ExtractVersion 从 User-Agent 中提取 Claude Code 版本号 // ExtractVersion 从 User-Agent 中提取 Claude Code 版本号
// 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串 // 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串
func (v *ClaudeCodeValidator) ExtractVersion(ua string) string { func (v *ClaudeCodeValidator) ExtractVersion(ua string) string {
matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua) return ExtractCLIVersion(ua)
if len(matches) >= 2 {
return matches[1]
}
return ""
} }
// SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中 // SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中

View File

@ -140,6 +140,27 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
return stats, nil return stats, nil
} }
func (s *DashboardService) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, modelSource string) ([]usagestats.ModelStat, error) {
normalizedSource := usagestats.NormalizeModelSource(modelSource)
if normalizedSource == usagestats.ModelSourceRequested {
return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
}
type modelStatsBySourceRepo interface {
GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) ([]usagestats.ModelStat, error)
}
if sourceRepo, ok := s.usageRepo.(modelStatsBySourceRepo); ok {
stats, err := sourceRepo.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, normalizedSource)
if err != nil {
return nil, fmt.Errorf("get model stats with filters by source: %w", err)
}
return stats, nil
}
return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
}
func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) { func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil { if err != nil {
@ -148,6 +169,15 @@ func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTi
return stats, nil return stats, nil
} }
// GetGroupUsageSummary returns today's and cumulative cost for all groups.
func (s *DashboardService) GetGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
results, err := s.usageRepo.GetAllGroupUsageSummary(ctx, todayStart)
if err != nil {
return nil, fmt.Errorf("get group usage summary: %w", err)
}
return results, nil
}
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) { func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
data, err := s.cache.GetDashboardStats(ctx) data, err := s.cache.GetDashboardStats(ctx)
if err != nil { if err != nil {

View File

@ -170,6 +170,13 @@ const (
// SettingKeyOpsRuntimeLogConfig stores JSON config for runtime log settings. // SettingKeyOpsRuntimeLogConfig stores JSON config for runtime log settings.
SettingKeyOpsRuntimeLogConfig = "ops_runtime_log_config" SettingKeyOpsRuntimeLogConfig = "ops_runtime_log_config"
// =========================
// Overload Cooldown (529)
// =========================
// SettingKeyOverloadCooldownSettings stores JSON config for 529 overload cooldown handling.
SettingKeyOverloadCooldownSettings = "overload_cooldown_settings"
// ========================= // =========================
// Stream Timeout Handling // Stream Timeout Handling
// ========================= // =========================

View File

@ -788,7 +788,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuc
rateLimitService: &RateLimitService{}, rateLimitService: &RateLimitService{},
} }
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", false, time.Now()) result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now())
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, 12, result.Usage.InputTokens) require.Equal(t, 12, result.Usage.InputTokens)
@ -815,7 +815,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_InvalidTokenTyp
} }
svc := &GatewayService{} svc := &GatewayService{}
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", false, time.Now()) result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now())
require.Nil(t, result) require.Nil(t, result)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "requires apikey token") require.Contains(t, err.Error(), "requires apikey token")
@ -840,7 +840,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_UpstreamRequest
} }
account := newAnthropicAPIKeyAccountForTest() account := newAnthropicAPIKeyAccountForTest()
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", false, time.Now()) result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", "x", false, time.Now())
require.Nil(t, result) require.Nil(t, result)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "upstream request failed") require.Contains(t, err.Error(), "upstream request failed")
@ -873,7 +873,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_EmptyResponseBo
httpUpstream: upstream, httpUpstream: upstream,
} }
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", false, time.Now()) result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", "x", false, time.Now())
require.Nil(t, result) require.Nil(t, result)
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "empty response") require.Contains(t, err.Error(), "empty response")

View File

@ -278,8 +278,8 @@ func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, plat
func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) { func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil return false, nil
} }
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, nil return 0, 0, nil
} }
func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil return 0, nil

View File

@ -28,6 +28,12 @@ var (
patternEmptyContentSpaced = []byte(`"content": []`) patternEmptyContentSpaced = []byte(`"content": []`)
patternEmptyContentSp1 = []byte(`"content" : []`) patternEmptyContentSp1 = []byte(`"content" : []`)
patternEmptyContentSp2 = []byte(`"content" :[]`) patternEmptyContentSp2 = []byte(`"content" :[]`)
// Fast-path patterns for empty text blocks: {"type":"text","text":""}
patternEmptyText = []byte(`"text":""`)
patternEmptyTextSpaced = []byte(`"text": ""`)
patternEmptyTextSp1 = []byte(`"text" : ""`)
patternEmptyTextSp2 = []byte(`"text" :""`)
) )
// SessionContext 粘性会话上下文,用于区分不同来源的请求。 // SessionContext 粘性会话上下文,用于区分不同来源的请求。
@ -233,15 +239,22 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
bytes.Contains(body, patternThinkingField) || bytes.Contains(body, patternThinkingField) ||
bytes.Contains(body, patternThinkingFieldSpaced) bytes.Contains(body, patternThinkingFieldSpaced)
// Also check for empty content arrays that need fixing. // Also check for empty content arrays and empty text blocks that need fixing.
// Note: This is a heuristic check; the actual empty content handling is done below. // Note: This is a heuristic check; the actual empty content handling is done below.
hasEmptyContent := bytes.Contains(body, patternEmptyContent) || hasEmptyContent := bytes.Contains(body, patternEmptyContent) ||
bytes.Contains(body, patternEmptyContentSpaced) || bytes.Contains(body, patternEmptyContentSpaced) ||
bytes.Contains(body, patternEmptyContentSp1) || bytes.Contains(body, patternEmptyContentSp1) ||
bytes.Contains(body, patternEmptyContentSp2) bytes.Contains(body, patternEmptyContentSp2)
// Check for empty text blocks: {"type":"text","text":""}
// These cause upstream 400: "text content blocks must be non-empty"
hasEmptyTextBlock := bytes.Contains(body, patternEmptyText) ||
bytes.Contains(body, patternEmptyTextSpaced) ||
bytes.Contains(body, patternEmptyTextSp1) ||
bytes.Contains(body, patternEmptyTextSp2)
// Fast path: nothing to process // Fast path: nothing to process
if !hasThinkingContent && !hasEmptyContent { if !hasThinkingContent && !hasEmptyContent && !hasEmptyTextBlock {
return body return body
} }
@ -260,7 +273,7 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
bytes.Contains(body, patternTypeRedactedThinking) || bytes.Contains(body, patternTypeRedactedThinking) ||
bytes.Contains(body, patternTypeRedactedSpaced) || bytes.Contains(body, patternTypeRedactedSpaced) ||
bytes.Contains(body, patternThinkingFieldSpaced) bytes.Contains(body, patternThinkingFieldSpaced)
if !hasEmptyContent && !containsThinkingBlocks { if !hasEmptyContent && !hasEmptyTextBlock && !containsThinkingBlocks {
if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() { if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() {
if out, err := sjson.DeleteBytes(body, "thinking"); err == nil { if out, err := sjson.DeleteBytes(body, "thinking"); err == nil {
out = removeThinkingDependentContextStrategies(out) out = removeThinkingDependentContextStrategies(out)
@ -320,6 +333,16 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
blockType, _ := blockMap["type"].(string) blockType, _ := blockMap["type"].(string)
// Strip empty text blocks: {"type":"text","text":""}
// Upstream rejects these with 400: "text content blocks must be non-empty"
if blockType == "text" {
if txt, _ := blockMap["text"].(string); txt == "" {
modifiedThisMsg = true
ensureNewContent(bi)
continue
}
}
// Convert thinking blocks to text (preserve content) and drop redacted_thinking. // Convert thinking blocks to text (preserve content) and drop redacted_thinking.
switch blockType { switch blockType {
case "thinking": case "thinking":

View File

@ -404,6 +404,51 @@ func TestFilterThinkingBlocksForRetry_EmptyContentGetsPlaceholder(t *testing.T)
require.NotEmpty(t, content0["text"]) require.NotEmpty(t, content0["text"])
} }
func TestFilterThinkingBlocksForRetry_StripsEmptyTextBlocks(t *testing.T) {
// Empty text blocks cause upstream 400: "text content blocks must be non-empty"
input := []byte(`{
"messages":[
{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":""}]},
{"role":"assistant","content":[{"type":"text","text":""}]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs, ok := req["messages"].([]any)
require.True(t, ok)
// First message: empty text block stripped, "hello" preserved
msg0 := msgs[0].(map[string]any)
content0 := msg0["content"].([]any)
require.Len(t, content0, 1)
require.Equal(t, "hello", content0[0].(map[string]any)["text"])
// Second message: only had empty text block → gets placeholder
msg1 := msgs[1].(map[string]any)
content1 := msg1["content"].([]any)
require.Len(t, content1, 1)
block1 := content1[0].(map[string]any)
require.Equal(t, "text", block1["type"])
require.NotEmpty(t, block1["text"])
}
func TestFilterThinkingBlocksForRetry_PreservesNonEmptyTextBlocks(t *testing.T) {
// Non-empty text blocks should pass through unchanged
input := []byte(`{
"messages":[
{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":"world"}]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
// Fast path: no thinking content, no empty content, no empty text blocks → unchanged
require.Equal(t, input, out)
}
func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) { func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
input := []byte(`{ input := []byte(`{
"thinking":{"type":"enabled","budget_tokens":1024}, "thinking":{"type":"enabled","budget_tokens":1024},

View File

@ -326,7 +326,6 @@ func isClaudeCodeCredentialScopeError(msg string) bool {
// Some upstream APIs return non-standard "data:" without space (should be "data: "). // Some upstream APIs return non-standard "data:" without space (should be "data: ").
var ( var (
sseDataRe = regexp.MustCompile(`^data:\s*`) sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
@ -491,6 +490,7 @@ type ForwardResult struct {
RequestID string RequestID string
Usage ClaudeUsage Usage ClaudeUsage
Model string Model string
UpstreamModel string // Actual upstream model after mapping (empty = no mapping)
Stream bool Stream bool
Duration time.Duration Duration time.Duration
FirstTokenMs *int // 首字时间(流式请求) FirstTokenMs *int // 首字时间(流式请求)
@ -644,8 +644,8 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
// 1. 最高优先级:从 metadata.user_id 提取 session_xxx // 1. 最高优先级:从 metadata.user_id 提取 session_xxx
if parsed.MetadataUserID != "" { if parsed.MetadataUserID != "" {
if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 { if uid := ParseMetadataUserID(parsed.MetadataUserID); uid != nil && uid.SessionID != "" {
return match[1] return uid.SessionID
} }
} }
@ -1026,13 +1026,13 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
sessionID = generateSessionUUID(seed) sessionID = generateSessionUUID(seed)
} }
// Prefer the newer format that includes account_uuid (if present), // 根据指纹 UA 版本选择输出格式
// otherwise fall back to the legacy Claude Code format. var uaVersion string
accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid")) if fp != nil {
if accountUUID != "" { uaVersion = ExtractCLIVersion(fp.UserAgent)
return fmt.Sprintf("user_%s_account_%s_session_%s", userID, accountUUID, sessionID)
} }
return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID) accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid"))
return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion)
} }
// GenerateSessionUUID creates a deterministic UUID4 from a seed string. // GenerateSessionUUID creates a deterministic UUID4 from a seed string.
@ -3989,7 +3989,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
passthroughModel = mappedModel passthroughModel = mappedModel
} }
} }
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime) return s.forwardAnthropicAPIKeyPassthroughWithInput(ctx, c, account, anthropicPassthroughForwardInput{
Body: passthroughBody,
RequestModel: passthroughModel,
OriginalModel: parsed.Model,
RequestStream: parsed.Stream,
StartTime: startTime,
})
} }
if account != nil && account.IsBedrock() { if account != nil && account.IsBedrock() {
@ -4513,6 +4519,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
RequestID: resp.Header.Get("x-request-id"), RequestID: resp.Header.Get("x-request-id"),
Usage: *usage, Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志 Model: originalModel, // 使用原始模型用于计费和日志
UpstreamModel: mappedModel,
Stream: reqStream, Stream: reqStream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
@ -4520,14 +4527,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}, nil }, nil
} }
type anthropicPassthroughForwardInput struct {
Body []byte
RequestModel string
OriginalModel string
RequestStream bool
StartTime time.Time
}
func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
ctx context.Context, ctx context.Context,
c *gin.Context, c *gin.Context,
account *Account, account *Account,
body []byte, body []byte,
reqModel string, reqModel string,
originalModel string,
reqStream bool, reqStream bool,
startTime time.Time, startTime time.Time,
) (*ForwardResult, error) {
return s.forwardAnthropicAPIKeyPassthroughWithInput(ctx, c, account, anthropicPassthroughForwardInput{
Body: body,
RequestModel: reqModel,
OriginalModel: originalModel,
RequestStream: reqStream,
StartTime: startTime,
})
}
func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
ctx context.Context,
c *gin.Context,
account *Account,
input anthropicPassthroughForwardInput,
) (*ForwardResult, error) { ) (*ForwardResult, error) {
token, tokenType, err := s.GetAccessToken(ctx, account) token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil { if err != nil {
@ -4543,19 +4574,19 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
} }
logger.LegacyPrintf("service.gateway", "[Anthropic 自动透传] 命中 API Key 透传分支: account=%d name=%s model=%s stream=%v", logger.LegacyPrintf("service.gateway", "[Anthropic 自动透传] 命中 API Key 透传分支: account=%d name=%s model=%s stream=%v",
account.ID, account.Name, reqModel, reqStream) account.ID, account.Name, input.RequestModel, input.RequestStream)
if c != nil { if c != nil {
c.Set("anthropic_passthrough", true) c.Set("anthropic_passthrough", true)
} }
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
setOpsUpstreamRequestBody(c, body) setOpsUpstreamRequestBody(c, input.Body)
var resp *http.Response var resp *http.Response
retryStart := time.Now() retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ { for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, input.RequestStream)
upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, body, token) upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, input.Body, token)
releaseUpstreamCtx() releaseUpstreamCtx()
if err != nil { if err != nil {
return nil, err return nil, err
@ -4713,8 +4744,8 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
var usage *ClaudeUsage var usage *ClaudeUsage
var firstTokenMs *int var firstTokenMs *int
var clientDisconnect bool var clientDisconnect bool
if reqStream { if input.RequestStream {
streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, startTime, reqModel) streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, input.StartTime, input.RequestModel)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -4734,9 +4765,10 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
return &ForwardResult{ return &ForwardResult{
RequestID: resp.Header.Get("x-request-id"), RequestID: resp.Header.Get("x-request-id"),
Usage: *usage, Usage: *usage,
Model: reqModel, Model: input.OriginalModel,
Stream: reqStream, UpstreamModel: input.RequestModel,
Duration: time.Since(startTime), Stream: input.RequestStream,
Duration: time.Since(input.StartTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect, ClientDisconnect: clientDisconnect,
}, nil }, nil
@ -5241,6 +5273,7 @@ func (s *GatewayService) forwardBedrock(
RequestID: resp.Header.Get("x-amzn-requestid"), RequestID: resp.Header.Get("x-amzn-requestid"),
Usage: *usage, Usage: *usage,
Model: reqModel, Model: reqModel,
UpstreamModel: mappedModel,
Stream: reqStream, Stream: reqStream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
@ -5533,7 +5566,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 如果启用了会话ID伪装会在重写后替换 session 部分为固定值 // 如果启用了会话ID伪装会在重写后替换 session 部分为固定值
accountUUID := account.GetExtraString("account_uuid") accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" { if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 {
body = newBody body = newBody
} }
} }
@ -6068,9 +6101,11 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
return true return true
} }
// 检测空消息内容错误(可能是过滤 thinking blocks 后导致的 // 检测空消息内容错误(可能是过滤 thinking blocks 后导致的,或客户端发送了空 text block
// 例如: "all messages must have non-empty content" // 例如: "all messages must have non-empty content"
if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") { // "messages: text content blocks must be non-empty"
if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") ||
strings.Contains(msg, "content blocks must be non-empty") {
logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected empty content error") logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected empty content error")
return true return true
} }
@ -7529,6 +7564,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
AccountID: account.ID, AccountID: account.ID,
RequestID: requestID, RequestID: requestID,
Model: result.Model, Model: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
@ -7710,6 +7746,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
AccountID: account.ID, AccountID: account.ID,
RequestID: requestID, RequestID: requestID,
Model: result.Model, Model: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
@ -8161,7 +8198,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if err == nil { if err == nil {
accountUUID := account.GetExtraString("account_uuid") accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" { if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 {
body = newBody body = newBody
} }
} }

View File

@ -230,8 +230,8 @@ func (m *mockGroupRepoForGemini) ListActiveByPlatform(ctx context.Context, platf
func (m *mockGroupRepoForGemini) ExistsByName(ctx context.Context, name string) (bool, error) { func (m *mockGroupRepoForGemini) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil return false, nil
} }
func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, nil return 0, 0, nil
} }
func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil return 0, nil

View File

@ -24,7 +24,7 @@ func TestGenerateSessionHash_MetadataHasHighestPriority(t *testing.T) {
svc := &GatewayService{} svc := &GatewayService{}
parsed := &ParsedRequest{ parsed := &ParsedRequest{
MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000",
System: "You are a helpful assistant.", System: "You are a helpful assistant.",
HasSystem: true, HasSystem: true,
Messages: []any{ Messages: []any{
@ -196,7 +196,7 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) {
svc := &GatewayService{} svc := &GatewayService{}
parsed := &ParsedRequest{ parsed := &ParsedRequest{
MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000",
Messages: []any{ Messages: []any{
map[string]any{"role": "user", "content": "hello"}, map[string]any{"role": "user", "content": "hello"},
}, },
@ -212,6 +212,22 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) {
"metadata session_id should take priority over SessionContext") "metadata session_id should take priority over SessionContext")
} }
func TestGenerateSessionHash_MetadataJSON_HasHighestPriority(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
MetadataUserID: `{"device_id":"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`,
System: "You are a helpful assistant.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
hash := svc.GenerateSessionHash(parsed)
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", hash, "JSON format metadata session_id should have highest priority")
}
func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) { func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) {
svc := &GatewayService{} svc := &GatewayService{}

View File

@ -64,8 +64,10 @@ type Group struct {
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
AccountGroups []AccountGroup AccountGroups []AccountGroup
AccountCount int64 AccountCount int64
ActiveAccountCount int64
RateLimitedAccountCount int64
} }
func (g *Group) IsActive() bool { func (g *Group) IsActive() bool {

View File

@ -0,0 +1,131 @@
package service
import (
"context"
"time"
)
// GroupCapacitySummary holds aggregated capacity for a single group.
type GroupCapacitySummary struct {
GroupID int64 `json:"group_id"`
ConcurrencyUsed int `json:"concurrency_used"`
ConcurrencyMax int `json:"concurrency_max"`
SessionsUsed int `json:"sessions_used"`
SessionsMax int `json:"sessions_max"`
RPMUsed int `json:"rpm_used"`
RPMMax int `json:"rpm_max"`
}
// GroupCapacityService aggregates per-group capacity from runtime data.
type GroupCapacityService struct {
accountRepo AccountRepository
groupRepo GroupRepository
concurrencyService *ConcurrencyService
sessionLimitCache SessionLimitCache
rpmCache RPMCache
}
// NewGroupCapacityService creates a new GroupCapacityService.
func NewGroupCapacityService(
accountRepo AccountRepository,
groupRepo GroupRepository,
concurrencyService *ConcurrencyService,
sessionLimitCache SessionLimitCache,
rpmCache RPMCache,
) *GroupCapacityService {
return &GroupCapacityService{
accountRepo: accountRepo,
groupRepo: groupRepo,
concurrencyService: concurrencyService,
sessionLimitCache: sessionLimitCache,
rpmCache: rpmCache,
}
}
// GetAllGroupCapacity returns capacity summary for all active groups.
func (s *GroupCapacityService) GetAllGroupCapacity(ctx context.Context) ([]GroupCapacitySummary, error) {
groups, err := s.groupRepo.ListActive(ctx)
if err != nil {
return nil, err
}
results := make([]GroupCapacitySummary, 0, len(groups))
for i := range groups {
cap, err := s.getGroupCapacity(ctx, groups[i].ID)
if err != nil {
// Skip groups with errors, return partial results
continue
}
cap.GroupID = groups[i].ID
results = append(results, cap)
}
return results, nil
}
func (s *GroupCapacityService) getGroupCapacity(ctx context.Context, groupID int64) (GroupCapacitySummary, error) {
accounts, err := s.accountRepo.ListSchedulableByGroupID(ctx, groupID)
if err != nil {
return GroupCapacitySummary{}, err
}
if len(accounts) == 0 {
return GroupCapacitySummary{}, nil
}
// Collect account IDs and config values
accountIDs := make([]int64, 0, len(accounts))
sessionTimeouts := make(map[int64]time.Duration)
var concurrencyMax, sessionsMax, rpmMax int
for i := range accounts {
acc := &accounts[i]
accountIDs = append(accountIDs, acc.ID)
concurrencyMax += acc.Concurrency
if ms := acc.GetMaxSessions(); ms > 0 {
sessionsMax += ms
timeout := time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
if timeout <= 0 {
timeout = 5 * time.Minute
}
sessionTimeouts[acc.ID] = timeout
}
if rpm := acc.GetBaseRPM(); rpm > 0 {
rpmMax += rpm
}
}
// Batch query runtime data from Redis
concurrencyMap, _ := s.concurrencyService.GetAccountConcurrencyBatch(ctx, accountIDs)
var sessionsMap map[int64]int
if sessionsMax > 0 && s.sessionLimitCache != nil {
sessionsMap, _ = s.sessionLimitCache.GetActiveSessionCountBatch(ctx, accountIDs, sessionTimeouts)
}
var rpmMap map[int64]int
if rpmMax > 0 && s.rpmCache != nil {
rpmMap, _ = s.rpmCache.GetRPMBatch(ctx, accountIDs)
}
// Aggregate
var concurrencyUsed, sessionsUsed, rpmUsed int
for _, id := range accountIDs {
concurrencyUsed += concurrencyMap[id]
if sessionsMap != nil {
sessionsUsed += sessionsMap[id]
}
if rpmMap != nil {
rpmUsed += rpmMap[id]
}
}
return GroupCapacitySummary{
ConcurrencyUsed: concurrencyUsed,
ConcurrencyMax: concurrencyMax,
SessionsUsed: sessionsUsed,
SessionsMax: sessionsMax,
RPMUsed: rpmUsed,
RPMMax: rpmMax,
}, nil
}

View File

@ -27,7 +27,7 @@ type GroupRepository interface {
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
ExistsByName(ctx context.Context, name string) (bool, error) ExistsByName(ctx context.Context, name string) (bool, error)
GetAccountCount(ctx context.Context, groupID int64) (int64, error) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error)
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID去重 // GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID去重
GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error)
@ -202,7 +202,7 @@ func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any,
} }
// 获取账号数量 // 获取账号数量
accountCount, err := s.groupRepo.GetAccountCount(ctx, id) accountCount, _, err := s.groupRepo.GetAccountCount(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account count: %w", err) return nil, fmt.Errorf("get account count: %w", err)
} }

View File

@ -19,10 +19,6 @@ import (
// 预编译正则表达式(避免每次调用重新编译) // 预编译正则表达式(避免每次调用重新编译)
var ( var (
// 匹配 user_id 格式:
// 旧格式: user_{64位hex}_account__session_{uuid} (account 后无 UUID)
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid} (account 后有 UUID)
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account_([a-f0-9-]*)_session_([a-f0-9-]{36})$`)
// 匹配 User-Agent 版本号: xxx/x.y.z // 匹配 User-Agent 版本号: xxx/x.y.z
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`) userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
) )
@ -209,12 +205,12 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
} }
// RewriteUserID 重写body中的metadata.user_id // RewriteUserID 重写body中的metadata.user_id
// 输入格式user_{clientId}_account__session_{sessionUUID} // 支持旧拼接格式和新 JSON 格式的 user_id 解析,
// 输出格式user_{cachedClientID}_account_{accountUUID}_session_{newHash} // 根据 fingerprintUA 版本选择输出格式。
// //
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节, // 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。 // 避免重新序列化导致 thinking 块等内容被修改。
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) { func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) {
if len(body) == 0 || accountUUID == "" || cachedClientID == "" { if len(body) == 0 || accountUUID == "" || cachedClientID == "" {
return body, nil return body, nil
} }
@ -241,24 +237,21 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
return body, nil return body, nil
} }
// 匹配格式: // 解析 user_id兼容旧拼接格式和新 JSON 格式)
// 旧格式: user_{64位hex}_account__session_{uuid} parsed := ParseMetadataUserID(userID)
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid} if parsed == nil {
matches := userIDRegex.FindStringSubmatch(userID)
if matches == nil {
return body, nil return body, nil
} }
// matches[1] = account UUID (可能为空), matches[2] = session UUID sessionTail := parsed.SessionID // 原始session UUID
sessionTail := matches[2] // 原始session UUID
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式 // 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
seed := fmt.Sprintf("%d::%s", accountID, sessionTail) seed := fmt.Sprintf("%d::%s", accountID, sessionTail)
newSessionHash := generateUUIDFromSeed(seed) newSessionHash := generateUUIDFromSeed(seed)
// 构建新的user_id // 根据客户端版本选择输出格式
// 格式: user_{cachedClientID}_account_{account_uuid}_session_{newSessionHash} version := ExtractCLIVersion(fingerprintUA)
newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash) newUserID := FormatMetadataUserID(cachedClientID, accountUUID, newSessionHash, version)
metadata["user_id"] = newUserID metadata["user_id"] = newUserID
@ -278,9 +271,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
// //
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节, // 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。 // 避免重新序列化导致 thinking 块等内容被修改。
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) { func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) {
// 先执行常规的 RewriteUserID 逻辑 // 先执行常规的 RewriteUserID 逻辑
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID) newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID, fingerprintUA)
if err != nil { if err != nil {
return newBody, err return newBody, err
} }
@ -312,10 +305,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
return newBody, nil return newBody, nil
} }
// 查找 _session_ 的位置,替换其后的内容 // 解析已重写的 user_id
const sessionMarker = "_session_" uidParsed := ParseMetadataUserID(userID)
idx := strings.LastIndex(userID, sessionMarker) if uidParsed == nil {
if idx == -1 {
return newBody, nil return newBody, nil
} }
@ -337,8 +329,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err) logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err)
} }
// 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容 // 用 FormatMetadataUserID 重建(保持与 RewriteUserID 相同的格式)
newUserID := userID[:idx+len(sessionMarker)] + maskedSessionID version := ExtractCLIVersion(fingerprintUA)
newUserID := FormatMetadataUserID(uidParsed.DeviceID, uidParsed.AccountUUID, maskedSessionID, version)
slog.Debug("session_id_masking_applied", slog.Debug("session_id_masking_applied",
"account_id", account.ID, "account_id", account.ID,

View File

@ -0,0 +1,104 @@
package service
import (
"encoding/json"
"regexp"
"strings"
)
// NewMetadataFormatMinVersion is the minimum Claude Code version that uses
// JSON-formatted metadata.user_id instead of the legacy concatenated string.
const NewMetadataFormatMinVersion = "2.1.78"
// ParsedUserID represents the components extracted from a metadata.user_id value.
type ParsedUserID struct {
DeviceID string // 64-char hex (or arbitrary client id)
AccountUUID string // may be empty
SessionID string // UUID
IsNewFormat bool // true if the original was JSON format
}
// legacyUserIDRegex matches the legacy user_id format:
//
// user_{64hex}_account_{optional_uuid}_session_{uuid}
var legacyUserIDRegex = regexp.MustCompile(`^user_([a-fA-F0-9]{64})_account_([a-fA-F0-9-]*)_session_([a-fA-F0-9-]{36})$`)
// jsonUserID is the JSON structure for the new metadata.user_id format.
type jsonUserID struct {
DeviceID string `json:"device_id"`
AccountUUID string `json:"account_uuid"`
SessionID string `json:"session_id"`
}
// ParseMetadataUserID parses a metadata.user_id string in either format.
// Returns nil if the input cannot be parsed.
func ParseMetadataUserID(raw string) *ParsedUserID {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
// Try JSON format first (starts with '{')
if raw[0] == '{' {
var j jsonUserID
if err := json.Unmarshal([]byte(raw), &j); err != nil {
return nil
}
if j.DeviceID == "" || j.SessionID == "" {
return nil
}
return &ParsedUserID{
DeviceID: j.DeviceID,
AccountUUID: j.AccountUUID,
SessionID: j.SessionID,
IsNewFormat: true,
}
}
// Try legacy format
matches := legacyUserIDRegex.FindStringSubmatch(raw)
if matches == nil {
return nil
}
return &ParsedUserID{
DeviceID: matches[1],
AccountUUID: matches[2],
SessionID: matches[3],
IsNewFormat: false,
}
}
// FormatMetadataUserID builds a metadata.user_id string in the format
// appropriate for the given CLI version. Components are the rewritten values
// (not necessarily the originals).
func FormatMetadataUserID(deviceID, accountUUID, sessionID, uaVersion string) string {
if IsNewMetadataFormatVersion(uaVersion) {
b, _ := json.Marshal(jsonUserID{
DeviceID: deviceID,
AccountUUID: accountUUID,
SessionID: sessionID,
})
return string(b)
}
// Legacy format
return "user_" + deviceID + "_account_" + accountUUID + "_session_" + sessionID
}
// IsNewMetadataFormatVersion returns true if the given CLI version uses the
// new JSON metadata.user_id format (>= 2.1.78).
func IsNewMetadataFormatVersion(version string) bool {
if version == "" {
return false
}
return CompareVersions(version, NewMetadataFormatMinVersion) >= 0
}
// ExtractCLIVersion extracts the Claude Code version from a User-Agent string.
// Returns "" if the UA doesn't match the expected pattern.
func ExtractCLIVersion(ua string) string {
matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua)
if len(matches) >= 2 {
return matches[1]
}
return ""
}

View File

@ -0,0 +1,183 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// ============ ParseMetadataUserID Tests ============
func TestParseMetadataUserID_LegacyFormat_WithoutAccountUUID(t *testing.T) {
raw := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000"
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", parsed.SessionID)
require.False(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_LegacyFormat_WithAccountUUID(t *testing.T) {
raw := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account_550e8400-e29b-41d4-a716-446655440000_session_123e4567-e89b-12d3-a456-426614174000"
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", parsed.DeviceID)
require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", parsed.AccountUUID)
require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", parsed.SessionID)
require.False(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_JSONFormat_WithoutAccountUUID(t *testing.T) {
raw := `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677", parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", parsed.SessionID)
require.True(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_JSONFormat_WithAccountUUID(t *testing.T) {
raw := `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"550e8400-e29b-41d4-a716-446655440000","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677", parsed.DeviceID)
require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", parsed.AccountUUID)
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", parsed.SessionID)
require.True(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_InvalidInputs(t *testing.T) {
tests := []struct {
name string
raw string
}{
{"empty string", ""},
{"whitespace only", " "},
{"random text", "not-a-valid-user-id"},
{"partial legacy format", "session_123e4567-e89b-12d3-a456-426614174000"},
{"invalid JSON", `{"device_id":}`},
{"JSON missing device_id", `{"account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`},
{"JSON missing session_id", `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":""}`},
{"JSON empty device_id", `{"device_id":"","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`},
{"JSON empty session_id", `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"","session_id":""}`},
{"legacy format short hex", "user_a1b2c3d4_account__session_123e4567-e89b-12d3-a456-426614174000"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Nil(t, ParseMetadataUserID(tt.raw), "should return nil for: %s", tt.raw)
})
}
}
func TestParseMetadataUserID_HexCaseInsensitive(t *testing.T) {
// Legacy format should accept both upper and lower case hex
rawUpper := "user_A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2_account__session_123e4567-e89b-12d3-a456-426614174000"
parsed := ParseMetadataUserID(rawUpper)
require.NotNil(t, parsed, "legacy format should accept uppercase hex")
require.Equal(t, "A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2", parsed.DeviceID)
}
// ============ FormatMetadataUserID Tests ============
func TestFormatMetadataUserID_LegacyVersion(t *testing.T) {
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "acc-uuid", "sess-uuid", "2.1.77")
require.Equal(t, "user_deadbeef00112233445566778899aabbccddeeff0011223344556677_account_acc-uuid_session_sess-uuid", result)
}
func TestFormatMetadataUserID_NewVersion(t *testing.T) {
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "acc-uuid", "sess-uuid", "2.1.78")
require.Equal(t, `{"device_id":"deadbeef00112233445566778899aabbccddeeff0011223344556677","account_uuid":"acc-uuid","session_id":"sess-uuid"}`, result)
}
func TestFormatMetadataUserID_EmptyVersion_Legacy(t *testing.T) {
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "")
require.Equal(t, "user_deadbeef00112233445566778899aabbccddeeff0011223344556677_account__session_sess-uuid", result)
}
func TestFormatMetadataUserID_EmptyAccountUUID(t *testing.T) {
// Legacy format with empty account UUID → double underscore
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "2.1.22")
require.Contains(t, result, "_account__session_")
// New format with empty account UUID → empty string in JSON
result = FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "2.1.78")
require.Contains(t, result, `"account_uuid":""`)
}
// ============ IsNewMetadataFormatVersion Tests ============
func TestIsNewMetadataFormatVersion(t *testing.T) {
tests := []struct {
version string
want bool
}{
{"", false},
{"2.1.77", false},
{"2.1.78", true},
{"2.1.79", true},
{"2.2.0", true},
{"3.0.0", true},
{"2.0.100", false},
{"1.9.99", false},
}
for _, tt := range tests {
t.Run(tt.version, func(t *testing.T) {
require.Equal(t, tt.want, IsNewMetadataFormatVersion(tt.version))
})
}
}
// ============ Round-trip Tests ============
func TestParseFormat_RoundTrip_Legacy(t *testing.T) {
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
accountUUID := "550e8400-e29b-41d4-a716-446655440000"
sessionID := "123e4567-e89b-12d3-a456-426614174000"
formatted := FormatMetadataUserID(deviceID, accountUUID, sessionID, "2.1.22")
parsed := ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, accountUUID, parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
require.False(t, parsed.IsNewFormat)
}
func TestParseFormat_RoundTrip_JSON(t *testing.T) {
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
accountUUID := "550e8400-e29b-41d4-a716-446655440000"
sessionID := "123e4567-e89b-12d3-a456-426614174000"
formatted := FormatMetadataUserID(deviceID, accountUUID, sessionID, "2.1.78")
parsed := ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, accountUUID, parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
require.True(t, parsed.IsNewFormat)
}
func TestParseFormat_RoundTrip_EmptyAccountUUID(t *testing.T) {
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
sessionID := "123e4567-e89b-12d3-a456-426614174000"
// Legacy round-trip with empty account UUID
formatted := FormatMetadataUserID(deviceID, "", sessionID, "2.1.22")
parsed := ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
// JSON round-trip with empty account UUID
formatted = FormatMetadataUserID(deviceID, "", sessionID, "2.1.78")
parsed = ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
}

View File

@ -277,12 +277,13 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
c.JSON(http.StatusOK, chatResp) c.JSON(http.StatusOK, chatResp)
return &OpenAIForwardResult{ return &OpenAIForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
BillingModel: mappedModel, BillingModel: mappedModel,
Stream: false, UpstreamModel: mappedModel,
Duration: time.Since(startTime), Stream: false,
Duration: time.Since(startTime),
}, nil }, nil
} }
@ -324,13 +325,14 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
resultWithUsage := func() *OpenAIForwardResult { resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{ return &OpenAIForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
BillingModel: mappedModel, BillingModel: mappedModel,
Stream: true, UpstreamModel: mappedModel,
Duration: time.Since(startTime), Stream: true,
FirstTokenMs: firstTokenMs, Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
} }
} }

View File

@ -299,12 +299,13 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
c.JSON(http.StatusOK, anthropicResp) c.JSON(http.StatusOK, anthropicResp)
return &OpenAIForwardResult{ return &OpenAIForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
BillingModel: mappedModel, BillingModel: mappedModel,
Stream: false, UpstreamModel: mappedModel,
Duration: time.Since(startTime), Stream: false,
Duration: time.Since(startTime),
}, nil }, nil
} }
@ -347,13 +348,14 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
// resultWithUsage builds the final result snapshot. // resultWithUsage builds the final result snapshot.
resultWithUsage := func() *OpenAIForwardResult { resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{ return &OpenAIForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
BillingModel: mappedModel, BillingModel: mappedModel,
Stream: true, UpstreamModel: mappedModel,
Duration: time.Since(startTime), Stream: true,
FirstTokenMs: firstTokenMs, Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
} }
} }

View File

@ -846,7 +846,7 @@ func TestExtractOpenAIServiceTierFromBody(t *testing.T) {
require.Nil(t, extractOpenAIServiceTierFromBody(nil)) require.Nil(t, extractOpenAIServiceTierFromBody(nil))
} }
func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *testing.T) { func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetadataFields(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{} userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{} subRepo := &openAIRecordUsageSubRepoStub{}
@ -859,6 +859,7 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te
RequestID: "resp_billing_model_override", RequestID: "resp_billing_model_override",
BillingModel: "gpt-5.1-codex", BillingModel: "gpt-5.1-codex",
Model: "gpt-5.1", Model: "gpt-5.1",
UpstreamModel: "gpt-5.1-codex",
ServiceTier: &serviceTier, ServiceTier: &serviceTier,
ReasoningEffort: &reasoning, ReasoningEffort: &reasoning,
Usage: OpenAIUsage{ Usage: OpenAIUsage{
@ -877,7 +878,9 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog) require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "gpt-5.1-codex", usageRepo.lastLog.Model) require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model)
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
require.Equal(t, "gpt-5.1-codex", *usageRepo.lastLog.UpstreamModel)
require.NotNil(t, usageRepo.lastLog.ServiceTier) require.NotNil(t, usageRepo.lastLog.ServiceTier)
require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier) require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier)
require.NotNil(t, usageRepo.lastLog.ReasoningEffort) require.NotNil(t, usageRepo.lastLog.ReasoningEffort)

View File

@ -216,6 +216,9 @@ type OpenAIForwardResult struct {
// This is set by the Anthropic Messages conversion path where // This is set by the Anthropic Messages conversion path where
// the mapped upstream model differs from the client-facing model. // the mapped upstream model differs from the client-facing model.
BillingModel string BillingModel string
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Empty when no mapping was applied (requested model was used as-is).
UpstreamModel string
// ServiceTier records the OpenAI Responses API service tier, e.g. "priority" / "flex". // ServiceTier records the OpenAI Responses API service tier, e.g. "priority" / "flex".
// Nil means the request did not specify a recognized tier. // Nil means the request did not specify a recognized tier.
ServiceTier *string ServiceTier *string
@ -2128,6 +2131,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
firstTokenMs, firstTokenMs,
wsAttempts, wsAttempts,
) )
wsResult.UpstreamModel = mappedModel
return wsResult, nil return wsResult, nil
} }
s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr) s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr)
@ -2263,6 +2267,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
RequestID: resp.Header.Get("x-request-id"), RequestID: resp.Header.Get("x-request-id"),
Usage: *usage, Usage: *usage,
Model: originalModel, Model: originalModel,
UpstreamModel: mappedModel,
ServiceTier: serviceTier, ServiceTier: serviceTier,
ReasoningEffort: reasoningEffort, ReasoningEffort: reasoningEffort,
Stream: reqStream, Stream: reqStream,
@ -4134,7 +4139,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
APIKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
RequestID: requestID, RequestID: requestID,
Model: billingModel, Model: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ServiceTier: result.ServiceTier, ServiceTier: result.ServiceTier,
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
@ -4700,11 +4706,3 @@ func normalizeOpenAIReasoningEffort(raw string) string {
return "" return ""
} }
} }
func optionalTrimmedStringPtr(raw string) *string {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return nil
}
return &trimmed
}

View File

@ -0,0 +1,298 @@
//go:build unit
package service
import (
"context"
"encoding/json"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// errSettingRepo: a SettingRepository that always returns errors on read
// ---------------------------------------------------------------------------
type errSettingRepo struct {
mockSettingRepo // embed the existing mock from backup_service_test.go
readErr error
}
func (r *errSettingRepo) GetValue(_ context.Context, _ string) (string, error) {
return "", r.readErr
}
func (r *errSettingRepo) Get(_ context.Context, _ string) (*Setting, error) {
return nil, r.readErr
}
// ---------------------------------------------------------------------------
// overloadAccountRepoStub: records SetOverloaded calls
// ---------------------------------------------------------------------------
type overloadAccountRepoStub struct {
mockAccountRepoForGemini
overloadCalls int
lastOverloadID int64
lastOverloadEnd time.Time
}
func (r *overloadAccountRepoStub) SetOverloaded(_ context.Context, id int64, until time.Time) error {
r.overloadCalls++
r.lastOverloadID = id
r.lastOverloadEnd = until
return nil
}
// ===========================================================================
// SettingService: GetOverloadCooldownSettings
// ===========================================================================
func TestGetOverloadCooldownSettings_DefaultsWhenNotSet(t *testing.T) {
repo := newMockSettingRepo()
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.True(t, settings.Enabled)
require.Equal(t, 10, settings.CooldownMinutes)
}
func TestGetOverloadCooldownSettings_ReadsFromDB(t *testing.T) {
repo := newMockSettingRepo()
data, _ := json.Marshal(OverloadCooldownSettings{Enabled: false, CooldownMinutes: 30})
repo.data[SettingKeyOverloadCooldownSettings] = string(data)
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.False(t, settings.Enabled)
require.Equal(t, 30, settings.CooldownMinutes)
}
func TestGetOverloadCooldownSettings_ClampsMinValue(t *testing.T) {
repo := newMockSettingRepo()
data, _ := json.Marshal(OverloadCooldownSettings{Enabled: true, CooldownMinutes: 0})
repo.data[SettingKeyOverloadCooldownSettings] = string(data)
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.Equal(t, 1, settings.CooldownMinutes)
}
func TestGetOverloadCooldownSettings_ClampsMaxValue(t *testing.T) {
repo := newMockSettingRepo()
data, _ := json.Marshal(OverloadCooldownSettings{Enabled: true, CooldownMinutes: 999})
repo.data[SettingKeyOverloadCooldownSettings] = string(data)
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.Equal(t, 120, settings.CooldownMinutes)
}
func TestGetOverloadCooldownSettings_InvalidJSON_ReturnsDefaults(t *testing.T) {
repo := newMockSettingRepo()
repo.data[SettingKeyOverloadCooldownSettings] = "not-json"
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.True(t, settings.Enabled)
require.Equal(t, 10, settings.CooldownMinutes)
}
func TestGetOverloadCooldownSettings_EmptyValue_ReturnsDefaults(t *testing.T) {
repo := newMockSettingRepo()
repo.data[SettingKeyOverloadCooldownSettings] = ""
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.True(t, settings.Enabled)
require.Equal(t, 10, settings.CooldownMinutes)
}
// ===========================================================================
// SettingService: SetOverloadCooldownSettings
// ===========================================================================
func TestSetOverloadCooldownSettings_Success(t *testing.T) {
repo := newMockSettingRepo()
svc := NewSettingService(repo, &config.Config{})
err := svc.SetOverloadCooldownSettings(context.Background(), &OverloadCooldownSettings{
Enabled: false,
CooldownMinutes: 25,
})
require.NoError(t, err)
// Verify round-trip
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.False(t, settings.Enabled)
require.Equal(t, 25, settings.CooldownMinutes)
}
func TestSetOverloadCooldownSettings_RejectsNil(t *testing.T) {
svc := NewSettingService(newMockSettingRepo(), &config.Config{})
err := svc.SetOverloadCooldownSettings(context.Background(), nil)
require.Error(t, err)
}
func TestSetOverloadCooldownSettings_EnabledRejectsOutOfRange(t *testing.T) {
svc := NewSettingService(newMockSettingRepo(), &config.Config{})
for _, minutes := range []int{0, -1, 121, 999} {
err := svc.SetOverloadCooldownSettings(context.Background(), &OverloadCooldownSettings{
Enabled: true, CooldownMinutes: minutes,
})
require.Error(t, err, "should reject enabled=true + cooldown_minutes=%d", minutes)
require.Contains(t, err.Error(), "cooldown_minutes must be between 1-120")
}
}
func TestSetOverloadCooldownSettings_DisabledNormalizesOutOfRange(t *testing.T) {
repo := newMockSettingRepo()
svc := NewSettingService(repo, &config.Config{})
// enabled=false + cooldown_minutes=0 应该保存成功值被归一化为10
err := svc.SetOverloadCooldownSettings(context.Background(), &OverloadCooldownSettings{
Enabled: false, CooldownMinutes: 0,
})
require.NoError(t, err, "disabled with invalid minutes should NOT be rejected")
// 验证持久化后读回来的值
settings, err := svc.GetOverloadCooldownSettings(context.Background())
require.NoError(t, err)
require.False(t, settings.Enabled)
require.Equal(t, 10, settings.CooldownMinutes, "should be normalized to default")
}
func TestSetOverloadCooldownSettings_AcceptsBoundaries(t *testing.T) {
svc := NewSettingService(newMockSettingRepo(), &config.Config{})
for _, minutes := range []int{1, 60, 120} {
err := svc.SetOverloadCooldownSettings(context.Background(), &OverloadCooldownSettings{
Enabled: true, CooldownMinutes: minutes,
})
require.NoError(t, err, "should accept cooldown_minutes=%d", minutes)
}
}
// ===========================================================================
// RateLimitService: handle529 behaviour
// ===========================================================================
func TestHandle529_EnabledFromDB_PausesAccount(t *testing.T) {
accountRepo := &overloadAccountRepoStub{}
settingRepo := newMockSettingRepo()
data, _ := json.Marshal(OverloadCooldownSettings{Enabled: true, CooldownMinutes: 15})
settingRepo.data[SettingKeyOverloadCooldownSettings] = string(data)
settingSvc := NewSettingService(settingRepo, &config.Config{})
svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil)
svc.SetSettingService(settingSvc)
account := &Account{ID: 42, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
before := time.Now()
svc.handle529(context.Background(), account)
require.Equal(t, 1, accountRepo.overloadCalls)
require.Equal(t, int64(42), accountRepo.lastOverloadID)
require.WithinDuration(t, before.Add(15*time.Minute), accountRepo.lastOverloadEnd, 2*time.Second)
}
func TestHandle529_DisabledFromDB_SkipsAccount(t *testing.T) {
accountRepo := &overloadAccountRepoStub{}
settingRepo := newMockSettingRepo()
data, _ := json.Marshal(OverloadCooldownSettings{Enabled: false, CooldownMinutes: 15})
settingRepo.data[SettingKeyOverloadCooldownSettings] = string(data)
settingSvc := NewSettingService(settingRepo, &config.Config{})
svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil)
svc.SetSettingService(settingSvc)
account := &Account{ID: 42, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
svc.handle529(context.Background(), account)
require.Equal(t, 0, accountRepo.overloadCalls, "should NOT pause when disabled")
}
func TestHandle529_NilSettingService_FallsBackToConfig(t *testing.T) {
accountRepo := &overloadAccountRepoStub{}
cfg := &config.Config{}
cfg.RateLimit.OverloadCooldownMinutes = 20
svc := NewRateLimitService(accountRepo, nil, cfg, nil, nil)
// NOT calling SetSettingService — remains nil
account := &Account{ID: 77, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
before := time.Now()
svc.handle529(context.Background(), account)
require.Equal(t, 1, accountRepo.overloadCalls)
require.WithinDuration(t, before.Add(20*time.Minute), accountRepo.lastOverloadEnd, 2*time.Second)
}
func TestHandle529_NilSettingService_ZeroConfig_DefaultsTen(t *testing.T) {
accountRepo := &overloadAccountRepoStub{}
svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil)
account := &Account{ID: 88, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
before := time.Now()
svc.handle529(context.Background(), account)
require.Equal(t, 1, accountRepo.overloadCalls)
require.WithinDuration(t, before.Add(10*time.Minute), accountRepo.lastOverloadEnd, 2*time.Second)
}
func TestHandle529_DBReadError_FallsBackToConfig(t *testing.T) {
accountRepo := &overloadAccountRepoStub{}
errRepo := &errSettingRepo{readErr: context.DeadlineExceeded}
errRepo.data = make(map[string]string)
cfg := &config.Config{}
cfg.RateLimit.OverloadCooldownMinutes = 7
settingSvc := NewSettingService(errRepo, cfg)
svc := NewRateLimitService(accountRepo, nil, cfg, nil, nil)
svc.SetSettingService(settingSvc)
account := &Account{ID: 99, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
before := time.Now()
svc.handle529(context.Background(), account)
require.Equal(t, 1, accountRepo.overloadCalls)
require.WithinDuration(t, before.Add(7*time.Minute), accountRepo.lastOverloadEnd, 2*time.Second)
}
// ===========================================================================
// Model: defaults & JSON round-trip
// ===========================================================================
func TestDefaultOverloadCooldownSettings(t *testing.T) {
d := DefaultOverloadCooldownSettings()
require.True(t, d.Enabled)
require.Equal(t, 10, d.CooldownMinutes)
}
func TestOverloadCooldownSettings_JSONRoundTrip(t *testing.T) {
original := OverloadCooldownSettings{Enabled: false, CooldownMinutes: 42}
data, err := json.Marshal(original)
require.NoError(t, err)
var decoded OverloadCooldownSettings
require.NoError(t, json.Unmarshal(data, &decoded))
require.Equal(t, original, decoded)
// Verify JSON uses snake_case field names
var raw map[string]any
require.NoError(t, json.Unmarshal(data, &raw))
_, hasEnabled := raw["enabled"]
_, hasCooldown := raw["cooldown_minutes"]
require.True(t, hasEnabled, "JSON must use 'enabled'")
require.True(t, hasCooldown, "JSON must use 'cooldown_minutes'")
}

View File

@ -1023,11 +1023,34 @@ func parseOpenAIRateLimitResetTime(body []byte) *int64 {
} }
// handle529 处理529过载错误 // handle529 处理529过载错误
// 根据配置设置过载冷却时间 // 根据配置决定是否暂停账号调度及冷却时长
func (s *RateLimitService) handle529(ctx context.Context, account *Account) { func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
cooldownMinutes := s.cfg.RateLimit.OverloadCooldownMinutes var settings *OverloadCooldownSettings
if s.settingService != nil {
var err error
settings, err = s.settingService.GetOverloadCooldownSettings(ctx)
if err != nil {
slog.Warn("overload_settings_read_failed", "account_id", account.ID, "error", err)
settings = nil
}
}
// 回退到配置文件
if settings == nil {
cooldown := s.cfg.RateLimit.OverloadCooldownMinutes
if cooldown <= 0 {
cooldown = 10
}
settings = &OverloadCooldownSettings{Enabled: true, CooldownMinutes: cooldown}
}
if !settings.Enabled {
slog.Info("account_529_ignored", "account_id", account.ID, "reason", "overload_cooldown_disabled")
return
}
cooldownMinutes := settings.CooldownMinutes
if cooldownMinutes <= 0 { if cooldownMinutes <= 0 {
cooldownMinutes = 10 // 默认10分钟 cooldownMinutes = 10
} }
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute) until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)

View File

@ -1172,6 +1172,57 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
return effective, nil return effective, nil
} }
// GetOverloadCooldownSettings 获取529过载冷却配置
func (s *SettingService) GetOverloadCooldownSettings(ctx context.Context) (*OverloadCooldownSettings, error) {
value, err := s.settingRepo.GetValue(ctx, SettingKeyOverloadCooldownSettings)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return DefaultOverloadCooldownSettings(), nil
}
return nil, fmt.Errorf("get overload cooldown settings: %w", err)
}
if value == "" {
return DefaultOverloadCooldownSettings(), nil
}
var settings OverloadCooldownSettings
if err := json.Unmarshal([]byte(value), &settings); err != nil {
return DefaultOverloadCooldownSettings(), nil
}
// 修正配置值范围
if settings.CooldownMinutes < 1 {
settings.CooldownMinutes = 1
}
if settings.CooldownMinutes > 120 {
settings.CooldownMinutes = 120
}
return &settings, nil
}
// SetOverloadCooldownSettings 设置529过载冷却配置
func (s *SettingService) SetOverloadCooldownSettings(ctx context.Context, settings *OverloadCooldownSettings) error {
if settings == nil {
return fmt.Errorf("settings cannot be nil")
}
// 禁用时修正为合法值即可,不拒绝请求
if settings.CooldownMinutes < 1 || settings.CooldownMinutes > 120 {
if settings.Enabled {
return fmt.Errorf("cooldown_minutes must be between 1-120")
}
settings.CooldownMinutes = 10 // 禁用状态下归一化为默认值
}
data, err := json.Marshal(settings)
if err != nil {
return fmt.Errorf("marshal overload cooldown settings: %w", err)
}
return s.settingRepo.Set(ctx, SettingKeyOverloadCooldownSettings, string(data))
}
// GetStreamTimeoutSettings 获取流超时处理配置 // GetStreamTimeoutSettings 获取流超时处理配置
func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamTimeoutSettings, error) { func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamTimeoutSettings, error) {
value, err := s.settingRepo.GetValue(ctx, SettingKeyStreamTimeoutSettings) value, err := s.settingRepo.GetValue(ctx, SettingKeyStreamTimeoutSettings)

View File

@ -222,6 +222,22 @@ type BetaPolicySettings struct {
Rules []BetaPolicyRule `json:"rules"` Rules []BetaPolicyRule `json:"rules"`
} }
// OverloadCooldownSettings 529过载冷却配置
type OverloadCooldownSettings struct {
// Enabled 是否在收到529时暂停账号调度
Enabled bool `json:"enabled"`
// CooldownMinutes 冷却时长(分钟)
CooldownMinutes int `json:"cooldown_minutes"`
}
// DefaultOverloadCooldownSettings 返回默认的过载冷却配置启用10分钟
func DefaultOverloadCooldownSettings() *OverloadCooldownSettings {
return &OverloadCooldownSettings{
Enabled: true,
CooldownMinutes: 10,
}
}
// DefaultBetaPolicySettings 返回默认的 Beta 策略配置 // DefaultBetaPolicySettings 返回默认的 Beta 策略配置
func DefaultBetaPolicySettings() *BetaPolicySettings { func DefaultBetaPolicySettings() *BetaPolicySettings {
return &BetaPolicySettings{ return &BetaPolicySettings{

View File

@ -52,8 +52,8 @@ func (r *stubGroupRepoForQuota) ListActiveByPlatform(context.Context, string) ([
func (r *stubGroupRepoForQuota) ExistsByName(context.Context, string) (bool, error) { func (r *stubGroupRepoForQuota) ExistsByName(context.Context, string) (bool, error) {
return false, nil return false, nil
} }
func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, error) { func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, int64, error) {
return 0, nil return 0, 0, nil
} }
func (r *stubGroupRepoForQuota) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { func (r *stubGroupRepoForQuota) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
return 0, nil return 0, nil

View File

@ -40,7 +40,7 @@ func (groupRepoNoop) ListActiveByPlatform(context.Context, string) ([]Group, err
func (groupRepoNoop) ExistsByName(context.Context, string) (bool, error) { func (groupRepoNoop) ExistsByName(context.Context, string) (bool, error) {
panic("unexpected ExistsByName call") panic("unexpected ExistsByName call")
} }
func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, error) { func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, int64, error) {
panic("unexpected GetAccountCount call") panic("unexpected GetAccountCount call")
} }
func (groupRepoNoop) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { func (groupRepoNoop) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
@ -92,7 +92,7 @@ func (userSubRepoNoop) ListActiveByUserID(context.Context, int64) ([]UserSubscri
func (userSubRepoNoop) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) { func (userSubRepoNoop) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call") panic("unexpected ListByGroupID call")
} }
func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) { func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) {
panic("unexpected List call") panic("unexpected List call")
} }
func (userSubRepoNoop) ExistsByUserIDAndGroupID(context.Context, int64, int64) (bool, error) { func (userSubRepoNoop) ExistsByUserIDAndGroupID(context.Context, int64, int64) (bool, error) {

View File

@ -634,9 +634,9 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI
} }
// List 获取所有订阅(分页,支持筛选和排序) // List 获取所有订阅(分页,支持筛选和排序)
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) { func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, sortBy, sortOrder) subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, platform, sortBy, sortOrder)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View File

@ -98,6 +98,9 @@ type UsageLog struct {
AccountID int64 AccountID int64
RequestID string RequestID string
Model string Model string
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Nil means no mapping was applied (requested model was used as-is).
UpstreamModel *string
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string ServiceTier *string
// ReasoningEffort is the request's reasoning effort level. // ReasoningEffort is the request's reasoning effort level.

View File

@ -0,0 +1,21 @@
package service
import "strings"
func optionalTrimmedStringPtr(raw string) *string {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return nil
}
return &trimmed
}
// optionalNonEqualStringPtr returns a pointer to value if it is non-empty and
// differs from compare; otherwise nil. Used to store upstream_model only when
// it differs from the requested model.
func optionalNonEqualStringPtr(value, compare string) *string {
if value == "" || value == compare {
return nil
}
return &value
}

View File

@ -18,7 +18,7 @@ type UserSubscriptionRepository interface {
ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error) ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error) ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error)
ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error)
ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error

View File

@ -486,4 +486,5 @@ var ProviderSet = wire.NewSet(
ProvideIdempotencyCleanupService, ProvideIdempotencyCleanupService,
ProvideScheduledTestService, ProvideScheduledTestService,
ProvideScheduledTestRunnerService, ProvideScheduledTestRunnerService,
NewGroupCapacityService,
) )

View File

@ -247,6 +247,12 @@ func install(c *gin.Context) {
return return
} }
req.Admin.Email = strings.TrimSpace(req.Admin.Email)
req.Database.Host = strings.TrimSpace(req.Database.Host)
req.Database.User = strings.TrimSpace(req.Database.User)
req.Database.DBName = strings.TrimSpace(req.Database.DBName)
req.Redis.Host = strings.TrimSpace(req.Redis.Host)
// ========== COMPREHENSIVE INPUT VALIDATION ========== // ========== COMPREHENSIVE INPUT VALIDATION ==========
// Database validation // Database validation
if !validateHostname(req.Database.Host) { if !validateHostname(req.Database.Host) {
@ -319,13 +325,6 @@ func install(c *gin.Context) {
return return
} }
// Trim whitespace from string inputs
req.Admin.Email = strings.TrimSpace(req.Admin.Email)
req.Database.Host = strings.TrimSpace(req.Database.Host)
req.Database.User = strings.TrimSpace(req.Database.User)
req.Database.DBName = strings.TrimSpace(req.Database.DBName)
req.Redis.Host = strings.TrimSpace(req.Redis.Host)
cfg := &SetupConfig{ cfg := &SetupConfig{
Database: req.Database, Database: req.Database,
Redis: req.Redis, Redis: req.Redis,

View File

@ -180,7 +180,37 @@ func (s *FrontendServer) injectSettings(settingsJSON []byte) []byte {
// Inject before </head> // Inject before </head>
headClose := []byte("</head>") headClose := []byte("</head>")
return bytes.Replace(s.baseHTML, headClose, append(script, headClose...), 1) result := bytes.Replace(s.baseHTML, headClose, append(script, headClose...), 1)
// Replace <title> with custom site name so the browser tab shows it immediately
result = injectSiteTitle(result, settingsJSON)
return result
}
// injectSiteTitle replaces the static <title> in HTML with the configured site name.
// This ensures the browser tab shows the correct title before JS executes.
func injectSiteTitle(html, settingsJSON []byte) []byte {
var cfg struct {
SiteName string `json:"site_name"`
}
if err := json.Unmarshal(settingsJSON, &cfg); err != nil || cfg.SiteName == "" {
return html
}
// Find and replace the existing <title>...</title>
titleStart := bytes.Index(html, []byte("<title>"))
titleEnd := bytes.Index(html, []byte("</title>"))
if titleStart == -1 || titleEnd == -1 || titleEnd <= titleStart {
return html
}
newTitle := []byte("<title>" + cfg.SiteName + " - AI API Gateway</title>")
var buf bytes.Buffer
buf.Write(html[:titleStart])
buf.Write(newTitle)
buf.Write(html[titleEnd+len("</title>"):])
return buf.Bytes()
} }
// replaceNoncePlaceholder replaces the nonce placeholder with actual nonce value // replaceNoncePlaceholder replaces the nonce placeholder with actual nonce value

View File

@ -20,6 +20,78 @@ func init() {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
} }
func TestInjectSiteTitle(t *testing.T) {
t.Run("replaces_title_with_site_name", func(t *testing.T) {
html := []byte(`<html><head><title>Sub2API - AI API Gateway</title></head><body></body></html>`)
settingsJSON := []byte(`{"site_name":"MyCustomSite"}`)
result := injectSiteTitle(html, settingsJSON)
assert.Contains(t, string(result), "<title>MyCustomSite - AI API Gateway</title>")
assert.NotContains(t, string(result), "Sub2API")
})
t.Run("returns_unchanged_when_site_name_empty", func(t *testing.T) {
html := []byte(`<html><head><title>Sub2API - AI API Gateway</title></head><body></body></html>`)
settingsJSON := []byte(`{"site_name":""}`)
result := injectSiteTitle(html, settingsJSON)
assert.Equal(t, string(html), string(result))
})
t.Run("returns_unchanged_when_site_name_missing", func(t *testing.T) {
html := []byte(`<html><head><title>Sub2API - AI API Gateway</title></head><body></body></html>`)
settingsJSON := []byte(`{"other_field":"value"}`)
result := injectSiteTitle(html, settingsJSON)
assert.Equal(t, string(html), string(result))
})
t.Run("returns_unchanged_when_invalid_json", func(t *testing.T) {
html := []byte(`<html><head><title>Sub2API - AI API Gateway</title></head><body></body></html>`)
settingsJSON := []byte(`{invalid json}`)
result := injectSiteTitle(html, settingsJSON)
assert.Equal(t, string(html), string(result))
})
t.Run("returns_unchanged_when_no_title_tag", func(t *testing.T) {
html := []byte(`<html><head></head><body></body></html>`)
settingsJSON := []byte(`{"site_name":"MyCustomSite"}`)
result := injectSiteTitle(html, settingsJSON)
assert.Equal(t, string(html), string(result))
})
t.Run("returns_unchanged_when_title_has_attributes", func(t *testing.T) {
// The function looks for "<title>" literally, so attributes are not supported
// This is acceptable since index.html uses plain <title> without attributes
html := []byte(`<html><head><title lang="en">Sub2API</title></head><body></body></html>`)
settingsJSON := []byte(`{"site_name":"NewSite"}`)
result := injectSiteTitle(html, settingsJSON)
// Should return unchanged since <title> with attributes is not matched
assert.Equal(t, string(html), string(result))
})
t.Run("preserves_rest_of_html", func(t *testing.T) {
html := []byte(`<html><head><meta charset="UTF-8"><title>Sub2API</title><script src="app.js"></script></head><body><div id="app"></div></body></html>`)
settingsJSON := []byte(`{"site_name":"TestSite"}`)
result := injectSiteTitle(html, settingsJSON)
assert.Contains(t, string(result), `<meta charset="UTF-8">`)
assert.Contains(t, string(result), `<script src="app.js"></script>`)
assert.Contains(t, string(result), `<div id="app"></div>`)
assert.Contains(t, string(result), "<title>TestSite - AI API Gateway</title>")
})
}
func TestReplaceNoncePlaceholder(t *testing.T) { func TestReplaceNoncePlaceholder(t *testing.T) {
t.Run("replaces_single_placeholder", func(t *testing.T) { t.Run("replaces_single_placeholder", func(t *testing.T) {
html := []byte(`<script nonce="__CSP_NONCE_VALUE__">console.log('test');</script>`) html := []byte(`<script nonce="__CSP_NONCE_VALUE__">console.log('test');</script>`)

View File

@ -0,0 +1,4 @@
-- Add upstream_model field to usage_logs.
-- Stores the actual upstream model name when it differs from the requested model
-- (i.e., when model mapping is applied). NULL means no mapping was applied.
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS upstream_model VARCHAR(100);

View File

@ -0,0 +1,17 @@
-- Map claude-haiku-4-5 variants target from claude-sonnet-4-5 to claude-sonnet-4-6
--
-- Only updates when the current target is exactly claude-sonnet-4-5.
-- 1. claude-haiku-4-5
UPDATE accounts
SET credentials = jsonb_set(credentials, '{model_mapping,claude-haiku-4-5}', '"claude-sonnet-4-6"')
WHERE platform = 'antigravity'
AND deleted_at IS NULL
AND credentials->'model_mapping'->>'claude-haiku-4-5' = 'claude-sonnet-4-5';
-- 2. claude-haiku-4-5-20251001
UPDATE accounts
SET credentials = jsonb_set(credentials, '{model_mapping,claude-haiku-4-5-20251001}', '"claude-sonnet-4-6"')
WHERE platform = 'antigravity'
AND deleted_at IS NULL
AND credentials->'model_mapping'->>'claude-haiku-4-5-20251001' = 'claude-sonnet-4-5';

View File

@ -0,0 +1,3 @@
-- Support upstream_model / mapping model distribution aggregations with time-range filters.
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_logs_created_model_upstream_model
ON usage_logs (created_at, model, upstream_model);

View File

@ -34,18 +34,18 @@ Example: `017_add_gemini_tier_id.sql`
## Migration File Structure ## Migration File Structure
```sql This project uses a custom migration runner (`internal/repository/migrations_runner.go`) that executes the full SQL file content as-is.
-- +goose Up
-- +goose StatementBegin
-- Your forward migration SQL here
-- +goose StatementEnd
-- +goose Down - Regular migrations (`*.sql`): executed in a transaction.
-- +goose StatementBegin - Non-transactional migrations (`*_notx.sql`): split by statement and executed without transaction (for `CONCURRENTLY`).
-- Your rollback migration SQL here
-- +goose StatementEnd ```sql
-- Forward-only migration (recommended)
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS example_column VARCHAR(100);
``` ```
> ⚠️ Do **not** place executable "Down" SQL in the same file. The runner does not parse goose Up/Down sections and will execute all SQL statements in the file.
## Important Rules ## Important Rules
### ⚠️ Immutability Principle ### ⚠️ Immutability Principle
@ -66,9 +66,9 @@ Why?
touch migrations/018_your_change.sql touch migrations/018_your_change.sql
``` ```
2. **Write Up and Down migrations** 2. **Write forward-only migration SQL**
- Up: Apply the change - Put only the intended schema change in the file
- Down: Revert the change (should be symmetric with Up) - If rollback is needed, create a new migration file to revert
3. **Test locally** 3. **Test locally**
```bash ```bash
@ -144,8 +144,6 @@ touch migrations/018_your_new_change.sql
## Example Migration ## Example Migration
```sql ```sql
-- +goose Up
-- +goose StatementBegin
-- Add tier_id field to Gemini OAuth accounts for quota tracking -- Add tier_id field to Gemini OAuth accounts for quota tracking
UPDATE accounts UPDATE accounts
SET credentials = jsonb_set( SET credentials = jsonb_set(
@ -157,17 +155,6 @@ SET credentials = jsonb_set(
WHERE platform = 'gemini' WHERE platform = 'gemini'
AND type = 'oauth' AND type = 'oauth'
AND credentials->>'tier_id' IS NULL; AND credentials->>'tier_id' IS NULL;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
-- Remove tier_id field
UPDATE accounts
SET credentials = credentials - 'tier_id'
WHERE platform = 'gemini'
AND type = 'oauth'
AND credentials->>'tier_id' = 'LEGACY';
-- +goose StatementEnd
``` ```
## Troubleshooting ## Troubleshooting
@ -194,5 +181,4 @@ VALUES ('NNN_migration.sql', 'calculated_checksum', NOW());
## References ## References
- Migration runner: `internal/repository/migrations_runner.go` - Migration runner: `internal/repository/migrations_runner.go`
- Goose syntax: https://github.com/pressly/goose
- PostgreSQL docs: https://www.postgresql.org/docs/ - PostgreSQL docs: https://www.postgresql.org/docs/

View File

@ -38,7 +38,7 @@ services:
- ./data:/app/data - ./data:/app/data
# Optional: Mount custom config.yaml (uncomment and create the file first) # Optional: Mount custom config.yaml (uncomment and create the file first)
# Copy config.example.yaml to config.yaml, modify it, then uncomment: # Copy config.example.yaml to config.yaml, modify it, then uncomment:
# - ./config.yaml:/app/data/config.yaml:ro # - ./config.yaml:/app/data/config.yaml
environment: environment:
# ======================================================================= # =======================================================================
# Auto Setup (REQUIRED for Docker deployment) # Auto Setup (REQUIRED for Docker deployment)

View File

@ -30,7 +30,7 @@ services:
- sub2api_data:/app/data - sub2api_data:/app/data
# Optional: Mount custom config.yaml (uncomment and create the file first) # Optional: Mount custom config.yaml (uncomment and create the file first)
# Copy config.example.yaml to config.yaml, modify it, then uncomment: # Copy config.example.yaml to config.yaml, modify it, then uncomment:
# - ./config.yaml:/app/data/config.yaml:ro # - ./config.yaml:/app/data/config.yaml
environment: environment:
# ======================================================================= # =======================================================================
# Auto Setup (REQUIRED for Docker deployment) # Auto Setup (REQUIRED for Docker deployment)

View File

@ -6,7 +6,8 @@ set -e
# preventing the non-root sub2api user from writing files. # preventing the non-root sub2api user from writing files.
if [ "$(id -u)" = "0" ]; then if [ "$(id -u)" = "0" ]; then
mkdir -p /app/data mkdir -p /app/data
chown -R sub2api:sub2api /app/data # Use || true to avoid failure on read-only mounted files (e.g. config.yaml:ro)
chown -R sub2api:sub2api /app/data 2>/dev/null || true
# Re-invoke this script as sub2api so the flag-detection below # Re-invoke this script as sub2api so the flag-detection below
# also runs under the correct user. # also runs under the correct user.
exec su-exec sub2api "$0" "$@" exec su-exec sub2api "$0" "$@"

View File

@ -3,6 +3,7 @@ import { RouterView, useRouter, useRoute } from 'vue-router'
import { onMounted, onBeforeUnmount, watch } from 'vue' import { onMounted, onBeforeUnmount, watch } from 'vue'
import Toast from '@/components/common/Toast.vue' import Toast from '@/components/common/Toast.vue'
import NavigationProgress from '@/components/common/NavigationProgress.vue' import NavigationProgress from '@/components/common/NavigationProgress.vue'
import { resolveDocumentTitle } from '@/router/title'
import AnnouncementPopup from '@/components/common/AnnouncementPopup.vue' import AnnouncementPopup from '@/components/common/AnnouncementPopup.vue'
import { useAppStore, useAuthStore, useSubscriptionStore, useAnnouncementStore } from '@/stores' import { useAppStore, useAuthStore, useSubscriptionStore, useAnnouncementStore } from '@/stores'
import { getSetupStatus } from '@/api/setup' import { getSetupStatus } from '@/api/setup'
@ -104,6 +105,9 @@ onMounted(async () => {
// Load public settings into appStore (will be cached for other components) // Load public settings into appStore (will be cached for other components)
await appStore.fetchPublicSettings() await appStore.fetchPublicSettings()
// Re-resolve document title now that siteName is available
document.title = resolveDocumentTitle(route.meta.title, appStore.siteName, route.meta.titleKey as string)
}) })
</script> </script>

View File

@ -81,6 +81,7 @@ export interface ModelStatsParams {
user_id?: number user_id?: number
api_key_id?: number api_key_id?: number
model?: string model?: string
model_source?: 'requested' | 'upstream' | 'mapping'
account_id?: number account_id?: number
group_id?: number group_id?: number
request_type?: UsageRequestType request_type?: UsageRequestType
@ -162,6 +163,7 @@ export interface UserBreakdownParams {
end_date?: string end_date?: string
group_id?: number group_id?: number
model?: string model?: string
model_source?: 'requested' | 'upstream' | 'mapping'
endpoint?: string endpoint?: string
endpoint_type?: 'inbound' | 'upstream' | 'path' endpoint_type?: 'inbound' | 'upstream' | 'path'
limit?: number limit?: number

View File

@ -218,6 +218,34 @@ export async function batchSetGroupRateMultipliers(
return data return data
} }
/**
* Get usage summary (today + cumulative cost) for all groups
* @param timezone - IANA timezone string (e.g. "Asia/Shanghai")
* @returns Array of group usage summaries
*/
export async function getUsageSummary(
timezone?: string
): Promise<{ group_id: number; today_cost: number; total_cost: number }[]> {
const { data } = await apiClient.get<
{ group_id: number; today_cost: number; total_cost: number }[]
>('/admin/groups/usage-summary', {
params: timezone ? { timezone } : undefined
})
return data
}
/**
* Get capacity summary (concurrency/sessions/RPM) for all active groups
*/
export async function getCapacitySummary(): Promise<
{ group_id: number; concurrency_used: number; concurrency_max: number; sessions_used: number; sessions_max: number; rpm_used: number; rpm_max: number }[]
> {
const { data } = await apiClient.get<
{ group_id: number; concurrency_used: number; concurrency_max: number; sessions_used: number; sessions_max: number; rpm_used: number; rpm_max: number }[]
>('/admin/groups/capacity-summary')
return data
}
export const groupsAPI = { export const groupsAPI = {
list, list,
getAll, getAll,
@ -232,7 +260,9 @@ export const groupsAPI = {
getGroupRateMultipliers, getGroupRateMultipliers,
clearGroupRateMultipliers, clearGroupRateMultipliers,
batchSetGroupRateMultipliers, batchSetGroupRateMultipliers,
updateSortOrder updateSortOrder,
getUsageSummary,
getCapacitySummary
} }
export default groupsAPI export default groupsAPI

View File

@ -242,6 +242,33 @@ export async function deleteAdminApiKey(): Promise<{ message: string }> {
return data return data
} }
// ==================== Overload Cooldown Settings ====================
/**
* Overload cooldown settings interface (529 handling)
*/
export interface OverloadCooldownSettings {
enabled: boolean
cooldown_minutes: number
}
export async function getOverloadCooldownSettings(): Promise<OverloadCooldownSettings> {
const { data } = await apiClient.get<OverloadCooldownSettings>('/admin/settings/overload-cooldown')
return data
}
export async function updateOverloadCooldownSettings(
settings: OverloadCooldownSettings
): Promise<OverloadCooldownSettings> {
const { data } = await apiClient.put<OverloadCooldownSettings>(
'/admin/settings/overload-cooldown',
settings
)
return data
}
// ==================== Stream Timeout Settings ====================
/** /**
* Stream timeout settings interface * Stream timeout settings interface
*/ */
@ -499,6 +526,8 @@ export const settingsAPI = {
getAdminApiKey, getAdminApiKey,
regenerateAdminApiKey, regenerateAdminApiKey,
deleteAdminApiKey, deleteAdminApiKey,
getOverloadCooldownSettings,
updateOverloadCooldownSettings,
getStreamTimeoutSettings, getStreamTimeoutSettings,
updateStreamTimeoutSettings, updateStreamTimeoutSettings,
getRectifierSettings, getRectifierSettings,

View File

@ -27,6 +27,7 @@ export async function list(
status?: 'active' | 'expired' | 'revoked' status?: 'active' | 'expired' | 'revoked'
user_id?: number user_id?: number
group_id?: number group_id?: number
platform?: string
sort_by?: string sort_by?: string
sort_order?: 'asc' | 'desc' sort_order?: 'asc' | 'desc'
}, },

View File

@ -82,6 +82,7 @@
:utilization="usageInfo.five_hour.utilization" :utilization="usageInfo.five_hour.utilization"
:resets-at="usageInfo.five_hour.resets_at" :resets-at="usageInfo.five_hour.resets_at"
:window-stats="usageInfo.five_hour.window_stats" :window-stats="usageInfo.five_hour.window_stats"
:show-now-when-idle="true"
color="indigo" color="indigo"
/> />
<UsageProgressBar <UsageProgressBar
@ -90,6 +91,7 @@
:utilization="usageInfo.seven_day.utilization" :utilization="usageInfo.seven_day.utilization"
:resets-at="usageInfo.seven_day.resets_at" :resets-at="usageInfo.seven_day.resets_at"
:window-stats="usageInfo.seven_day.window_stats" :window-stats="usageInfo.seven_day.window_stats"
:show-now-when-idle="true"
color="emerald" color="emerald"
/> />
</div> </div>

View File

@ -48,7 +48,7 @@
</span> </span>
<!-- Reset time --> <!-- Reset time -->
<span v-if="resetsAt" class="shrink-0 text-[10px] text-gray-400"> <span v-if="shouldShowResetTime" class="shrink-0 text-[10px] text-gray-400">
{{ formatResetTime }} {{ formatResetTime }}
</span> </span>
</div> </div>
@ -68,6 +68,7 @@ const props = defineProps<{
resetsAt?: string | null resetsAt?: string | null
color: 'indigo' | 'emerald' | 'purple' | 'amber' color: 'indigo' | 'emerald' | 'purple' | 'amber'
windowStats?: WindowStats | null windowStats?: WindowStats | null
showNowWhenIdle?: boolean
}>() }>()
const { t } = useI18n() const { t } = useI18n()
@ -139,9 +140,20 @@ const displayPercent = computed(() => {
return percent > 999 ? '>999%' : `${percent}%` return percent > 999 ? '>999%' : `${percent}%`
}) })
const shouldShowResetTime = computed(() => {
if (props.resetsAt) return true
return Boolean(props.showNowWhenIdle && props.utilization <= 0)
})
// Format reset time // Format reset time
const formatResetTime = computed(() => { const formatResetTime = computed(() => {
// For rolling windows, when utilization is 0%, treat as immediately available.
if (props.showNowWhenIdle && props.utilization <= 0) {
return '现在'
}
if (!props.resetsAt) return '-' if (!props.resetsAt) return '-'
const date = new Date(props.resetsAt) const date = new Date(props.resetsAt)
const diffMs = date.getTime() - now.value.getTime() const diffMs = date.getTime() - now.value.getTime()

View File

@ -0,0 +1,69 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { mount } from '@vue/test-utils'
import UsageProgressBar from '../UsageProgressBar.vue'
vi.mock('vue-i18n', async () => {
const actual = await vi.importActual<typeof import('vue-i18n')>('vue-i18n')
return {
...actual,
useI18n: () => ({
t: (key: string) => key
})
}
})
describe('UsageProgressBar', () => {
beforeEach(() => {
vi.useFakeTimers()
vi.setSystemTime(new Date('2026-03-17T00:00:00Z'))
})
afterEach(() => {
vi.useRealTimers()
})
it('showNowWhenIdle=true 且利用率为 0 时显示“现在”', () => {
const wrapper = mount(UsageProgressBar, {
props: {
label: '5h',
utilization: 0,
resetsAt: '2026-03-17T02:30:00Z',
showNowWhenIdle: true,
color: 'indigo'
}
})
expect(wrapper.text()).toContain('现在')
expect(wrapper.text()).not.toContain('2h 30m')
})
it('showNowWhenIdle=true 但利用率大于 0 时显示倒计时', () => {
const wrapper = mount(UsageProgressBar, {
props: {
label: '7d',
utilization: 12,
resetsAt: '2026-03-17T02:30:00Z',
showNowWhenIdle: true,
color: 'emerald'
}
})
expect(wrapper.text()).toContain('2h 30m')
expect(wrapper.text()).not.toContain('现在')
})
it('showNowWhenIdle=false 时保持原有倒计时行为', () => {
const wrapper = mount(UsageProgressBar, {
props: {
label: '1d',
utilization: 0,
resetsAt: '2026-03-17T02:30:00Z',
showNowWhenIdle: false,
color: 'indigo'
}
})
expect(wrapper.text()).toContain('2h 30m')
expect(wrapper.text()).not.toContain('现在')
})
})

View File

@ -25,8 +25,16 @@
<span class="text-sm text-gray-900 dark:text-white">{{ row.account?.name || '-' }}</span> <span class="text-sm text-gray-900 dark:text-white">{{ row.account?.name || '-' }}</span>
</template> </template>
<template #cell-model="{ value }"> <template #cell-model="{ row }">
<span class="font-medium text-gray-900 dark:text-white">{{ value }}</span> <div v-if="row.upstream_model && row.upstream_model !== row.model" class="space-y-0.5 text-xs">
<div class="break-all font-medium text-gray-900 dark:text-white">
{{ row.model }}
</div>
<div class="break-all text-gray-500 dark:text-gray-400">
<span class="mr-0.5"></span>{{ row.upstream_model }}
</div>
</div>
<span v-else class="font-medium text-gray-900 dark:text-white">{{ row.model }}</span>
</template> </template>
<template #cell-reasoning_effort="{ row }"> <template #cell-reasoning_effort="{ row }">

View File

@ -1,10 +1,10 @@
<template> <template>
<div class="card p-4"> <div class="card p-4">
<div class="mb-4 flex items-start justify-between gap-3"> <div class="mb-4 flex items-center justify-between gap-3">
<h3 class="text-sm font-semibold text-gray-900 dark:text-white"> <h3 class="text-sm font-semibold text-gray-900 dark:text-white">
{{ title || t('usage.endpointDistribution') }} {{ title || t('usage.endpointDistribution') }}
</h3> </h3>
<div class="flex flex-col items-end gap-2"> <div class="flex flex-wrap items-center justify-end gap-2">
<div <div
v-if="showSourceToggle" v-if="showSourceToggle"
class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800" class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800"

View File

@ -6,7 +6,42 @@
? t('admin.dashboard.modelDistribution') ? t('admin.dashboard.modelDistribution')
: t('admin.dashboard.spendingRankingTitle') }} : t('admin.dashboard.spendingRankingTitle') }}
</h3> </h3>
<div class="flex items-center gap-2"> <div class="flex flex-wrap items-center justify-end gap-2">
<div
v-if="showSourceToggle"
class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800"
>
<button
type="button"
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
:class="source === 'requested'
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
@click="emit('update:source', 'requested')"
>
{{ t('usage.requestedModel') }}
</button>
<button
type="button"
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
:class="source === 'upstream'
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
@click="emit('update:source', 'upstream')"
>
{{ t('usage.upstreamModel') }}
</button>
<button
type="button"
class="rounded-md px-2.5 py-1 text-xs font-medium transition-colors"
:class="source === 'mapping'
? 'bg-white text-gray-900 shadow-sm dark:bg-dark-700 dark:text-white'
: 'text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200'"
@click="emit('update:source', 'mapping')"
>
{{ t('usage.mapping') }}
</button>
</div>
<div <div
v-if="showMetricToggle" v-if="showMetricToggle"
class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800" class="inline-flex rounded-lg border border-gray-200 bg-gray-50 p-0.5 dark:border-gray-700 dark:bg-dark-800"
@ -215,9 +250,13 @@ ChartJS.register(ArcElement, Tooltip, Legend)
const { t } = useI18n() const { t } = useI18n()
type DistributionMetric = 'tokens' | 'actual_cost' type DistributionMetric = 'tokens' | 'actual_cost'
type ModelSource = 'requested' | 'upstream' | 'mapping'
type RankingDisplayItem = UserSpendingRankingItem & { isOther?: boolean } type RankingDisplayItem = UserSpendingRankingItem & { isOther?: boolean }
const props = withDefaults(defineProps<{ const props = withDefaults(defineProps<{
modelStats: ModelStat[] modelStats: ModelStat[]
upstreamModelStats?: ModelStat[]
mappingModelStats?: ModelStat[]
source?: ModelSource
enableRankingView?: boolean enableRankingView?: boolean
rankingItems?: UserSpendingRankingItem[] rankingItems?: UserSpendingRankingItem[]
rankingTotalActualCost?: number rankingTotalActualCost?: number
@ -225,12 +264,16 @@ const props = withDefaults(defineProps<{
rankingTotalTokens?: number rankingTotalTokens?: number
loading?: boolean loading?: boolean
metric?: DistributionMetric metric?: DistributionMetric
showSourceToggle?: boolean
showMetricToggle?: boolean showMetricToggle?: boolean
rankingLoading?: boolean rankingLoading?: boolean
rankingError?: boolean rankingError?: boolean
startDate?: string startDate?: string
endDate?: string endDate?: string
}>(), { }>(), {
upstreamModelStats: () => [],
mappingModelStats: () => [],
source: 'requested',
enableRankingView: false, enableRankingView: false,
rankingItems: () => [], rankingItems: () => [],
rankingTotalActualCost: 0, rankingTotalActualCost: 0,
@ -238,6 +281,7 @@ const props = withDefaults(defineProps<{
rankingTotalTokens: 0, rankingTotalTokens: 0,
loading: false, loading: false,
metric: 'tokens', metric: 'tokens',
showSourceToggle: false,
showMetricToggle: false, showMetricToggle: false,
rankingLoading: false, rankingLoading: false,
rankingError: false rankingError: false
@ -261,6 +305,7 @@ const toggleBreakdown = async (type: string, id: string) => {
start_date: props.startDate, start_date: props.startDate,
end_date: props.endDate, end_date: props.endDate,
model: id, model: id,
model_source: props.source,
}) })
breakdownItems.value = res.users || [] breakdownItems.value = res.users || []
} catch { } catch {
@ -272,6 +317,7 @@ const toggleBreakdown = async (type: string, id: string) => {
const emit = defineEmits<{ const emit = defineEmits<{
'update:metric': [value: DistributionMetric] 'update:metric': [value: DistributionMetric]
'update:source': [value: ModelSource]
'ranking-click': [item: UserSpendingRankingItem] 'ranking-click': [item: UserSpendingRankingItem]
}>() }>()
@ -294,14 +340,19 @@ const chartColors = [
] ]
const displayModelStats = computed(() => { const displayModelStats = computed(() => {
if (!props.modelStats?.length) return [] const sourceStats = props.source === 'upstream'
? props.upstreamModelStats
: props.source === 'mapping'
? props.mappingModelStats
: props.modelStats
if (!sourceStats?.length) return []
const metricKey = props.metric === 'actual_cost' ? 'actual_cost' : 'total_tokens' const metricKey = props.metric === 'actual_cost' ? 'actual_cost' : 'total_tokens'
return [...props.modelStats].sort((a, b) => b[metricKey] - a[metricKey]) return [...sourceStats].sort((a, b) => b[metricKey] - a[metricKey])
}) })
const chartData = computed(() => { const chartData = computed(() => {
if (!props.modelStats?.length) return null if (!displayModelStats.value.length) return null
return { return {
labels: displayModelStats.value.map((m) => m.model), labels: displayModelStats.value.map((m) => m.model),

View File

@ -0,0 +1,84 @@
<template>
<div class="flex flex-col gap-1">
<!-- 并发槽位 -->
<div class="flex items-center gap-1">
<span
:class="[
'inline-flex items-center gap-1 rounded-md px-1.5 py-0.5 text-[10px] font-medium',
capacityClass(concurrencyUsed, concurrencyMax)
]"
>
<svg class="h-2.5 w-2.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M3.75 6A2.25 2.25 0 016 3.75h2.25A2.25 2.25 0 0110.5 6v2.25a2.25 2.25 0 01-2.25 2.25H6a2.25 2.25 0 01-2.25-2.25V6zM3.75 15.75A2.25 2.25 0 016 13.5h2.25a2.25 2.25 0 012.25 2.25V18a2.25 2.25 0 01-2.25 2.25H6A2.25 2.25 0 013.75 18v-2.25zM13.5 6a2.25 2.25 0 012.25-2.25H18A2.25 2.25 0 0120.25 6v2.25A2.25 2.25 0 0118 10.5h-2.25a2.25 2.25 0 01-2.25-2.25V6zM13.5 15.75a2.25 2.25 0 012.25-2.25H18a2.25 2.25 0 012.25 2.25V18A2.25 2.25 0 0118 20.25h-2.25A2.25 2.25 0 0113.5 18v-2.25z" />
</svg>
<span class="font-mono">{{ concurrencyUsed }}</span>
<span class="text-gray-400 dark:text-gray-500">/</span>
<span class="font-mono">{{ concurrencyMax }}</span>
</span>
</div>
<!-- 会话数 -->
<div v-if="sessionsMax > 0" class="flex items-center gap-1">
<span
:class="[
'inline-flex items-center gap-1 rounded-md px-1.5 py-0.5 text-[10px] font-medium',
capacityClass(sessionsUsed, sessionsMax)
]"
>
<svg class="h-2.5 w-2.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M15 19.128a9.38 9.38 0 002.625.372 9.337 9.337 0 004.121-.952 4.125 4.125 0 00-7.533-2.493M15 19.128v-.003c0-1.113-.285-2.16-.786-3.07M15 19.128v.106A12.318 12.318 0 018.624 21c-2.331 0-4.512-.645-6.374-1.766l-.001-.109a6.375 6.375 0 0111.964-3.07M12 6.375a3.375 3.375 0 11-6.75 0 3.375 3.375 0 016.75 0zm8.25 2.25a2.625 2.625 0 11-5.25 0 2.625 2.625 0 015.25 0z" />
</svg>
<span class="font-mono">{{ sessionsUsed }}</span>
<span class="text-gray-400 dark:text-gray-500">/</span>
<span class="font-mono">{{ sessionsMax }}</span>
</span>
</div>
<!-- RPM -->
<div v-if="rpmMax > 0" class="flex items-center gap-1">
<span
:class="[
'inline-flex items-center gap-1 rounded-md px-1.5 py-0.5 text-[10px] font-medium',
capacityClass(rpmUsed, rpmMax)
]"
>
<svg class="h-2.5 w-2.5" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" d="M12 6v6h4.5m4.5 0a9 9 0 1 1-18 0 9 9 0 0 1 18 0Z" />
</svg>
<span class="font-mono">{{ rpmUsed }}</span>
<span class="text-gray-400 dark:text-gray-500">/</span>
<span class="font-mono">{{ rpmMax }}</span>
</span>
</div>
</div>
</template>
<script setup lang="ts">
interface Props {
concurrencyUsed: number
concurrencyMax: number
sessionsUsed: number
sessionsMax: number
rpmUsed: number
rpmMax: number
}
withDefaults(defineProps<Props>(), {
concurrencyUsed: 0,
concurrencyMax: 0,
sessionsUsed: 0,
sessionsMax: 0,
rpmUsed: 0,
rpmMax: 0
})
function capacityClass(used: number, max: number): string {
if (max > 0 && used >= max) {
return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400'
}
if (used > 0) {
return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/30 dark:text-yellow-400'
}
return 'bg-gray-100 text-gray-600 dark:bg-gray-800 dark:text-gray-400'
}
</script>

View File

@ -218,7 +218,7 @@ export default {
email: 'Email', email: 'Email',
password: 'Password', password: 'Password',
confirmPassword: 'Confirm Password', confirmPassword: 'Confirm Password',
passwordPlaceholder: 'Min 6 characters', passwordPlaceholder: 'Min 8 characters',
confirmPasswordPlaceholder: 'Confirm password', confirmPasswordPlaceholder: 'Confirm password',
passwordMismatch: 'Passwords do not match' passwordMismatch: 'Passwords do not match'
}, },
@ -718,11 +718,14 @@ export default {
exporting: 'Exporting...', exporting: 'Exporting...',
preparingExport: 'Preparing export...', preparingExport: 'Preparing export...',
model: 'Model', model: 'Model',
requestedModel: 'Requested',
upstreamModel: 'Upstream',
reasoningEffort: 'Reasoning Effort', reasoningEffort: 'Reasoning Effort',
endpoint: 'Endpoint', endpoint: 'Endpoint',
endpointDistribution: 'Endpoint Distribution', endpointDistribution: 'Endpoint Distribution',
inbound: 'Inbound', inbound: 'Inbound',
upstream: 'Upstream', upstream: 'Upstream',
mapping: 'Mapping',
path: 'Path', path: 'Path',
inboundEndpoint: 'Inbound Endpoint', inboundEndpoint: 'Inbound Endpoint',
upstreamEndpoint: 'Upstream Endpoint', upstreamEndpoint: 'Upstream Endpoint',
@ -1505,6 +1508,8 @@ export default {
rateMultiplier: 'Rate Multiplier', rateMultiplier: 'Rate Multiplier',
type: 'Type', type: 'Type',
accounts: 'Accounts', accounts: 'Accounts',
capacity: 'Capacity',
usage: 'Usage',
status: 'Status', status: 'Status',
actions: 'Actions', actions: 'Actions',
billingType: 'Billing Type', billingType: 'Billing Type',
@ -1513,6 +1518,12 @@ export default {
userNotes: 'Notes', userNotes: 'Notes',
userStatus: 'Status' userStatus: 'Status'
}, },
usageToday: 'Today',
usageTotal: 'Total',
accountsAvailable: 'Avail:',
accountsRateLimited: 'Limited:',
accountsTotal: 'Total:',
accountsUnit: '',
rateAndAccounts: '{rate}x rate · {count} accounts', rateAndAccounts: '{rate}x rate · {count} accounts',
accountsCount: '{count} accounts', accountsCount: '{count} accounts',
form: { form: {
@ -1694,6 +1705,7 @@ export default {
revokeSubscription: 'Revoke Subscription', revokeSubscription: 'Revoke Subscription',
allStatus: 'All Status', allStatus: 'All Status',
allGroups: 'All Groups', allGroups: 'All Groups',
allPlatforms: 'All Platforms',
daily: 'Daily', daily: 'Daily',
weekly: 'Weekly', weekly: 'Weekly',
monthly: 'Monthly', monthly: 'Monthly',
@ -1759,7 +1771,37 @@ export default {
pleaseSelectGroup: 'Please select a group', pleaseSelectGroup: 'Please select a group',
validityDaysRequired: 'Please enter a valid number of days (at least 1)', validityDaysRequired: 'Please enter a valid number of days (at least 1)',
revokeConfirm: revokeConfirm:
"Are you sure you want to revoke the subscription for '{user}'? This action cannot be undone." "Are you sure you want to revoke the subscription for '{user}'? This action cannot be undone.",
guide: {
title: 'Subscription Management Guide',
subtitle: 'Subscription mode lets you assign time-based usage quotas to users, with daily/weekly/monthly limits. Follow these steps to get started.',
showGuide: 'Usage Guide',
step1: {
title: 'Create a Subscription Group',
line1: 'Go to "Group Management" page, click "Create Group"',
line2: 'Set billing type to "Subscription", configure daily/weekly/monthly quota limits',
line3: 'Save the group and ensure its status is "Active"',
link: 'Go to Group Management'
},
step2: {
title: 'Assign Subscription to User',
line1: 'Click the "Assign Subscription" button in the top right',
line2: 'Search for a user by email and select them',
line3: 'Choose a subscription group, set validity days, then click "Assign"'
},
step3: {
title: 'Manage Existing Subscriptions'
},
actions: {
adjust: 'Adjust',
adjustDesc: 'Extend or shorten the subscription validity period',
resetQuota: 'Reset Quota',
resetQuotaDesc: 'Reset daily/weekly/monthly usage to zero',
revoke: 'Revoke',
revokeDesc: 'Immediately terminate the subscription (irreversible)'
},
tip: 'Tip: Only groups with billing type "Subscription" and status "Active" appear in the group dropdown. If no options are available, create one in Group Management first.'
}
}, },
// Accounts // Accounts
@ -4320,6 +4362,16 @@ export default {
testFailed: 'Google Drive storage test failed' testFailed: 'Google Drive storage test failed'
} }
}, },
overloadCooldown: {
title: '529 Overload Cooldown',
description: 'Configure account scheduling pause strategy when upstream returns 529 (overloaded)',
enabled: 'Enable Overload Cooldown',
enabledHint: 'Pause account scheduling on 529 errors, auto-recover after cooldown',
cooldownMinutes: 'Cooldown Duration (minutes)',
cooldownMinutesHint: 'Duration to pause account scheduling (1-120 minutes)',
saved: 'Overload cooldown settings saved',
saveFailed: 'Failed to save overload cooldown settings'
},
streamTimeout: { streamTimeout: {
title: 'Stream Timeout Handling', title: 'Stream Timeout Handling',
description: 'Configure account handling strategy when upstream response times out', description: 'Configure account handling strategy when upstream response times out',

Some files were not shown because too many files have changed in this diff Show More