Merge origin/main into fix/deepseek-reasoning-content

This commit is contained in:
L494264Tt 2026-05-19 17:00:57 +08:00
commit 6082d02d22
219 changed files with 16693 additions and 1294 deletions

View File

@ -20,8 +20,8 @@ FROM ${NODE_IMAGE} AS frontend-builder
WORKDIR /app/frontend
# Install pnpm
RUN corepack enable && corepack prepare pnpm@latest --activate
# Install pnpm (pinned to v9 to match CI and keep builds reproducible)
RUN corepack enable && corepack prepare pnpm@9 --activate
# Install dependencies first (better caching)
COPY frontend/package.json frontend/pnpm-lock.yaml ./

View File

@ -81,7 +81,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
}
totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService, userAttributeService)
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
@ -198,7 +201,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
registry := payment.ProvideRegistry()
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService)
opsHandler := admin.NewOpsHandler(opsService)
updateCache := repository.NewUpdateCache(redisClient)
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
@ -211,9 +214,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
usageCleanupRepository := repository.NewUsageCleanupRepository(client, db)
usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig)
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService, usageCleanupService)
userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
errorPassthroughRepository := repository.NewErrorPassthroughRepository(client)
errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient)

View File

@ -1120,6 +1120,7 @@ var (
{Name: "used_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "notes", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "validity_days", Type: field.TypeInt, Default: 30},
{Name: "group_id", Type: field.TypeInt64, Nullable: true},
{Name: "used_by", Type: field.TypeInt64, Nullable: true},
@ -1132,13 +1133,13 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "redeem_codes_groups_redeem_codes",
Columns: []*schema.Column{RedeemCodesColumns[9]},
Columns: []*schema.Column{RedeemCodesColumns[10]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "redeem_codes_users_redeem_codes",
Columns: []*schema.Column{RedeemCodesColumns[10]},
Columns: []*schema.Column{RedeemCodesColumns[11]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.SetNull,
},
@ -1152,12 +1153,17 @@ var (
{
Name: "redeemcode_used_by",
Unique: false,
Columns: []*schema.Column{RedeemCodesColumns[10]},
Columns: []*schema.Column{RedeemCodesColumns[11]},
},
{
Name: "redeemcode_group_id",
Unique: false,
Columns: []*schema.Column{RedeemCodesColumns[9]},
Columns: []*schema.Column{RedeemCodesColumns[10]},
},
{
Name: "redeemcode_expires_at",
Unique: false,
Columns: []*schema.Column{RedeemCodesColumns[8]},
},
},
}
@ -1318,6 +1324,10 @@ var (
{Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45},
{Name: "image_count", Type: field.TypeInt, Default: 0},
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
{Name: "image_input_size", Type: field.TypeString, Nullable: true, Size: 32},
{Name: "image_output_size", Type: field.TypeString, Nullable: true, Size: 32},
{Name: "image_size_source", Type: field.TypeString, Nullable: true, Size: 16},
{Name: "image_size_breakdown", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "api_key_id", Type: field.TypeInt64},
@ -1334,31 +1344,31 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[33]},
Columns: []*schema.Column{UsageLogsColumns[37]},
RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[34]},
Columns: []*schema.Column{UsageLogsColumns[38]},
RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[35]},
Columns: []*schema.Column{UsageLogsColumns[39]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[36]},
Columns: []*schema.Column{UsageLogsColumns[40]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[37]},
Columns: []*schema.Column{UsageLogsColumns[41]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull,
},
@ -1367,32 +1377,32 @@ var (
{
Name: "usagelog_user_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[36]},
Columns: []*schema.Column{UsageLogsColumns[40]},
},
{
Name: "usagelog_api_key_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[33]},
Columns: []*schema.Column{UsageLogsColumns[37]},
},
{
Name: "usagelog_account_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[34]},
Columns: []*schema.Column{UsageLogsColumns[38]},
},
{
Name: "usagelog_group_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[35]},
Columns: []*schema.Column{UsageLogsColumns[39]},
},
{
Name: "usagelog_subscription_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[37]},
Columns: []*schema.Column{UsageLogsColumns[41]},
},
{
Name: "usagelog_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[32]},
Columns: []*schema.Column{UsageLogsColumns[36]},
},
{
Name: "usagelog_model",
@ -1412,17 +1422,17 @@ var (
{
Name: "usagelog_user_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[32]},
Columns: []*schema.Column{UsageLogsColumns[40], UsageLogsColumns[36]},
},
{
Name: "usagelog_api_key_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[32]},
Columns: []*schema.Column{UsageLogsColumns[37], UsageLogsColumns[36]},
},
{
Name: "usagelog_group_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[35], UsageLogsColumns[32]},
Columns: []*schema.Column{UsageLogsColumns[39], UsageLogsColumns[36]},
},
},
}

View File

@ -28602,6 +28602,7 @@ type RedeemCodeMutation struct {
used_at *time.Time
notes *string
created_at *time.Time
expires_at *time.Time
validity_days *int
addvalidity_days *int
clearedFields map[string]struct{}
@ -29059,6 +29060,55 @@ func (m *RedeemCodeMutation) ResetCreatedAt() {
m.created_at = nil
}
// SetExpiresAt sets the "expires_at" field.
func (m *RedeemCodeMutation) SetExpiresAt(t time.Time) {
m.expires_at = &t
}
// ExpiresAt returns the value of the "expires_at" field in the mutation.
func (m *RedeemCodeMutation) ExpiresAt() (r time.Time, exists bool) {
v := m.expires_at
if v == nil {
return
}
return *v, true
}
// OldExpiresAt returns the old "expires_at" field's value of the RedeemCode entity.
// If the RedeemCode object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *RedeemCodeMutation) OldExpiresAt(ctx context.Context) (v *time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldExpiresAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err)
}
return oldValue.ExpiresAt, nil
}
// ClearExpiresAt clears the value of the "expires_at" field.
func (m *RedeemCodeMutation) ClearExpiresAt() {
m.expires_at = nil
m.clearedFields[redeemcode.FieldExpiresAt] = struct{}{}
}
// ExpiresAtCleared returns if the "expires_at" field was cleared in this mutation.
func (m *RedeemCodeMutation) ExpiresAtCleared() bool {
_, ok := m.clearedFields[redeemcode.FieldExpiresAt]
return ok
}
// ResetExpiresAt resets all changes to the "expires_at" field.
func (m *RedeemCodeMutation) ResetExpiresAt() {
m.expires_at = nil
delete(m.clearedFields, redeemcode.FieldExpiresAt)
}
// SetGroupID sets the "group_id" field.
func (m *RedeemCodeMutation) SetGroupID(i int64) {
m.group = &i
@ -29265,7 +29315,7 @@ func (m *RedeemCodeMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *RedeemCodeMutation) Fields() []string {
fields := make([]string, 0, 10)
fields := make([]string, 0, 11)
if m.code != nil {
fields = append(fields, redeemcode.FieldCode)
}
@ -29290,6 +29340,9 @@ func (m *RedeemCodeMutation) Fields() []string {
if m.created_at != nil {
fields = append(fields, redeemcode.FieldCreatedAt)
}
if m.expires_at != nil {
fields = append(fields, redeemcode.FieldExpiresAt)
}
if m.group != nil {
fields = append(fields, redeemcode.FieldGroupID)
}
@ -29320,6 +29373,8 @@ func (m *RedeemCodeMutation) Field(name string) (ent.Value, bool) {
return m.Notes()
case redeemcode.FieldCreatedAt:
return m.CreatedAt()
case redeemcode.FieldExpiresAt:
return m.ExpiresAt()
case redeemcode.FieldGroupID:
return m.GroupID()
case redeemcode.FieldValidityDays:
@ -29349,6 +29404,8 @@ func (m *RedeemCodeMutation) OldField(ctx context.Context, name string) (ent.Val
return m.OldNotes(ctx)
case redeemcode.FieldCreatedAt:
return m.OldCreatedAt(ctx)
case redeemcode.FieldExpiresAt:
return m.OldExpiresAt(ctx)
case redeemcode.FieldGroupID:
return m.OldGroupID(ctx)
case redeemcode.FieldValidityDays:
@ -29418,6 +29475,13 @@ func (m *RedeemCodeMutation) SetField(name string, value ent.Value) error {
}
m.SetCreatedAt(v)
return nil
case redeemcode.FieldExpiresAt:
v, ok := value.(time.Time)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetExpiresAt(v)
return nil
case redeemcode.FieldGroupID:
v, ok := value.(int64)
if !ok {
@ -29498,6 +29562,9 @@ func (m *RedeemCodeMutation) ClearedFields() []string {
if m.FieldCleared(redeemcode.FieldNotes) {
fields = append(fields, redeemcode.FieldNotes)
}
if m.FieldCleared(redeemcode.FieldExpiresAt) {
fields = append(fields, redeemcode.FieldExpiresAt)
}
if m.FieldCleared(redeemcode.FieldGroupID) {
fields = append(fields, redeemcode.FieldGroupID)
}
@ -29524,6 +29591,9 @@ func (m *RedeemCodeMutation) ClearField(name string) error {
case redeemcode.FieldNotes:
m.ClearNotes()
return nil
case redeemcode.FieldExpiresAt:
m.ClearExpiresAt()
return nil
case redeemcode.FieldGroupID:
m.ClearGroupID()
return nil
@ -29559,6 +29629,9 @@ func (m *RedeemCodeMutation) ResetField(name string) error {
case redeemcode.FieldCreatedAt:
m.ResetCreatedAt()
return nil
case redeemcode.FieldExpiresAt:
m.ResetExpiresAt()
return nil
case redeemcode.FieldGroupID:
m.ResetGroupID()
return nil
@ -34260,6 +34333,10 @@ type UsageLogMutation struct {
image_count *int
addimage_count *int
image_size *string
image_input_size *string
image_output_size *string
image_size_source *string
image_size_breakdown *map[string]int
cache_ttl_overridden *bool
created_at *time.Time
clearedFields map[string]struct{}
@ -36202,6 +36279,202 @@ func (m *UsageLogMutation) ResetImageSize() {
delete(m.clearedFields, usagelog.FieldImageSize)
}
// SetImageInputSize sets the "image_input_size" field.
func (m *UsageLogMutation) SetImageInputSize(s string) {
m.image_input_size = &s
}
// ImageInputSize returns the value of the "image_input_size" field in the mutation.
func (m *UsageLogMutation) ImageInputSize() (r string, exists bool) {
v := m.image_input_size
if v == nil {
return
}
return *v, true
}
// OldImageInputSize returns the old "image_input_size" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldImageInputSize(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldImageInputSize is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldImageInputSize requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldImageInputSize: %w", err)
}
return oldValue.ImageInputSize, nil
}
// ClearImageInputSize clears the value of the "image_input_size" field.
func (m *UsageLogMutation) ClearImageInputSize() {
m.image_input_size = nil
m.clearedFields[usagelog.FieldImageInputSize] = struct{}{}
}
// ImageInputSizeCleared returns if the "image_input_size" field was cleared in this mutation.
func (m *UsageLogMutation) ImageInputSizeCleared() bool {
_, ok := m.clearedFields[usagelog.FieldImageInputSize]
return ok
}
// ResetImageInputSize resets all changes to the "image_input_size" field.
func (m *UsageLogMutation) ResetImageInputSize() {
m.image_input_size = nil
delete(m.clearedFields, usagelog.FieldImageInputSize)
}
// SetImageOutputSize sets the "image_output_size" field.
func (m *UsageLogMutation) SetImageOutputSize(s string) {
m.image_output_size = &s
}
// ImageOutputSize returns the value of the "image_output_size" field in the mutation.
func (m *UsageLogMutation) ImageOutputSize() (r string, exists bool) {
v := m.image_output_size
if v == nil {
return
}
return *v, true
}
// OldImageOutputSize returns the old "image_output_size" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldImageOutputSize(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldImageOutputSize is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldImageOutputSize requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldImageOutputSize: %w", err)
}
return oldValue.ImageOutputSize, nil
}
// ClearImageOutputSize clears the value of the "image_output_size" field.
func (m *UsageLogMutation) ClearImageOutputSize() {
m.image_output_size = nil
m.clearedFields[usagelog.FieldImageOutputSize] = struct{}{}
}
// ImageOutputSizeCleared returns if the "image_output_size" field was cleared in this mutation.
func (m *UsageLogMutation) ImageOutputSizeCleared() bool {
_, ok := m.clearedFields[usagelog.FieldImageOutputSize]
return ok
}
// ResetImageOutputSize resets all changes to the "image_output_size" field.
func (m *UsageLogMutation) ResetImageOutputSize() {
m.image_output_size = nil
delete(m.clearedFields, usagelog.FieldImageOutputSize)
}
// SetImageSizeSource sets the "image_size_source" field.
func (m *UsageLogMutation) SetImageSizeSource(s string) {
m.image_size_source = &s
}
// ImageSizeSource returns the value of the "image_size_source" field in the mutation.
func (m *UsageLogMutation) ImageSizeSource() (r string, exists bool) {
v := m.image_size_source
if v == nil {
return
}
return *v, true
}
// OldImageSizeSource returns the old "image_size_source" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldImageSizeSource(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldImageSizeSource is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldImageSizeSource requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldImageSizeSource: %w", err)
}
return oldValue.ImageSizeSource, nil
}
// ClearImageSizeSource clears the value of the "image_size_source" field.
func (m *UsageLogMutation) ClearImageSizeSource() {
m.image_size_source = nil
m.clearedFields[usagelog.FieldImageSizeSource] = struct{}{}
}
// ImageSizeSourceCleared returns if the "image_size_source" field was cleared in this mutation.
func (m *UsageLogMutation) ImageSizeSourceCleared() bool {
_, ok := m.clearedFields[usagelog.FieldImageSizeSource]
return ok
}
// ResetImageSizeSource resets all changes to the "image_size_source" field.
func (m *UsageLogMutation) ResetImageSizeSource() {
m.image_size_source = nil
delete(m.clearedFields, usagelog.FieldImageSizeSource)
}
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
func (m *UsageLogMutation) SetImageSizeBreakdown(value map[string]int) {
m.image_size_breakdown = &value
}
// ImageSizeBreakdown returns the value of the "image_size_breakdown" field in the mutation.
func (m *UsageLogMutation) ImageSizeBreakdown() (r map[string]int, exists bool) {
v := m.image_size_breakdown
if v == nil {
return
}
return *v, true
}
// OldImageSizeBreakdown returns the old "image_size_breakdown" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldImageSizeBreakdown(ctx context.Context) (v map[string]int, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldImageSizeBreakdown is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldImageSizeBreakdown requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldImageSizeBreakdown: %w", err)
}
return oldValue.ImageSizeBreakdown, nil
}
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
func (m *UsageLogMutation) ClearImageSizeBreakdown() {
m.image_size_breakdown = nil
m.clearedFields[usagelog.FieldImageSizeBreakdown] = struct{}{}
}
// ImageSizeBreakdownCleared returns if the "image_size_breakdown" field was cleared in this mutation.
func (m *UsageLogMutation) ImageSizeBreakdownCleared() bool {
_, ok := m.clearedFields[usagelog.FieldImageSizeBreakdown]
return ok
}
// ResetImageSizeBreakdown resets all changes to the "image_size_breakdown" field.
func (m *UsageLogMutation) ResetImageSizeBreakdown() {
m.image_size_breakdown = nil
delete(m.clearedFields, usagelog.FieldImageSizeBreakdown)
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) {
m.cache_ttl_overridden = &b
@ -36443,7 +36716,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UsageLogMutation) Fields() []string {
fields := make([]string, 0, 37)
fields := make([]string, 0, 41)
if m.user != nil {
fields = append(fields, usagelog.FieldUserID)
}
@ -36549,6 +36822,18 @@ func (m *UsageLogMutation) Fields() []string {
if m.image_size != nil {
fields = append(fields, usagelog.FieldImageSize)
}
if m.image_input_size != nil {
fields = append(fields, usagelog.FieldImageInputSize)
}
if m.image_output_size != nil {
fields = append(fields, usagelog.FieldImageOutputSize)
}
if m.image_size_source != nil {
fields = append(fields, usagelog.FieldImageSizeSource)
}
if m.image_size_breakdown != nil {
fields = append(fields, usagelog.FieldImageSizeBreakdown)
}
if m.cache_ttl_overridden != nil {
fields = append(fields, usagelog.FieldCacheTTLOverridden)
}
@ -36633,6 +36918,14 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.ImageCount()
case usagelog.FieldImageSize:
return m.ImageSize()
case usagelog.FieldImageInputSize:
return m.ImageInputSize()
case usagelog.FieldImageOutputSize:
return m.ImageOutputSize()
case usagelog.FieldImageSizeSource:
return m.ImageSizeSource()
case usagelog.FieldImageSizeBreakdown:
return m.ImageSizeBreakdown()
case usagelog.FieldCacheTTLOverridden:
return m.CacheTTLOverridden()
case usagelog.FieldCreatedAt:
@ -36716,6 +37009,14 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldImageCount(ctx)
case usagelog.FieldImageSize:
return m.OldImageSize(ctx)
case usagelog.FieldImageInputSize:
return m.OldImageInputSize(ctx)
case usagelog.FieldImageOutputSize:
return m.OldImageOutputSize(ctx)
case usagelog.FieldImageSizeSource:
return m.OldImageSizeSource(ctx)
case usagelog.FieldImageSizeBreakdown:
return m.OldImageSizeBreakdown(ctx)
case usagelog.FieldCacheTTLOverridden:
return m.OldCacheTTLOverridden(ctx)
case usagelog.FieldCreatedAt:
@ -36974,6 +37275,34 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
}
m.SetImageSize(v)
return nil
case usagelog.FieldImageInputSize:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetImageInputSize(v)
return nil
case usagelog.FieldImageOutputSize:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetImageOutputSize(v)
return nil
case usagelog.FieldImageSizeSource:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetImageSizeSource(v)
return nil
case usagelog.FieldImageSizeBreakdown:
v, ok := value.(map[string]int)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetImageSizeBreakdown(v)
return nil
case usagelog.FieldCacheTTLOverridden:
v, ok := value.(bool)
if !ok {
@ -37291,6 +37620,18 @@ func (m *UsageLogMutation) ClearedFields() []string {
if m.FieldCleared(usagelog.FieldImageSize) {
fields = append(fields, usagelog.FieldImageSize)
}
if m.FieldCleared(usagelog.FieldImageInputSize) {
fields = append(fields, usagelog.FieldImageInputSize)
}
if m.FieldCleared(usagelog.FieldImageOutputSize) {
fields = append(fields, usagelog.FieldImageOutputSize)
}
if m.FieldCleared(usagelog.FieldImageSizeSource) {
fields = append(fields, usagelog.FieldImageSizeSource)
}
if m.FieldCleared(usagelog.FieldImageSizeBreakdown) {
fields = append(fields, usagelog.FieldImageSizeBreakdown)
}
return fields
}
@ -37347,6 +37688,18 @@ func (m *UsageLogMutation) ClearField(name string) error {
case usagelog.FieldImageSize:
m.ClearImageSize()
return nil
case usagelog.FieldImageInputSize:
m.ClearImageInputSize()
return nil
case usagelog.FieldImageOutputSize:
m.ClearImageOutputSize()
return nil
case usagelog.FieldImageSizeSource:
m.ClearImageSizeSource()
return nil
case usagelog.FieldImageSizeBreakdown:
m.ClearImageSizeBreakdown()
return nil
}
return fmt.Errorf("unknown UsageLog nullable field %s", name)
}
@ -37460,6 +37813,18 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldImageSize:
m.ResetImageSize()
return nil
case usagelog.FieldImageInputSize:
m.ResetImageInputSize()
return nil
case usagelog.FieldImageOutputSize:
m.ResetImageOutputSize()
return nil
case usagelog.FieldImageSizeSource:
m.ResetImageSizeSource()
return nil
case usagelog.FieldImageSizeBreakdown:
m.ResetImageSizeBreakdown()
return nil
case usagelog.FieldCacheTTLOverridden:
m.ResetCacheTTLOverridden()
return nil

View File

@ -35,6 +35,8 @@ type RedeemCode struct {
Notes *string `json:"notes,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// ExpiresAt holds the value of the "expires_at" field.
ExpiresAt *time.Time `json:"expires_at,omitempty"`
// GroupID holds the value of the "group_id" field.
GroupID *int64 `json:"group_id,omitempty"`
// ValidityDays holds the value of the "validity_days" field.
@ -89,7 +91,7 @@ func (*RedeemCode) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullInt64)
case redeemcode.FieldCode, redeemcode.FieldType, redeemcode.FieldStatus, redeemcode.FieldNotes:
values[i] = new(sql.NullString)
case redeemcode.FieldUsedAt, redeemcode.FieldCreatedAt:
case redeemcode.FieldUsedAt, redeemcode.FieldCreatedAt, redeemcode.FieldExpiresAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@ -163,6 +165,13 @@ func (_m *RedeemCode) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.CreatedAt = value.Time
}
case redeemcode.FieldExpiresAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field expires_at", values[i])
} else if value.Valid {
_m.ExpiresAt = new(time.Time)
*_m.ExpiresAt = value.Time
}
case redeemcode.FieldGroupID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field group_id", values[i])
@ -252,6 +261,11 @@ func (_m *RedeemCode) String() string {
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
if v := _m.ExpiresAt; v != nil {
builder.WriteString("expires_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.GroupID; v != nil {
builder.WriteString("group_id=")
builder.WriteString(fmt.Sprintf("%v", *v))

View File

@ -30,6 +30,8 @@ const (
FieldNotes = "notes"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldExpiresAt holds the string denoting the expires_at field in the database.
FieldExpiresAt = "expires_at"
// FieldGroupID holds the string denoting the group_id field in the database.
FieldGroupID = "group_id"
// FieldValidityDays holds the string denoting the validity_days field in the database.
@ -67,6 +69,7 @@ var Columns = []string{
FieldUsedAt,
FieldNotes,
FieldCreatedAt,
FieldExpiresAt,
FieldGroupID,
FieldValidityDays,
}
@ -148,6 +151,11 @@ func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByExpiresAt orders the results by the expires_at field.
func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
}
// ByGroupID orders the results by the group_id field.
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldGroupID, opts...).ToFunc()

View File

@ -95,6 +95,11 @@ func CreatedAt(v time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldEQ(FieldCreatedAt, v))
}
// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ.
func ExpiresAt(v time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldEQ(FieldExpiresAt, v))
}
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
func GroupID(v int64) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldEQ(FieldGroupID, v))
@ -535,6 +540,56 @@ func CreatedAtLTE(v time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldLTE(FieldCreatedAt, v))
}
// ExpiresAtEQ applies the EQ predicate on the "expires_at" field.
func ExpiresAtEQ(v time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldEQ(FieldExpiresAt, v))
}
// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field.
func ExpiresAtNEQ(v time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldNEQ(FieldExpiresAt, v))
}
// ExpiresAtIn applies the In predicate on the "expires_at" field.
func ExpiresAtIn(vs ...time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldIn(FieldExpiresAt, vs...))
}
// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field.
func ExpiresAtNotIn(vs ...time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldNotIn(FieldExpiresAt, vs...))
}
// ExpiresAtGT applies the GT predicate on the "expires_at" field.
func ExpiresAtGT(v time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldGT(FieldExpiresAt, v))
}
// ExpiresAtGTE applies the GTE predicate on the "expires_at" field.
func ExpiresAtGTE(v time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldGTE(FieldExpiresAt, v))
}
// ExpiresAtLT applies the LT predicate on the "expires_at" field.
func ExpiresAtLT(v time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldLT(FieldExpiresAt, v))
}
// ExpiresAtLTE applies the LTE predicate on the "expires_at" field.
func ExpiresAtLTE(v time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldLTE(FieldExpiresAt, v))
}
// ExpiresAtIsNil applies the IsNil predicate on the "expires_at" field.
func ExpiresAtIsNil() predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldIsNull(FieldExpiresAt))
}
// ExpiresAtNotNil applies the NotNil predicate on the "expires_at" field.
func ExpiresAtNotNil() predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldNotNull(FieldExpiresAt))
}
// GroupIDEQ applies the EQ predicate on the "group_id" field.
func GroupIDEQ(v int64) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldEQ(FieldGroupID, v))

View File

@ -128,6 +128,20 @@ func (_c *RedeemCodeCreate) SetNillableCreatedAt(v *time.Time) *RedeemCodeCreate
return _c
}
// SetExpiresAt sets the "expires_at" field.
func (_c *RedeemCodeCreate) SetExpiresAt(v time.Time) *RedeemCodeCreate {
_c.mutation.SetExpiresAt(v)
return _c
}
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
func (_c *RedeemCodeCreate) SetNillableExpiresAt(v *time.Time) *RedeemCodeCreate {
if v != nil {
_c.SetExpiresAt(*v)
}
return _c
}
// SetGroupID sets the "group_id" field.
func (_c *RedeemCodeCreate) SetGroupID(v int64) *RedeemCodeCreate {
_c.mutation.SetGroupID(v)
@ -327,6 +341,10 @@ func (_c *RedeemCodeCreate) createSpec() (*RedeemCode, *sqlgraph.CreateSpec) {
_spec.SetField(redeemcode.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value
}
if value, ok := _c.mutation.ExpiresAt(); ok {
_spec.SetField(redeemcode.FieldExpiresAt, field.TypeTime, value)
_node.ExpiresAt = &value
}
if value, ok := _c.mutation.ValidityDays(); ok {
_spec.SetField(redeemcode.FieldValidityDays, field.TypeInt, value)
_node.ValidityDays = value
@ -525,6 +543,24 @@ func (u *RedeemCodeUpsert) ClearNotes() *RedeemCodeUpsert {
return u
}
// SetExpiresAt sets the "expires_at" field.
func (u *RedeemCodeUpsert) SetExpiresAt(v time.Time) *RedeemCodeUpsert {
u.Set(redeemcode.FieldExpiresAt, v)
return u
}
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
func (u *RedeemCodeUpsert) UpdateExpiresAt() *RedeemCodeUpsert {
u.SetExcluded(redeemcode.FieldExpiresAt)
return u
}
// ClearExpiresAt clears the value of the "expires_at" field.
func (u *RedeemCodeUpsert) ClearExpiresAt() *RedeemCodeUpsert {
u.SetNull(redeemcode.FieldExpiresAt)
return u
}
// SetGroupID sets the "group_id" field.
func (u *RedeemCodeUpsert) SetGroupID(v int64) *RedeemCodeUpsert {
u.Set(redeemcode.FieldGroupID, v)
@ -732,6 +768,27 @@ func (u *RedeemCodeUpsertOne) ClearNotes() *RedeemCodeUpsertOne {
})
}
// SetExpiresAt sets the "expires_at" field.
func (u *RedeemCodeUpsertOne) SetExpiresAt(v time.Time) *RedeemCodeUpsertOne {
return u.Update(func(s *RedeemCodeUpsert) {
s.SetExpiresAt(v)
})
}
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
func (u *RedeemCodeUpsertOne) UpdateExpiresAt() *RedeemCodeUpsertOne {
return u.Update(func(s *RedeemCodeUpsert) {
s.UpdateExpiresAt()
})
}
// ClearExpiresAt clears the value of the "expires_at" field.
func (u *RedeemCodeUpsertOne) ClearExpiresAt() *RedeemCodeUpsertOne {
return u.Update(func(s *RedeemCodeUpsert) {
s.ClearExpiresAt()
})
}
// SetGroupID sets the "group_id" field.
func (u *RedeemCodeUpsertOne) SetGroupID(v int64) *RedeemCodeUpsertOne {
return u.Update(func(s *RedeemCodeUpsert) {
@ -1111,6 +1168,27 @@ func (u *RedeemCodeUpsertBulk) ClearNotes() *RedeemCodeUpsertBulk {
})
}
// SetExpiresAt sets the "expires_at" field.
func (u *RedeemCodeUpsertBulk) SetExpiresAt(v time.Time) *RedeemCodeUpsertBulk {
return u.Update(func(s *RedeemCodeUpsert) {
s.SetExpiresAt(v)
})
}
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
func (u *RedeemCodeUpsertBulk) UpdateExpiresAt() *RedeemCodeUpsertBulk {
return u.Update(func(s *RedeemCodeUpsert) {
s.UpdateExpiresAt()
})
}
// ClearExpiresAt clears the value of the "expires_at" field.
func (u *RedeemCodeUpsertBulk) ClearExpiresAt() *RedeemCodeUpsertBulk {
return u.Update(func(s *RedeemCodeUpsert) {
s.ClearExpiresAt()
})
}
// SetGroupID sets the "group_id" field.
func (u *RedeemCodeUpsertBulk) SetGroupID(v int64) *RedeemCodeUpsertBulk {
return u.Update(func(s *RedeemCodeUpsert) {

View File

@ -153,6 +153,26 @@ func (_u *RedeemCodeUpdate) ClearNotes() *RedeemCodeUpdate {
return _u
}
// SetExpiresAt sets the "expires_at" field.
func (_u *RedeemCodeUpdate) SetExpiresAt(v time.Time) *RedeemCodeUpdate {
_u.mutation.SetExpiresAt(v)
return _u
}
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
func (_u *RedeemCodeUpdate) SetNillableExpiresAt(v *time.Time) *RedeemCodeUpdate {
if v != nil {
_u.SetExpiresAt(*v)
}
return _u
}
// ClearExpiresAt clears the value of the "expires_at" field.
func (_u *RedeemCodeUpdate) ClearExpiresAt() *RedeemCodeUpdate {
_u.mutation.ClearExpiresAt()
return _u
}
// SetGroupID sets the "group_id" field.
func (_u *RedeemCodeUpdate) SetGroupID(v int64) *RedeemCodeUpdate {
_u.mutation.SetGroupID(v)
@ -321,6 +341,12 @@ func (_u *RedeemCodeUpdate) sqlSave(ctx context.Context) (_node int, err error)
if _u.mutation.NotesCleared() {
_spec.ClearField(redeemcode.FieldNotes, field.TypeString)
}
if value, ok := _u.mutation.ExpiresAt(); ok {
_spec.SetField(redeemcode.FieldExpiresAt, field.TypeTime, value)
}
if _u.mutation.ExpiresAtCleared() {
_spec.ClearField(redeemcode.FieldExpiresAt, field.TypeTime)
}
if value, ok := _u.mutation.ValidityDays(); ok {
_spec.SetField(redeemcode.FieldValidityDays, field.TypeInt, value)
}
@ -528,6 +554,26 @@ func (_u *RedeemCodeUpdateOne) ClearNotes() *RedeemCodeUpdateOne {
return _u
}
// SetExpiresAt sets the "expires_at" field.
func (_u *RedeemCodeUpdateOne) SetExpiresAt(v time.Time) *RedeemCodeUpdateOne {
_u.mutation.SetExpiresAt(v)
return _u
}
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
func (_u *RedeemCodeUpdateOne) SetNillableExpiresAt(v *time.Time) *RedeemCodeUpdateOne {
if v != nil {
_u.SetExpiresAt(*v)
}
return _u
}
// ClearExpiresAt clears the value of the "expires_at" field.
func (_u *RedeemCodeUpdateOne) ClearExpiresAt() *RedeemCodeUpdateOne {
_u.mutation.ClearExpiresAt()
return _u
}
// SetGroupID sets the "group_id" field.
func (_u *RedeemCodeUpdateOne) SetGroupID(v int64) *RedeemCodeUpdateOne {
_u.mutation.SetGroupID(v)
@ -726,6 +772,12 @@ func (_u *RedeemCodeUpdateOne) sqlSave(ctx context.Context) (_node *RedeemCode,
if _u.mutation.NotesCleared() {
_spec.ClearField(redeemcode.FieldNotes, field.TypeString)
}
if value, ok := _u.mutation.ExpiresAt(); ok {
_spec.SetField(redeemcode.FieldExpiresAt, field.TypeTime, value)
}
if _u.mutation.ExpiresAtCleared() {
_spec.ClearField(redeemcode.FieldExpiresAt, field.TypeTime)
}
if value, ok := _u.mutation.ValidityDays(); ok {
_spec.SetField(redeemcode.FieldValidityDays, field.TypeInt, value)
}

View File

@ -1386,7 +1386,7 @@ func init() {
// redeemcode.DefaultCreatedAt holds the default value on creation for the created_at field.
redeemcode.DefaultCreatedAt = redeemcodeDescCreatedAt.Default.(func() time.Time)
// redeemcodeDescValidityDays is the schema descriptor for validity_days field.
redeemcodeDescValidityDays := redeemcodeFields[9].Descriptor()
redeemcodeDescValidityDays := redeemcodeFields[10].Descriptor()
// redeemcode.DefaultValidityDays holds the default value on creation for the validity_days field.
redeemcode.DefaultValidityDays = redeemcodeDescValidityDays.Default.(int)
securitysecretMixin := schema.SecuritySecret{}.Mixin()
@ -1722,12 +1722,24 @@ func init() {
usagelogDescImageSize := usagelogFields[34].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescImageInputSize is the schema descriptor for image_input_size field.
usagelogDescImageInputSize := usagelogFields[35].Descriptor()
// usagelog.ImageInputSizeValidator is a validator for the "image_input_size" field. It is called by the builders before save.
usagelog.ImageInputSizeValidator = usagelogDescImageInputSize.Validators[0].(func(string) error)
// usagelogDescImageOutputSize is the schema descriptor for image_output_size field.
usagelogDescImageOutputSize := usagelogFields[36].Descriptor()
// usagelog.ImageOutputSizeValidator is a validator for the "image_output_size" field. It is called by the builders before save.
usagelog.ImageOutputSizeValidator = usagelogDescImageOutputSize.Validators[0].(func(string) error)
// usagelogDescImageSizeSource is the schema descriptor for image_size_source field.
usagelogDescImageSizeSource := usagelogFields[37].Descriptor()
// usagelog.ImageSizeSourceValidator is a validator for the "image_size_source" field. It is called by the builders before save.
usagelog.ImageSizeSourceValidator = usagelogDescImageSizeSource.Validators[0].(func(string) error)
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
usagelogDescCacheTTLOverridden := usagelogFields[35].Descriptor()
usagelogDescCacheTTLOverridden := usagelogFields[39].Descriptor()
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
// usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[36].Descriptor()
usagelogDescCreatedAt := usagelogFields[40].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin()

View File

@ -15,12 +15,13 @@ import (
)
var authProviderTypes = map[string]struct{}{
"email": {},
"github": {},
"google": {},
"linuxdo": {},
"oidc": {},
"wechat": {},
"email": {},
"github": {},
"google": {},
"linuxdo": {},
"oidc": {},
"wechat": {},
"dingtalk": {},
}
func validateAuthProviderType(value string) error {

View File

@ -83,7 +83,7 @@ func TestAuthIdentityFoundationSchemas(t *testing.T) {
require.Equal(t, 1, signupSource.Validators)
validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source")
for _, value := range []string{"email", "linuxdo", "wechat", "oidc", "github", "google"} {
for _, value := range []string{"email", "linuxdo", "wechat", "oidc", "github", "google", "dingtalk"} {
require.NoError(t, validator(value))
}
require.Error(t, validator("unknown"))

View File

@ -63,6 +63,10 @@ func (RedeemCode) Fields() []ent.Field {
Immutable().
Default(time.Now).
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("expires_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Int64("group_id").
Optional().
Nillable(),
@ -90,5 +94,6 @@ func (RedeemCode) Indexes() []ent.Index {
index.Fields("status"),
index.Fields("used_by"),
index.Fields("group_id"),
index.Fields("expires_at"),
}
}

View File

@ -134,6 +134,21 @@ func (UsageLog) Fields() []ent.Field {
MaxLen(10).
Optional().
Nillable(),
field.String("image_input_size").
MaxLen(32).
Optional().
Nillable(),
field.String("image_output_size").
MaxLen(32).
Optional().
Nillable(),
field.String("image_size_source").
MaxLen(16).
Optional().
Nillable(),
field.JSON("image_size_breakdown", map[string]int{}).
Optional().
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
// Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
field.Bool("cache_ttl_overridden").
Default(false),

View File

@ -77,10 +77,10 @@ func (User) Fields() []ent.Field {
field.String("signup_source").
Validate(func(value string) error {
switch value {
case "email", "linuxdo", "wechat", "oidc", "github", "google":
case "email", "linuxdo", "wechat", "oidc", "github", "google", "dingtalk":
return nil
default:
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc, github, google")
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc, github, google, dingtalk")
}
}).
Default("email"),

View File

@ -3,6 +3,7 @@
package ent
import (
"encoding/json"
"fmt"
"strings"
"time"
@ -92,6 +93,14 @@ type UsageLog struct {
ImageCount int `json:"image_count,omitempty"`
// ImageSize holds the value of the "image_size" field.
ImageSize *string `json:"image_size,omitempty"`
// ImageInputSize holds the value of the "image_input_size" field.
ImageInputSize *string `json:"image_input_size,omitempty"`
// ImageOutputSize holds the value of the "image_output_size" field.
ImageOutputSize *string `json:"image_output_size,omitempty"`
// ImageSizeSource holds the value of the "image_size_source" field.
ImageSizeSource *string `json:"image_size_source,omitempty"`
// ImageSizeBreakdown holds the value of the "image_size_breakdown" field.
ImageSizeBreakdown map[string]int `json:"image_size_breakdown,omitempty"`
// CacheTTLOverridden holds the value of the "cache_ttl_overridden" field.
CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"`
// CreatedAt holds the value of the "created_at" field.
@ -179,13 +188,15 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case usagelog.FieldImageSizeBreakdown:
values[i] = new([]byte)
case usagelog.FieldStream, usagelog.FieldCacheTTLOverridden:
values[i] = new(sql.NullBool)
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldChannelID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64)
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize:
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldImageInputSize, usagelog.FieldImageOutputSize, usagelog.FieldImageSizeSource:
values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime)
@ -434,6 +445,35 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
_m.ImageSize = new(string)
*_m.ImageSize = value.String
}
case usagelog.FieldImageInputSize:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field image_input_size", values[i])
} else if value.Valid {
_m.ImageInputSize = new(string)
*_m.ImageInputSize = value.String
}
case usagelog.FieldImageOutputSize:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field image_output_size", values[i])
} else if value.Valid {
_m.ImageOutputSize = new(string)
*_m.ImageOutputSize = value.String
}
case usagelog.FieldImageSizeSource:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field image_size_source", values[i])
} else if value.Valid {
_m.ImageSizeSource = new(string)
*_m.ImageSizeSource = value.String
}
case usagelog.FieldImageSizeBreakdown:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field image_size_breakdown", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.ImageSizeBreakdown); err != nil {
return fmt.Errorf("unmarshal field image_size_breakdown: %w", err)
}
}
case usagelog.FieldCacheTTLOverridden:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i])
@ -640,6 +680,24 @@ func (_m *UsageLog) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.ImageInputSize; v != nil {
builder.WriteString("image_input_size=")
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.ImageOutputSize; v != nil {
builder.WriteString("image_output_size=")
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.ImageSizeSource; v != nil {
builder.WriteString("image_size_source=")
builder.WriteString(*v)
}
builder.WriteString(", ")
builder.WriteString("image_size_breakdown=")
builder.WriteString(fmt.Sprintf("%v", _m.ImageSizeBreakdown))
builder.WriteString(", ")
builder.WriteString("cache_ttl_overridden=")
builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden))
builder.WriteString(", ")

View File

@ -84,6 +84,14 @@ const (
FieldImageCount = "image_count"
// FieldImageSize holds the string denoting the image_size field in the database.
FieldImageSize = "image_size"
// FieldImageInputSize holds the string denoting the image_input_size field in the database.
FieldImageInputSize = "image_input_size"
// FieldImageOutputSize holds the string denoting the image_output_size field in the database.
FieldImageOutputSize = "image_output_size"
// FieldImageSizeSource holds the string denoting the image_size_source field in the database.
FieldImageSizeSource = "image_size_source"
// FieldImageSizeBreakdown holds the string denoting the image_size_breakdown field in the database.
FieldImageSizeBreakdown = "image_size_breakdown"
// FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database.
FieldCacheTTLOverridden = "cache_ttl_overridden"
// FieldCreatedAt holds the string denoting the created_at field in the database.
@ -175,6 +183,10 @@ var Columns = []string{
FieldIPAddress,
FieldImageCount,
FieldImageSize,
FieldImageInputSize,
FieldImageOutputSize,
FieldImageSizeSource,
FieldImageSizeBreakdown,
FieldCacheTTLOverridden,
FieldCreatedAt,
}
@ -242,6 +254,12 @@ var (
DefaultImageCount int
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
ImageSizeValidator func(string) error
// ImageInputSizeValidator is a validator for the "image_input_size" field. It is called by the builders before save.
ImageInputSizeValidator func(string) error
// ImageOutputSizeValidator is a validator for the "image_output_size" field. It is called by the builders before save.
ImageOutputSizeValidator func(string) error
// ImageSizeSourceValidator is a validator for the "image_size_source" field. It is called by the builders before save.
ImageSizeSourceValidator func(string) error
// DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field.
DefaultCacheTTLOverridden bool
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
@ -431,6 +449,21 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageSize, opts...).ToFunc()
}
// ByImageInputSize orders the results by the image_input_size field.
func ByImageInputSize(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageInputSize, opts...).ToFunc()
}
// ByImageOutputSize orders the results by the image_output_size field.
func ByImageOutputSize(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageOutputSize, opts...).ToFunc()
}
// ByImageSizeSource orders the results by the image_size_source field.
func ByImageSizeSource(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageSizeSource, opts...).ToFunc()
}
// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field.
func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc()

View File

@ -230,6 +230,21 @@ func ImageSize(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v))
}
// ImageInputSize applies equality check predicate on the "image_input_size" field. It's identical to ImageInputSizeEQ.
func ImageInputSize(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageInputSize, v))
}
// ImageOutputSize applies equality check predicate on the "image_output_size" field. It's identical to ImageOutputSizeEQ.
func ImageOutputSize(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageOutputSize, v))
}
// ImageSizeSource applies equality check predicate on the "image_size_source" field. It's identical to ImageSizeSourceEQ.
func ImageSizeSource(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageSizeSource, v))
}
// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ.
func CacheTTLOverridden(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
@ -1900,6 +1915,241 @@ func ImageSizeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v))
}
// ImageInputSizeEQ applies the EQ predicate on the "image_input_size" field.
func ImageInputSizeEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageInputSize, v))
}
// ImageInputSizeNEQ applies the NEQ predicate on the "image_input_size" field.
func ImageInputSizeNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldImageInputSize, v))
}
// ImageInputSizeIn applies the In predicate on the "image_input_size" field.
func ImageInputSizeIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldImageInputSize, vs...))
}
// ImageInputSizeNotIn applies the NotIn predicate on the "image_input_size" field.
func ImageInputSizeNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldImageInputSize, vs...))
}
// ImageInputSizeGT applies the GT predicate on the "image_input_size" field.
func ImageInputSizeGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldImageInputSize, v))
}
// ImageInputSizeGTE applies the GTE predicate on the "image_input_size" field.
func ImageInputSizeGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldImageInputSize, v))
}
// ImageInputSizeLT applies the LT predicate on the "image_input_size" field.
func ImageInputSizeLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldImageInputSize, v))
}
// ImageInputSizeLTE applies the LTE predicate on the "image_input_size" field.
func ImageInputSizeLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldImageInputSize, v))
}
// ImageInputSizeContains applies the Contains predicate on the "image_input_size" field.
func ImageInputSizeContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldImageInputSize, v))
}
// ImageInputSizeHasPrefix applies the HasPrefix predicate on the "image_input_size" field.
func ImageInputSizeHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldImageInputSize, v))
}
// ImageInputSizeHasSuffix applies the HasSuffix predicate on the "image_input_size" field.
func ImageInputSizeHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldImageInputSize, v))
}
// ImageInputSizeIsNil applies the IsNil predicate on the "image_input_size" field.
func ImageInputSizeIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldImageInputSize))
}
// ImageInputSizeNotNil applies the NotNil predicate on the "image_input_size" field.
func ImageInputSizeNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldImageInputSize))
}
// ImageInputSizeEqualFold applies the EqualFold predicate on the "image_input_size" field.
func ImageInputSizeEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldImageInputSize, v))
}
// ImageInputSizeContainsFold applies the ContainsFold predicate on the "image_input_size" field.
func ImageInputSizeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldImageInputSize, v))
}
// ImageOutputSizeEQ applies the EQ predicate on the "image_output_size" field.
func ImageOutputSizeEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageOutputSize, v))
}
// ImageOutputSizeNEQ applies the NEQ predicate on the "image_output_size" field.
func ImageOutputSizeNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldImageOutputSize, v))
}
// ImageOutputSizeIn applies the In predicate on the "image_output_size" field.
func ImageOutputSizeIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldImageOutputSize, vs...))
}
// ImageOutputSizeNotIn applies the NotIn predicate on the "image_output_size" field.
func ImageOutputSizeNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldImageOutputSize, vs...))
}
// ImageOutputSizeGT applies the GT predicate on the "image_output_size" field.
func ImageOutputSizeGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldImageOutputSize, v))
}
// ImageOutputSizeGTE applies the GTE predicate on the "image_output_size" field.
func ImageOutputSizeGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldImageOutputSize, v))
}
// ImageOutputSizeLT applies the LT predicate on the "image_output_size" field.
func ImageOutputSizeLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldImageOutputSize, v))
}
// ImageOutputSizeLTE applies the LTE predicate on the "image_output_size" field.
func ImageOutputSizeLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldImageOutputSize, v))
}
// ImageOutputSizeContains applies the Contains predicate on the "image_output_size" field.
func ImageOutputSizeContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldImageOutputSize, v))
}
// ImageOutputSizeHasPrefix applies the HasPrefix predicate on the "image_output_size" field.
func ImageOutputSizeHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldImageOutputSize, v))
}
// ImageOutputSizeHasSuffix applies the HasSuffix predicate on the "image_output_size" field.
func ImageOutputSizeHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldImageOutputSize, v))
}
// ImageOutputSizeIsNil applies the IsNil predicate on the "image_output_size" field.
func ImageOutputSizeIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldImageOutputSize))
}
// ImageOutputSizeNotNil applies the NotNil predicate on the "image_output_size" field.
func ImageOutputSizeNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldImageOutputSize))
}
// ImageOutputSizeEqualFold applies the EqualFold predicate on the "image_output_size" field.
func ImageOutputSizeEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldImageOutputSize, v))
}
// ImageOutputSizeContainsFold applies the ContainsFold predicate on the "image_output_size" field.
func ImageOutputSizeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldImageOutputSize, v))
}
// ImageSizeSourceEQ applies the EQ predicate on the "image_size_source" field.
func ImageSizeSourceEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageSizeSource, v))
}
// ImageSizeSourceNEQ applies the NEQ predicate on the "image_size_source" field.
func ImageSizeSourceNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldImageSizeSource, v))
}
// ImageSizeSourceIn applies the In predicate on the "image_size_source" field.
func ImageSizeSourceIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldImageSizeSource, vs...))
}
// ImageSizeSourceNotIn applies the NotIn predicate on the "image_size_source" field.
func ImageSizeSourceNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldImageSizeSource, vs...))
}
// ImageSizeSourceGT applies the GT predicate on the "image_size_source" field.
func ImageSizeSourceGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldImageSizeSource, v))
}
// ImageSizeSourceGTE applies the GTE predicate on the "image_size_source" field.
func ImageSizeSourceGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldImageSizeSource, v))
}
// ImageSizeSourceLT applies the LT predicate on the "image_size_source" field.
func ImageSizeSourceLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldImageSizeSource, v))
}
// ImageSizeSourceLTE applies the LTE predicate on the "image_size_source" field.
func ImageSizeSourceLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldImageSizeSource, v))
}
// ImageSizeSourceContains applies the Contains predicate on the "image_size_source" field.
func ImageSizeSourceContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldImageSizeSource, v))
}
// ImageSizeSourceHasPrefix applies the HasPrefix predicate on the "image_size_source" field.
func ImageSizeSourceHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldImageSizeSource, v))
}
// ImageSizeSourceHasSuffix applies the HasSuffix predicate on the "image_size_source" field.
func ImageSizeSourceHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldImageSizeSource, v))
}
// ImageSizeSourceIsNil applies the IsNil predicate on the "image_size_source" field.
func ImageSizeSourceIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldImageSizeSource))
}
// ImageSizeSourceNotNil applies the NotNil predicate on the "image_size_source" field.
func ImageSizeSourceNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldImageSizeSource))
}
// ImageSizeSourceEqualFold applies the EqualFold predicate on the "image_size_source" field.
func ImageSizeSourceEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldImageSizeSource, v))
}
// ImageSizeSourceContainsFold applies the ContainsFold predicate on the "image_size_source" field.
func ImageSizeSourceContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldImageSizeSource, v))
}
// ImageSizeBreakdownIsNil applies the IsNil predicate on the "image_size_breakdown" field.
func ImageSizeBreakdownIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldImageSizeBreakdown))
}
// ImageSizeBreakdownNotNil applies the NotNil predicate on the "image_size_breakdown" field.
func ImageSizeBreakdownNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldImageSizeBreakdown))
}
// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field.
func CacheTTLOverriddenEQ(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))

View File

@ -477,6 +477,54 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate {
return _c
}
// SetImageInputSize sets the "image_input_size" field.
func (_c *UsageLogCreate) SetImageInputSize(v string) *UsageLogCreate {
_c.mutation.SetImageInputSize(v)
return _c
}
// SetNillableImageInputSize sets the "image_input_size" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableImageInputSize(v *string) *UsageLogCreate {
if v != nil {
_c.SetImageInputSize(*v)
}
return _c
}
// SetImageOutputSize sets the "image_output_size" field.
func (_c *UsageLogCreate) SetImageOutputSize(v string) *UsageLogCreate {
_c.mutation.SetImageOutputSize(v)
return _c
}
// SetNillableImageOutputSize sets the "image_output_size" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableImageOutputSize(v *string) *UsageLogCreate {
if v != nil {
_c.SetImageOutputSize(*v)
}
return _c
}
// SetImageSizeSource sets the "image_size_source" field.
func (_c *UsageLogCreate) SetImageSizeSource(v string) *UsageLogCreate {
_c.mutation.SetImageSizeSource(v)
return _c
}
// SetNillableImageSizeSource sets the "image_size_source" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableImageSizeSource(v *string) *UsageLogCreate {
if v != nil {
_c.SetImageSizeSource(*v)
}
return _c
}
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
func (_c *UsageLogCreate) SetImageSizeBreakdown(v map[string]int) *UsageLogCreate {
_c.mutation.SetImageSizeBreakdown(v)
return _c
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate {
_c.mutation.SetCacheTTLOverridden(v)
@ -754,6 +802,21 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
if v, ok := _c.mutation.ImageInputSize(); ok {
if err := usagelog.ImageInputSizeValidator(v); err != nil {
return &ValidationError{Name: "image_input_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_input_size": %w`, err)}
}
}
if v, ok := _c.mutation.ImageOutputSize(); ok {
if err := usagelog.ImageOutputSizeValidator(v); err != nil {
return &ValidationError{Name: "image_output_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_output_size": %w`, err)}
}
}
if v, ok := _c.mutation.ImageSizeSource(); ok {
if err := usagelog.ImageSizeSourceValidator(v); err != nil {
return &ValidationError{Name: "image_size_source", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size_source": %w`, err)}
}
}
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)}
}
@ -916,6 +979,22 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldImageSize, field.TypeString, value)
_node.ImageSize = &value
}
if value, ok := _c.mutation.ImageInputSize(); ok {
_spec.SetField(usagelog.FieldImageInputSize, field.TypeString, value)
_node.ImageInputSize = &value
}
if value, ok := _c.mutation.ImageOutputSize(); ok {
_spec.SetField(usagelog.FieldImageOutputSize, field.TypeString, value)
_node.ImageOutputSize = &value
}
if value, ok := _c.mutation.ImageSizeSource(); ok {
_spec.SetField(usagelog.FieldImageSizeSource, field.TypeString, value)
_node.ImageSizeSource = &value
}
if value, ok := _c.mutation.ImageSizeBreakdown(); ok {
_spec.SetField(usagelog.FieldImageSizeBreakdown, field.TypeJSON, value)
_node.ImageSizeBreakdown = value
}
if value, ok := _c.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
_node.CacheTTLOverridden = value
@ -1679,6 +1758,78 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert {
return u
}
// SetImageInputSize sets the "image_input_size" field.
func (u *UsageLogUpsert) SetImageInputSize(v string) *UsageLogUpsert {
u.Set(usagelog.FieldImageInputSize, v)
return u
}
// UpdateImageInputSize sets the "image_input_size" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateImageInputSize() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldImageInputSize)
return u
}
// ClearImageInputSize clears the value of the "image_input_size" field.
func (u *UsageLogUpsert) ClearImageInputSize() *UsageLogUpsert {
u.SetNull(usagelog.FieldImageInputSize)
return u
}
// SetImageOutputSize sets the "image_output_size" field.
func (u *UsageLogUpsert) SetImageOutputSize(v string) *UsageLogUpsert {
u.Set(usagelog.FieldImageOutputSize, v)
return u
}
// UpdateImageOutputSize sets the "image_output_size" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateImageOutputSize() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldImageOutputSize)
return u
}
// ClearImageOutputSize clears the value of the "image_output_size" field.
func (u *UsageLogUpsert) ClearImageOutputSize() *UsageLogUpsert {
u.SetNull(usagelog.FieldImageOutputSize)
return u
}
// SetImageSizeSource sets the "image_size_source" field.
func (u *UsageLogUpsert) SetImageSizeSource(v string) *UsageLogUpsert {
u.Set(usagelog.FieldImageSizeSource, v)
return u
}
// UpdateImageSizeSource sets the "image_size_source" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateImageSizeSource() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldImageSizeSource)
return u
}
// ClearImageSizeSource clears the value of the "image_size_source" field.
func (u *UsageLogUpsert) ClearImageSizeSource() *UsageLogUpsert {
u.SetNull(usagelog.FieldImageSizeSource)
return u
}
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
func (u *UsageLogUpsert) SetImageSizeBreakdown(v map[string]int) *UsageLogUpsert {
u.Set(usagelog.FieldImageSizeBreakdown, v)
return u
}
// UpdateImageSizeBreakdown sets the "image_size_breakdown" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateImageSizeBreakdown() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldImageSizeBreakdown)
return u
}
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
func (u *UsageLogUpsert) ClearImageSizeBreakdown() *UsageLogUpsert {
u.SetNull(usagelog.FieldImageSizeBreakdown)
return u
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert {
u.Set(usagelog.FieldCacheTTLOverridden, v)
@ -2457,6 +2608,90 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne {
})
}
// SetImageInputSize sets the "image_input_size" field.
func (u *UsageLogUpsertOne) SetImageInputSize(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetImageInputSize(v)
})
}
// UpdateImageInputSize sets the "image_input_size" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateImageInputSize() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateImageInputSize()
})
}
// ClearImageInputSize clears the value of the "image_input_size" field.
func (u *UsageLogUpsertOne) ClearImageInputSize() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearImageInputSize()
})
}
// SetImageOutputSize sets the "image_output_size" field.
func (u *UsageLogUpsertOne) SetImageOutputSize(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetImageOutputSize(v)
})
}
// UpdateImageOutputSize sets the "image_output_size" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateImageOutputSize() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateImageOutputSize()
})
}
// ClearImageOutputSize clears the value of the "image_output_size" field.
func (u *UsageLogUpsertOne) ClearImageOutputSize() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearImageOutputSize()
})
}
// SetImageSizeSource sets the "image_size_source" field.
func (u *UsageLogUpsertOne) SetImageSizeSource(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetImageSizeSource(v)
})
}
// UpdateImageSizeSource sets the "image_size_source" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateImageSizeSource() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateImageSizeSource()
})
}
// ClearImageSizeSource clears the value of the "image_size_source" field.
func (u *UsageLogUpsertOne) ClearImageSizeSource() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearImageSizeSource()
})
}
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
func (u *UsageLogUpsertOne) SetImageSizeBreakdown(v map[string]int) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetImageSizeBreakdown(v)
})
}
// UpdateImageSizeBreakdown sets the "image_size_breakdown" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateImageSizeBreakdown() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateImageSizeBreakdown()
})
}
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
func (u *UsageLogUpsertOne) ClearImageSizeBreakdown() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearImageSizeBreakdown()
})
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
@ -3403,6 +3638,90 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk {
})
}
// SetImageInputSize sets the "image_input_size" field.
func (u *UsageLogUpsertBulk) SetImageInputSize(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetImageInputSize(v)
})
}
// UpdateImageInputSize sets the "image_input_size" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateImageInputSize() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateImageInputSize()
})
}
// ClearImageInputSize clears the value of the "image_input_size" field.
func (u *UsageLogUpsertBulk) ClearImageInputSize() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearImageInputSize()
})
}
// SetImageOutputSize sets the "image_output_size" field.
func (u *UsageLogUpsertBulk) SetImageOutputSize(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetImageOutputSize(v)
})
}
// UpdateImageOutputSize sets the "image_output_size" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateImageOutputSize() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateImageOutputSize()
})
}
// ClearImageOutputSize clears the value of the "image_output_size" field.
func (u *UsageLogUpsertBulk) ClearImageOutputSize() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearImageOutputSize()
})
}
// SetImageSizeSource sets the "image_size_source" field.
func (u *UsageLogUpsertBulk) SetImageSizeSource(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetImageSizeSource(v)
})
}
// UpdateImageSizeSource sets the "image_size_source" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateImageSizeSource() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateImageSizeSource()
})
}
// ClearImageSizeSource clears the value of the "image_size_source" field.
func (u *UsageLogUpsertBulk) ClearImageSizeSource() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearImageSizeSource()
})
}
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
func (u *UsageLogUpsertBulk) SetImageSizeBreakdown(v map[string]int) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetImageSizeBreakdown(v)
})
}
// UpdateImageSizeBreakdown sets the "image_size_breakdown" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateImageSizeBreakdown() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateImageSizeBreakdown()
})
}
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
func (u *UsageLogUpsertBulk) ClearImageSizeBreakdown() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearImageSizeBreakdown()
})
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {

View File

@ -739,6 +739,78 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate {
return _u
}
// SetImageInputSize sets the "image_input_size" field.
func (_u *UsageLogUpdate) SetImageInputSize(v string) *UsageLogUpdate {
_u.mutation.SetImageInputSize(v)
return _u
}
// SetNillableImageInputSize sets the "image_input_size" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableImageInputSize(v *string) *UsageLogUpdate {
if v != nil {
_u.SetImageInputSize(*v)
}
return _u
}
// ClearImageInputSize clears the value of the "image_input_size" field.
func (_u *UsageLogUpdate) ClearImageInputSize() *UsageLogUpdate {
_u.mutation.ClearImageInputSize()
return _u
}
// SetImageOutputSize sets the "image_output_size" field.
func (_u *UsageLogUpdate) SetImageOutputSize(v string) *UsageLogUpdate {
_u.mutation.SetImageOutputSize(v)
return _u
}
// SetNillableImageOutputSize sets the "image_output_size" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableImageOutputSize(v *string) *UsageLogUpdate {
if v != nil {
_u.SetImageOutputSize(*v)
}
return _u
}
// ClearImageOutputSize clears the value of the "image_output_size" field.
func (_u *UsageLogUpdate) ClearImageOutputSize() *UsageLogUpdate {
_u.mutation.ClearImageOutputSize()
return _u
}
// SetImageSizeSource sets the "image_size_source" field.
func (_u *UsageLogUpdate) SetImageSizeSource(v string) *UsageLogUpdate {
_u.mutation.SetImageSizeSource(v)
return _u
}
// SetNillableImageSizeSource sets the "image_size_source" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableImageSizeSource(v *string) *UsageLogUpdate {
if v != nil {
_u.SetImageSizeSource(*v)
}
return _u
}
// ClearImageSizeSource clears the value of the "image_size_source" field.
func (_u *UsageLogUpdate) ClearImageSizeSource() *UsageLogUpdate {
_u.mutation.ClearImageSizeSource()
return _u
}
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
func (_u *UsageLogUpdate) SetImageSizeBreakdown(v map[string]int) *UsageLogUpdate {
_u.mutation.SetImageSizeBreakdown(v)
return _u
}
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
func (_u *UsageLogUpdate) ClearImageSizeBreakdown() *UsageLogUpdate {
_u.mutation.ClearImageSizeBreakdown()
return _u
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate {
_u.mutation.SetCacheTTLOverridden(v)
@ -892,6 +964,21 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
if v, ok := _u.mutation.ImageInputSize(); ok {
if err := usagelog.ImageInputSizeValidator(v); err != nil {
return &ValidationError{Name: "image_input_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_input_size": %w`, err)}
}
}
if v, ok := _u.mutation.ImageOutputSize(); ok {
if err := usagelog.ImageOutputSizeValidator(v); err != nil {
return &ValidationError{Name: "image_output_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_output_size": %w`, err)}
}
}
if v, ok := _u.mutation.ImageSizeSource(); ok {
if err := usagelog.ImageSizeSourceValidator(v); err != nil {
return &ValidationError{Name: "image_size_source", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size_source": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
}
@ -1099,6 +1186,30 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.ImageSizeCleared() {
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
}
if value, ok := _u.mutation.ImageInputSize(); ok {
_spec.SetField(usagelog.FieldImageInputSize, field.TypeString, value)
}
if _u.mutation.ImageInputSizeCleared() {
_spec.ClearField(usagelog.FieldImageInputSize, field.TypeString)
}
if value, ok := _u.mutation.ImageOutputSize(); ok {
_spec.SetField(usagelog.FieldImageOutputSize, field.TypeString, value)
}
if _u.mutation.ImageOutputSizeCleared() {
_spec.ClearField(usagelog.FieldImageOutputSize, field.TypeString)
}
if value, ok := _u.mutation.ImageSizeSource(); ok {
_spec.SetField(usagelog.FieldImageSizeSource, field.TypeString, value)
}
if _u.mutation.ImageSizeSourceCleared() {
_spec.ClearField(usagelog.FieldImageSizeSource, field.TypeString)
}
if value, ok := _u.mutation.ImageSizeBreakdown(); ok {
_spec.SetField(usagelog.FieldImageSizeBreakdown, field.TypeJSON, value)
}
if _u.mutation.ImageSizeBreakdownCleared() {
_spec.ClearField(usagelog.FieldImageSizeBreakdown, field.TypeJSON)
}
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
}
@ -1974,6 +2085,78 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne {
return _u
}
// SetImageInputSize sets the "image_input_size" field.
func (_u *UsageLogUpdateOne) SetImageInputSize(v string) *UsageLogUpdateOne {
_u.mutation.SetImageInputSize(v)
return _u
}
// SetNillableImageInputSize sets the "image_input_size" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableImageInputSize(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetImageInputSize(*v)
}
return _u
}
// ClearImageInputSize clears the value of the "image_input_size" field.
func (_u *UsageLogUpdateOne) ClearImageInputSize() *UsageLogUpdateOne {
_u.mutation.ClearImageInputSize()
return _u
}
// SetImageOutputSize sets the "image_output_size" field.
func (_u *UsageLogUpdateOne) SetImageOutputSize(v string) *UsageLogUpdateOne {
_u.mutation.SetImageOutputSize(v)
return _u
}
// SetNillableImageOutputSize sets the "image_output_size" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableImageOutputSize(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetImageOutputSize(*v)
}
return _u
}
// ClearImageOutputSize clears the value of the "image_output_size" field.
func (_u *UsageLogUpdateOne) ClearImageOutputSize() *UsageLogUpdateOne {
_u.mutation.ClearImageOutputSize()
return _u
}
// SetImageSizeSource sets the "image_size_source" field.
func (_u *UsageLogUpdateOne) SetImageSizeSource(v string) *UsageLogUpdateOne {
_u.mutation.SetImageSizeSource(v)
return _u
}
// SetNillableImageSizeSource sets the "image_size_source" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableImageSizeSource(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetImageSizeSource(*v)
}
return _u
}
// ClearImageSizeSource clears the value of the "image_size_source" field.
func (_u *UsageLogUpdateOne) ClearImageSizeSource() *UsageLogUpdateOne {
_u.mutation.ClearImageSizeSource()
return _u
}
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
func (_u *UsageLogUpdateOne) SetImageSizeBreakdown(v map[string]int) *UsageLogUpdateOne {
_u.mutation.SetImageSizeBreakdown(v)
return _u
}
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
func (_u *UsageLogUpdateOne) ClearImageSizeBreakdown() *UsageLogUpdateOne {
_u.mutation.ClearImageSizeBreakdown()
return _u
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne {
_u.mutation.SetCacheTTLOverridden(v)
@ -2140,6 +2323,21 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
if v, ok := _u.mutation.ImageInputSize(); ok {
if err := usagelog.ImageInputSizeValidator(v); err != nil {
return &ValidationError{Name: "image_input_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_input_size": %w`, err)}
}
}
if v, ok := _u.mutation.ImageOutputSize(); ok {
if err := usagelog.ImageOutputSizeValidator(v); err != nil {
return &ValidationError{Name: "image_output_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_output_size": %w`, err)}
}
}
if v, ok := _u.mutation.ImageSizeSource(); ok {
if err := usagelog.ImageSizeSourceValidator(v); err != nil {
return &ValidationError{Name: "image_size_source", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size_source": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
}
@ -2364,6 +2562,30 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if _u.mutation.ImageSizeCleared() {
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
}
if value, ok := _u.mutation.ImageInputSize(); ok {
_spec.SetField(usagelog.FieldImageInputSize, field.TypeString, value)
}
if _u.mutation.ImageInputSizeCleared() {
_spec.ClearField(usagelog.FieldImageInputSize, field.TypeString)
}
if value, ok := _u.mutation.ImageOutputSize(); ok {
_spec.SetField(usagelog.FieldImageOutputSize, field.TypeString, value)
}
if _u.mutation.ImageOutputSizeCleared() {
_spec.ClearField(usagelog.FieldImageOutputSize, field.TypeString)
}
if value, ok := _u.mutation.ImageSizeSource(); ok {
_spec.SetField(usagelog.FieldImageSizeSource, field.TypeString, value)
}
if _u.mutation.ImageSizeSourceCleared() {
_spec.ClearField(usagelog.FieldImageSizeSource, field.TypeString)
}
if value, ok := _u.mutation.ImageSizeBreakdown(); ok {
_spec.SetField(usagelog.FieldImageSizeBreakdown, field.TypeJSON, value)
}
if _u.mutation.ImageSizeBreakdownCleared() {
_spec.ClearField(usagelog.FieldImageSizeBreakdown, field.TypeJSON)
}
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
}

View File

@ -216,6 +216,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@ -249,6 +251,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@ -278,6 +282,8 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
@ -310,6 +316,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=

View File

@ -72,6 +72,7 @@ type Config struct {
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
WeChat WeChatConnectConfig `mapstructure:"wechat_connect"`
OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
DingTalk DingTalkConnectConfig `mapstructure:"dingtalk_connect"`
GitHubOAuth EmailOAuthProviderConfig `mapstructure:"github_oauth"`
GoogleOAuth EmailOAuthProviderConfig `mapstructure:"google_oauth"`
Default DefaultConfig `mapstructure:"default"`
@ -242,6 +243,47 @@ type OIDCConnectConfig struct {
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
type DingTalkConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
AuthorizeURL string `mapstructure:"authorize_url"`
TokenURL string `mapstructure:"token_url"`
UserInfoURL string `mapstructure:"userinfo_url"`
Scopes string `mapstructure:"scopes"`
RedirectURL string `mapstructure:"redirect_url"`
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"`
// 平台底座 + 业务行为
DingTalkAppKind string `mapstructure:"dingtalk_app_kind"` // 仅 "internal_app"V4 fail-closed
AppType string `mapstructure:"app_type"` // "public" (default) | "internal"
// Corp 限定none | internal_only
CorpRestrictionPolicy string `mapstructure:"corp_restriction_policy"`
InternalCorpID string `mapstructure:"internal_corp_id"`
BypassRegistration bool `mapstructure:"bypass_registration"`
SyncCorpEmail bool `mapstructure:"sync_corp_email"`
SyncDisplayName bool `mapstructure:"sync_display_name"`
SyncDept bool `mapstructure:"sync_dept"`
SyncCorpEmailAttrKey string `mapstructure:"sync_corp_email_attr_key"`
SyncDisplayNameAttrKey string `mapstructure:"sync_display_name_attr_key"`
SyncDeptAttrKey string `mapstructure:"sync_dept_attr_key"`
SyncCorpEmailAttrName string `mapstructure:"sync_corp_email_attr_name"`
SyncDisplayNameAttrName string `mapstructure:"sync_display_name_attr_name"`
SyncDeptAttrName string `mapstructure:"sync_dept_attr_name"`
// 邮箱 + Username
RequireEmail bool `mapstructure:"require_email"`
UsernameOverwritePolicy string `mapstructure:"username_overwrite_policy"`
// Attribute私有版扩展点开源版仅声明
UsernameAttributeKey string `mapstructure:"username_attribute_key"`
EnableAttributeMatching bool `mapstructure:"enable_attribute_matching"`
EnableAttributeSync bool `mapstructure:"enable_attribute_sync"`
AttributeSyncFields []string `mapstructure:"attribute_sync_fields"`
AttributeSyncOverwritePolicy string `mapstructure:"attribute_sync_overwrite_policy"`
}
type EmailOAuthProviderConfig struct {
Enabled bool `mapstructure:"enabled"`
ClientID string `mapstructure:"client_id"`
@ -1536,6 +1578,19 @@ func setDefaults() {
viper.SetDefault("oidc_connect.userinfo_id_path", "")
viper.SetDefault("oidc_connect.userinfo_username_path", "")
// DingTalk Connect OAuth 登录
viper.SetDefault("dingtalk_connect.enabled", false)
viper.SetDefault("dingtalk_connect.authorize_url", "https://login.dingtalk.com/oauth2/auth")
viper.SetDefault("dingtalk_connect.token_url", "https://api.dingtalk.com/v1.0/oauth2/userAccessToken")
viper.SetDefault("dingtalk_connect.userinfo_url", "https://api.dingtalk.com/v1.0/contact/users/me")
viper.SetDefault("dingtalk_connect.scopes", "openid")
viper.SetDefault("dingtalk_connect.frontend_redirect_url", "/auth/dingtalk/callback")
viper.SetDefault("dingtalk_connect.dingtalk_app_kind", "internal_app")
viper.SetDefault("dingtalk_connect.app_type", "public")
viper.SetDefault("dingtalk_connect.corp_restriction_policy", "none")
viper.SetDefault("dingtalk_connect.require_email", true)
viper.SetDefault("dingtalk_connect.username_overwrite_policy", "if_empty")
// Database
viper.SetDefault("database.host", "localhost")
viper.SetDefault("database.port", 5432)
@ -2608,6 +2663,9 @@ func (c *Config) Validate() error {
if c.Concurrency.PingInterval < 5 || c.Concurrency.PingInterval > 30 {
return fmt.Errorf("concurrency.ping_interval must be between 5-30 seconds")
}
if err := ValidateDingTalkConfig(c.DingTalk); err != nil {
return fmt.Errorf("dingtalk_connect: %w", err)
}
return nil
}

View File

@ -0,0 +1,30 @@
// Package config 包含钉钉连接配置的校验逻辑。
//
// internal_only 模式安全模型(方案 A
// 不再要求 admin 填写 InternalCorpID 做二次 corpID 比对。
// 安全边界由钉钉"企业内部应用"类型本身保证——只有应用所属企业的员工才能完成 OAuth
// 因此 ValidateDingTalkConfig 只要求 app_type=internalV1不再要求 InternalCorpID 非空(原 V3 已删除)。
// InternalCorpID 字段保留admin 可选填若填写checkDingTalkCorpAllowed 不会使用它做约束。
package config
import "errors"
var (
ErrDingTalkV1AppTypeMismatch = errors.New("dingtalk: internal_only requires app_type=internal")
ErrDingTalkV4InvalidAppKind = errors.New("dingtalk: dingtalk_app_kind must be internal_app")
)
func ValidateDingTalkConfig(cfg DingTalkConnectConfig) error {
if !cfg.Enabled {
return nil
}
if cfg.DingTalkAppKind != "internal_app" {
return ErrDingTalkV4InvalidAppKind
}
if cfg.CorpRestrictionPolicy == "internal_only" {
if cfg.AppType != "internal" {
return ErrDingTalkV1AppTypeMismatch
}
}
return nil
}

View File

@ -0,0 +1,53 @@
package config
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestValidateDingTalkConfig_Disabled_Skip(t *testing.T) {
require.NoError(t, ValidateDingTalkConfig(DingTalkConnectConfig{Enabled: false}))
}
func TestValidateDingTalkConfig_V4_DingTalkAppKind(t *testing.T) {
err := ValidateDingTalkConfig(DingTalkConnectConfig{
Enabled: true,
DingTalkAppKind: "third_party_enterprise_app",
CorpRestrictionPolicy: "none",
})
require.ErrorIs(t, err, ErrDingTalkV4InvalidAppKind)
}
func TestValidateDingTalkConfig_V1_InternalOnlyRequiresInternalAppType(t *testing.T) {
err := ValidateDingTalkConfig(DingTalkConnectConfig{
Enabled: true,
DingTalkAppKind: "internal_app",
AppType: "public",
CorpRestrictionPolicy: "internal_only",
InternalCorpID: "dingABC",
})
require.ErrorIs(t, err, ErrDingTalkV1AppTypeMismatch)
}
// TestValidateDingTalkConfig_V3_InternalOnlyAllowsEmptyCorpID 验证方案 A
// internal_only 策略下InternalCorpID="" 应通过校验(企业隔离由钉钉 AppType=internal 保证)。
func TestValidateDingTalkConfig_V3_InternalOnlyAllowsEmptyCorpID(t *testing.T) {
err := ValidateDingTalkConfig(DingTalkConnectConfig{
Enabled: true,
DingTalkAppKind: "internal_app",
AppType: "internal",
CorpRestrictionPolicy: "internal_only",
InternalCorpID: "",
})
require.NoError(t, err)
}
func TestValidateDingTalkConfig_HappyPath_None(t *testing.T) {
require.NoError(t, ValidateDingTalkConfig(DingTalkConnectConfig{
Enabled: true,
DingTalkAppKind: "internal_app",
AppType: "public",
CorpRestrictionPolicy: "none",
}))
}

View File

@ -43,6 +43,9 @@ type DataProxy struct {
Status string `json:"status"`
}
// DataAccount 是管理员显式备份导出使用的账号结构,故意不走 dto.Account 的脱敏路径,
// Credentials 原文返回。这是"管理员备份"这一显式行为的一部分;如未来需要导出脱敏版本,
// 应新增独立结构而非修改这里。
type DataAccount struct {
Name string `json:"name"`
Notes *string `json:"notes,omitempty"`

View File

@ -1994,6 +1994,48 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
response.Success(c, models)
}
// SyncUpstreamModels handles syncing live supported models from an account's upstream.
// POST /api/v1/admin/accounts/:id/models/sync-upstream
func (h *AccountHandler) SyncUpstreamModels(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
if h.accountTestService == nil {
response.InternalError(c, "Account test service is not configured")
return
}
models, err := h.accountTestService.FetchUpstreamSupportedModels(c.Request.Context(), account)
if err != nil {
var syncErr *service.UpstreamModelSyncError
if errors.As(err, &syncErr) {
switch syncErr.Kind {
case service.UpstreamModelSyncErrorConfiguration, service.UpstreamModelSyncErrorUnsupported:
response.BadRequest(c, syncErr.SafeMessage())
default:
slog.Warn("sync_upstream_models_failed", "account_id", accountID, "kind", syncErr.Kind)
response.Error(c, http.StatusBadGateway, syncErr.SafeMessage())
}
return
}
slog.Warn("sync_upstream_models_failed", "account_id", accountID)
response.Error(c, http.StatusBadGateway, "Failed to sync upstream models from upstream")
return
}
response.Success(c, gin.H{"models": models})
}
// SetPrivacy handles setting privacy for a single OpenAI/Antigravity OAuth account
// POST /api/v1/admin/accounts/:id/set-privacy
func (h *AccountHandler) SetPrivacy(c *gin.Context) {

View File

@ -3,10 +3,14 @@ package admin
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
@ -33,6 +37,39 @@ func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
return router
}
type syncUpstreamHTTPUpstream struct {
resp *http.Response
err error
}
func (u *syncUpstreamHTTPUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
if u.err != nil {
return nil, u.err
}
return u.resp, nil
}
func (u *syncUpstreamHTTPUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
return u.Do(req, proxyURL, accountID, accountConcurrency)
}
func setupSyncUpstreamModelsRouter(adminSvc service.AdminService, upstream service.HTTPUpstream) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
accountTestSvc := service.NewAccountTestService(
nil,
nil,
nil,
nil,
upstream,
&config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
nil,
)
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, accountTestSvc, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/:id/models/sync-upstream", handler.SyncUpstreamModels)
return router
}
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
@ -103,3 +140,58 @@ func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefau
require.NotEmpty(t, resp.Data)
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
}
func TestAccountHandlerSyncUpstreamModels_ConfigErrorReturnsBadRequest(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 44,
Name: "openai-apikey-missing-key",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
Credentials: map[string]any{
"base_url": "https://openai.example.com/v1",
},
},
}
router := setupSyncUpstreamModelsRouter(svc, &syncUpstreamHTTPUpstream{})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/44/models/sync-upstream", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "No OpenAI API key is available")
}
func TestAccountHandlerSyncUpstreamModels_UpstreamErrorDoesNotExposeBody(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 45,
Name: "openai-apikey-upstream-error",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
Credentials: map[string]any{
"api_key": "openai-key",
"base_url": "https://openai.example.com/v1",
},
},
}
upstream := &syncUpstreamHTTPUpstream{resp: &http.Response{
StatusCode: http.StatusBadGateway,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"error":"SECRET_TOKEN should not be exposed"}`)),
}}
router := setupSyncUpstreamModelsRouter(svc, upstream)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/45/models/sync-upstream", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadGateway, rec.Code)
require.Contains(t, rec.Body.String(), "Upstream model list request failed with HTTP 502")
require.NotContains(t, rec.Body.String(), "SECRET_TOKEN")
}

View File

@ -546,9 +546,14 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
return
}
// cacheKey 必须包含当日日期,否则跨午夜后 30s 内会复用昨天的 "today_*" 结果。
keyRaw, _ := json.Marshal(struct {
V int `json:"v"`
Day string `json:"day"`
UserIDs []int64 `json:"user_ids"`
}{
V: 2, // bump 当响应结构变化(如加入 by_platform 时)
Day: timezone.Today().Format("2006-01-02"),
UserIDs: userIDs,
})
cacheKey := string(keyRaw)

View File

@ -8,6 +8,7 @@ import (
"fmt"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
@ -33,23 +34,51 @@ func NewRedeemHandler(adminService service.AdminService, redeemService *service.
// GenerateRedeemCodesRequest represents generate redeem codes request
type GenerateRedeemCodesRequest struct {
Count int `json:"count" binding:"required,min=1,max=100"`
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
Value float64 `json:"value"`
GroupID *int64 `json:"group_id"` // 订阅类型必填
ValidityDays int `json:"validity_days"` // 订阅类型使用,正数增加/负数退款扣减
Count int `json:"count" binding:"required,min=1,max=100"`
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
Value float64 `json:"value"`
GroupID *int64 `json:"group_id"` // 订阅类型必填
ValidityDays int `json:"validity_days"` // 订阅类型使用,正数增加/负数退款扣减
ExpiresAt *time.Time `json:"expires_at"`
ExpiresInDays *int `json:"expires_in_days" binding:"omitempty,min=1,max=3650"`
}
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
// Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance
type CreateAndRedeemCodeRequest struct {
Code string `json:"code" binding:"required,min=3,max=128"`
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance向后兼容
Value float64 `json:"value" binding:"required"`
UserID int64 `json:"user_id" binding:"required,gt=0"`
GroupID *int64 `json:"group_id"` // subscription 类型必填
ValidityDays int `json:"validity_days"` // subscription 类型:正数增加,负数退款扣减
Notes string `json:"notes"`
Code string `json:"code" binding:"required,min=3,max=128"`
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance向后兼容
Value float64 `json:"value" binding:"required"`
UserID int64 `json:"user_id" binding:"required,gt=0"`
GroupID *int64 `json:"group_id"` // subscription 类型必填
ValidityDays int `json:"validity_days"` // subscription 类型:正数增加,负数退款扣减
Notes string `json:"notes"`
ExpiresAt *time.Time `json:"expires_at"`
ExpiresInDays *int `json:"expires_in_days" binding:"omitempty,min=1,max=3650"`
}
func resolveRedeemCodeExpiresAt(expiresAt *time.Time, expiresInDays *int) (*time.Time, error) {
if expiresAt != nil && expiresInDays != nil {
return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRY_CONFLICT", "expires_at and expires_in_days cannot both be set")
}
now := time.Now().UTC()
if expiresInDays != nil {
if *expiresInDays <= 0 {
return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRES_IN_DAYS_INVALID", "expires_in_days must be greater than zero")
}
expires := now.AddDate(0, 0, *expiresInDays)
return &expires, nil
}
if expiresAt == nil {
return nil, nil
}
expires := expiresAt.UTC()
if !expires.After(now) {
return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRES_AT_INVALID", "expires_at must be in the future")
}
return &expires, nil
}
// List handles listing all redeem codes with pagination
@ -107,6 +136,12 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
return
}
expiresAt, err := resolveRedeemCodeExpiresAt(req.ExpiresAt, req.ExpiresInDays)
if err != nil {
response.ErrorFrom(c, err)
return
}
executeAdminIdempotentJSON(c, "admin.redeem_codes.generate", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
codes, execErr := h.adminService.GenerateRedeemCodes(ctx, &service.GenerateRedeemCodesInput{
Count: req.Count,
@ -114,6 +149,7 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
Value: req.Value,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
ExpiresAt: expiresAt,
})
if execErr != nil {
return nil, execErr
@ -158,6 +194,12 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
}
}
expiresAt, err := resolveRedeemCodeExpiresAt(req.ExpiresAt, req.ExpiresInDays)
if err != nil {
response.ErrorFrom(c, err)
return
}
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
existing, err := h.redeemService.GetByCode(ctx, req.Code)
if err == nil {
@ -175,6 +217,7 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
Notes: req.Notes,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
ExpiresAt: expiresAt,
})
if createErr != nil {
// Unique code race: if code now exists, use idempotent semantics by used_by.
@ -199,6 +242,9 @@ func (h *RedeemHandler) resolveCreateAndRedeemExisting(ctx context.Context, exis
}
// If previous run created the code but crashed before redeem, redeem it now.
if existing.IsExpired() {
return nil, service.ErrRedeemCodeExpired
}
if existing.CanUse() {
redeemed, err := h.redeemService.Redeem(ctx, userID, existing.Code)
if err == nil {
@ -321,7 +367,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
writer := csv.NewWriter(&buf)
// Write header
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_by_email", "used_at", "created_at"}); err != nil {
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_by_email", "used_at", "expires_at", "created_at"}); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
@ -340,6 +386,10 @@ func (h *RedeemHandler) Export(c *gin.Context) {
if code.UsedAt != nil {
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
}
expiresAt := ""
if code.ExpiresAt != nil {
expiresAt = code.ExpiresAt.Format("2006-01-02 15:04:05")
}
if err := writer.Write([]string{
fmt.Sprintf("%d", code.ID),
code.Code,
@ -349,6 +399,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
usedBy,
usedByEmail,
usedAt,
expiresAt,
code.CreatedAt.Format("2006-01-02 15:04:05"),
}); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())

View File

@ -6,6 +6,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@ -139,3 +140,33 @@ func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) {
assert.NotEqual(t, http.StatusBadRequest, code,
"balance type should not require group_id or validity_days")
}
func TestResolveRedeemCodeExpiresAt_FromDays(t *testing.T) {
days := 3
expiresAt, err := resolveRedeemCodeExpiresAt(nil, &days)
require.NoError(t, err)
require.NotNil(t, expiresAt)
require.WithinDuration(t, time.Now().UTC().AddDate(0, 0, days), *expiresAt, 2*time.Second)
}
func TestResolveRedeemCodeExpiresAt_RejectsPastAbsoluteTime(t *testing.T) {
past := time.Now().UTC().Add(-time.Minute)
expiresAt, err := resolveRedeemCodeExpiresAt(&past, nil)
require.Error(t, err)
require.Nil(t, expiresAt)
}
func TestResolveRedeemCodeExpiresAt_RejectsNonPositiveDays(t *testing.T) {
days := 0
expiresAt, err := resolveRedeemCodeExpiresAt(nil, &days)
require.Error(t, err)
require.Nil(t, expiresAt)
}
func TestResolveRedeemCodeExpiresAt_RejectsConflictingInputs(t *testing.T) {
future := time.Now().UTC().Add(time.Hour)
days := 3
expiresAt, err := resolveRedeemCodeExpiresAt(&future, &days)
require.Error(t, err)
require.Nil(t, expiresAt)
}

View File

@ -1,9 +1,11 @@
package admin
import (
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
@ -60,10 +62,11 @@ type SettingHandler struct {
opsService *service.OpsService
paymentConfigService *service.PaymentConfigService
paymentService *service.PaymentService
userAttributeService *service.UserAttributeService
}
// NewSettingHandler 创建系统设置处理器
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, paymentConfigService *service.PaymentConfigService, paymentService *service.PaymentService) *SettingHandler {
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, paymentConfigService *service.PaymentConfigService, paymentService *service.PaymentService, userAttributeService *service.UserAttributeService) *SettingHandler {
return &SettingHandler{
settingService: settingService,
emailService: emailService,
@ -71,6 +74,7 @@ func NewSettingHandler(settingService *service.SettingService, emailService *ser
opsService: opsService,
paymentConfigService: paymentConfigService,
paymentService: paymentService,
userAttributeService: userAttributeService,
}
}
@ -135,6 +139,22 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
DingTalkConnectEnabled: settings.DingTalkConnectEnabled,
DingTalkConnectClientID: settings.DingTalkConnectClientID,
DingTalkConnectClientSecretConfigured: settings.DingTalkConnectClientSecretConfigured,
DingTalkConnectRedirectURL: settings.DingTalkConnectRedirectURL,
DingTalkConnectCorpRestrictionPolicy: settings.DingTalkConnectCorpRestrictionPolicy,
DingTalkConnectInternalCorpID: settings.DingTalkConnectInternalCorpID,
DingTalkConnectBypassRegistration: settings.DingTalkConnectBypassRegistration,
DingTalkConnectSyncCorpEmail: settings.DingTalkConnectSyncCorpEmail,
DingTalkConnectSyncDisplayName: settings.DingTalkConnectSyncDisplayName,
DingTalkConnectSyncDept: settings.DingTalkConnectSyncDept,
DingTalkConnectSyncCorpEmailAttrKey: settings.DingTalkConnectSyncCorpEmailAttrKey,
DingTalkConnectSyncDisplayNameAttrKey: settings.DingTalkConnectSyncDisplayNameAttrKey,
DingTalkConnectSyncDeptAttrKey: settings.DingTalkConnectSyncDeptAttrKey,
DingTalkConnectSyncCorpEmailAttrName: settings.DingTalkConnectSyncCorpEmailAttrName,
DingTalkConnectSyncDisplayNameAttrName: settings.DingTalkConnectSyncDisplayNameAttrName,
DingTalkConnectSyncDeptAttrName: settings.DingTalkConnectSyncDeptAttrName,
WeChatConnectEnabled: settings.WeChatConnectEnabled,
WeChatConnectAppID: settings.WeChatConnectAppID,
WeChatConnectAppSecretConfigured: settings.WeChatConnectAppSecretConfigured,
@ -376,6 +396,24 @@ type UpdateSettingsRequest struct {
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
// DingTalk Connect OAuth 登录
DingTalkConnectEnabled bool `json:"dingtalk_connect_enabled"`
DingTalkConnectClientID string `json:"dingtalk_connect_client_id"`
DingTalkConnectClientSecret string `json:"dingtalk_connect_client_secret"`
DingTalkConnectRedirectURL string `json:"dingtalk_connect_redirect_url"`
DingTalkConnectCorpRestrictionPolicy string `json:"dingtalk_connect_corp_restriction_policy"`
DingTalkConnectInternalCorpID string `json:"dingtalk_connect_internal_corp_id"`
DingTalkConnectBypassRegistration bool `json:"dingtalk_connect_bypass_registration"`
DingTalkConnectSyncCorpEmail bool `json:"dingtalk_connect_sync_corp_email"`
DingTalkConnectSyncDisplayName bool `json:"dingtalk_connect_sync_display_name"`
DingTalkConnectSyncDept bool `json:"dingtalk_connect_sync_dept"`
DingTalkConnectSyncCorpEmailAttrKey string `json:"dingtalk_connect_sync_corp_email_attr_key"`
DingTalkConnectSyncDisplayNameAttrKey string `json:"dingtalk_connect_sync_display_name_attr_key"`
DingTalkConnectSyncDeptAttrKey string `json:"dingtalk_connect_sync_dept_attr_key"`
DingTalkConnectSyncCorpEmailAttrName string `json:"dingtalk_connect_sync_corp_email_attr_name"`
DingTalkConnectSyncDisplayNameAttrName string `json:"dingtalk_connect_sync_display_name_attr_name"`
DingTalkConnectSyncDeptAttrName string `json:"dingtalk_connect_sync_dept_attr_name"`
// WeChat Connect OAuth 登录
WeChatConnectEnabled bool `json:"wechat_connect_enabled"`
WeChatConnectAppID string `json:"wechat_connect_app_id"`
@ -446,45 +484,50 @@ type UpdateSettingsRequest struct {
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"`
AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"`
AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"`
AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"`
AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"`
AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"`
AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"`
AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"`
AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"`
AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"`
AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"`
AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"`
AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"`
AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"`
AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"`
AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
AuthSourceDefaultGitHubBalance *float64 `json:"auth_source_default_github_balance"`
AuthSourceDefaultGitHubConcurrency *int `json:"auth_source_default_github_concurrency"`
AuthSourceDefaultGitHubSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_github_subscriptions"`
AuthSourceDefaultGitHubGrantOnSignup *bool `json:"auth_source_default_github_grant_on_signup"`
AuthSourceDefaultGitHubGrantOnFirstBind *bool `json:"auth_source_default_github_grant_on_first_bind"`
AuthSourceDefaultGoogleBalance *float64 `json:"auth_source_default_google_balance"`
AuthSourceDefaultGoogleConcurrency *int `json:"auth_source_default_google_concurrency"`
AuthSourceDefaultGoogleSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_google_subscriptions"`
AuthSourceDefaultGoogleGrantOnSignup *bool `json:"auth_source_default_google_grant_on_signup"`
AuthSourceDefaultGoogleGrantOnFirstBind *bool `json:"auth_source_default_google_grant_on_first_bind"`
ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"`
AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"`
AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"`
AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"`
AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"`
AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"`
AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"`
AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"`
AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"`
AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"`
AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"`
AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"`
AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"`
AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"`
AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"`
AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
AuthSourceDefaultGitHubBalance *float64 `json:"auth_source_default_github_balance"`
AuthSourceDefaultGitHubConcurrency *int `json:"auth_source_default_github_concurrency"`
AuthSourceDefaultGitHubSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_github_subscriptions"`
AuthSourceDefaultGitHubGrantOnSignup *bool `json:"auth_source_default_github_grant_on_signup"`
AuthSourceDefaultGitHubGrantOnFirstBind *bool `json:"auth_source_default_github_grant_on_first_bind"`
AuthSourceDefaultGoogleBalance *float64 `json:"auth_source_default_google_balance"`
AuthSourceDefaultGoogleConcurrency *int `json:"auth_source_default_google_concurrency"`
AuthSourceDefaultGoogleSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_google_subscriptions"`
AuthSourceDefaultGoogleGrantOnSignup *bool `json:"auth_source_default_google_grant_on_signup"`
AuthSourceDefaultGoogleGrantOnFirstBind *bool `json:"auth_source_default_google_grant_on_first_bind"`
AuthSourceDefaultDingTalkBalance *float64 `json:"auth_source_default_dingtalk_balance"`
AuthSourceDefaultDingTalkConcurrency *int `json:"auth_source_default_dingtalk_concurrency"`
AuthSourceDefaultDingTalkSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_dingtalk_subscriptions"`
AuthSourceDefaultDingTalkGrantOnSignup *bool `json:"auth_source_default_dingtalk_grant_on_signup"`
AuthSourceDefaultDingTalkGrantOnFirstBind *bool `json:"auth_source_default_dingtalk_grant_on_first_bind"`
ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@ -661,6 +704,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.AuthSourceDefaultLinuxDoSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultLinuxDoSubscriptions)
req.AuthSourceDefaultOIDCSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultOIDCSubscriptions)
req.AuthSourceDefaultWeChatSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultWeChatSubscriptions)
req.AuthSourceDefaultDingTalkSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultDingTalkSubscriptions)
// SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置
// 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置
@ -777,6 +821,100 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// DingTalk Connect 参数验证
// 防御性:任何写入路径上把已废弃的 corp_restriction_policy=whitelist 入参 coerce 为 none
// 避免任何直连 admin API 的客户端把死值写回 DB前端 UI 已无此选项)。
req.DingTalkConnectCorpRestrictionPolicy = service.CoerceDingTalkCorpPolicyForWrite(req.DingTalkConnectCorpRestrictionPolicy)
if req.DingTalkConnectEnabled {
req.DingTalkConnectClientID = strings.TrimSpace(req.DingTalkConnectClientID)
req.DingTalkConnectClientSecret = strings.TrimSpace(req.DingTalkConnectClientSecret)
req.DingTalkConnectRedirectURL = strings.TrimSpace(req.DingTalkConnectRedirectURL)
req.DingTalkConnectCorpRestrictionPolicy = strings.TrimSpace(req.DingTalkConnectCorpRestrictionPolicy)
req.DingTalkConnectInternalCorpID = strings.TrimSpace(req.DingTalkConnectInternalCorpID)
if req.DingTalkConnectClientID == "" {
response.BadRequest(c, "DingTalk Client ID is required when enabled")
return
}
if req.DingTalkConnectRedirectURL == "" {
response.BadRequest(c, "DingTalk Redirect URL is required when enabled")
return
}
if err := config.ValidateAbsoluteHTTPURL(req.DingTalkConnectRedirectURL); err != nil {
response.BadRequest(c, "DingTalk Redirect URL must be an absolute http(s) URL")
return
}
// 如果未提供 client_secret则保留现有值如有
if req.DingTalkConnectClientSecret == "" {
if previousSettings.DingTalkConnectClientSecret == "" {
response.BadRequest(c, "DingTalk Client Secret is required when enabled")
return
}
req.DingTalkConnectClientSecret = previousSettings.DingTalkConnectClientSecret
}
// Corp 策略校验V1/V4 fail-closed
dingTalkCfg := config.DingTalkConnectConfig{
Enabled: true,
DingTalkAppKind: "internal_app", // 硬编码settings 层仅支持 internal_app
AppType: "internal", // 对于 internal_only 策略的默认值
CorpRestrictionPolicy: req.DingTalkConnectCorpRestrictionPolicy,
InternalCorpID: req.DingTalkConnectInternalCorpID,
}
// 若未填 corp_restriction_policy保留已有配置
if dingTalkCfg.CorpRestrictionPolicy == "" {
dingTalkCfg.CorpRestrictionPolicy = previousSettings.DingTalkConnectCorpRestrictionPolicy
}
// 对于 internal_only 策略app_type 必须为 internalV1 校验)
if dingTalkCfg.CorpRestrictionPolicy == "internal_only" {
dingTalkCfg.AppType = "internal"
} else {
dingTalkCfg.AppType = "public"
}
if err := config.ValidateDingTalkConfig(dingTalkCfg); err != nil {
response.ErrorWithDetails(c, http.StatusBadRequest, err.Error(), mapDingTalkValidateError(err), nil)
return
}
// bypass_registration 仅在 internal_only 模式下有意义;其它策略下强制为 false
// 防止 admin 在切换 policy 时把 bypass 残留在 DB 中(前端 UI 也已隐藏该开关)。
if dingTalkCfg.CorpRestrictionPolicy != "internal_only" {
req.DingTalkConnectBypassRegistration = false
// 身份同步三开关同理:仅 internal_only 模式下有意义,其它策略强制 false。
req.DingTalkConnectSyncCorpEmail = false
req.DingTalkConnectSyncDisplayName = false
req.DingTalkConnectSyncDept = false
}
// 身份同步目标 attr keytrimSpace + 空值 fallback 到默认值
req.DingTalkConnectSyncCorpEmailAttrKey = strings.TrimSpace(req.DingTalkConnectSyncCorpEmailAttrKey)
if req.DingTalkConnectSyncCorpEmailAttrKey == "" {
req.DingTalkConnectSyncCorpEmailAttrKey = "dingtalk_email"
}
req.DingTalkConnectSyncDisplayNameAttrKey = strings.TrimSpace(req.DingTalkConnectSyncDisplayNameAttrKey)
if req.DingTalkConnectSyncDisplayNameAttrKey == "" {
req.DingTalkConnectSyncDisplayNameAttrKey = "dingtalk_name"
}
req.DingTalkConnectSyncDeptAttrKey = strings.TrimSpace(req.DingTalkConnectSyncDeptAttrKey)
if req.DingTalkConnectSyncDeptAttrKey == "" {
req.DingTalkConnectSyncDeptAttrKey = "dingtalk_department"
}
// 身份同步目标 attr 显示名称trim + 空值 fallback 到默认中文名
req.DingTalkConnectSyncCorpEmailAttrName = strings.TrimSpace(req.DingTalkConnectSyncCorpEmailAttrName)
if req.DingTalkConnectSyncCorpEmailAttrName == "" {
req.DingTalkConnectSyncCorpEmailAttrName = "钉钉企业邮箱"
}
req.DingTalkConnectSyncDisplayNameAttrName = strings.TrimSpace(req.DingTalkConnectSyncDisplayNameAttrName)
if req.DingTalkConnectSyncDisplayNameAttrName == "" {
req.DingTalkConnectSyncDisplayNameAttrName = "钉钉姓名"
}
req.DingTalkConnectSyncDeptAttrName = strings.TrimSpace(req.DingTalkConnectSyncDeptAttrName)
if req.DingTalkConnectSyncDeptAttrName == "" {
req.DingTalkConnectSyncDeptAttrName = "钉钉部门"
}
}
if req.WeChatConnectEnabled {
req.WeChatConnectAppID = strings.TrimSpace(req.WeChatConnectAppID)
req.WeChatConnectAppSecret = strings.TrimSpace(req.WeChatConnectAppSecret)
@ -1272,113 +1410,129 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
FrontendURL: req.FrontendURL,
InvitationCodeEnabled: req.InvitationCodeEnabled,
TotpEnabled: req.TotpEnabled,
LoginAgreementEnabled: req.LoginAgreementEnabled,
LoginAgreementMode: loginAgreementMode,
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
LoginAgreementDocuments: loginAgreementDocuments,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
WeChatConnectEnabled: req.WeChatConnectEnabled,
WeChatConnectAppID: req.WeChatConnectAppID,
WeChatConnectAppSecret: req.WeChatConnectAppSecret,
WeChatConnectOpenAppID: req.WeChatConnectOpenAppID,
WeChatConnectOpenAppSecret: req.WeChatConnectOpenAppSecret,
WeChatConnectMPAppID: req.WeChatConnectMPAppID,
WeChatConnectMPAppSecret: req.WeChatConnectMPAppSecret,
WeChatConnectMobileAppID: req.WeChatConnectMobileAppID,
WeChatConnectMobileAppSecret: req.WeChatConnectMobileAppSecret,
WeChatConnectOpenEnabled: req.WeChatConnectOpenEnabled,
WeChatConnectMPEnabled: req.WeChatConnectMPEnabled,
WeChatConnectMobileEnabled: req.WeChatConnectMobileEnabled,
WeChatConnectMode: req.WeChatConnectMode,
WeChatConnectScopes: req.WeChatConnectScopes,
WeChatConnectRedirectURL: req.WeChatConnectRedirectURL,
WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL,
OIDCConnectEnabled: req.OIDCConnectEnabled,
OIDCConnectProviderName: req.OIDCConnectProviderName,
OIDCConnectClientID: req.OIDCConnectClientID,
OIDCConnectClientSecret: req.OIDCConnectClientSecret,
OIDCConnectIssuerURL: req.OIDCConnectIssuerURL,
OIDCConnectDiscoveryURL: req.OIDCConnectDiscoveryURL,
OIDCConnectAuthorizeURL: req.OIDCConnectAuthorizeURL,
OIDCConnectTokenURL: req.OIDCConnectTokenURL,
OIDCConnectUserInfoURL: req.OIDCConnectUserInfoURL,
OIDCConnectJWKSURL: req.OIDCConnectJWKSURL,
OIDCConnectScopes: req.OIDCConnectScopes,
OIDCConnectRedirectURL: req.OIDCConnectRedirectURL,
OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL,
OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod,
OIDCConnectUsePKCE: oidcUsePKCE,
OIDCConnectValidateIDToken: oidcValidateIDToken,
OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs,
OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds,
OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified,
OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath,
OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath,
OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath,
GitHubOAuthEnabled: req.GitHubOAuthEnabled,
GitHubOAuthClientID: req.GitHubOAuthClientID,
GitHubOAuthClientSecret: req.GitHubOAuthClientSecret,
GitHubOAuthRedirectURL: req.GitHubOAuthRedirectURL,
GitHubOAuthFrontendRedirectURL: req.GitHubOAuthFrontendRedirectURL,
GoogleOAuthEnabled: req.GoogleOAuthEnabled,
GoogleOAuthClientID: req.GoogleOAuthClientID,
GoogleOAuthClientSecret: req.GoogleOAuthClientSecret,
GoogleOAuthRedirectURL: req.GoogleOAuthRedirectURL,
GoogleOAuthFrontendRedirectURL: req.GoogleOAuthFrontendRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
HideCcsImportButton: req.HideCcsImportButton,
PurchaseSubscriptionEnabled: purchaseEnabled,
PurchaseSubscriptionURL: purchaseURL,
TableDefaultPageSize: req.TableDefaultPageSize,
TablePageSizeOptions: req.TablePageSizeOptions,
CustomMenuItems: customMenuJSON,
CustomEndpoints: customEndpointsJSON,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
AffiliateRebateRate: affiliateRebateRate,
AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
AffiliateRebateDurationDays: affiliateRebateDurationDays,
AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
MaxClaudeCodeVersion: req.MaxClaudeCodeVersion,
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
BackendModeEnabled: req.BackendModeEnabled,
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
FrontendURL: req.FrontendURL,
InvitationCodeEnabled: req.InvitationCodeEnabled,
TotpEnabled: req.TotpEnabled,
LoginAgreementEnabled: req.LoginAgreementEnabled,
LoginAgreementMode: loginAgreementMode,
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
LoginAgreementDocuments: loginAgreementDocuments,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
DingTalkConnectEnabled: req.DingTalkConnectEnabled,
DingTalkConnectClientID: req.DingTalkConnectClientID,
DingTalkConnectClientSecret: req.DingTalkConnectClientSecret,
DingTalkConnectRedirectURL: req.DingTalkConnectRedirectURL,
DingTalkConnectCorpRestrictionPolicy: req.DingTalkConnectCorpRestrictionPolicy,
DingTalkConnectInternalCorpID: req.DingTalkConnectInternalCorpID,
DingTalkConnectBypassRegistration: req.DingTalkConnectBypassRegistration,
DingTalkConnectSyncCorpEmail: req.DingTalkConnectSyncCorpEmail,
DingTalkConnectSyncDisplayName: req.DingTalkConnectSyncDisplayName,
DingTalkConnectSyncDept: req.DingTalkConnectSyncDept,
DingTalkConnectSyncCorpEmailAttrKey: req.DingTalkConnectSyncCorpEmailAttrKey,
DingTalkConnectSyncDisplayNameAttrKey: req.DingTalkConnectSyncDisplayNameAttrKey,
DingTalkConnectSyncDeptAttrKey: req.DingTalkConnectSyncDeptAttrKey,
DingTalkConnectSyncCorpEmailAttrName: req.DingTalkConnectSyncCorpEmailAttrName,
DingTalkConnectSyncDisplayNameAttrName: req.DingTalkConnectSyncDisplayNameAttrName,
DingTalkConnectSyncDeptAttrName: req.DingTalkConnectSyncDeptAttrName,
WeChatConnectEnabled: req.WeChatConnectEnabled,
WeChatConnectAppID: req.WeChatConnectAppID,
WeChatConnectAppSecret: req.WeChatConnectAppSecret,
WeChatConnectOpenAppID: req.WeChatConnectOpenAppID,
WeChatConnectOpenAppSecret: req.WeChatConnectOpenAppSecret,
WeChatConnectMPAppID: req.WeChatConnectMPAppID,
WeChatConnectMPAppSecret: req.WeChatConnectMPAppSecret,
WeChatConnectMobileAppID: req.WeChatConnectMobileAppID,
WeChatConnectMobileAppSecret: req.WeChatConnectMobileAppSecret,
WeChatConnectOpenEnabled: req.WeChatConnectOpenEnabled,
WeChatConnectMPEnabled: req.WeChatConnectMPEnabled,
WeChatConnectMobileEnabled: req.WeChatConnectMobileEnabled,
WeChatConnectMode: req.WeChatConnectMode,
WeChatConnectScopes: req.WeChatConnectScopes,
WeChatConnectRedirectURL: req.WeChatConnectRedirectURL,
WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL,
OIDCConnectEnabled: req.OIDCConnectEnabled,
OIDCConnectProviderName: req.OIDCConnectProviderName,
OIDCConnectClientID: req.OIDCConnectClientID,
OIDCConnectClientSecret: req.OIDCConnectClientSecret,
OIDCConnectIssuerURL: req.OIDCConnectIssuerURL,
OIDCConnectDiscoveryURL: req.OIDCConnectDiscoveryURL,
OIDCConnectAuthorizeURL: req.OIDCConnectAuthorizeURL,
OIDCConnectTokenURL: req.OIDCConnectTokenURL,
OIDCConnectUserInfoURL: req.OIDCConnectUserInfoURL,
OIDCConnectJWKSURL: req.OIDCConnectJWKSURL,
OIDCConnectScopes: req.OIDCConnectScopes,
OIDCConnectRedirectURL: req.OIDCConnectRedirectURL,
OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL,
OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod,
OIDCConnectUsePKCE: oidcUsePKCE,
OIDCConnectValidateIDToken: oidcValidateIDToken,
OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs,
OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds,
OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified,
OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath,
OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath,
OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath,
GitHubOAuthEnabled: req.GitHubOAuthEnabled,
GitHubOAuthClientID: req.GitHubOAuthClientID,
GitHubOAuthClientSecret: req.GitHubOAuthClientSecret,
GitHubOAuthRedirectURL: req.GitHubOAuthRedirectURL,
GitHubOAuthFrontendRedirectURL: req.GitHubOAuthFrontendRedirectURL,
GoogleOAuthEnabled: req.GoogleOAuthEnabled,
GoogleOAuthClientID: req.GoogleOAuthClientID,
GoogleOAuthClientSecret: req.GoogleOAuthClientSecret,
GoogleOAuthRedirectURL: req.GoogleOAuthRedirectURL,
GoogleOAuthFrontendRedirectURL: req.GoogleOAuthFrontendRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
HideCcsImportButton: req.HideCcsImportButton,
PurchaseSubscriptionEnabled: purchaseEnabled,
PurchaseSubscriptionURL: purchaseURL,
TableDefaultPageSize: req.TableDefaultPageSize,
TablePageSizeOptions: req.TablePageSizeOptions,
CustomMenuItems: customMenuJSON,
CustomEndpoints: customEndpointsJSON,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
AffiliateRebateRate: affiliateRebateRate,
AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
AffiliateRebateDurationDays: affiliateRebateDurationDays,
AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
MaxClaudeCodeVersion: req.MaxClaudeCodeVersion,
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
BackendModeEnabled: req.BackendModeEnabled,
OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled
@ -1574,6 +1728,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnSignup, previousAuthSourceDefaults.Google.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnFirstBind, previousAuthSourceDefaults.Google.GrantOnFirstBind),
},
DingTalk: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultDingTalkBalance, previousAuthSourceDefaults.DingTalk.Balance),
Concurrency: intValueOrDefault(req.AuthSourceDefaultDingTalkConcurrency, previousAuthSourceDefaults.DingTalk.Concurrency),
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultDingTalkSubscriptions, previousAuthSourceDefaults.DingTalk.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnSignup, previousAuthSourceDefaults.DingTalk.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnFirstBind, previousAuthSourceDefaults.DingTalk.GrantOnFirstBind),
},
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
}
if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil {
@ -1632,6 +1793,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
h.ensureDingTalkSyncAttributes(c.Request.Context(), updatedSettings)
updatedAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
@ -1682,6 +1844,22 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
DingTalkConnectEnabled: updatedSettings.DingTalkConnectEnabled,
DingTalkConnectClientID: updatedSettings.DingTalkConnectClientID,
DingTalkConnectClientSecretConfigured: updatedSettings.DingTalkConnectClientSecretConfigured,
DingTalkConnectRedirectURL: updatedSettings.DingTalkConnectRedirectURL,
DingTalkConnectCorpRestrictionPolicy: updatedSettings.DingTalkConnectCorpRestrictionPolicy,
DingTalkConnectInternalCorpID: updatedSettings.DingTalkConnectInternalCorpID,
DingTalkConnectBypassRegistration: updatedSettings.DingTalkConnectBypassRegistration,
DingTalkConnectSyncCorpEmail: updatedSettings.DingTalkConnectSyncCorpEmail,
DingTalkConnectSyncDisplayName: updatedSettings.DingTalkConnectSyncDisplayName,
DingTalkConnectSyncDept: updatedSettings.DingTalkConnectSyncDept,
DingTalkConnectSyncCorpEmailAttrKey: updatedSettings.DingTalkConnectSyncCorpEmailAttrKey,
DingTalkConnectSyncDisplayNameAttrKey: updatedSettings.DingTalkConnectSyncDisplayNameAttrKey,
DingTalkConnectSyncDeptAttrKey: updatedSettings.DingTalkConnectSyncDeptAttrKey,
DingTalkConnectSyncCorpEmailAttrName: updatedSettings.DingTalkConnectSyncCorpEmailAttrName,
DingTalkConnectSyncDisplayNameAttrName: updatedSettings.DingTalkConnectSyncDisplayNameAttrName,
DingTalkConnectSyncDeptAttrName: updatedSettings.DingTalkConnectSyncDeptAttrName,
WeChatConnectEnabled: updatedSettings.WeChatConnectEnabled,
WeChatConnectAppID: updatedSettings.WeChatConnectAppID,
WeChatConnectAppSecretConfigured: updatedSettings.WeChatConnectAppSecretConfigured,
@ -1822,6 +2000,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
// hasPaymentFields returns true if any payment-related field was explicitly provided.
// mapDingTalkValidateError maps ValidateDingTalkConfig errors to machine-readable reason codes.
func mapDingTalkValidateError(err error) string {
switch {
case errors.Is(err, config.ErrDingTalkV1AppTypeMismatch):
return "dingtalk_apptype_mismatch"
case errors.Is(err, config.ErrDingTalkV4InvalidAppKind):
return "dingtalk_app_kind_invalid"
default:
return "dingtalk_corp_config_invalid"
}
}
func hasPaymentFields(req UpdateSettingsRequest) bool {
return req.PaymentEnabled != nil || req.PaymentMinAmount != nil ||
req.PaymentMaxAmount != nil || req.PaymentDailyLimit != nil ||
@ -1935,6 +2125,45 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL {
changed = append(changed, "linuxdo_connect_redirect_url")
}
if before.DingTalkConnectEnabled != after.DingTalkConnectEnabled {
changed = append(changed, "dingtalk_connect_enabled")
}
if before.DingTalkConnectClientID != after.DingTalkConnectClientID {
changed = append(changed, "dingtalk_connect_client_id")
}
if req.DingTalkConnectClientSecret != "" {
changed = append(changed, "dingtalk_connect_client_secret")
}
if before.DingTalkConnectRedirectURL != after.DingTalkConnectRedirectURL {
changed = append(changed, "dingtalk_connect_redirect_url")
}
if before.DingTalkConnectCorpRestrictionPolicy != after.DingTalkConnectCorpRestrictionPolicy {
changed = append(changed, "dingtalk_connect_corp_restriction_policy")
}
if before.DingTalkConnectInternalCorpID != after.DingTalkConnectInternalCorpID {
changed = append(changed, "dingtalk_connect_internal_corp_id")
}
if before.DingTalkConnectBypassRegistration != after.DingTalkConnectBypassRegistration {
changed = append(changed, "dingtalk_connect_bypass_registration")
}
if before.DingTalkConnectSyncCorpEmail != after.DingTalkConnectSyncCorpEmail {
changed = append(changed, "dingtalk_connect_sync_corp_email")
}
if before.DingTalkConnectSyncDisplayName != after.DingTalkConnectSyncDisplayName {
changed = append(changed, "dingtalk_connect_sync_display_name")
}
if before.DingTalkConnectSyncDept != after.DingTalkConnectSyncDept {
changed = append(changed, "dingtalk_connect_sync_dept")
}
if before.DingTalkConnectSyncCorpEmailAttrKey != after.DingTalkConnectSyncCorpEmailAttrKey {
changed = append(changed, "dingtalk_connect_sync_corp_email_attr_key")
}
if before.DingTalkConnectSyncDisplayNameAttrKey != after.DingTalkConnectSyncDisplayNameAttrKey {
changed = append(changed, "dingtalk_connect_sync_display_name_attr_key")
}
if before.DingTalkConnectSyncDeptAttrKey != after.DingTalkConnectSyncDeptAttrKey {
changed = append(changed, "dingtalk_connect_sync_dept_attr_key")
}
if before.WeChatConnectEnabled != after.WeChatConnectEnabled {
changed = append(changed, "wechat_connect_enabled")
}
@ -2246,6 +2475,7 @@ func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSource
{name: "wechat", before: before.WeChat, after: after.WeChat},
{name: "github", before: before.GitHub, after: after.GitHub},
{name: "google", before: before.Google, after: after.Google},
{name: "dingtalk", before: before.DingTalk, after: after.DingTalk},
}
for _, field := range fields {
if field.before.Balance != field.after.Balance {
@ -2350,6 +2580,11 @@ func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults
data["auth_source_default_linuxdo_subscriptions"] = authSourceDefaults.LinuxDo.Subscriptions
data["auth_source_default_linuxdo_grant_on_signup"] = authSourceDefaults.LinuxDo.GrantOnSignup
data["auth_source_default_linuxdo_grant_on_first_bind"] = authSourceDefaults.LinuxDo.GrantOnFirstBind
data["auth_source_default_dingtalk_balance"] = authSourceDefaults.DingTalk.Balance
data["auth_source_default_dingtalk_concurrency"] = authSourceDefaults.DingTalk.Concurrency
data["auth_source_default_dingtalk_subscriptions"] = authSourceDefaults.DingTalk.Subscriptions
data["auth_source_default_dingtalk_grant_on_signup"] = authSourceDefaults.DingTalk.GrantOnSignup
data["auth_source_default_dingtalk_grant_on_first_bind"] = authSourceDefaults.DingTalk.GrantOnFirstBind
data["auth_source_default_oidc_balance"] = authSourceDefaults.OIDC.Balance
data["auth_source_default_oidc_concurrency"] = authSourceDefaults.OIDC.Concurrency
data["auth_source_default_oidc_subscriptions"] = authSourceDefaults.OIDC.Subscriptions
@ -3044,3 +3279,56 @@ func (h *SettingHandler) TestWebSearchEmulation(c *gin.Context) {
}
response.Success(c, result)
}
// ensureDingTalkSyncAttributes 在保存 settings 后,按 admin 配置的 (attr key, attr name)
// 兜底 upsert 对应 user attribute definition不存在则创建存在但 name 不同则更新 name
// type/options/required 不变)。仅 internal_only + 对应 sync 开关开启时执行。
// 失败仅记录日志,不阻塞 settings 保存。
func (h *SettingHandler) ensureDingTalkSyncAttributes(ctx context.Context, settings *service.SystemSettings) {
if h.userAttributeService == nil || settings == nil {
return
}
if settings.DingTalkConnectCorpRestrictionPolicy != "internal_only" {
return
}
if settings.DingTalkConnectSyncDisplayName {
h.ensureUserAttributeDefinition(ctx, settings.DingTalkConnectSyncDisplayNameAttrKey, settings.DingTalkConnectSyncDisplayNameAttrName, "钉钉 internal_only 登录时同步的钉钉姓名", service.AttributeTypeText)
}
if settings.DingTalkConnectSyncCorpEmail {
h.ensureUserAttributeDefinition(ctx, settings.DingTalkConnectSyncCorpEmailAttrKey, settings.DingTalkConnectSyncCorpEmailAttrName, "钉钉 internal_only 登录时同步的企业邮箱", service.AttributeTypeEmail)
}
if settings.DingTalkConnectSyncDept {
h.ensureUserAttributeDefinition(ctx, settings.DingTalkConnectSyncDeptAttrKey, settings.DingTalkConnectSyncDeptAttrName, "钉钉 internal_only 登录时同步的完整部门路径(如:公司/研发部)", service.AttributeTypeText)
}
}
func (h *SettingHandler) ensureUserAttributeDefinition(ctx context.Context, key, name, description string, attrType service.UserAttributeType) {
key = strings.TrimSpace(key)
if key == "" {
return
}
existing, err := h.userAttributeService.GetDefinitionByKey(ctx, key)
if err == nil && existing != nil {
if strings.TrimSpace(name) != "" && existing.Name != name {
if _, err := h.userAttributeService.UpdateDefinition(ctx, existing.ID, service.UpdateAttributeDefinitionInput{
Name: &name,
}); err != nil {
slog.Warn("dingtalk: update user attribute definition name failed", "key", key, "err", err.Error())
return
}
slog.Info("dingtalk: updated user attribute definition name", "key", key, "name", name)
}
return
}
if _, err := h.userAttributeService.CreateDefinition(ctx, service.CreateAttributeDefinitionInput{
Key: key,
Name: name,
Description: description,
Type: attrType,
Enabled: true,
}); err != nil {
slog.Warn("dingtalk: ensure user attribute definition failed", "key", key, "err", err.Error())
return
}
slog.Info("dingtalk: created user attribute definition", "key", key, "name", name, "type", attrType)
}

View File

@ -137,7 +137,7 @@ func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) {
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
@ -174,7 +174,7 @@ func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *tes
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"registration_enabled": true,
@ -214,7 +214,7 @@ func TestSettingHandler_UpdateSettings_PersistsPaymentVisibleMethodsAndAdvancedS
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
@ -264,7 +264,7 @@ func TestSettingHandler_UpdateSettings_PreservesLegacyBlankPaymentVisibleMethodS
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": false,
@ -309,7 +309,7 @@ func TestSettingHandler_UpdateSettings_PersistsExplicitFalseOIDCCompatibilityFla
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
@ -388,7 +388,7 @@ func TestSettingHandler_UpdateSettings_DoesNotSolidifyImplicitOIDCSecurityDefaul
ClockSkewSeconds: 120,
},
})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
@ -417,7 +417,7 @@ func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
@ -450,7 +450,7 @@ func TestSettingHandler_UpdateSettings_DoesNotPersistPartialSystemSettingsWhenAu
err: errors.New("write auth source defaults failed"),
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"registration_enabled": true,

View File

@ -0,0 +1,319 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// dingtalkSettingsRepoStub 复用 settingHandlerRepoStub已在 setting_handler_auth_source_defaults_test.go 定义)
func newDingTalkSettingsHandler() (*SettingHandler, *settingHandlerRepoStub) {
repo := &settingHandlerRepoStub{values: map[string]string{}}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
return handler, repo
}
// baseValidDingTalkBody 返回一个可以通过所有校验的最小合法 body。
func baseValidDingTalkBody() map[string]any {
return map[string]any{
"dingtalk_connect_enabled": true,
"dingtalk_connect_client_id": "test-client-id",
"dingtalk_connect_client_secret": "test-client-secret",
"dingtalk_connect_redirect_url": "https://example.com/auth/dingtalk/callback",
"dingtalk_connect_corp_restriction_policy": "none",
}
}
// TestSettingsPUT_DingTalk_V3_InternalOnlyAllowsEmptyCorpID 验证方案 A
// internal_only + internal_corp_id="" 应通过校验(→ 200不再是 400。
func TestSettingsPUT_DingTalk_V3_InternalOnlyAllowsEmptyCorpID(t *testing.T) {
gin.SetMode(gin.TestMode)
handler, _ := newDingTalkSettingsHandler()
body := baseValidDingTalkBody()
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
body["dingtalk_connect_internal_corp_id"] = "" // 空值现在合法
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
}
// TestSettingsPUT_DingTalk_HappyPath_None 验证 none policy → 200
func TestSettingsPUT_DingTalk_HappyPath_None(t *testing.T) {
gin.SetMode(gin.TestMode)
handler, _ := newDingTalkSettingsHandler()
body := baseValidDingTalkBody()
body["dingtalk_connect_corp_restriction_policy"] = "none"
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
var resp response.Response
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
require.Equal(t, true, data["dingtalk_connect_enabled"])
}
// TestSettingsPUT_DingTalk_HappyPath_InternalOnly_WithCorpID 验证 internal_only + corp_id → 200
func TestSettingsPUT_DingTalk_HappyPath_InternalOnly_WithCorpID(t *testing.T) {
gin.SetMode(gin.TestMode)
handler, _ := newDingTalkSettingsHandler()
body := baseValidDingTalkBody()
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
body["dingtalk_connect_internal_corp_id"] = "ding-corp-123"
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
}
// TestSettingsPUT_DingTalk_BypassRegistration_RoundTrip 验证 bypass_registration 字段 save+load。
// 必须用 policy=internal_onlybypass 仅在该 policy 下生效,其它 policy 写入层会 coerce 为 false。
func TestSettingsPUT_DingTalk_BypassRegistration_RoundTrip(t *testing.T) {
gin.SetMode(gin.TestMode)
handler, _ := newDingTalkSettingsHandler()
body := baseValidDingTalkBody()
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
body["dingtalk_connect_bypass_registration"] = true
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
var resp response.Response
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
require.Equal(t, true, data["dingtalk_connect_bypass_registration"])
}
// TestSettingsPUT_DingTalk_Disabled_SkipsValidation 验证 disabled 时跳过 corp 校验 → 200。
// 用 enabled=true 时必然触发"Client ID is required when enabled"的空 client_id 作为
// 哨兵——只要 enabled=false 仍能 200 就证明跳过了。
func TestSettingsPUT_DingTalk_Disabled_SkipsValidation(t *testing.T) {
gin.SetMode(gin.TestMode)
handler, _ := newDingTalkSettingsHandler()
body := map[string]any{
"dingtalk_connect_enabled": false,
"dingtalk_connect_client_id": "", // 这种空值在 enabled=true 时会被 400 拒绝
"dingtalk_connect_corp_restriction_policy": "internal_only",
}
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
}
// TestSettingsPUT_DingTalk_SyncFlags_InternalOnly_RoundTrip 验证三个 sync 开关在 internal_only 下可正常 save+load。
func TestSettingsPUT_DingTalk_SyncFlags_InternalOnly_RoundTrip(t *testing.T) {
gin.SetMode(gin.TestMode)
handler, _ := newDingTalkSettingsHandler()
body := baseValidDingTalkBody()
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
body["dingtalk_connect_sync_corp_email"] = true
body["dingtalk_connect_sync_display_name"] = true
body["dingtalk_connect_sync_dept"] = true
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
var resp response.Response
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
require.Equal(t, true, data["dingtalk_connect_sync_corp_email"], "sync_corp_email should be true for internal_only")
require.Equal(t, true, data["dingtalk_connect_sync_display_name"], "sync_display_name should be true for internal_only")
require.Equal(t, true, data["dingtalk_connect_sync_dept"], "sync_dept should be true for internal_only")
}
// TestSettingsPUT_DingTalk_SyncFlags_PolicyNone_CoercedToFalse 验证 policy=none 时三个 sync 开关被 coerce 为 false。
func TestSettingsPUT_DingTalk_SyncFlags_PolicyNone_CoercedToFalse(t *testing.T) {
gin.SetMode(gin.TestMode)
handler, _ := newDingTalkSettingsHandler()
body := baseValidDingTalkBody()
body["dingtalk_connect_corp_restriction_policy"] = "none"
body["dingtalk_connect_sync_corp_email"] = true
body["dingtalk_connect_sync_display_name"] = true
body["dingtalk_connect_sync_dept"] = true
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
var resp response.Response
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
require.Equal(t, false, data["dingtalk_connect_sync_corp_email"], "sync_corp_email must be coerced to false when policy=none")
require.Equal(t, false, data["dingtalk_connect_sync_display_name"], "sync_display_name must be coerced to false when policy=none")
require.Equal(t, false, data["dingtalk_connect_sync_dept"], "sync_dept must be coerced to false when policy=none")
}
// TestSettingsPUT_DingTalk_StaleWhitelist_CoercedToNone 验证升级兼容:
// admin 直接把 corp_restriction_policy=whitelist 提交(前端 UI 已无此选项,但 API 仍可命中)
// 不应导致 400 失败,应该被静默 coerce 为 none 后通过校验。
func TestSettingsPUT_DingTalk_StaleWhitelist_CoercedToNone(t *testing.T) {
gin.SetMode(gin.TestMode)
handler, repo := newDingTalkSettingsHandler()
body := baseValidDingTalkBody()
body["dingtalk_connect_corp_restriction_policy"] = "whitelist"
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "none", repo.values[service.SettingKeyDingTalkConnectCorpRestrictionPolicy],
"stale whitelist 应在写入路径被 coerce 为 none")
}
// TestSettingsPUT_DingTalk_SyncAttrKey_RoundTrip 验证 3 个 attr key 字段 save+load + 空值 fallback 到默认值。
func TestSettingsPUT_DingTalk_SyncAttrKey_RoundTrip(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Run("custom_attr_keys_saved", func(t *testing.T) {
handler, repo := newDingTalkSettingsHandler()
body := baseValidDingTalkBody()
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
body["dingtalk_connect_sync_corp_email"] = true
body["dingtalk_connect_sync_display_name"] = true
body["dingtalk_connect_sync_dept"] = true
body["dingtalk_connect_sync_corp_email_attr_key"] = "my_email_attr"
body["dingtalk_connect_sync_display_name_attr_key"] = "my_name_attr"
body["dingtalk_connect_sync_dept_attr_key"] = "my_dept_attr"
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
// 验证写入 DB 的 key
require.Equal(t, "my_email_attr", repo.values[service.SettingKeyDingTalkConnectSyncCorpEmailAttrKey])
require.Equal(t, "my_name_attr", repo.values[service.SettingKeyDingTalkConnectSyncDisplayNameAttrKey])
require.Equal(t, "my_dept_attr", repo.values[service.SettingKeyDingTalkConnectSyncDeptAttrKey])
// 验证响应中的 attr key
var resp response.Response
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
require.Equal(t, "my_email_attr", data["dingtalk_connect_sync_corp_email_attr_key"])
require.Equal(t, "my_name_attr", data["dingtalk_connect_sync_display_name_attr_key"])
require.Equal(t, "my_dept_attr", data["dingtalk_connect_sync_dept_attr_key"])
})
t.Run("empty_attr_keys_fallback_to_defaults", func(t *testing.T) {
handler, repo := newDingTalkSettingsHandler()
body := baseValidDingTalkBody()
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
// 不传 attr key → 写入层 fallback 到默认值
body["dingtalk_connect_sync_corp_email_attr_key"] = ""
body["dingtalk_connect_sync_display_name_attr_key"] = ""
body["dingtalk_connect_sync_dept_attr_key"] = ""
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
// 空值应 fallback 到默认值并持久化
require.Equal(t, "dingtalk_email", repo.values[service.SettingKeyDingTalkConnectSyncCorpEmailAttrKey])
require.Equal(t, "dingtalk_name", repo.values[service.SettingKeyDingTalkConnectSyncDisplayNameAttrKey])
require.Equal(t, "dingtalk_department", repo.values[service.SettingKeyDingTalkConnectSyncDeptAttrKey])
})
}

View File

@ -0,0 +1,398 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
// dingTalkClientConfig 是 DingTalkClient 需要的最小配置子集
type dingTalkClientConfig struct {
ClientID string
ClientSecret string
TokenURL string
UserInfoURL string
}
type DingTalkClient struct {
cfg dingTalkClientConfig
appToken string
appTokenExp time.Time // 钉钉 7200s留 200s 余量 → 7000s
mu sync.Mutex
httpClient *http.Client
// TODO(multi-instance): Redis 集中缓存 appToken
}
type DingTalkUserTokenResp struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ExpireIn int64 `json:"expireIn"`
CorpID string `json:"corpId"`
}
func (c *DingTalkClient) ExchangeCodeForUserToken(ctx context.Context, code string) (*DingTalkUserTokenResp, error) {
body := map[string]string{
"clientId": c.cfg.ClientID,
"clientSecret": c.cfg.ClientSecret,
"code": code,
"grantType": "authorization_code",
}
payload, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.TokenURL, bytes.NewReader(payload))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
raw, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, parseDingTalkErr(raw, resp.StatusCode)
}
var out DingTalkUserTokenResp
if err := json.Unmarshal(raw, &out); err != nil {
return nil, err
}
if strings.TrimSpace(out.AccessToken) == "" {
return nil, parseDingTalkErr(raw, resp.StatusCode)
}
return &out, nil
}
type DingTalkAPIError struct {
Code string
Message string
HTTP int
}
func (e *DingTalkAPIError) Error() string {
return fmt.Sprintf("dingtalk api error code=%s msg=%s http=%d", e.Code, e.Message, e.HTTP)
}
func parseDingTalkErr(raw []byte, status int) error {
var v struct {
Code string `json:"code"`
Message string `json:"message"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
_ = json.Unmarshal(raw, &v)
code := v.Code
if code == "" && v.ErrCode != 0 {
code = fmt.Sprintf("%d", v.ErrCode)
}
msg := v.Message
if msg == "" {
msg = v.ErrMsg
}
return &DingTalkAPIError{Code: code, Message: msg, HTTP: status}
}
// GetUnionIdByUserToken 调用 /v1.0/contact/users/me 返回 unionId 与用户自设昵称 nick。
// nick 来自钉钉新版 OIDC 接口(用户在 App 个人资料填的昵称),与旧版 user/get.nickname 不同源。
func (c *DingTalkClient) GetUnionIdByUserToken(ctx context.Context, userToken string) (unionID string, nick string, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.cfg.UserInfoURL, nil)
if err != nil {
return "", "", err
}
req.Header.Set("x-acs-dingtalk-access-token", userToken)
resp, err := c.httpClient.Do(req)
if err != nil {
return "", "", err
}
defer func() { _ = resp.Body.Close() }()
raw, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return "", "", parseDingTalkErr(raw, resp.StatusCode)
}
var v struct {
UnionID string `json:"unionId"`
Nick string `json:"nick"`
}
if err := json.Unmarshal(raw, &v); err != nil {
return "", "", err
}
if strings.TrimSpace(v.UnionID) == "" {
return "", "", parseDingTalkErr(raw, resp.StatusCode)
}
return v.UnionID, v.Nick, nil
}
type DingTalkStaffInfo struct {
UserID string
Name string // 企业内真实姓名(钉钉企业管理后台配置)
Nickname string // 钉钉个人昵称(用户自己设置)
Email string
DeptIDs []int64
// CorpID 不来自 staff 接口,来自 userToken不在此 struct
}
// dingTalkOAPIBase 推导钉钉旧版 OAPI base URLhost: api.dingtalk.com → oapi.dingtalk.com
// getbyunionid 与 topapi/v2/user/get 仅在旧版 OAPI 提供,不在 v1.0 OpenAPI。
func (c *DingTalkClient) dingTalkOAPIBase() string {
u, err := url.Parse(c.cfg.UserInfoURL)
if err != nil || u.Scheme == "" || u.Host == "" {
return "https://oapi.dingtalk.com"
}
host := u.Host
if strings.HasPrefix(host, "api.") {
host = "oapi." + strings.TrimPrefix(host, "api.")
}
return u.Scheme + "://" + host
}
func (c *DingTalkClient) GetAppToken(ctx context.Context) (string, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.appToken != "" && time.Now().Before(c.appTokenExp) {
return c.appToken, nil
}
body := map[string]string{"appKey": c.cfg.ClientID, "appSecret": c.cfg.ClientSecret}
payload, _ := json.Marshal(body)
// 钉钉新版 v1.0 企业内部应用 access_token: POST /v1.0/oauth2/accessToken
// 此 token 也可作为旧版 OAPI 的 access_token 使用(钉钉文档已说明)
appTokenURL := strings.Replace(c.cfg.TokenURL, "/oauth2/userAccessToken", "/oauth2/accessToken", 1)
if !strings.Contains(appTokenURL, "accessToken") && !strings.Contains(appTokenURL, "gettoken") {
appTokenURL = c.cfg.TokenURL // fallback for test stub
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, appTokenURL, bytes.NewReader(payload))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
raw, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return "", parseDingTalkErr(raw, resp.StatusCode)
}
var v struct {
AccessToken string `json:"accessToken"`
ExpireIn int64 `json:"expireIn"`
}
if err := json.Unmarshal(raw, &v); err != nil {
return "", err
}
if v.AccessToken == "" {
return "", parseDingTalkErr(raw, resp.StatusCode)
}
c.appToken = v.AccessToken
ttl := v.ExpireIn
if ttl > 200 {
ttl -= 200
}
c.appTokenExp = time.Now().Add(time.Duration(ttl) * time.Second)
return c.appToken, nil
}
func (c *DingTalkClient) GetUserIdByUnionId(ctx context.Context, unionID string) (string, error) {
appToken, err := c.GetAppToken(ctx)
if err != nil {
return "", err
}
body := map[string]string{"unionid": unionID}
payload, _ := json.Marshal(body)
// 钉钉旧版 OAPI: POST https://oapi.dingtalk.com/topapi/user/getbyunionid?access_token=XXX
// access_token 通过 query string 传递(不是 header
var targetURL string
if strings.Contains(c.cfg.UserInfoURL, "/contact/users/me") {
targetURL = c.dingTalkOAPIBase() + "/topapi/user/getbyunionid?access_token=" + url.QueryEscape(appToken)
} else {
targetURL = c.cfg.UserInfoURL // fallback for test stub
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(payload))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
raw, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return "", parseDingTalkErr(raw, resp.StatusCode)
}
var v struct {
Result struct {
UserID string `json:"userid"`
} `json:"result"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
if err := json.Unmarshal(raw, &v); err != nil {
return "", err
}
if v.ErrCode != 0 {
return "", parseDingTalkErr(raw, resp.StatusCode)
}
if strings.TrimSpace(v.Result.UserID) == "" {
return "", parseDingTalkErr(raw, resp.StatusCode)
}
return v.Result.UserID, nil
}
// DingTalkDeptInfo 部门信息topapi/v2/department/get 返回子集)
type DingTalkDeptInfo struct {
DeptID int64
Name string
ParentID int64
}
// GetDeptInfo 查询单个部门信息(用于递归拼部门路径)。
// 调用钉钉旧版 OAPI: POST /topapi/v2/department/get?access_token=XXX
func (c *DingTalkClient) GetDeptInfo(ctx context.Context, deptID int64) (*DingTalkDeptInfo, error) {
appToken, err := c.GetAppToken(ctx)
if err != nil {
return nil, err
}
body := map[string]any{"dept_id": deptID, "language": "zh_CN"}
payload, _ := json.Marshal(body)
var targetURL string
if strings.Contains(c.cfg.UserInfoURL, "/contact/users/me") {
targetURL = c.dingTalkOAPIBase() + "/topapi/v2/department/get?access_token=" + url.QueryEscape(appToken)
} else {
targetURL = c.cfg.UserInfoURL // test stub fallback
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(payload))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
raw, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, parseDingTalkErr(raw, resp.StatusCode)
}
var v struct {
Result struct {
DeptID int64 `json:"dept_id"`
Name string `json:"name"`
ParentID int64 `json:"parent_id"`
} `json:"result"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
if err := json.Unmarshal(raw, &v); err != nil {
return nil, err
}
if v.ErrCode != 0 {
return nil, parseDingTalkErr(raw, resp.StatusCode)
}
return &DingTalkDeptInfo{
DeptID: v.Result.DeptID,
Name: v.Result.Name,
ParentID: v.Result.ParentID,
}, nil
}
func (c *DingTalkClient) GetStaffInfoByUserId(ctx context.Context, userID string) (*DingTalkStaffInfo, error) {
appToken, err := c.GetAppToken(ctx)
if err != nil {
return nil, err
}
body := map[string]string{"userid": userID}
payload, _ := json.Marshal(body)
// 钉钉旧版 OAPI: POST https://oapi.dingtalk.com/topapi/v2/user/get?access_token=XXX
var targetURL string
if strings.Contains(c.cfg.UserInfoURL, "/contact/users/me") {
targetURL = c.dingTalkOAPIBase() + "/topapi/v2/user/get?access_token=" + url.QueryEscape(appToken)
} else {
targetURL = c.cfg.UserInfoURL // fallback for test stub
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(payload))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
raw, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, parseDingTalkErr(raw, resp.StatusCode)
}
var v struct {
Result struct {
UserID string `json:"userid"`
Name string `json:"name"`
Nickname string `json:"nickname"`
Email string `json:"email"`
OrgEmail string `json:"org_email"`
Extension string `json:"extension"`
DeptID []int64 `json:"dept_id_list"`
} `json:"result"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
if err := json.Unmarshal(raw, &v); err != nil {
return nil, err
}
if v.ErrCode != 0 {
return nil, parseDingTalkErr(raw, resp.StatusCode)
}
if strings.TrimSpace(v.Result.UserID) == "" {
return nil, parseDingTalkErr(raw, resp.StatusCode)
}
// 邮箱三级 fallbackorg_email > email > extension["企业邮箱"]钉钉自定义扩展字段JSON string
email := strings.TrimSpace(v.Result.OrgEmail)
emailSource := "org_email"
if email == "" {
email = strings.TrimSpace(v.Result.Email)
emailSource = "email"
}
extensionParsed := false
if email == "" && strings.TrimSpace(v.Result.Extension) != "" {
var ext map[string]string
if err := json.Unmarshal([]byte(v.Result.Extension), &ext); err == nil {
extensionParsed = true
if v, ok := ext["企业邮箱"]; ok {
email = strings.TrimSpace(v)
emailSource = "extension.企业邮箱"
}
}
}
if email == "" {
emailSource = "none"
}
slog.Info("dingtalk staff fetched",
"userid", v.Result.UserID,
"name_present", v.Result.Name != "",
"nickname_present", v.Result.Nickname != "",
"name_eq_nickname", v.Result.Name != "" && v.Result.Name == v.Result.Nickname,
"email_present", v.Result.Email != "",
"org_email_present", v.Result.OrgEmail != "",
"extension_present", v.Result.Extension != "",
"extension_parsed", extensionParsed,
"email_source", emailSource,
"dept_count", len(v.Result.DeptID),
)
return &DingTalkStaffInfo{
UserID: v.Result.UserID,
Name: v.Result.Name,
Nickname: v.Result.Nickname,
Email: email,
DeptIDs: v.Result.DeptID,
}, nil
}

View File

@ -0,0 +1,143 @@
package handler
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestDingTalkClient_ExchangeCodeForUserToken_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "POST", r.Method)
require.Equal(t, "/v1.0/oauth2/userAccessToken", r.URL.Path)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"USER_TOKEN_X","expireIn":7200,"refreshToken":"R","corpId":"dingABC"}`))
}))
defer server.Close()
cli := &DingTalkClient{
cfg: dingTalkClientConfig{
ClientID: "k", ClientSecret: "s",
TokenURL: server.URL + "/v1.0/oauth2/userAccessToken",
},
httpClient: server.Client(),
}
resp, err := cli.ExchangeCodeForUserToken(context.Background(), "AUTH_CODE")
require.NoError(t, err)
require.Equal(t, "USER_TOKEN_X", resp.AccessToken)
require.Equal(t, "dingABC", resp.CorpID)
}
func TestDingTalkClient_GetUnionIdByUserToken_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "USER_TOKEN_X", r.Header.Get("x-acs-dingtalk-access-token"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"nick":"张三","unionId":"UID_AAA","openId":"OPEN","avatarUrl":"http://x"}`))
}))
defer server.Close()
cli := &DingTalkClient{
cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/v1.0/contact/users/me"},
httpClient: server.Client(),
}
unionID, nick, err := cli.GetUnionIdByUserToken(context.Background(), "USER_TOKEN_X")
require.NoError(t, err)
require.Equal(t, "UID_AAA", unionID)
require.Equal(t, "张三", nick)
}
func TestDingTalkClient_GetAppToken_Cached(t *testing.T) {
callCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
_, _ = w.Write([]byte(`{"accessToken":"APP_TKN","expireIn":7200}`))
}))
defer server.Close()
cli := &DingTalkClient{
cfg: dingTalkClientConfig{ClientID: "k", ClientSecret: "s", TokenURL: server.URL + "/gettoken"},
httpClient: server.Client(),
}
t1, err := cli.GetAppToken(context.Background())
require.NoError(t, err)
t2, err := cli.GetAppToken(context.Background())
require.NoError(t, err)
require.Equal(t, t1, t2)
require.Equal(t, 1, callCount, "second call should hit cache")
}
func TestDingTalkClient_GetUserIdByUnionId_60011(t *testing.T) {
appTokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"accessToken":"APP_TKN","expireIn":7200}`))
}))
defer appTokenServer.Close()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"errcode":60011,"errmsg":"not in directory"}`))
}))
defer server.Close()
cli := &DingTalkClient{
cfg: dingTalkClientConfig{TokenURL: appTokenServer.URL + "/gettoken"},
httpClient: server.Client(),
}
cli.appToken = "APP_TKN"
cli.appTokenExp = time.Now().Add(time.Hour)
cli.cfg.UserInfoURL = server.URL + "/v1.0/contact/users/byUnionId"
_, err := cli.GetUserIdByUnionId(context.Background(), "UID_AAA")
require.Error(t, err)
apiErr, ok := err.(*DingTalkAPIError)
require.True(t, ok)
require.Equal(t, "60011", apiErr.Code)
}
// TestDingTalkClient_GetDeptInfo_Success 验证 GetDeptInfo 正常情况返回部门信息。
func TestDingTalkClient_GetDeptInfo_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"errcode":0,"errmsg":"ok","result":{"dept_id":42,"name":"AI数据","parent_id":1}}`))
}))
defer server.Close()
cli := &DingTalkClient{
cfg: dingTalkClientConfig{
UserInfoURL: server.URL + "/stub", // 不含 /contact/users/me走 test stub 路径
},
httpClient: server.Client(),
}
cli.appToken = "APP_TKN"
cli.appTokenExp = time.Now().Add(time.Hour)
info, err := cli.GetDeptInfo(context.Background(), 42)
require.NoError(t, err)
require.Equal(t, int64(42), info.DeptID)
require.Equal(t, "AI数据", info.Name)
require.Equal(t, int64(1), info.ParentID)
}
// TestDingTalkClient_GetDeptInfo_ErrCode60003 验证 errcode=60003部门不存在时返回错误。
func TestDingTalkClient_GetDeptInfo_ErrCode60003(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"errcode":60003,"errmsg":"dept not found"}`))
}))
defer server.Close()
cli := &DingTalkClient{
cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/stub"},
httpClient: server.Client(),
}
cli.appToken = "APP_TKN"
cli.appTokenExp = time.Now().Add(time.Hour)
_, err := cli.GetDeptInfo(context.Background(), 999)
require.Error(t, err)
apiErr, ok := err.(*DingTalkAPIError)
require.True(t, ok)
require.Equal(t, "60003", apiErr.Code)
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,391 @@
package handler
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestDingTalkOAuthStart_Disabled は sentinel テスト。
// TODO(task-1.10): newTestAuthHandlerWithDingTalk helper が追加されたら t.Skip を外す。
func TestDingTalkOAuthStart_Disabled(t *testing.T) {
t.Skip("helper newTestAuthHandlerWithDingTalk added in Task 1.10; sentinel only")
}
// TestBuildDingTalkSyntheticEmail_UsesUnionID 验证合成邮箱种子使用 unionID。
func TestBuildDingTalkSyntheticEmail_UsesUnionID(t *testing.T) {
unionID := "union_AbCdEf123"
email := buildDingTalkSyntheticEmail(unionID)
want := "dingtalk-union_abcdef123@dingtalk-connect.invalid"
require.Equal(t, want, email)
// 确保结果都是小写(邮箱大小写不敏感,统一小写)
require.True(t, strings.ToLower(email) == email, "synthetic email should be all lowercase")
// 确保前缀正确
require.True(t, strings.HasPrefix(email, "dingtalk-"), "should have dingtalk- prefix")
// 确保后缀是合成邮箱域名
require.True(t, strings.HasSuffix(email, "@dingtalk-connect.invalid"), "should have reserved domain suffix")
}
// TestBuildDingTalkSyntheticEmail_TrimsSpace 验证 unionID 空白被修剪。
func TestBuildDingTalkSyntheticEmail_TrimsSpace(t *testing.T) {
email := buildDingTalkSyntheticEmail(" UID_XYZ ")
require.Equal(t, "dingtalk-uid_xyz@dingtalk-connect.invalid", email)
}
// TestBuildDingTalkUpstreamClaims_EmptyStaff 验证 staff 为空 struct跨组织降级路径
// - subject 等于 unionID与 identityKey.ProviderSubject 一致)
// - corp_user_id 为空字符串(跨组织时拿不到企业 userid
// - email/username 为空字符串
// B/C: Step 3/4 失败降级时 staff = &DingTalkStaffInfo{}claims 不应有 nil。
func TestBuildDingTalkUpstreamClaims_EmptyStaff(t *testing.T) {
staff := &DingTalkStaffInfo{}
claims := buildDingTalkUpstreamClaims(staff, "UNION_AAA", "CORP_X")
require.Equal(t, "", claims["email"])
require.Equal(t, "", claims["username"])
// 重构后 subject = unionID与 identityKey.ProviderSubject 保持一致)
require.Equal(t, "UNION_AAA", claims["subject"])
require.Equal(t, "", claims["corp_user_id"]) // 企业 userid 跨组织时为空
require.Equal(t, "UNION_AAA", claims["union_id"])
require.Equal(t, "CORP_X", claims["corp_id"])
}
// TestCheckDingTalkCorpAllowed_CrossOrgPolicy 验证 policy=none 时允许任意 corp。
// D: corp 校验提前后逻辑不变。
func TestCheckDingTalkCorpAllowed_CrossOrgPolicy(t *testing.T) {
cfg := config.DingTalkConnectConfig{CorpRestrictionPolicy: "none"}
assert.True(t, checkDingTalkCorpAllowed(cfg, "dingABC"), "policy=none should allow any corp")
assert.True(t, checkDingTalkCorpAllowed(cfg, ""), "policy=none should allow empty corp")
assert.True(t, checkDingTalkCorpAllowed(cfg, "foreign_corp"), "policy=none should allow foreign corp")
}
// TestCheckDingTalkCorpAllowed_InternalOnly 验证 policy=internal_only 时的 corp 校验语义(方案 A 修订)。
// 钉钉 userAccessToken 在部分授权场景(扫码登录、非企业工作台入口)不返回 corpId 字段,
// 因此 checkDingTalkCorpAllowed 完全不校验 corpID由 step 3 GetUserIdByUnionId 做真实判定
// (跨企业用户会被钉钉错误码 60011/60121 拒绝mapDingTalkErrorCode 映射回 corp_rejected
func TestCheckDingTalkCorpAllowed_InternalOnly(t *testing.T) {
cfgWithCorpID := config.DingTalkConnectConfig{
CorpRestrictionPolicy: "internal_only",
InternalCorpID: "dingInternal",
}
assert.True(t, checkDingTalkCorpAllowed(cfgWithCorpID, "dingInternal"), "internal_only: matching corpID allowed")
assert.True(t, checkDingTalkCorpAllowed(cfgWithCorpID, "foreign_corp"), "internal_only: corpID 字段不再用于决策step 3 兜底")
assert.True(t, checkDingTalkCorpAllowed(cfgWithCorpID, ""), "internal_only: 空 corpID 也通过(钉钉部分授权场景不返回 corpId")
cfgNoCorpID := config.DingTalkConnectConfig{
CorpRestrictionPolicy: "internal_only",
InternalCorpID: "",
}
assert.True(t, checkDingTalkCorpAllowed(cfgNoCorpID, "dingAnyNonEmpty"), "internal_only + no InternalCorpID: 非空 corpID 通过")
assert.True(t, checkDingTalkCorpAllowed(cfgNoCorpID, ""), "internal_only + no InternalCorpID: 空 corpID 也通过")
}
// TestDecideDingTalkStep34Strategy_PolicyNone 验证 policy=none 时
// Step 3/4 失败应降级shouldFallback=true, isFatal=false
func TestDecideDingTalkStep34Strategy_PolicyNone(t *testing.T) {
step3Err := &DingTalkAPIError{Code: "60011", Message: "not in directory", HTTP: 403}
shouldFallback, isFatal := decideDingTalkStep34Strategy("none", step3Err)
require.True(t, shouldFallback, "policy=none: step3 failure should trigger fallback")
require.False(t, isFatal, "policy=none: step3 failure should NOT be fatal")
}
// TestDecideDingTalkStep34Strategy_PolicyNoneEmpty 验证 policy="" 时行为与 "none" 相同。
func TestDecideDingTalkStep34Strategy_PolicyNoneEmpty(t *testing.T) {
stepErr := &DingTalkAPIError{Code: "60011", Message: "not in directory", HTTP: 403}
shouldFallback, isFatal := decideDingTalkStep34Strategy("", stepErr)
require.True(t, shouldFallback, "policy='': step failure should trigger fallback")
require.False(t, isFatal, "policy='': step failure should NOT be fatal")
}
// TestDecideDingTalkStep34Strategy_PolicyInternalOnly 验证 policy=internal_only 时
// Step 3/4 失败应 hard failisFatal=true
func TestDecideDingTalkStep34Strategy_PolicyInternalOnly(t *testing.T) {
step3Err := &DingTalkAPIError{Code: "60011", Message: "not in directory", HTTP: 403}
shouldFallback, isFatal := decideDingTalkStep34Strategy("internal_only", step3Err)
require.False(t, shouldFallback, "policy=internal_only: should NOT fallback on step3 error")
require.True(t, isFatal, "policy=internal_only: step3 failure should be fatal")
}
// TestDecideDingTalkStep34Strategy_NoError 验证 stepErr=nil 时两个返回值均为 false。
func TestDecideDingTalkStep34Strategy_NoError(t *testing.T) {
for _, policy := range []string{"none", "internal_only", ""} {
shouldFallback, isFatal := decideDingTalkStep34Strategy(policy, nil)
require.False(t, shouldFallback, "no error should not trigger fallback (policy=%q)", policy)
require.False(t, isFatal, "no error should not be fatal (policy=%q)", policy)
}
}
// TestCompleteDingTalkRegistration_UsernameFromEmailLocalPart 验证 username 为空时
// 退到 email local part@ 之前的部分)。
// E: CompleteDingTalkOAuthRegistration username fallback。
func TestCompleteDingTalkRegistration_UsernameFromEmailLocalPart(t *testing.T) {
tests := []struct {
name string
email string
username string
wantUser string
wantValid bool
}{
{
name: "username empty, normal email → local part",
email: "dingtalk-uid123@dingtalk-connect.invalid",
username: "",
wantUser: "dingtalk-uid123",
wantValid: true,
},
{
name: "username already set → keep original",
email: "user@example.com",
username: "张三",
wantUser: "张三",
wantValid: true,
},
{
name: "username empty, no @ in email → use whole email",
email: "noemail",
username: "",
wantUser: "noemail",
wantValid: true,
},
{
name: "both empty → invalid",
email: "",
username: "",
wantUser: "",
wantValid: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
username := tc.username
email := tc.email
// 模拟 CompleteDingTalkOAuthRegistration 中的 fallback 逻辑
if username == "" {
if at := strings.Index(email, "@"); at > 0 {
username = email[:at]
} else {
username = email
}
}
isValid := email != "" && username != ""
require.Equal(t, tc.wantUser, username, fmt.Sprintf("username for email=%q", tc.email))
require.Equal(t, tc.wantValid, isValid, "validity check")
})
}
}
// TestBuildDingTalkUpstreamClaims_SubjectEqualsUnionID 验证重构后 subject = unionID
// 而非 staff.UserID与 identityKey.ProviderSubject 保持一致。
// §4.2: buildDingTalkUpstreamClaims subject 字段修正。
func TestBuildDingTalkUpstreamClaims_SubjectEqualsUnionID(t *testing.T) {
staff := &DingTalkStaffInfo{UserID: "user123", Name: "张三", Email: "zhangsan@corp.com"}
claims := buildDingTalkUpstreamClaims(staff, "union456", "dingcorp789")
// 重构后 subject = unionID全局唯一与 identityKey.ProviderSubject 一致)
require.Equal(t, "union456", claims["subject"], "subject should equal unionID after refactor")
// 企业 userid 保留为独立字段,供 audit/debug 使用
require.Equal(t, "user123", claims["corp_user_id"], "corp_user_id should be staff.UserID")
// union_id 字段与 subject 相同(冗余保留,便于读取)
require.Equal(t, "union456", claims["union_id"])
require.Equal(t, "dingcorp789", claims["corp_id"])
require.Equal(t, "张三", claims["username"])
require.Equal(t, "zhangsan@corp.com", claims["email"])
}
// TestBuildDingTalkUpstreamClaims_CrossOrgEmptyCorpUserID 验证跨组织降级时
// corp_user_id 为空字符串(跨组织拿不到企业 useridsubject 仍为 unionID。
func TestBuildDingTalkUpstreamClaims_CrossOrgEmptyCorpUserID(t *testing.T) {
// 跨组织降级路径staff = &DingTalkStaffInfo{}(所有字段为零值)
staff := &DingTalkStaffInfo{}
claims := buildDingTalkUpstreamClaims(staff, "union_cross_org", "foreign_corp")
require.Equal(t, "union_cross_org", claims["subject"], "subject should still be unionID for cross-org users")
require.Equal(t, "", claims["corp_user_id"], "corp_user_id should be empty for cross-org fallback")
require.Equal(t, "", claims["email"])
require.Equal(t, "", claims["username"])
}
// TestBuildDingTalkUpstreamClaims_PrimaryDeptIDInClaims 验证首个 dept_id 被存入 claims。
func TestBuildDingTalkUpstreamClaims_PrimaryDeptIDInClaims(t *testing.T) {
staff := &DingTalkStaffInfo{UserID: "u1", Name: "张三", Email: "a@b.com", DeptIDs: []int64{42, 99}}
claims := buildDingTalkUpstreamClaims(staff, "uid1", "corpX")
// 只取首个 dept_id
require.Equal(t, int64(42), claims["primary_dept_id"], "primary_dept_id should be the first dept_id")
}
// TestBuildDingTalkUpstreamClaims_NoDeptIDs 验证无部门时 primary_dept_id=0。
func TestBuildDingTalkUpstreamClaims_NoDeptIDs(t *testing.T) {
staff := &DingTalkStaffInfo{UserID: "u2", Name: "李四"}
claims := buildDingTalkUpstreamClaims(staff, "uid2", "corpY")
require.Equal(t, int64(0), claims["primary_dept_id"], "primary_dept_id should be 0 when no depts")
}
// TestDingTalkStaffFromClaims_RoundTrip 验证 dingTalkStaffFromClaims 能从 claims 恢复 staff 信息。
func TestDingTalkStaffFromClaims_RoundTrip(t *testing.T) {
staff := &DingTalkStaffInfo{UserID: "u3", Name: "王五", Email: "ww@corp.com", DeptIDs: []int64{55}}
claims := buildDingTalkUpstreamClaims(staff, "uid3", "corpZ")
recovered := dingTalkStaffFromClaims(claims)
require.Equal(t, "王五", recovered.Name)
require.Equal(t, "ww@corp.com", recovered.Email)
require.Equal(t, "u3", recovered.UserID)
require.Equal(t, []int64{55}, recovered.DeptIDs)
}
// TestResolveDingTalkDeptPath_SingleLevel 验证单层部门parent_id=1返回部门名。
func TestResolveDingTalkDeptPath_SingleLevel(t *testing.T) {
handler := &AuthHandler{}
callCount := 0
responses := map[string]string{
"42": `{"errcode":0,"result":{"dept_id":42,"name":"研发部","parent_id":1}}`,
"1": `{"errcode":0,"result":{"dept_id":1,"name":"公司","parent_id":0}}`,
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
var req struct {
DeptID int64 `json:"dept_id"`
}
_ = json.NewDecoder(r.Body).Decode(&req)
w.Header().Set("Content-Type", "application/json")
if resp, ok := responses[fmt.Sprintf("%d", req.DeptID)]; ok {
_, _ = w.Write([]byte(resp))
} else {
_, _ = w.Write([]byte(`{"errcode":60003,"errmsg":"not found"}`))
}
}))
defer server.Close()
cli := &DingTalkClient{
cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/stub"},
httpClient: server.Client(),
}
cli.appToken = "tok"
cli.appTokenExp = time.Now().Add(time.Hour)
path, err := handler.resolveDingTalkDeptPath(context.Background(), cli, 42)
require.NoError(t, err)
require.Equal(t, "研发部", path)
require.Equal(t, 2, callCount)
}
// TestSyncDingTalkIdentity_UsesCfgAttrKeys 验证 syncDingTalkIdentity 使用 cfg 中配置的 attr key
// 而不是硬编码值。通过 userAttributeService=nil 使同步路径走 warn 跳过,但在此之前先验证
// syncField 构建逻辑(即 attr key 从 cfg 读取)。
// 间接验证:通过构造定制 cfg确认不同 attr key 可以正确传入(编译时保证类型正确,运行时不 panic
func TestSyncDingTalkIdentity_UsesCfgAttrKeys_NoopWithNilService(t *testing.T) {
handler := &AuthHandler{
userAttributeService: nil, // nil → 触发 warn 跳过,但不 panic
}
cfg := config.DingTalkConnectConfig{
CorpRestrictionPolicy: "internal_only",
SyncCorpEmail: true,
SyncDisplayName: true,
SyncDept: true,
// 自定义 attr key非默认值
SyncCorpEmailAttrKey: "custom_email_key",
SyncDisplayNameAttrKey: "custom_name_key",
SyncDeptAttrKey: "custom_dept_key",
}
staff := &DingTalkStaffInfo{
Name: "张三",
Email: "zhangsan@example.com",
}
// 调用不应 panicuserAttributeService 为 nil 时走 warn 跳过路径)
require.NotPanics(t, func() {
handler.syncDingTalkIdentity(context.Background(), cfg, nil, 42, staff, false)
})
}
// TestSyncDingTalkIdentity_DefaultAttrKeys_NoopWithNilService 验证 cfg 默认 attr key 为空时
// 使用 fallback 默认值dingtalk_email / dingtalk_name / dingtalk_department
// 此测试主要验证调用路径不 panic实际 key 赋值默认值的逻辑在 GetDingTalkConnectOAuthConfig 层。
func TestSyncDingTalkIdentity_DefaultAttrKeys_NoopWithNilService(t *testing.T) {
handler := &AuthHandler{
userAttributeService: nil,
}
cfg := config.DingTalkConnectConfig{
CorpRestrictionPolicy: "internal_only",
SyncCorpEmail: true,
SyncDisplayName: true,
SyncDept: false,
// 不设置 attr key等同于 GetDingTalkConnectOAuthConfig 未设置时 fallback 后的默认值已在调用前填充)
SyncCorpEmailAttrKey: "dingtalk_email",
SyncDisplayNameAttrKey: "dingtalk_name",
SyncDeptAttrKey: "dingtalk_department",
}
staff := &DingTalkStaffInfo{
Name: "李四",
Email: "lisi@corp.com",
}
require.NotPanics(t, func() {
handler.syncDingTalkIdentity(context.Background(), cfg, nil, 99, staff, false)
})
}
// TestResolveDingTalkDeptPath_MultiLevel 验证多层部门路径拼接。
func TestResolveDingTalkDeptPath_MultiLevel(t *testing.T) {
handler := &AuthHandler{}
// 模拟42(AI研发) → parent=10(研发部) → parent=1(根)
responses := map[string]string{
"42": `{"errcode":0,"result":{"dept_id":42,"name":"AI研发","parent_id":10}}`,
"10": `{"errcode":0,"result":{"dept_id":10,"name":"研发部","parent_id":1}}`,
"1": `{"errcode":0,"result":{"dept_id":1,"name":"公司","parent_id":0}}`,
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 解析请求 body 拿到 dept_id
var req struct {
DeptID int64 `json:"dept_id"`
}
_ = json.NewDecoder(r.Body).Decode(&req)
key := fmt.Sprintf("%d", req.DeptID)
w.Header().Set("Content-Type", "application/json")
if resp, ok := responses[key]; ok {
_, _ = w.Write([]byte(resp))
} else {
_, _ = w.Write([]byte(`{"errcode":60003,"errmsg":"not found"}`))
}
}))
defer server.Close()
cli := &DingTalkClient{
cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/stub"},
httpClient: server.Client(),
}
cli.appToken = "tok"
cli.appTokenExp = time.Now().Add(time.Hour)
path, err := handler.resolveDingTalkDeptPath(context.Background(), cli, 42)
require.NoError(t, err)
require.Equal(t, "研发部/AI研发", path)
}

View File

@ -4,6 +4,7 @@ import (
"context"
"log/slog"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@ -18,25 +19,30 @@ import (
// AuthHandler handles authentication-related requests
type AuthHandler struct {
cfg *config.Config
authService *service.AuthService
userService *service.UserService
settingSvc *service.SettingService
promoService *service.PromoService
redeemService *service.RedeemService
totpService *service.TotpService
cfg *config.Config
authService *service.AuthService
userService *service.UserService
settingSvc *service.SettingService
promoService *service.PromoService
redeemService *service.RedeemService
totpService *service.TotpService
userAttributeService *service.UserAttributeService
dingTalkClientInstance *DingTalkClient
dingTalkClientMu sync.Mutex
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService) *AuthHandler {
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService, userAttributeService *service.UserAttributeService) *AuthHandler {
return &AuthHandler{
cfg: cfg,
authService: authService,
userService: userService,
settingSvc: settingService,
promoService: promoService,
redeemService: redeemService,
totpService: totpService,
cfg: cfg,
authService: authService,
userService: userService,
settingSvc: settingService,
promoService: promoService,
redeemService: redeemService,
totpService: totpService,
userAttributeService: userAttributeService,
}
}

View File

@ -350,7 +350,8 @@ func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email stri
if email == "" ||
strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) ||
strings.HasSuffix(email, service.DingTalkConnectSyntheticEmailDomain) {
return nil, nil
}
@ -519,7 +520,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode, "linuxdo")
if err != nil {
response.ErrorFrom(c, err)
return

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strings"
@ -195,6 +196,14 @@ func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPen
},
})
if err != nil {
slog.Error("pending auth session create failed",
"intent", strings.TrimSpace(payload.Intent),
"provider_type", strings.TrimSpace(payload.Identity.ProviderType),
"provider_key", strings.TrimSpace(payload.Identity.ProviderKey),
"provider_subject_len", len(strings.TrimSpace(payload.Identity.ProviderSubject)),
"resolved_email_len", len(strings.TrimSpace(payload.ResolvedEmail)),
"has_target_user", payload.TargetUserID != nil,
"error", err.Error())
return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err)
}
@ -266,6 +275,22 @@ func pendingSessionWantsInvitation(payload map[string]any) bool {
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required")
}
// pendingSessionRequiresEmailCompletion 判断 callback 写入的 completion payload 是否处于"补邮箱"状态。
// 钉钉跨组织/staff 邮箱缺失时进入此状态前端跳到补邮箱页exchange 不应走 adoption apply。
func pendingSessionRequiresEmailCompletion(payload map[string]any) bool {
if v, ok := payload["requires_email_completion"].(bool); ok && v {
return true
}
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "email_completion")
}
// pendingSessionRequiresBindLogin 判断 callback 写入的 completion payload 是否处于"必须绑定已有账户"状态。
// 钉钉 signupBlocked=true注册关 + 钉钉企业豁免关)时进入此状态:前端渲染 bind_login 表单,
// exchange 不应消费 session否则后续 /pending/bind-login 找不到 session。
func pendingSessionRequiresBindLogin(payload map[string]any) bool {
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "bind_login_required")
}
func pendingOAuthCompletionCanIssueTokenPair(session *dbent.PendingAuthSession, payload map[string]any) bool {
if session == nil {
return false
@ -1467,8 +1492,10 @@ func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string]
delete(normalized, key)
}
step := strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step")))
// 把多种 choice 别名归一为 oauthPendingChoiceStepbind_login_required 是独立终态
// (前端渲染 needsBindLogin 而非 needsChooser故不能并入归一化列表。
switch step {
case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required":
case "choice", "choose_account_action", "choose_account", "choose", "email_required":
normalized["step"] = oauthPendingChoiceStep
}
if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(normalized, "step")), oauthPendingChoiceStep) {
@ -1594,6 +1621,8 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
// bindPendingOAuthLogin = 绑定已有账户登录,不动 users.username用户已有自己的名字
h.maybeSyncDingTalkAfterLogin(c.Request.Context(), session, user.ID)
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
if err != nil {
response.InternalError(c, "Failed to generate token pair")
@ -1792,6 +1821,8 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
// createPendingOAuthAccount = 注册新账户,需要把钉钉昵称同步到 users.username 作为初始值
h.maybeSyncDingTalkAfterRegistration(c.Request.Context(), session, user.ID)
clearCookies()
writeOAuthTokenPairResponse(c, tokenPair)
}
@ -1893,6 +1924,14 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
response.Success(c, payload)
return
}
if pendingSessionRequiresEmailCompletion(payload) {
response.Success(c, payload)
return
}
if pendingSessionRequiresBindLogin(payload) {
response.Success(c, payload)
return
}
if !adoptionDecision.hasDecision() {
adoptionRequired, _ := payload["adoption_required"].(bool)
if adoptionRequired {

View File

@ -502,7 +502,8 @@ func (h *AuthHandler) findOIDCCompatEmailUser(ctx context.Context, email string)
if email == "" ||
strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) ||
strings.HasSuffix(email, service.DingTalkConnectSyntheticEmailDomain) {
return nil, nil
}
@ -666,7 +667,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode, "oidc")
if err != nil {
response.ErrorFrom(c, err)
return

View File

@ -548,7 +548,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode, "wechat")
if err != nil {
response.ErrorFrom(c, err)
return

View File

@ -0,0 +1,67 @@
package dto
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func TestAccountFromServiceShallow_RedactsSensitiveCredentials(t *testing.T) {
src := &service.Account{
ID: 42,
Name: "demo",
Platform: "anthropic",
Type: "oauth",
Credentials: map[string]any{
"access_token": "at-secret",
"refresh_token": "rt-secret",
"id_token": "id-secret",
"api_key": "sk-secret",
"base_url": "https://api.example.com",
"model_mapping": map[string]any{"foo": "bar"},
},
}
got := AccountFromServiceShallow(src)
require.NotNil(t, got)
// 敏感键不在 Credentials 里
require.NotContains(t, got.Credentials, "access_token")
require.NotContains(t, got.Credentials, "refresh_token")
require.NotContains(t, got.Credentials, "id_token")
require.NotContains(t, got.Credentials, "api_key")
// 非敏感键保留
require.Equal(t, "https://api.example.com", got.Credentials["base_url"])
require.Equal(t, map[string]any{"foo": "bar"}, got.Credentials["model_mapping"])
// 状态 map 标记敏感键存在
require.True(t, got.CredentialsStatus["has_access_token"])
require.True(t, got.CredentialsStatus["has_refresh_token"])
require.True(t, got.CredentialsStatus["has_id_token"])
require.True(t, got.CredentialsStatus["has_api_key"])
// JSON 序列化校验:响应体里不会出现敏感子串
raw, err := json.Marshal(got)
require.NoError(t, err)
require.NotContains(t, string(raw), "rt-secret")
require.NotContains(t, string(raw), "at-secret")
require.NotContains(t, string(raw), "sk-secret")
require.NotContains(t, string(raw), "id-secret")
// 状态标识应序列化进 JSON
require.Contains(t, string(raw), "credentials_status")
require.Contains(t, string(raw), "has_refresh_token")
// 原始 service.Account 不应被改动
require.Equal(t, "rt-secret", src.Credentials["refresh_token"])
}
func TestAccountFromServiceShallow_NilCredentialsOmitsStatus(t *testing.T) {
src := &service.Account{ID: 1, Name: "n", Platform: "anthropic", Type: "oauth"}
got := AccountFromServiceShallow(src)
require.NotNil(t, got)
require.Nil(t, got.Credentials)
require.Nil(t, got.CredentialsStatus)
}

View File

@ -0,0 +1,44 @@
// Package dto provides data transfer objects for HTTP handlers.
package dto
import "github.com/Wei-Shaw/sub2api/internal/service"
// RedactCredentials 复制一份 in剥离 service.SensitiveCredentialKeys 列出的所有敏感子键,
// 并产出一个 has_<key> 状态 map 表示哪些敏感键存在且非零值。
//
// 输入 nil 时返回 nil, nil避免响应里出现空对象
// 不修改入参;调用方拿到的 out 可安全序列化进 JSON 返回前端。
func RedactCredentials(in map[string]any) (out map[string]any, status map[string]bool) {
if in == nil {
return nil, nil
}
out = make(map[string]any, len(in))
for k, v := range in {
if service.IsSensitiveCredentialKey(k) {
if isCredentialValuePresent(v) {
if status == nil {
status = make(map[string]bool, 4)
}
status["has_"+k] = true
}
continue
}
out[k] = v
}
return out, status
}
// isCredentialValuePresent 判断值是否"存在且非零"。空字符串、nil、false 均视为未配置;
// 其余非零类型(数字、对象、字符串等)视为已配置。
func isCredentialValuePresent(v any) bool {
switch x := v.(type) {
case nil:
return false
case string:
return x != ""
case bool:
return x
default:
return true
}
}

View File

@ -0,0 +1,97 @@
package dto
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestRedactCredentials_NilInput(t *testing.T) {
out, status := RedactCredentials(nil)
require.Nil(t, out)
require.Nil(t, status)
}
func TestRedactCredentials_StripsSensitiveKeysAndReportsStatus(t *testing.T) {
in := map[string]any{
"refresh_token": "rt-secret",
"access_token": "at-secret",
"api_key": "sk-secret",
"aws_secret_access_key": "aws-secret",
"service_account_json": map[string]any{"private_key": "..."},
"private_key": "raw-key",
// 非敏感
"base_url": "https://api.example.com",
"model_mapping": map[string]any{"foo": "bar"},
"project_id": "proj-1",
"expires_at": int64(123456),
}
out, status := RedactCredentials(in)
require.NotContains(t, out, "refresh_token")
require.NotContains(t, out, "access_token")
require.NotContains(t, out, "api_key")
require.NotContains(t, out, "aws_secret_access_key")
require.NotContains(t, out, "service_account_json")
require.NotContains(t, out, "private_key")
require.Equal(t, "https://api.example.com", out["base_url"])
require.Equal(t, map[string]any{"foo": "bar"}, out["model_mapping"])
require.Equal(t, "proj-1", out["project_id"])
require.Equal(t, int64(123456), out["expires_at"])
require.True(t, status["has_refresh_token"])
require.True(t, status["has_access_token"])
require.True(t, status["has_api_key"])
require.True(t, status["has_aws_secret_access_key"])
require.True(t, status["has_service_account_json"])
require.True(t, status["has_private_key"])
// 状态 map 不应携带非敏感键的 has_*
require.NotContains(t, status, "has_base_url")
require.NotContains(t, status, "has_project_id")
}
func TestRedactCredentials_EmptyValuesNotMarkedPresent(t *testing.T) {
in := map[string]any{
"refresh_token": "",
"access_token": nil,
"api_key": false,
"id_token": "actual-id",
}
out, status := RedactCredentials(in)
require.Empty(t, out, "敏感键即使为空也不应出现在 redacted output")
require.False(t, status["has_refresh_token"])
require.False(t, status["has_access_token"])
require.False(t, status["has_api_key"])
require.True(t, status["has_id_token"])
}
func TestRedactCredentials_DoesNotMutateInput(t *testing.T) {
in := map[string]any{
"refresh_token": "secret",
"base_url": "x",
}
_, _ = RedactCredentials(in)
require.Equal(t, "secret", in["refresh_token"], "原始 map 不应被修改")
require.Equal(t, "x", in["base_url"])
}
func TestRedactCredentials_AllKnownSensitiveKeys(t *testing.T) {
keys := []string{
"access_token", "refresh_token", "id_token",
"api_key", "session_key", "cookie",
"aws_secret_access_key", "aws_session_token",
"service_account_json", "service_account", "private_key",
}
in := make(map[string]any, len(keys))
for _, k := range keys {
in[k] = "filled"
}
out, status := RedactCredentials(in)
require.Empty(t, out)
for _, k := range keys {
require.True(t, status["has_"+k], "key %s 应在 status 中标记为已配置", k)
}
}

View File

@ -198,13 +198,15 @@ func AccountFromServiceShallow(a *service.Account) *Account {
if a == nil {
return nil
}
redactedCreds, credsStatus := RedactCredentials(a.Credentials)
out := &Account{
ID: a.ID,
Name: a.Name,
Notes: a.Notes,
Platform: a.Platform,
Type: a.Type,
Credentials: a.Credentials,
Credentials: redactedCreds,
CredentialsStatus: credsStatus,
Extra: a.Extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
@ -531,11 +533,15 @@ func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
UsedBy: rc.UsedBy,
UsedAt: rc.UsedAt,
CreatedAt: rc.CreatedAt,
ExpiresAt: rc.ExpiresAt,
GroupID: rc.GroupID,
ValidityDays: rc.ValidityDays,
User: UserFromServiceShallow(rc.User),
Group: GroupFromServiceShallow(rc.Group),
}
if rc.IsExpired() {
out.Status = service.StatusExpired
}
// For admin_balance/admin_concurrency types, include notes so users can see
// why they were charged or credited by admin
@ -600,6 +606,10 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
FirstTokenMs: l.FirstTokenMs,
ImageCount: l.ImageCount,
ImageSize: l.ImageSize,
ImageInputSize: l.ImageInputSize,
ImageOutputSize: l.ImageOutputSize,
ImageSizeSource: l.ImageSizeSource,
ImageSizeBreakdown: l.ImageSizeBreakdown,
MediaType: l.MediaType,
UserAgent: l.UserAgent,
CacheTTLOverridden: l.CacheTTLOverridden,

View File

@ -148,6 +148,65 @@ func TestUsageLogFromService_FallsBackToLegacyModelWhenRequestedModelMissing(t *
require.Equal(t, "claude-3", adminDTO.Model)
}
func TestUsageLogFromService_IncludesImageBillingMetadataForUserAndAdmin(t *testing.T) {
t.Parallel()
imageSize := "4K"
inputSize := "1024x1024"
outputSize := "3840x2160"
source := "output"
log := &service.UsageLog{
RequestID: "req_image_metadata",
Model: "gpt-image-2",
ImageCount: 2,
ImageSize: &imageSize,
ImageInputSize: &inputSize,
ImageOutputSize: &outputSize,
ImageSizeSource: &source,
ImageSizeBreakdown: map[string]int{"4K": 2},
}
userDTO := UsageLogFromService(log)
adminDTO := UsageLogFromServiceAdmin(log)
for _, got := range []*UsageLog{userDTO, &adminDTO.UsageLog} {
require.Equal(t, 2, got.ImageCount)
require.NotNil(t, got.ImageSize)
require.Equal(t, imageSize, *got.ImageSize)
require.NotNil(t, got.ImageInputSize)
require.Equal(t, inputSize, *got.ImageInputSize)
require.NotNil(t, got.ImageOutputSize)
require.Equal(t, outputSize, *got.ImageOutputSize)
require.NotNil(t, got.ImageSizeSource)
require.Equal(t, source, *got.ImageSizeSource)
require.Equal(t, map[string]int{"4K": 2}, got.ImageSizeBreakdown)
}
}
func TestUsageLogFromService_PreservesHistoricalMissingImageSize(t *testing.T) {
t.Parallel()
log := &service.UsageLog{
RequestID: "req_legacy_image_missing_size",
Model: "gpt-image-2",
ImageCount: 1,
ImageSize: nil,
}
dto := UsageLogFromService(log)
require.Equal(t, 1, dto.ImageCount)
require.Nil(t, dto.ImageSize)
require.Nil(t, dto.ImageInputSize)
require.Nil(t, dto.ImageOutputSize)
require.Nil(t, dto.ImageSizeSource)
require.Nil(t, dto.ImageSizeBreakdown)
body, err := json.Marshal(dto)
require.NoError(t, err)
require.Contains(t, string(body), `"image_size":null`)
require.NotContains(t, string(body), `"image_size":"2K"`)
}
func f64Ptr(value float64) *float64 {
return &value
}

View File

@ -56,6 +56,23 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
DingTalkConnectEnabled bool `json:"dingtalk_connect_enabled"`
DingTalkConnectClientID string `json:"dingtalk_connect_client_id"`
DingTalkConnectClientSecretConfigured bool `json:"dingtalk_connect_client_secret_configured"`
DingTalkConnectRedirectURL string `json:"dingtalk_connect_redirect_url"`
DingTalkConnectCorpRestrictionPolicy string `json:"dingtalk_connect_corp_restriction_policy"`
DingTalkConnectInternalCorpID string `json:"dingtalk_connect_internal_corp_id"`
DingTalkConnectBypassRegistration bool `json:"dingtalk_connect_bypass_registration"`
DingTalkConnectSyncCorpEmail bool `json:"dingtalk_connect_sync_corp_email"`
DingTalkConnectSyncDisplayName bool `json:"dingtalk_connect_sync_display_name"`
DingTalkConnectSyncDept bool `json:"dingtalk_connect_sync_dept"`
DingTalkConnectSyncCorpEmailAttrKey string `json:"dingtalk_connect_sync_corp_email_attr_key"`
DingTalkConnectSyncDisplayNameAttrKey string `json:"dingtalk_connect_sync_display_name_attr_key"`
DingTalkConnectSyncDeptAttrKey string `json:"dingtalk_connect_sync_dept_attr_key"`
DingTalkConnectSyncCorpEmailAttrName string `json:"dingtalk_connect_sync_corp_email_attr_name"`
DingTalkConnectSyncDisplayNameAttrName string `json:"dingtalk_connect_sync_display_name_attr_name"`
DingTalkConnectSyncDeptAttrName string `json:"dingtalk_connect_sync_dept_attr_name"`
WeChatConnectEnabled bool `json:"wechat_connect_enabled"`
WeChatConnectAppID string `json:"wechat_connect_app_id"`
WeChatConnectAppSecretConfigured bool `json:"wechat_connect_app_secret_configured"`
@ -260,6 +277,7 @@ type PublicSettings struct {
TablePageSizeOptions []int `json:"table_page_size_options"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
DingTalkOAuthEnabled bool `json:"dingtalk_oauth_enabled"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`

View File

@ -149,25 +149,28 @@ type AdminGroup struct {
}
type Account struct {
ID int64 `json:"id"`
Name string `json:"name"`
Notes *string `json:"notes"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
LoadFactor *int `json:"load_factor,omitempty"`
Priority int `json:"priority"`
RateMultiplier float64 `json:"rate_multiplier"`
Status string `json:"status"`
ErrorMessage string `json:"error_message"`
LastUsedAt *time.Time `json:"last_used_at"`
ExpiresAt *int64 `json:"expires_at"`
AutoPauseOnExpired bool `json:"auto_pause_on_expired"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ID int64 `json:"id"`
Name string `json:"name"`
Notes *string `json:"notes"`
Platform string `json:"platform"`
Type string `json:"type"`
// Credentials 经 RedactCredentials 处理后只含非敏感子键;敏感 token / api_key / 私钥
// 的存在性通过 CredentialsStatushas_<key>)暴露,原始值不返回前端。
Credentials map[string]any `json:"credentials"`
CredentialsStatus map[string]bool `json:"credentials_status,omitempty"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
LoadFactor *int `json:"load_factor,omitempty"`
Priority int `json:"priority"`
RateMultiplier float64 `json:"rate_multiplier"`
Status string `json:"status"`
ErrorMessage string `json:"error_message"`
LastUsedAt *time.Time `json:"last_used_at"`
ExpiresAt *int64 `json:"expires_at"`
AutoPauseOnExpired bool `json:"auto_pause_on_expired"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Schedulable bool `json:"schedulable"`
@ -335,6 +338,7 @@ type RedeemCode struct {
UsedBy *int64 `json:"used_by"`
UsedAt *time.Time `json:"used_at"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
GroupID *int64 `json:"group_id"`
ValidityDays int `json:"validity_days"`
@ -400,9 +404,13 @@ type UsageLog struct {
FirstTokenMs *int `json:"first_token_ms"`
// 图片生成字段
ImageCount int `json:"image_count"`
ImageSize *string `json:"image_size"`
MediaType *string `json:"media_type"`
ImageCount int `json:"image_count"`
ImageSize *string `json:"image_size"`
ImageInputSize *string `json:"image_input_size"`
ImageOutputSize *string `json:"image_output_size"`
ImageSizeSource *string `json:"image_size_source"`
ImageSizeBreakdown map[string]int `json:"image_size_breakdown"`
MediaType *string `json:"media_type"`
// User-Agent
UserAgent *string `json:"user_agent"`

View File

@ -18,6 +18,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@ -325,6 +326,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
reqLog.Warn("gateway.select_account_no_available",
zap.String("model", reqModel),
zap.Int64p("group_id", apiKey.GroupID),
@ -374,6 +376,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
markOpsRoutingCapacityLimited(c)
reqLog.Warn("gateway.select_account_no_slot_no_wait_plan",
zap.Int64("account_id", account.ID),
zap.String("model", reqModel),
@ -566,6 +569,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID)
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
reqLog.Warn("gateway.select_account_no_available",
zap.String("model", reqModel),
zap.Int64p("group_id", currentAPIKey.GroupID),
@ -626,6 +630,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
markOpsRoutingCapacityLimited(c)
reqLog.Warn("gateway.select_account_no_slot_no_wait_plan",
zap.Int64("account_id", account.ID),
zap.String("model", reqModel),
@ -946,8 +951,8 @@ func (h *GatewayHandler) Models(c *gin.Context) {
platform = forcedPlatform
}
// Get available models from account configurations (without platform filter)
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
// Get available models from account configurations for the selected group platform.
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, platform)
if len(availableModels) > 0 {
// Build model list from whitelist
@ -968,7 +973,7 @@ func (h *GatewayHandler) Models(c *gin.Context) {
}
// Fallback to default models
if platform == "openai" {
if platform == service.PlatformOpenAI {
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": openai.DefaultModels,
@ -976,6 +981,14 @@ func (h *GatewayHandler) Models(c *gin.Context) {
return
}
if platform == service.PlatformGemini {
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": geminicli.DefaultModels,
})
return
}
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": claude.DefaultModels,
@ -1312,6 +1325,11 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
if service.IsOpenAISilentRefusalErrorBody(responseBody) {
service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "")
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage(), streamStarted)
return
}
// 先检查透传规则
if h.errorPassthroughService != nil && len(responseBody) > 0 {
@ -1542,6 +1560,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
if err != nil {
reqLog.Warn("gateway.count_tokens_select_account_failed", zap.Error(err))
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
return
}

View File

@ -161,14 +161,26 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
groupPlatform := ""
if apiKey.Group != nil {
groupPlatform = apiKey.Group.Platform
}
selectionSessionHash := sessionHash
if groupPlatform == service.PlatformGemini && selectionSessionHash != "" {
selectionSessionHash = "gemini:" + selectionSessionHash
}
// 3. Account selection + failover loop
fs := NewFailoverState(h.maxAccountSwitches, false)
if groupPlatform == service.PlatformGemini {
fs = NewFailoverState(h.maxAccountSwitchesGemini, false)
}
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, selectionSessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
@ -194,6 +206,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
markOpsRoutingCapacityLimited(c)
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
@ -213,13 +226,33 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
}
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
if groupPlatform == service.PlatformGemini && account.Platform != service.PlatformGemini {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
fs.FailedAccountIDs[account.ID] = struct{}{}
continue
}
// 5. Forward request
writerSizeBeforeForward := c.Writer.Size()
forwardBody := body
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
var result *service.ForwardResult
if account.Platform == service.PlatformGemini {
if h.geminiCompatService == nil {
h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "upstream_error", "Gemini compatibility service is not configured")
if accountReleaseFunc != nil {
accountReleaseFunc()
}
return
}
result, err = h.geminiCompatService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody)
} else {
result, err = h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
}
if accountReleaseFunc != nil {
accountReleaseFunc()
@ -302,5 +335,10 @@ func (h *GatewayHandler) handleCCFailoverExhausted(c *gin.Context, lastErr *serv
if lastErr != nil && lastErr.StatusCode > 0 {
statusCode = lastErr.StatusCode
}
if lastErr != nil && service.IsOpenAISilentRefusalErrorBody(lastErr.ResponseBody) {
service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "")
h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage())
return
}
h.chatCompletionsErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
}

View File

@ -174,6 +174,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
@ -199,6 +200,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
markOpsRoutingCapacityLimited(c)
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
@ -308,5 +310,10 @@ func (h *GatewayHandler) handleResponsesFailoverExhausted(c *gin.Context, lastEr
if lastErr != nil && lastErr.StatusCode > 0 {
statusCode = lastErr.StatusCode
}
if lastErr != nil && service.IsOpenAISilentRefusalErrorBody(lastErr.ResponseBody) {
service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "")
h.responsesErrorResponse(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage())
return
}
h.responsesErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
}

View File

@ -0,0 +1,136 @@
package handler
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type gatewayModelsAccountRepoStub struct {
service.AccountRepository
byGroup map[int64][]service.Account
}
type gatewayModelsResponseForTest struct {
Object string `json:"object"`
Data []gatewayModelItemForTest `json:"data"`
}
type gatewayModelItemForTest struct {
ID string `json:"id"`
}
func (s *gatewayModelsAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
accounts, ok := s.byGroup[groupID]
if !ok {
return nil, nil
}
out := make([]service.Account, len(accounts))
copy(out, accounts)
return out, nil
}
func newGatewayModelsHandlerForTest(repo service.AccountRepository) *GatewayHandler {
return &GatewayHandler{
gatewayService: service.NewGatewayService(
repo,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
),
}
}
func TestGatewayModels_GeminiGroupFallsBackToGeminiModels(t *testing.T) {
gin.SetMode(gin.TestMode)
groupID := int64(20)
h := newGatewayModelsHandlerForTest(
&gatewayModelsAccountRepoStub{
byGroup: map[int64][]service.Account{
groupID: {
{ID: 1, Platform: service.PlatformGemini},
},
},
},
)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
Group: &service.Group{ID: groupID, Platform: service.PlatformGemini},
})
h.Models(c)
require.Equal(t, http.StatusOK, rec.Code)
var got gatewayModelsResponseForTest
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
require.Equal(t, "list", got.Object)
require.Contains(t, modelIDsForTest(got.Data), "gemini-2.5-flash")
require.NotContains(t, modelIDsForTest(got.Data), "claude-sonnet-4-6")
}
func TestGatewayModels_GeminiGroupFiltersMappedModelsByPlatform(t *testing.T) {
gin.SetMode(gin.TestMode)
groupID := int64(21)
h := newGatewayModelsHandlerForTest(
&gatewayModelsAccountRepoStub{
byGroup: map[int64][]service.Account{
groupID: {
{
ID: 1,
Platform: service.PlatformAnthropic,
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-sonnet-4-6": "claude-sonnet-4-6",
},
},
},
{
ID: 2,
Platform: service.PlatformGemini,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gemini-2.5-flash": "gemini-2.5-flash",
},
},
},
},
},
},
)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
Group: &service.Group{ID: groupID, Platform: service.PlatformGemini},
})
h.Models(c)
require.Equal(t, http.StatusOK, rec.Code)
var got gatewayModelsResponseForTest
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
require.Equal(t, []string{"gemini-2.5-flash"}, modelIDsForTest(got.Data))
}
func modelIDsForTest(models []gatewayModelItemForTest) []string {
ids := make([]string, 0, len(models))
for _, model := range models {
ids = append(ids, model.ID)
}
return ids
}

View File

@ -61,6 +61,7 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
c.JSON(http.StatusOK, gemini.FallbackModelsList())
return
}
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
@ -113,6 +114,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
return
}
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
@ -372,6 +374,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
@ -419,6 +422,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
markOpsRoutingCapacityLimited(c)
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
return
}

View File

@ -143,6 +143,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
} else {
@ -155,6 +156,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
}
}
if selection == nil || selection.Account == nil {
markOpsRoutingCapacityLimited(c)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
@ -176,6 +178,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "")
forwardDurationMs := time.Since(forwardStart).Milliseconds()
@ -201,6 +204,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
} else {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
if c.Writer.Size() != writerSizeBeforeForward {
h.handleFailoverExhausted(c, failoverErr, true)
return
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// Pool mode: retry on the same account
if failoverErr.RetryableOnSameAccount {
@ -292,7 +299,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
// resolveRawCCUpstreamEndpoint returns the actual upstream endpoint for
// OpenAI Chat Completions requests. For APIKey accounts whose upstream
// has been probed to not support the Responses API, the request is
// is forced or probed to not support the Responses API, the request is
// forwarded directly to /v1/chat/completions — not through the default
// CC→Responses conversion path.
func resolveRawCCUpstreamEndpoint(c *gin.Context, account *service.Account) string {

View File

@ -282,6 +282,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
if errors.Is(err, service.ErrNoAvailableCompactAccounts) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "compact_not_supported", "No available OpenAI accounts support /responses/compact", streamStarted)
return
@ -297,6 +298,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
if selection == nil || selection.Account == nil {
markOpsRoutingCapacityLimited(c)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
@ -330,6 +332,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
@ -354,6 +357,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
} else {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
if c.Writer.Size() != writerSizeBeforeForward {
h.handleFailoverExhausted(c, failoverErr, true)
return
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// 池模式:同账号重试
if failoverErr.RetryableOnSameAccount {
@ -677,6 +684,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
)
if len(failedAccountIDs) == 0 {
if err != nil {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
@ -690,6 +698,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
}
}
if selection == nil || selection.Account == nil {
markOpsRoutingCapacityLimited(c)
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
@ -992,6 +1001,7 @@ func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot(
reqLog *zap.Logger,
) (func(), bool) {
if selection == nil || selection.Account == nil {
markOpsRoutingCapacityLimited(c)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
return nil, false
}
@ -1002,6 +1012,7 @@ func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot(
return wrapReleaseOnDone(ctx, selection.ReleaseFunc), true
}
if selection.WaitPlan == nil {
markOpsRoutingCapacityLimited(c)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
return nil, false
}
@ -1598,6 +1609,11 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error,
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
if service.IsOpenAISilentRefusalErrorBody(responseBody) {
service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "")
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage(), streamStarted)
return
}
// 先检查透传规则
if h.errorPassthroughService != nil && len(responseBody) > 0 {

View File

@ -157,6 +157,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted)
return
}
@ -168,6 +169,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
return
}
if selection == nil || selection.Account == nil {
markOpsRoutingCapacityLimited(c)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted)
return
}

View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"log"
"runtime"
"runtime/debug"
@ -22,10 +23,11 @@ import (
)
const (
opsModelKey = "ops_model"
opsStreamKey = "ops_stream"
opsRequestBodyKey = "ops_request_body"
opsAccountIDKey = "ops_account_id"
opsModelKey = "ops_model"
opsStreamKey = "ops_stream"
opsRequestBodyKey = "ops_request_body"
opsAccountIDKey = "ops_account_id"
opsRoutingCapacityLimitedKey = "ops_routing_capacity_limited"
opsUpstreamModelKey = "ops_upstream_model"
opsRequestTypeKey = "ops_request_type"
@ -45,6 +47,8 @@ const (
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
opsCodeUserInactive = "USER_INACTIVE"
opsCodeInvalidAPIKey = "INVALID_API_KEY"
opsCodeAPIKeyRequired = "API_KEY_REQUIRED"
)
const (
@ -393,6 +397,42 @@ func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string)
}
}
func markOpsRoutingCapacityLimited(c *gin.Context) {
if c == nil {
return
}
c.Set(opsRoutingCapacityLimitedKey, true)
}
func markOpsRoutingCapacityLimitedIfNoAvailable(c *gin.Context, err error) {
if !isOpsNoAvailableAccountError(err) {
return
}
markOpsRoutingCapacityLimited(c)
}
func isOpsRoutingCapacityLimited(c *gin.Context) bool {
if c == nil {
return false
}
v, ok := c.Get(opsRoutingCapacityLimitedKey)
if !ok {
return false
}
marked, _ := v.(bool)
return marked
}
func isOpsNoAvailableAccountError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, service.ErrNoAvailableAccounts) || errors.Is(err, service.ErrNoAvailableCompactAccounts) {
return true
}
return isOpsNoAvailableAccountMessage(err.Error())
}
type opsCaptureWriter struct {
gin.ResponseWriter
limit int
@ -775,11 +815,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
normalizedType := normalizeOpsErrorType(parsed.ErrorType, parsed.Code)
phase := classifyOpsPhase(normalizedType, parsed.Message, parsed.Code)
isBusinessLimited := classifyOpsIsBusinessLimited(normalizedType, phase, parsed.Code, status, parsed.Message)
errorOwner := classifyOpsErrorOwner(phase, parsed.Message)
errorSource := classifyOpsErrorSource(phase, parsed.Message)
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, normalizedType, parsed.Message, parsed.Code, status)
entry := &service.OpsInsertErrorLogInput{
RequestID: requestID,
@ -1114,6 +1150,9 @@ func classifyOpsPhase(errType, message, code string) string {
msg := strings.ToLower(message)
// Standardized phases: request|auth|routing|upstream|network|internal
// Map billing/concurrency/response => request; scheduling => routing.
if isOpsClientAuthError(code, msg) {
return "auth"
}
switch strings.TrimSpace(code) {
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
return "request"
@ -1134,7 +1173,7 @@ func classifyOpsPhase(errType, message, code string) string {
case "upstream_error", "overloaded_error":
return "upstream"
case "api_error":
if strings.Contains(msg, opsErrNoAvailableAccounts) {
if isOpsNoAvailableAccountMessage(msg) {
return "routing"
}
return "internal"
@ -1178,7 +1217,31 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
}
}
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
func classifyOpsErrorLog(c *gin.Context, errType, message, code string, status int) (phase string, isBusinessLimited bool, errorOwner string, errorSource string) {
phase = classifyOpsPhase(errType, message, code)
routingCapacityLimited := isOpsRoutingCapacityLimited(c)
clientBusinessLimited := service.HasOpsClientBusinessLimited(c)
upstreamError := hasOpsUpstreamErrorContext(c)
if upstreamError && !routingCapacityLimited {
phase = "upstream"
}
if clientBusinessLimited && !upstreamError && !routingCapacityLimited {
phase = "auth"
}
if routingCapacityLimited {
phase = "routing"
}
localClientAuthError := !upstreamError && phase == "auth" && isOpsClientAuthError(code, strings.ToLower(message))
isBusinessLimited = routingCapacityLimited || clientBusinessLimited || classifyOpsIsBusinessLimited(errType, phase, code, status, message, localClientAuthError)
errorOwner = classifyOpsErrorOwner(phase, message)
errorSource = classifyOpsErrorSource(phase, message)
return phase, isBusinessLimited, errorOwner, errorSource
}
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string, localClientAuthError ...bool) bool {
if len(localClientAuthError) > 0 && localClientAuthError[0] {
return true
}
switch strings.TrimSpace(code) {
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive:
return true
@ -1195,6 +1258,47 @@ func classifyOpsIsBusinessLimited(errType, phase, code string, status int, messa
return false
}
func isOpsClientAuthError(code string, msg string) bool {
switch strings.TrimSpace(code) {
case opsCodeInvalidAPIKey, opsCodeAPIKeyRequired:
return true
}
return strings.Contains(msg, "invalid api key") || strings.Contains(msg, "api key is required")
}
func hasOpsUpstreamErrorContext(c *gin.Context) bool {
if c == nil {
return false
}
if v, ok := c.Get(service.OpsUpstreamStatusCodeKey); ok {
switch code := v.(type) {
case int:
if code > 0 {
return true
}
case int64:
if code > 0 {
return true
}
}
}
if v, ok := c.Get(service.OpsUpstreamErrorsKey); ok {
if events, ok := v.([]*service.OpsUpstreamErrorEvent); ok && len(events) > 0 {
return true
}
}
return false
}
func isOpsNoAvailableAccountMessage(message string) bool {
msg := strings.ToLower(message)
return strings.Contains(msg, opsErrNoAvailableAccounts) ||
strings.Contains(msg, "no available account") ||
strings.Contains(msg, "no available gemini accounts") ||
strings.Contains(msg, "no available openai accounts") ||
strings.Contains(msg, "no available compatible accounts")
}
func classifyOpsErrorOwner(phase string, message string) string {
// Standardized owners: client|provider|platform
switch phase {

View File

@ -275,6 +275,218 @@ func TestNormalizeOpsErrorType(t *testing.T) {
}
}
func TestClassifyOpsNoAvailableAccountsExcludedFromSLA(t *testing.T) {
const message = "No available accounts"
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
markOpsRoutingCapacityLimited(c)
errType := normalizeOpsErrorType("api_error", "")
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, message, "", http.StatusServiceUnavailable)
require.Equal(t, "api_error", errType)
require.Equal(t, "routing", phase)
require.True(t, isBusinessLimited)
require.Equal(t, "platform", errorOwner)
require.Equal(t, "gateway", errorSource)
}
func TestClassifyOpsRoutingCapacityMarkerExcludesMaskedSelectionFailureFromSLA(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
markOpsRoutingCapacityLimited(c)
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
c,
"api_error",
"Service temporarily unavailable",
"",
http.StatusServiceUnavailable,
)
require.Equal(t, "routing", phase)
require.True(t, isBusinessLimited)
require.Equal(t, "platform", errorOwner)
require.Equal(t, "gateway", errorSource)
}
func TestClassifyOpsAuthClientErrorsExcludedFromSLA(t *testing.T) {
tests := []struct {
name string
errType string
message string
code string
status int
}{
{
name: "standard invalid API key",
errType: "api_error",
message: "Invalid API key",
code: "INVALID_API_KEY",
status: http.StatusUnauthorized,
},
{
name: "standard missing API key",
errType: "api_error",
message: "API key is required in Authorization header (Bearer scheme), x-api-key header, or x-goog-api-key header",
code: "API_KEY_REQUIRED",
status: http.StatusUnauthorized,
},
{
name: "google invalid API key",
errType: "api_error",
message: "Invalid API key",
code: "401",
status: http.StatusUnauthorized,
},
{
name: "google missing API key",
errType: "api_error",
message: "API key is required",
code: "401",
status: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
errType := normalizeOpsErrorType(tt.errType, tt.code)
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, tt.message, tt.code, tt.status)
require.Equal(t, "api_error", errType)
require.Equal(t, "auth", phase)
require.True(t, isBusinessLimited)
require.Equal(t, "client", errorOwner)
require.Equal(t, "client_request", errorSource)
})
}
}
func TestClassifyOpsIPRestrictionAccessDeniedExcludedFromSLA(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonIPRestriction)
errType := normalizeOpsErrorType("api_error", "ACCESS_DENIED")
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, "Access denied", "ACCESS_DENIED", http.StatusForbidden)
require.Equal(t, "api_error", errType)
require.Equal(t, "auth", phase)
require.True(t, isBusinessLimited)
require.Equal(t, "client", errorOwner)
require.Equal(t, "client_request", errorSource)
}
func TestClassifyOpsOtherErrorsStillCountForSLA(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
errType := normalizeOpsErrorType("api_error", "INTERNAL_ERROR")
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, "Failed to validate API key", "INTERNAL_ERROR", http.StatusInternalServerError)
require.Equal(t, "api_error", errType)
require.Equal(t, "internal", phase)
require.False(t, isBusinessLimited)
require.Equal(t, "platform", errorOwner)
require.Equal(t, "gateway", errorSource)
}
func TestClassifyOpsUnsupportedModelExcludedFromSLA(t *testing.T) {
tests := []string{
"No available accounts: no available accounts supporting model: made-up-model",
"No available accounts: no available OpenAI accounts supporting model: made-up-model",
"No available Gemini accounts: no available Gemini accounts supporting model: made-up-model",
"No available accounts: no available accounts supporting model: made-up-model (channel pricing restriction)",
}
for _, message := range tests {
t.Run(message, func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
markOpsRoutingCapacityLimited(c)
errType := normalizeOpsErrorType("api_error", "")
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, message, "", http.StatusServiceUnavailable)
require.Equal(t, "api_error", errType)
require.Equal(t, "routing", phase)
require.True(t, isBusinessLimited)
require.Equal(t, "platform", errorOwner)
require.Equal(t, "gateway", errorSource)
})
}
}
func TestClassifyOpsUnmarkedNoAvailableTextStillCountsForSLA(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
c,
"api_error",
"No available accounts",
"",
http.StatusServiceUnavailable,
)
require.Equal(t, "routing", phase)
require.False(t, isBusinessLimited)
require.Equal(t, "platform", errorOwner)
require.Equal(t, "gateway", errorSource)
}
func TestClassifyOpsUpstreamAuthTextStillCountsForSLA(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
service.SetOpsUpstreamError(c, http.StatusUnauthorized, "Invalid API key", "")
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
c,
"api_error",
"Invalid API key",
"401",
http.StatusUnauthorized,
)
require.Equal(t, "upstream", phase)
require.False(t, isBusinessLimited)
require.Equal(t, "provider", errorOwner)
require.Equal(t, "upstream_http", errorSource)
}
func TestClassifyOpsUpstreamNoAvailableTextStillCountsForSLA(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
service.SetOpsUpstreamError(c, http.StatusServiceUnavailable, "No available accounts", "")
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
c,
"api_error",
"No available accounts",
"",
http.StatusServiceUnavailable,
)
require.Equal(t, "upstream", phase)
require.False(t, isBusinessLimited)
require.Equal(t, "provider", errorOwner)
require.Equal(t, "upstream_http", errorSource)
}
func TestSetOpsEndpointContext_SetsContextKeys(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()

View File

@ -58,7 +58,7 @@ func TestResolvePageImagePath(t *testing.T) {
if !ok {
t.Fatal("expected direct image path to be accepted")
}
want := filepath.Join(base, "logo.png")
want := mustEvalSymlinks(t, filepath.Join(base, "logo.png"))
if got != want {
t.Fatalf("path = %q, want %q", got, want)
}
@ -67,7 +67,7 @@ func TestResolvePageImagePath(t *testing.T) {
if !ok {
t.Fatal("expected nested image path to be accepted")
}
want = filepath.Join(base, "images", "logo.png")
want = mustEvalSymlinks(t, filepath.Join(base, "images", "logo.png"))
if got != want {
t.Fatalf("path = %q, want %q", got, want)
}
@ -100,3 +100,13 @@ func TestResolvePageImagePathRejectsSymlinkEscape(t *testing.T) {
t.Fatalf("expected symlink escape to be rejected, got %q", got)
}
}
func mustEvalSymlinks(t *testing.T, path string) string {
t.Helper()
realPath, err := filepath.EvalSymlinks(path)
if err != nil {
t.Fatalf("eval symlinks for %q: %v", path, err)
}
return realPath
}

View File

@ -61,6 +61,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
TablePageSizeOptions: settings.TablePageSizeOptions,
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
DingTalkOAuthEnabled: settings.DingTalkOAuthEnabled,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
WeChatOAuthOpenEnabled: settings.WeChatOAuthOpenEnabled,

View File

@ -67,6 +67,7 @@ type userProfileResponse struct {
LinuxDoBound bool `json:"linuxdo_bound"`
OIDCBound bool `json:"oidc_bound"`
WeChatBound bool `json:"wechat_bound"`
DingTalkBound bool `json:"dingtalk_bound"`
}
type userProfileSourceContext struct {
@ -528,15 +529,17 @@ func userProfileResponseFromService(user *service.User, identities service.UserI
LinuxDoBound: identities.LinuxDo.Bound,
OIDCBound: identities.OIDC.Bound,
WeChatBound: identities.WeChat.Bound,
DingTalkBound: identities.DingTalk.Bound,
}
}
func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string]service.UserIdentitySummary {
return map[string]service.UserIdentitySummary{
"email": identities.Email,
"linuxdo": identities.LinuxDo,
"oidc": identities.OIDC,
"wechat": identities.WeChat,
"email": identities.Email,
"linuxdo": identities.LinuxDo,
"oidc": identities.OIDC,
"wechat": identities.WeChat,
"dingtalk": identities.DingTalk,
}
}
@ -585,7 +588,7 @@ func inferUserProfileSources(user *service.User, identities service.UserIdentity
func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary {
out := make([]service.UserIdentitySummary, 0, 3)
for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} {
for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat, identities.DingTalk} {
if summary.Bound {
out = append(out, summary)
}

View File

@ -105,10 +105,16 @@ func (a *Alipay) MerchantIdentityMetadata() map[string]string {
// CreatePayment creates an Alipay payment using the following routing:
// - Mobile (H5): alipay.trade.wap.pay — browser redirect into Alipay.
// - Desktop: prefer alipay.trade.precreate to get a scan payload directly.
// - Desktop fallback: if precreate is unavailable for the merchant, fall back
// to alipay.trade.page.pay and expose both pay_url and qr_code so the
// frontend can render a QR while still allowing direct page open.
// - Desktop, default: prefer alipay.trade.precreate (FACE_TO_FACE_PAYMENT) to
// get a scannable QR payload. If precreate is unavailable for the merchant,
// fall back to alipay.trade.page.pay and expose pay_url only — the frontend
// opens the Alipay checkout in a new tab.
// - Desktop, paymentMode == "redirect": skip precreate and go straight to
// alipay.trade.page.pay so the frontend always opens the Alipay checkout
// in a new tab. Use this when the merchant has not enabled FACE_TO_FACE_PAYMENT.
//
// Note: alipay.trade.page.pay returns a checkout page URL, not a scannable
// payment QR. Never expose it via the QRCode field.
func (a *Alipay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
client, err := a.getClient()
if err != nil {
@ -150,6 +156,13 @@ func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePayment
}
func (a *Alipay) createDesktopTrade(ctx context.Context, client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
// Explicit redirect mode: merchant opted into "always open the Alipay
// checkout page in a new tab" via the provider instance's payment_mode.
// Skip precreate to avoid a wasted API call.
if strings.EqualFold(strings.TrimSpace(a.config["paymentMode"]), "redirect") {
return a.createPagePayTrade(client, req, notifyURL, returnURL)
}
resp, precreateErr := a.createPrecreateTrade(ctx, client, req, notifyURL)
if precreateErr == nil {
return resp, nil
@ -204,10 +217,12 @@ func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePay
if err != nil {
return nil, fmt.Errorf("alipay TradePagePay: %w", err)
}
// Only PayURL is exposed: alipay.trade.page.pay returns a checkout page URL
// that must be opened in a browser, not a scannable payment QR. Setting it
// as QRCode would let the frontend render an unscannable image.
return &payment.CreatePaymentResponse{
TradeNo: req.OrderID,
PayURL: payURL.String(),
QRCode: payURL.String(),
}, nil
}

View File

@ -189,8 +189,63 @@ func TestCreateTradeUsesPagePayForDesktop(t *testing.T) {
if resp.PayURL == "" {
t.Fatal("expected pay_url for desktop page pay")
}
if resp.QRCode != resp.PayURL {
t.Fatalf("qr_code = %q, want same as pay_url %q", resp.QRCode, resp.PayURL)
// page.pay returns a checkout page URL, not a scannable QR payload —
// it must never be exposed via QRCode (the frontend would render an
// unscannable image from it).
if resp.QRCode != "" {
t.Fatalf("qr_code = %q, want empty for page pay", resp.QRCode)
}
}
// When the provider instance is configured with paymentMode == "redirect",
// the desktop flow must skip precreate and go straight to page.pay.
func TestCreateTradeRedirectModeSkipsPrecreate(t *testing.T) {
origPreCreate := alipayTradePreCreate
origPagePay := alipayTradePagePay
t.Cleanup(func() {
alipayTradePreCreate = origPreCreate
alipayTradePagePay = origPagePay
})
preCreateCalls := 0
pagePayCalls := 0
alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) {
preCreateCalls++
return &alipay.TradePreCreateRsp{
Error: alipay.Error{Code: alipay.CodeSuccess},
QRCode: "https://qr.alipay.example.com/precreate-token",
}, nil
}
alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
pagePayCalls++
if param.ProductCode != alipayProductCodePagePay {
t.Fatalf("product_code = %q, want %q", param.ProductCode, alipayProductCodePagePay)
}
return url.Parse("https://openapi.alipay.com/gateway.do?page-pay")
}
provider := &Alipay{
config: map[string]string{"paymentMode": "redirect"},
}
resp, err := provider.createDesktopTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{
OrderID: "sub2_103",
Amount: "12.00",
Subject: "Balance recharge",
}, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if preCreateCalls != 0 {
t.Fatalf("precreate calls = %d, want 0 (redirect mode must skip precreate)", preCreateCalls)
}
if pagePayCalls != 1 {
t.Fatalf("page pay calls = %d, want 1", pagePayCalls)
}
if resp.PayURL == "" {
t.Fatal("expected pay_url for redirect mode")
}
if resp.QRCode != "" {
t.Fatalf("qr_code = %q, want empty for redirect mode", resp.QRCode)
}
}

View File

@ -254,6 +254,8 @@ const (
proxyTLSHandshakeTimeout = 5 * time.Second
// clientTimeout 整体请求超时(含连接、发送、等待响应、读取 body
clientTimeout = 10 * time.Second
// fetchAvailableModelsBodyLimit limits model-list responses to avoid unbounded memory use.
fetchAvailableModelsBodyLimit int64 = 8 << 20
)
func NewClient(proxyURL string) (*Client, error) {
@ -655,6 +657,10 @@ type FetchAvailableModelsResponse struct {
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
// 支持 URL fallbacksandbox → daily → prod
func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) {
if c == nil || c.httpClient == nil {
return nil, nil, errors.New("antigravity client is not configured")
}
reqBody := FetchAvailableModelsRequest{Project: projectID}
bodyBytes, err := json.Marshal(reqBody)
if err != nil {
@ -664,6 +670,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
// 固定顺序prod -> daily
availableURLs := BaseURLs
fetchClient := c.fetchAvailableModelsHTTPClient()
var lastErr error
for urlIdx, baseURL := range availableURLs {
apiURL := baseURL + "/v1internal:fetchAvailableModels"
@ -676,7 +683,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", GetUserAgentForContext(ctx))
resp, err := c.httpClient.Do(req)
resp, err := fetchClient.Do(req)
if err != nil {
lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
@ -686,11 +693,14 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
return nil, nil, lastErr
}
respBodyBytes, err := io.ReadAll(resp.Body)
respBodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, fetchAvailableModelsBodyLimit+1))
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
if err != nil {
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
}
if int64(len(respBodyBytes)) > fetchAvailableModelsBodyLimit {
return nil, nil, fmt.Errorf("响应超过 %d 字节", fetchAvailableModelsBodyLimit)
}
// 检查是否需要 URL 降级
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
@ -726,6 +736,42 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
return nil, nil, lastErr
}
func (c *Client) fetchAvailableModelsHTTPClient() *http.Client {
fetchClient := *c.httpClient
fetchClient.CheckRedirect = checkFetchAvailableModelsRedirect
return &fetchClient
}
func checkFetchAvailableModelsRedirect(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return errors.New("stopped after 10 redirects")
}
if req == nil || req.URL == nil {
return errors.New("redirect url is nil")
}
if !isAllowedFetchAvailableModelsRedirectHost(req.URL.Hostname()) {
return fmt.Errorf("redirect to unsupported host: %s", req.URL.Hostname())
}
return nil
}
func isAllowedFetchAvailableModelsRedirectHost(host string) bool {
host = strings.ToLower(strings.TrimSpace(host))
if host == "" {
return false
}
for _, baseURL := range BaseURLs {
parsed, err := url.Parse(baseURL)
if err != nil {
continue
}
if strings.EqualFold(host, parsed.Hostname()) {
return true
}
}
return false
}
// ── Privacy API ──────────────────────────────────────────────────────
// privacyBaseURL 隐私设置 API 仅使用 daily 端点(与 Antigravity 客户端行为一致)

View File

@ -225,6 +225,41 @@ func TestChatCompletionsToResponses_WhitespaceOnlyBase64ImageURLSkipped(t *testi
assert.Equal(t, "Describe this", parts[0].Text)
}
func TestChatCompletionsToResponses_EmptyContentNeverNull(t *testing.T) {
// Regression for #2515: the upstream Responses API rejects an input item
// whose content field is JSON null. Any chat-completions message that
// yields no usable content parts must serialize content as a string.
cases := []struct {
name string
content json.RawMessage
}{
{"null content", json.RawMessage(`null`)},
{"empty array content", json.RawMessage(`[]`)},
{"only empty text part", json.RawMessage(`[{"type":"text","text":""}]`)},
{"only empty base64 image part", json.RawMessage(`[{"type":"image_url","image_url":{"url":"data:image/png;base64,"}}]`)},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-5.5",
Messages: []ChatMessage{
{Role: "user", Content: tc.content},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
assert.NotContains(t, string(resp.Input), `"content":null`,
"converted input must not contain a null content field")
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 1)
assert.Equal(t, `""`, string(items[0].Content),
"content must be an empty string, not null")
})
}
}
func TestChatCompletionsToResponses_SystemArrayContent(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",

View File

@ -339,7 +339,14 @@ func marshalChatInputContent(content chatMessageContent) (json.RawMessage, error
if content.Text != nil {
return json.Marshal(*content.Text)
}
return json.Marshal(convertChatContentPartsToResponses(content.Parts))
parts := convertChatContentPartsToResponses(content.Parts)
if len(parts) == 0 {
// A nil slice marshals to JSON null, which the upstream Responses API
// rejects ("expected an array of objects or string, but got null").
// Fall back to an empty string when no usable parts remain.
return json.Marshal("")
}
return json.Marshal(parts)
}
func convertChatContentPartsToResponses(parts []ChatContentPart) []ResponsesContentPart {

View File

@ -306,6 +306,37 @@ type ResponsesUsage struct {
OutputTokensDetails *ResponsesOutputTokensDetails `json:"output_tokens_details,omitempty"`
}
func (u *ResponsesUsage) UnmarshalJSON(data []byte) error {
type responsesUsageAlias ResponsesUsage
var aux struct {
responsesUsageAlias
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
PromptTokensDetails *ResponsesInputTokensDetails `json:"prompt_tokens_details,omitempty"`
CompletionTokensDetails *ResponsesOutputTokensDetails `json:"completion_tokens_details,omitempty"`
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
*u = ResponsesUsage(aux.responsesUsageAlias)
if u.InputTokens == 0 && aux.PromptTokens != 0 {
u.InputTokens = aux.PromptTokens
}
if u.OutputTokens == 0 && aux.CompletionTokens != 0 {
u.OutputTokens = aux.CompletionTokens
}
if u.InputTokensDetails == nil && aux.PromptTokensDetails != nil {
u.InputTokensDetails = aux.PromptTokensDetails
}
if u.OutputTokensDetails == nil && aux.CompletionTokensDetails != nil {
u.OutputTokensDetails = aux.CompletionTokensDetails
}
if u.TotalTokens == 0 && (u.InputTokens != 0 || u.OutputTokens != 0) {
u.TotalTokens = u.InputTokens + u.OutputTokens
}
return nil
}
// ResponsesInputTokensDetails breaks down input token usage.
type ResponsesInputTokensDetails struct {
CachedTokens int `json:"cached_tokens,omitempty"`

View File

@ -17,7 +17,7 @@
// pensieve/short-term/maxims/preserve-existing-runtime-behavior-when-replacing-logic-in-stateful-systems
package openai_compat
// AccountResponsesSupport 描述账号上游对 OpenAI Responses API 的支持状态。
// AccountResponsesSupport 描述账号上游对 OpenAI Responses API 的有效支持状态。
//
// 仅用于 platform=openai + type=apikey 的账号;其他账号类型不应调用本包判定。
type AccountResponsesSupport int
@ -35,11 +35,43 @@ const (
ResponsesSupportNo
)
// ExtraKeyResponsesSupported 是 accounts.extra JSON 中存储探测结果的键名。
// ResponsesSupportMode 描述账号级 Responses API 路由覆盖模式。
type ResponsesSupportMode string
const (
// ResponsesSupportModeAuto 表示跟随自动探测结果。
ResponsesSupportModeAuto ResponsesSupportMode = "auto"
// ResponsesSupportModeForceResponses 强制使用 /v1/responses。
ResponsesSupportModeForceResponses ResponsesSupportMode = "force_responses"
// ResponsesSupportModeForceChatCompletions 强制使用 /v1/chat/completions。
ResponsesSupportModeForceChatCompletions ResponsesSupportMode = "force_chat_completions"
)
// ExtraKeyResponsesMode 是 accounts.extra JSON 中存储手动覆盖模式的键名。
// 值类型为 stringauto=跟随探测force_responses=强制 Responses
// force_chat_completions=强制 Chat Completions。
const ExtraKeyResponsesMode = "openai_responses_mode"
// ExtraKeyResponsesSupported 是 accounts.extra JSON 中存储自动探测结果的键名。
// 值类型为 booltrue=支持、false=不支持、键缺失=未探测。
const ExtraKeyResponsesSupported = "openai_responses_supported"
// ResolveResponsesSupport 从账号的 extra map 中读取探测标记。
// NormalizeResponsesSupportMode 归一化账号级 Responses API 路由覆盖模式。
// 缺失或非法值按 auto 处理,以保持存量行为。
func NormalizeResponsesSupportMode(mode string) ResponsesSupportMode {
switch ResponsesSupportMode(mode) {
case ResponsesSupportModeForceResponses:
return ResponsesSupportModeForceResponses
case ResponsesSupportModeForceChatCompletions:
return ResponsesSupportModeForceChatCompletions
default:
return ResponsesSupportModeAuto
}
}
// ResolveResponsesSupport 从账号的 extra map 中读取手动覆盖模式与探测标记。
//
// 标记缺失或类型不匹配时返回 ResponsesSupportUnknown——调用方应按
// "未探测=保留旧行为=走 Responses" 处理(参见 ShouldUseResponsesAPI
@ -47,6 +79,14 @@ func ResolveResponsesSupport(extra map[string]any) AccountResponsesSupport {
if extra == nil {
return ResponsesSupportUnknown
}
if mode, ok := extra[ExtraKeyResponsesMode].(string); ok {
switch NormalizeResponsesSupportMode(mode) {
case ResponsesSupportModeForceResponses:
return ResponsesSupportYes
case ResponsesSupportModeForceChatCompletions:
return ResponsesSupportNo
}
}
v, ok := extra[ExtraKeyResponsesSupported]
if !ok {
return ResponsesSupportUnknown

View File

@ -16,6 +16,12 @@ func TestResolveResponsesSupport(t *testing.T) {
{"value wrong type string", map[string]any{ExtraKeyResponsesSupported: "true"}, ResponsesSupportUnknown},
{"value wrong type number", map[string]any{ExtraKeyResponsesSupported: 1}, ResponsesSupportUnknown},
{"value nil", map[string]any{ExtraKeyResponsesSupported: nil}, ResponsesSupportUnknown},
{"force responses", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceResponses)}, ResponsesSupportYes},
{"force chat completions", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceChatCompletions)}, ResponsesSupportNo},
{"auto follows probe", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeAuto), ExtraKeyResponsesSupported: false}, ResponsesSupportNo},
{"invalid mode follows probe", map[string]any{ExtraKeyResponsesMode: "bogus", ExtraKeyResponsesSupported: true}, ResponsesSupportYes},
{"force responses overrides probe false", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceResponses), ExtraKeyResponsesSupported: false}, ResponsesSupportYes},
{"force chat completions overrides probe true", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceChatCompletions), ExtraKeyResponsesSupported: true}, ResponsesSupportNo},
}
for _, tc := range tests {
@ -42,6 +48,10 @@ func TestShouldUseResponsesAPI(t *testing.T) {
// 已探测:标记决定
{"explicitly supported", map[string]any{ExtraKeyResponsesSupported: true}, true},
{"explicitly unsupported", map[string]any{ExtraKeyResponsesSupported: false}, false},
// 手动覆盖:覆盖自动探测结果
{"force responses overrides unsupported probe", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceResponses), ExtraKeyResponsesSupported: false}, true},
{"force chat completions overrides supported probe", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceChatCompletions), ExtraKeyResponsesSupported: true}, false},
}
for _, tc := range tests {
@ -53,3 +63,26 @@ func TestShouldUseResponsesAPI(t *testing.T) {
})
}
}
func TestNormalizeResponsesSupportMode(t *testing.T) {
tests := []struct {
name string
mode string
want ResponsesSupportMode
}{
{"empty", "", ResponsesSupportModeAuto},
{"auto", "auto", ResponsesSupportModeAuto},
{"force responses", "force_responses", ResponsesSupportModeForceResponses},
{"force chat completions", "force_chat_completions", ResponsesSupportModeForceChatCompletions},
{"invalid", "enabled", ResponsesSupportModeAuto},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := NormalizeResponsesSupportMode(tc.mode)
if got != tc.want {
t.Errorf("NormalizeResponsesSupportMode(%q) = %q, want %q", tc.mode, got, tc.want)
}
})
}
}

View File

@ -230,6 +230,20 @@ type UserDashboardStats struct {
// 性能指标
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
// 按"有效平台"维度拆分(与 ops 路径口径一致group.platform 优先,否则 account.platform
ByPlatform []PlatformDashboardStats `json:"by_platform,omitempty"`
}
// PlatformDashboardStats 单个平台的用量明细。
type PlatformDashboardStats struct {
Platform string `json:"platform"`
TotalRequests int64 `json:"total_requests"`
TotalTokens int64 `json:"total_tokens"`
TotalActualCost float64 `json:"total_actual_cost"`
TodayRequests int64 `json:"today_requests"`
TodayTokens int64 `json:"today_tokens"`
TodayActualCost float64 `json:"today_actual_cost"`
}
// UsageLogFilters represents filters for usage log queries
@ -265,13 +279,22 @@ type UsageStats struct {
EndpointPaths []EndpointStat `json:"endpoint_paths,omitempty"`
}
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats struct {
UserID int64 `json:"user_id"`
// PlatformUsage 表示某用户/某 API key 在单个"有效平台"维度的用量明细。
// Platform 取值与 ops 路径口径一致:优先 groups.platform否则 accounts.platform。
type PlatformUsage struct {
Platform string `json:"platform"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats struct {
UserID int64 `json:"user_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
ByPlatform []PlatformUsage `json:"by_platform,omitempty"`
}
// BatchAPIKeyUsageStats represents usage stats for a single API key
type BatchAPIKeyUsageStats struct {
APIKeyID int64 `json:"api_key_id"`

View File

@ -12,3 +12,14 @@ func TestShouldEnqueueSchedulerOutboxForExtraUpdates_CompactCapabilityKeysAreRel
t.Fatalf("expected compact capability updates to enqueue scheduler outbox")
}
}
func TestShouldEnqueueSchedulerOutboxForExtraUpdates_OpenAIResponsesCapabilityKeysAreRelevant(t *testing.T) {
updates := map[string]any{
"openai_responses_mode": "force_chat_completions",
"openai_responses_supported": false,
}
if !shouldEnqueueSchedulerOutboxForExtraUpdates(updates) {
t.Fatalf("expected responses capability updates to enqueue scheduler outbox")
}
}

View File

@ -204,7 +204,8 @@ func (r *announcementRepository) ListActive(ctx context.Context, now time.Time)
announcement.Or(announcement.StartsAtIsNil(), announcement.StartsAtLTE(now)),
announcement.Or(announcement.EndsAtIsNil(), announcement.EndsAtGT(now)),
).
Order(dbent.Desc(announcement.FieldID))
Order(dbent.Desc(announcement.FieldID)).
Limit(200)
items, err := q.All(ctx)
if err != nil {

View File

@ -283,47 +283,90 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
}
func (r *groupRepository) listWithAccountCountSort(ctx context.Context, q *dbent.GroupQuery, params pagination.PaginationParams, total int) ([]service.Group, *pagination.PaginationResult, error) {
groups, err := q.
// 第一步:只查 ID + sort_order轻量不做分页 — 需要全量排序 account_count
rows, err := q.Clone().
Select(group.FieldID, group.FieldSortOrder).
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
groupIDs := make([]int64, 0, len(groups))
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
g := groupEntityToService(groups[i])
outGroups = append(outGroups, *g)
groupIDs = append(groupIDs, g.ID)
type sortEntry struct {
id int64
sortOrder int
accountCount int64
}
entries := make([]sortEntry, 0, len(rows))
groupIDs := make([]int64, len(rows))
for i, r := range rows {
groupIDs[i] = r.ID
entries = append(entries, sortEntry{id: r.ID, sortOrder: r.SortOrder})
}
// 第二步:批量加载 account counts一次 SQL
counts, err := r.loadAccountCounts(ctx, groupIDs)
if err != nil {
return nil, nil, err
}
for i := range outGroups {
c := counts[outGroups[i].ID]
outGroups[i].AccountCount = c.Total
outGroups[i].ActiveAccountCount = c.Active
outGroups[i].RateLimitedAccountCount = c.RateLimited
for i := range entries {
c := counts[entries[i].id]
if c.Total > 0 {
entries[i].accountCount = c.Total
}
}
// 第三步Go 侧排序(数据量 = Group 总数,通常 < 200安全
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
sort.SliceStable(outGroups, func(i, j int) bool {
if outGroups[i].AccountCount == outGroups[j].AccountCount {
if outGroups[i].SortOrder == outGroups[j].SortOrder {
return outGroups[i].ID < outGroups[j].ID
}
return outGroups[i].SortOrder < outGroups[j].SortOrder
tieCmp := func(a, b sortEntry) bool {
if a.sortOrder == b.sortOrder {
return a.id < b.id
}
return a.sortOrder < b.sortOrder
}
sort.SliceStable(entries, func(i, j int) bool {
if entries[i].accountCount == entries[j].accountCount {
return tieCmp(entries[i], entries[j])
}
if sortOrder == pagination.SortOrderAsc {
return outGroups[i].AccountCount < outGroups[j].AccountCount
return entries[i].accountCount < entries[j].accountCount
}
return outGroups[i].AccountCount > outGroups[j].AccountCount
return entries[i].accountCount > entries[j].accountCount
})
return paginateSlice(outGroups, params), paginationResultFromTotal(int64(total), params), nil
// 第四步:分页,只加载当前页需要的完整 Group。
page := paginateSlice(entries, params)
if len(page) == 0 {
return nil, paginationResultFromTotal(int64(total), params), nil
}
pageIDs := make([]int64, len(page))
pageIdx := make(map[int64]int, len(page))
for i, e := range page {
pageIDs[i] = e.id
pageIdx[e.id] = i
}
groups, err := r.client.Group.Query().
Where(group.IDIn(pageIDs...)).
All(ctx)
if err != nil {
return nil, nil, err
}
outGroups := make([]service.Group, len(page))
for i := range groups {
g := groupEntityToService(groups[i])
c := counts[g.ID]
g.AccountCount = c.Total
g.ActiveAccountCount = c.Active
g.RateLimitedAccountCount = c.RateLimited
if idx, ok := pageIdx[g.ID]; ok {
outGroups[idx] = *g
}
}
return outGroups, paginationResultFromTotal(int64(total), params), nil
}
func groupListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {

View File

@ -44,6 +44,33 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false)
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
requireColumn(t, tx, "usage_logs", "image_input_size", "character varying", 32, true)
requireColumn(t, tx, "usage_logs", "image_output_size", "character varying", 32, true)
requireColumn(t, tx, "usage_logs", "image_size_source", "character varying", 16, true)
requireColumn(t, tx, "usage_logs", "image_size_breakdown", "jsonb", 0, true)
requireConstraintDefinitionContains(
t,
tx,
"usage_logs",
"usage_logs_image_size_source_check",
"image_size_source",
"'output'",
"'input'",
"'default'",
"'legacy'",
)
requireConstraintDefinitionContains(
t,
tx,
"usage_logs",
"usage_logs_image_billing_size_check",
"image_count",
"image_size IS NOT NULL",
"'1K'",
"'2K'",
"'4K'",
"'mixed'",
)
// usage_billing_dedup: billing idempotency narrow table
var usageBillingDedupRegclass sql.NullString

View File

@ -30,6 +30,7 @@ func (r *redeemCodeRepository) Create(ctx context.Context, code *service.RedeemC
SetStatus(code.Status).
SetNotes(code.Notes).
SetValidityDays(code.ValidityDays).
SetNillableExpiresAt(code.ExpiresAt).
SetNillableUsedBy(code.UsedBy).
SetNillableUsedAt(code.UsedAt).
SetNillableGroupID(code.GroupID).
@ -56,6 +57,7 @@ func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []service.
SetStatus(c.Status).
SetNotes(c.Notes).
SetValidityDays(c.ValidityDays).
SetNillableExpiresAt(c.ExpiresAt).
SetNillableUsedBy(c.UsedBy).
SetNillableUsedAt(c.UsedAt).
SetNillableGroupID(c.GroupID)
@ -107,7 +109,28 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
q = q.Where(redeemcode.TypeEQ(codeType))
}
if status != "" {
q = q.Where(redeemcode.StatusEQ(status))
now := time.Now()
switch status {
case service.StatusExpired:
q = q.Where(redeemcode.Or(
redeemcode.StatusEQ(service.StatusExpired),
redeemcode.And(
redeemcode.StatusEQ(service.StatusUnused),
redeemcode.ExpiresAtNotNil(),
redeemcode.ExpiresAtLTE(now),
),
))
case service.StatusUnused:
q = q.Where(
redeemcode.StatusEQ(service.StatusUnused),
redeemcode.Or(
redeemcode.ExpiresAtIsNil(),
redeemcode.ExpiresAtGT(now),
),
)
default:
q = q.Where(redeemcode.StatusEQ(status))
}
}
if search != "" {
q = q.Where(
@ -158,6 +181,8 @@ func redeemCodeListOrder(params pagination.PaginationParams) []func(*entsql.Sele
field = redeemcode.FieldUsedAt
case "created_at":
field = redeemcode.FieldCreatedAt
case "expires_at":
field = redeemcode.FieldExpiresAt
case "code":
field = redeemcode.FieldCode
default:
@ -194,6 +219,11 @@ func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemC
} else {
up.ClearGroupID()
}
if code.ExpiresAt != nil {
up.SetExpiresAt(*code.ExpiresAt)
} else {
up.ClearExpiresAt()
}
updated, err := up.Save(ctx)
if err != nil {
@ -307,6 +337,7 @@ func redeemCodeEntityToService(m *dbent.RedeemCode) *service.RedeemCode {
UsedAt: m.UsedAt,
Notes: derefString(m.Notes),
CreatedAt: m.CreatedAt,
ExpiresAt: m.ExpiresAt,
GroupID: m.GroupID,
ValidityDays: m.ValidityDays,
}

View File

@ -51,11 +51,13 @@ func (s *RedeemCodeRepoSuite) createGroup(name string) *dbent.Group {
// --- Create / CreateBatch / GetByID / GetByCode ---
func (s *RedeemCodeRepoSuite) TestCreate() {
expiresAt := time.Now().UTC().Add(2 * time.Hour)
code := &service.RedeemCode{
Code: "TEST-CREATE",
Type: service.RedeemTypeBalance,
Value: 100,
Status: service.StatusUnused,
Code: "TEST-CREATE",
Type: service.RedeemTypeBalance,
Value: 100,
Status: service.StatusUnused,
ExpiresAt: &expiresAt,
}
err := s.repo.Create(s.ctx, code)
@ -65,6 +67,8 @@ func (s *RedeemCodeRepoSuite) TestCreate() {
got, err := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("TEST-CREATE", got.Code)
s.Require().NotNil(got.ExpiresAt)
s.Require().WithinDuration(expiresAt, *got.ExpiresAt, time.Second)
}
func (s *RedeemCodeRepoSuite) TestCreateBatch() {
@ -166,6 +170,23 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
s.Require().Equal(service.StatusUsed, codes[0].Status)
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_StatusExpiredByExpiresAt() {
past := time.Now().UTC().Add(-time.Hour)
future := time.Now().UTC().Add(time.Hour)
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "STAT-EXPIRED-BY-TIME", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused, ExpiresAt: &past}))
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "STAT-UNUSED-FUTURE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused, ExpiresAt: &future}))
expired, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusExpired, "")
s.Require().NoError(err)
s.Require().Len(expired, 1)
s.Require().Equal("STAT-EXPIRED-BY-TIME", expired[0].Code)
unused, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusUnused, "")
s.Require().NoError(err)
s.Require().Len(unused, 1)
s.Require().Equal("STAT-UNUSED-FUTURE", unused[0].Code)
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "ALPHA-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "BETA-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))

View File

@ -546,6 +546,8 @@ func filterSchedulerExtra(extra map[string]any) map[string]any {
"responses_websockets_v2_enabled",
"openai_ws_enabled",
"openai_ws_force_http",
"openai_responses_mode",
"openai_responses_supported",
}
filtered := make(map[string]any)
for _, key := range keys {

View File

@ -18,6 +18,8 @@ func TestBuildSchedulerMetadataAccount_KeepsOpenAIWSFlags(t *testing.T) {
"openai_oauth_responses_websockets_v2_enabled": true,
"openai_oauth_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
"openai_ws_force_http": true,
"openai_responses_mode": "force_chat_completions",
"openai_responses_supported": false,
"mixed_scheduling": true,
"unused_large_field": "drop-me",
},
@ -28,6 +30,8 @@ func TestBuildSchedulerMetadataAccount_KeepsOpenAIWSFlags(t *testing.T) {
require.Equal(t, true, got.Extra["openai_oauth_responses_websockets_v2_enabled"])
require.Equal(t, service.OpenAIWSIngressModePassthrough, got.Extra["openai_oauth_responses_websockets_v2_mode"])
require.Equal(t, true, got.Extra["openai_ws_force_http"])
require.Equal(t, "force_chat_completions", got.Extra["openai_responses_mode"])
require.Equal(t, false, got.Extra["openai_responses_supported"])
require.Equal(t, true, got.Extra["mixed_scheduling"])
require.Nil(t, got.Extra["unused_large_field"])
}

View File

@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, account_stats_cost, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, image_input_size, image_output_size, image_size_source, image_size_breakdown, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, account_stats_cost, created_at"
// usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args
@ -73,6 +73,10 @@ var usageLogInsertArgTypes = [...]string{
"text", // ip_address
"integer", // image_count
"text", // image_size
"text", // image_input_size
"text", // image_output_size
"text", // image_size_source
"jsonb", // image_size_breakdown
"text", // service_tier
"text", // reasoning_effort
"text", // inbound_endpoint
@ -92,6 +96,22 @@ const rawUsageLogModelColumn = "model"
// Historical rows may contain upstream/billing model values, while newer rows store requested_model.
// Requested/upstream/mapping analytics must use resolveModelDimensionExpression instead.
// usageLogSuccessFilterUL 用于把"失败请求 usage log"tokens=0、cost=0、不计费的占位记录
// 从统计性聚合中排除,避免污染 Dashboard / 用量拆分等指标。
//
// schema 中没有 success bool 列;新增列要做迁移,风险大;这里用 actual_cost > 0 作为代理:
// 任何成功落账的请求都会产生 actual_cost包括 token 计费、纯图片 token 计费、按次/按图计费),
// 反之 failed-request usage log 的 actual_cost 为 0。
// 早期版本用 4 项 token 和 > 0 判定会把"按次/按图计费"与"image_output_tokens 独立计费"的纯图片
// 请求误判为失败,导致这部分请求从用量统计里消失,故改用 actual_cost。
// 配合 `FROM usage_logs ul` JOIN 查询使用。
const usageLogSuccessFilterUL = "ul.actual_cost > 0"
// usageLogEffectivePlatformExpr 用于按"有效平台"维度聚合 usage_logs
// 优先取请求实际走的分组 platform若分组未设置 platform 再 fallback 到 account.platform。
// 配套要求查询里 LEFT JOIN groups g ON g.id = ul.group_id 与 LEFT JOIN accounts a ON a.id = ul.account_id。
const usageLogEffectivePlatformExpr = "COALESCE(NULLIF(g.platform,''), a.platform)"
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
var dateFormatWhitelist = map[string]string{
"hour": "YYYY-MM-DD HH24:00",
@ -120,6 +140,24 @@ func appendRawUsageLogModelWhereCondition(conditions []string, args []any, model
return conditions, args
}
func appendUsageLogBillingModeWhereCondition(conditions []string, args []any, billingMode string) ([]string, []any) {
mode := strings.TrimSpace(billingMode)
if mode == "" {
return conditions, args
}
placeholder := fmt.Sprintf("$%d", len(args)+1)
switch service.BillingMode(mode) {
case service.BillingModeImage:
conditions = append(conditions, fmt.Sprintf("(billing_mode = %s OR COALESCE(image_count, 0) > 0)", placeholder))
case service.BillingModeToken:
conditions = append(conditions, fmt.Sprintf("(billing_mode = %s OR ((billing_mode IS NULL OR billing_mode = '') AND COALESCE(image_count, 0) <= 0))", placeholder))
default:
conditions = append(conditions, fmt.Sprintf("billing_mode = %s", placeholder))
}
args = append(args, mode)
return conditions, args
}
// appendRawUsageLogModelQueryFilter keeps direct model filters on the raw model column for backward
// compatibility with historical rows. Requested/upstream analytics must use
// resolveModelDimensionExpression instead.
@ -352,6 +390,10 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
ip_address,
image_count,
image_size,
image_input_size,
image_output_size,
image_size_source,
image_size_breakdown,
service_tier,
reasoning_effort,
inbound_endpoint,
@ -369,7 +411,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
$10, $11, $12, $13,
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@ -790,6 +832,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
ip_address,
image_count,
image_size,
image_input_size,
image_output_size,
image_size_source,
image_size_breakdown,
service_tier,
reasoning_effort,
inbound_endpoint,
@ -803,7 +849,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
created_at
) AS (VALUES `)
args := make([]any, 0, len(keys)*46)
args := make([]any, 0, len(keys)*50)
argPos := 1
for idx, key := range keys {
if idx > 0 {
@ -867,6 +913,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
ip_address,
image_count,
image_size,
image_input_size,
image_output_size,
image_size_source,
image_size_breakdown,
service_tier,
reasoning_effort,
inbound_endpoint,
@ -915,6 +965,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
ip_address,
image_count,
image_size,
image_input_size,
image_output_size,
image_size_source,
image_size_breakdown,
service_tier,
reasoning_effort,
inbound_endpoint,
@ -1003,6 +1057,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
ip_address,
image_count,
image_size,
image_input_size,
image_output_size,
image_size_source,
image_size_breakdown,
service_tier,
reasoning_effort,
inbound_endpoint,
@ -1016,7 +1074,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
created_at
) AS (VALUES `)
args := make([]any, 0, len(preparedList)*46)
args := make([]any, 0, len(preparedList)*50)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
@ -1077,6 +1135,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
ip_address,
image_count,
image_size,
image_input_size,
image_output_size,
image_size_source,
image_size_breakdown,
service_tier,
reasoning_effort,
inbound_endpoint,
@ -1125,6 +1187,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
ip_address,
image_count,
image_size,
image_input_size,
image_output_size,
image_size_source,
image_size_breakdown,
service_tier,
reasoning_effort,
inbound_endpoint,
@ -1181,6 +1247,10 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
ip_address,
image_count,
image_size,
image_input_size,
image_output_size,
image_size_source,
image_size_breakdown,
service_tier,
reasoning_effort,
inbound_endpoint,
@ -1198,7 +1268,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
$10, $11, $12, $13,
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
@ -1225,6 +1295,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
userAgent := nullString(log.UserAgent)
ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize)
imageInputSize := nullString(log.ImageInputSize)
imageOutputSize := nullString(log.ImageOutputSize)
imageSizeSource := nullString(log.ImageSizeSource)
imageSizeBreakdown := nullStringIntMapJSON(log.ImageSizeBreakdown)
serviceTier := nullString(log.ServiceTier)
reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint)
@ -1285,6 +1359,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
ipAddress,
log.ImageCount,
imageSize,
imageInputSize,
imageOutputSize,
imageSizeSource,
imageSizeBreakdown,
serviceTier,
reasoningEffort,
inboundEndpoint,
@ -2352,6 +2430,9 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
// UserDashboardStats 用户仪表盘统计
type UserDashboardStats = usagestats.UserDashboardStats
// PlatformDashboardStats 单平台用量明细
type PlatformDashboardStats = usagestats.PlatformDashboardStats
// GetUserDashboardStats 获取用户专属的仪表盘统计
func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) {
stats := &UserDashboardStats{}
@ -2447,6 +2528,57 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
stats.Rpm = rpm
stats.Tpm = tpm
// 按"有效平台"维度拆分group.platform 优先,否则 account.platform
// 与 ops 路径口径一致HAVING 过滤掉无法确定平台的行(避免出现空字符串平台)。
// 与上面 totalStatsQuery/todayStatsQuery 的总值可能略微差异,原因有二:
// 1) 无平台归属的极少数行group/account 都没 platform会被 HAVING 排除;
// 2) usageLogSuccessFilterUL 会把 actual_cost = 0 的失败 placeholder 行排除,
// 而 totalStatsQuery/todayStatsQuery 没有这层过滤、会把这些行的 request 计数算进去。
platformQuery := `
SELECT
` + usageLogEffectivePlatformExpr + ` as platform,
COUNT(*) as total_requests,
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(ul.actual_cost), 0) as total_actual_cost,
COUNT(*) FILTER (WHERE ul.created_at >= $2) as today_requests,
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens) FILTER (WHERE ul.created_at >= $2), 0) as today_tokens,
COALESCE(SUM(ul.actual_cost) FILTER (WHERE ul.created_at >= $2), 0) as today_actual_cost
FROM usage_logs ul
LEFT JOIN groups g ON g.id = ul.group_id
LEFT JOIN accounts a ON a.id = ul.account_id
WHERE ul.user_id = $1
AND ` + usageLogSuccessFilterUL + `
GROUP BY ` + usageLogEffectivePlatformExpr + `
HAVING ` + usageLogEffectivePlatformExpr + ` IS NOT NULL AND ` + usageLogEffectivePlatformExpr + ` <> ''
ORDER BY total_actual_cost DESC
`
rows, err := r.sql.QueryContext(ctx, platformQuery, userID, today)
if err != nil {
return nil, err
}
for rows.Next() {
var p PlatformDashboardStats
if err := rows.Scan(
&p.Platform,
&p.TotalRequests,
&p.TotalTokens,
&p.TotalActualCost,
&p.TodayRequests,
&p.TodayTokens,
&p.TodayActualCost,
); err != nil {
_ = rows.Close()
return nil, err
}
stats.ByPlatform = append(stats.ByPlatform, p)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return stats, nil
}
@ -2662,10 +2794,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
}
if filters.BillingMode != "" {
conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1))
args = append(args, filters.BillingMode)
}
conditions, args = appendUsageLogBillingModeWhereCondition(conditions, args, filters.BillingMode)
if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime)
@ -2710,6 +2839,9 @@ type UsageStats = usagestats.UsageStats
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats = usagestats.BatchUserUsageStats
// PlatformUsage represents per-platform usage breakdown
type PlatformUsage = usagestats.PlatformUsage
func normalizePositiveInt64IDs(ids []int64) []int64 {
if len(ids) == 0 {
return nil
@ -2750,15 +2882,21 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
result[id] = &BatchUserUsageStats{UserID: id}
}
// GROUP BY (user_id, effective_platform) 一次查询同时得到总值与按平台拆分。
// 应用层把同一 user_id 的多行累加为总值,并把非空 platform 行收集到 ByPlatform。
query := `
SELECT
user_id,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost
FROM usage_logs
WHERE user_id = ANY($1)
AND created_at >= LEAST($2, $4)
GROUP BY user_id
ul.user_id,
` + usageLogEffectivePlatformExpr + ` as platform,
COALESCE(SUM(ul.actual_cost) FILTER (WHERE ul.created_at >= $2 AND ul.created_at < $3), 0) as total_cost,
COALESCE(SUM(ul.actual_cost) FILTER (WHERE ul.created_at >= $4), 0) as today_cost
FROM usage_logs ul
LEFT JOIN groups g ON g.id = ul.group_id
LEFT JOIN accounts a ON a.id = ul.account_id
WHERE ul.user_id = ANY($1)
AND ul.created_at >= LEAST($2, $4)
AND ` + usageLogSuccessFilterUL + `
GROUP BY ul.user_id, ` + usageLogEffectivePlatformExpr + `
`
today := timezone.Today()
rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedUserIDs), startTime, endTime, today)
@ -2767,15 +2905,25 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
}
for rows.Next() {
var userID int64
var platform sql.NullString
var total float64
var todayTotal float64
if err := rows.Scan(&userID, &total, &todayTotal); err != nil {
if err := rows.Scan(&userID, &platform, &total, &todayTotal); err != nil {
_ = rows.Close()
return nil, err
}
if stats, ok := result[userID]; ok {
stats.TotalActualCost = total
stats.TodayActualCost = todayTotal
stats, ok := result[userID]
if !ok {
continue
}
stats.TotalActualCost += total
stats.TodayActualCost += todayTotal
if platform.Valid && platform.String != "" {
stats.ByPlatform = append(stats.ByPlatform, PlatformUsage{
Platform: platform.String,
TotalActualCost: total,
TodayActualCost: todayTotal,
})
}
}
if err := rows.Close(); err != nil {
@ -3363,10 +3511,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
}
if filters.BillingMode != "" {
conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1))
args = append(args, filters.BillingMode)
}
conditions, args = appendUsageLogBillingModeWhereCondition(conditions, args, filters.BillingMode)
if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime)
@ -4084,6 +4229,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
ipAddress sql.NullString
imageCount int
imageSize sql.NullString
imageInputSize sql.NullString
imageOutputSize sql.NullString
imageSizeSource sql.NullString
imageSizeBreakdown sql.NullString
serviceTier sql.NullString
reasoningEffort sql.NullString
inboundEndpoint sql.NullString
@ -4134,6 +4283,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&ipAddress,
&imageCount,
&imageSize,
&imageInputSize,
&imageOutputSize,
&imageSizeSource,
&imageSizeBreakdown,
&serviceTier,
&reasoningEffort,
&inboundEndpoint,
@ -4212,6 +4365,16 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if imageSize.Valid {
log.ImageSize = &imageSize.String
}
if imageInputSize.Valid {
log.ImageInputSize = &imageInputSize.String
}
if imageOutputSize.Valid {
log.ImageOutputSize = &imageOutputSize.String
}
if imageSizeSource.Valid {
log.ImageSizeSource = &imageSizeSource.String
}
log.ImageSizeBreakdown = stringIntMapFromNullJSON(imageSizeBreakdown)
if serviceTier.Valid {
log.ServiceTier = &serviceTier.String
}
@ -4378,6 +4541,31 @@ func nullString(v *string) sql.NullString {
return sql.NullString{String: *v, Valid: true}
}
func nullStringIntMapJSON(v map[string]int) any {
if len(v) == 0 {
return nil
}
payload, err := json.Marshal(v)
if err != nil {
return nil
}
return string(payload)
}
func stringIntMapFromNullJSON(v sql.NullString) map[string]int {
if !v.Valid || strings.TrimSpace(v.String) == "" {
return nil
}
var out map[string]int
if err := json.Unmarshal([]byte(v.String), &out); err != nil {
return nil
}
if len(out) == 0 {
return nil
}
return out
}
func coalesceTrimmedString(v sql.NullString, fallback string) string {
if v.Valid && strings.TrimSpace(v.String) != "" {
return v.String

View File

@ -76,6 +76,10 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // ip_address
log.ImageCount,
sqlmock.AnyArg(), // image_size
sqlmock.AnyArg(), // image_input_size
sqlmock.AnyArg(), // image_output_size
sqlmock.AnyArg(), // image_size_source
sqlmock.AnyArg(), // image_size_breakdown
sqlmock.AnyArg(), // service_tier
sqlmock.AnyArg(), // reasoning_effort
sqlmock.AnyArg(), // inbound_endpoint
@ -155,6 +159,10 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(),
log.ImageCount,
sqlmock.AnyArg(),
sqlmock.AnyArg(), // image_input_size
sqlmock.AnyArg(), // image_output_size
sqlmock.AnyArg(), // image_size_source
sqlmock.AnyArg(), // image_size_breakdown
serviceTier,
sqlmock.AnyArg(),
sqlmock.AnyArg(),
@ -230,12 +238,74 @@ func TestPrepareUsageLogInsert_ArgCountMatchesTypes(t *testing.T) {
require.Len(t, prepared.args, len(usageLogInsertArgTypes))
}
func TestPrepareUsageLogInsert_PersistsImageSizeMetadata(t *testing.T) {
imageSize := "4K"
inputSize := "1024x1024"
outputSize := "3840x2160"
source := "output"
prepared := prepareUsageLogInsert(&service.UsageLog{
UserID: 1,
APIKeyID: 2,
AccountID: 3,
RequestID: "req-image-metadata",
Model: "gpt-image-2",
RequestedModel: "gpt-image-2",
ImageCount: 2,
ImageSize: &imageSize,
ImageInputSize: &inputSize,
ImageOutputSize: &outputSize,
ImageSizeSource: &source,
ImageSizeBreakdown: map[string]int{"1K": 1, "4K": 1},
CreatedAt: time.Date(2025, 1, 6, 12, 0, 0, 0, time.UTC),
})
require.Equal(t, sql.NullString{String: imageSize, Valid: true}, prepared.args[34])
require.Equal(t, sql.NullString{String: inputSize, Valid: true}, prepared.args[35])
require.Equal(t, sql.NullString{String: outputSize, Valid: true}, prepared.args[36])
require.Equal(t, sql.NullString{String: source, Valid: true}, prepared.args[37])
breakdownJSON, ok := prepared.args[38].(string)
require.True(t, ok)
require.JSONEq(t, `{"1K":1,"4K":1}`, breakdownJSON)
}
func TestCoalesceTrimmedString(t *testing.T) {
require.Equal(t, "fallback", coalesceTrimmedString(sql.NullString{}, "fallback"))
require.Equal(t, "fallback", coalesceTrimmedString(sql.NullString{Valid: true, String: " "}, "fallback"))
require.Equal(t, "value", coalesceTrimmedString(sql.NullString{Valid: true, String: "value"}, "fallback"))
}
func TestAppendUsageLogBillingModeWhereCondition(t *testing.T) {
tests := []struct {
name string
billingMode string
wantCondition string
}{
{
name: "image includes legacy image rows",
billingMode: string(service.BillingModeImage),
wantCondition: "(billing_mode = $1 OR COALESCE(image_count, 0) > 0)",
},
{
name: "token includes legacy non-image rows",
billingMode: string(service.BillingModeToken),
wantCondition: "(billing_mode = $1 OR ((billing_mode IS NULL OR billing_mode = '') AND COALESCE(image_count, 0) <= 0))",
},
{
name: "per request remains exact",
billingMode: string(service.BillingModePerRequest),
wantCondition: "billing_mode = $1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conditions, args := appendUsageLogBillingModeWhereCondition(nil, nil, tt.billingMode)
require.Equal(t, []string{tt.wantCondition}, conditions)
require.Equal(t, []any{tt.billingMode}, args)
})
}
}
func anySliceToDriverValues(values []any) []driver.Value {
out := make([]driver.Value, 0, len(values))
for _, value := range values {
@ -528,6 +598,63 @@ func (s usageLogScannerStub) Scan(dest ...any) error {
}
func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
t.Run("image_size_metadata_is_scanned", func(t *testing.T) {
now := time.Now().UTC()
log, err := scanUsageLog(usageLogScannerStub{values: []any{
int64(4),
int64(13),
int64(23),
int64(33),
sql.NullString{Valid: true, String: "req-image-metadata"},
"gpt-image-2",
sql.NullString{Valid: true, String: "gpt-image-2"},
sql.NullString{},
sql.NullInt64{},
sql.NullInt64{},
0, 0, 0, 0, 0, 0,
0, 0.0, // image_output_tokens, image_output_cost
0.0, 0.0, 0.0, 0.0, 0.8, 0.8,
1.0,
sql.NullFloat64{},
int16(service.BillingTypeBalance),
int16(service.RequestTypeSync),
false,
false,
sql.NullInt64{},
sql.NullInt64{},
sql.NullString{},
sql.NullString{},
2,
sql.NullString{Valid: true, String: "4K"},
sql.NullString{Valid: true, String: "1024x1024"},
sql.NullString{Valid: true, String: "3840x2160"},
sql.NullString{Valid: true, String: "output"},
sql.NullString{Valid: true, String: `{"4K":2}`},
sql.NullString{},
sql.NullString{},
sql.NullString{},
sql.NullString{},
false,
sql.NullInt64{},
sql.NullString{},
sql.NullString{},
sql.NullString{},
sql.NullFloat64{},
now,
}})
require.NoError(t, err)
require.Equal(t, 2, log.ImageCount)
require.NotNil(t, log.ImageSize)
require.Equal(t, "4K", *log.ImageSize)
require.NotNil(t, log.ImageInputSize)
require.Equal(t, "1024x1024", *log.ImageInputSize)
require.NotNil(t, log.ImageOutputSize)
require.Equal(t, "3840x2160", *log.ImageOutputSize)
require.NotNil(t, log.ImageSizeSource)
require.Equal(t, "output", *log.ImageSizeSource)
require.Equal(t, map[string]int{"4K": 2}, log.ImageSizeBreakdown)
})
t.Run("request_type_ws_v2_overrides_legacy", func(t *testing.T) {
now := time.Now().UTC()
log, err := scanUsageLog(usageLogScannerStub{values: []any{
@ -567,6 +694,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
0,
sql.NullString{},
sql.NullString{}, // image_input_size
sql.NullString{}, // image_output_size
sql.NullString{}, // image_size_source
sql.NullString{}, // image_size_breakdown
sql.NullString{Valid: true, String: "priority"},
sql.NullString{},
sql.NullString{},
@ -615,6 +746,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
0,
sql.NullString{},
sql.NullString{}, // image_input_size
sql.NullString{}, // image_output_size
sql.NullString{}, // image_size_source
sql.NullString{}, // image_size_breakdown
sql.NullString{Valid: true, String: "flex"},
sql.NullString{},
sql.NullString{},
@ -663,6 +798,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
0,
sql.NullString{},
sql.NullString{}, // image_input_size
sql.NullString{}, // image_output_size
sql.NullString{}, // image_size_source
sql.NullString{}, // image_size_breakdown
sql.NullString{Valid: true, String: "priority"},
sql.NullString{},
sql.NullString{},

View File

@ -334,7 +334,8 @@ func normalizeEmailAuthIdentitySubject(email string) string {
}
if strings.HasSuffix(normalized, service.LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, service.OIDCConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, service.WeChatConnectSyntheticEmailDomain) {
strings.HasSuffix(normalized, service.WeChatConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, service.DingTalkConnectSyntheticEmailDomain) {
return ""
}
return normalized
@ -956,7 +957,7 @@ func userSignupSourceOrDefault(signupSource string) string {
switch strings.TrimSpace(strings.ToLower(signupSource)) {
case "", "email":
return "email"
case "linuxdo", "wechat", "oidc":
case "linuxdo", "wechat", "oidc", "dingtalk":
return strings.TrimSpace(strings.ToLower(signupSource))
default:
return "email"

View File

@ -68,6 +68,7 @@ func TestAPIContracts(t *testing.T) {
"linuxdo_bound": false,
"oidc_bound": false,
"wechat_bound": false,
"dingtalk_bound": false,
"identities": {
"email": {
"provider": "email",
@ -104,6 +105,14 @@ func TestAPIContracts(t *testing.T) {
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"dingtalk": {
"provider": "dingtalk",
"bound": false,
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/dingtalk/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"identity_bindings": {
@ -142,6 +151,14 @@ func TestAPIContracts(t *testing.T) {
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"dingtalk": {
"provider": "dingtalk",
"bound": false,
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/dingtalk/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"auth_bindings": {
@ -180,6 +197,14 @@ func TestAPIContracts(t *testing.T) {
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"dingtalk": {
"provider": "dingtalk",
"bound": false,
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/dingtalk/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"run_mode": "standard"
@ -554,6 +579,10 @@ func TestAPIContracts(t *testing.T) {
"first_token_ms": 50,
"image_count": 0,
"image_size": null,
"image_input_size": null,
"image_output_size": null,
"image_size_source": null,
"image_size_breakdown": null,
"media_type": null,
"cache_ttl_overridden": false,
"created_at": "2025-01-02T03:04:05Z",
@ -672,6 +701,22 @@ func TestAPIContracts(t *testing.T) {
"linuxdo_connect_client_id": "",
"linuxdo_connect_client_secret_configured": false,
"linuxdo_connect_redirect_url": "",
"dingtalk_connect_enabled": false,
"dingtalk_connect_bypass_registration": false,
"dingtalk_connect_client_id": "",
"dingtalk_connect_client_secret_configured": false,
"dingtalk_connect_redirect_url": "",
"dingtalk_connect_internal_corp_id": "",
"dingtalk_connect_corp_restriction_policy": "",
"dingtalk_connect_sync_corp_email": false,
"dingtalk_connect_sync_corp_email_attr_key": "dingtalk_email",
"dingtalk_connect_sync_corp_email_attr_name": "钉钉企业邮箱",
"dingtalk_connect_sync_dept": false,
"dingtalk_connect_sync_dept_attr_key": "dingtalk_department",
"dingtalk_connect_sync_dept_attr_name": "钉钉部门",
"dingtalk_connect_sync_display_name": false,
"dingtalk_connect_sync_display_name_attr_key": "dingtalk_name",
"dingtalk_connect_sync_display_name_attr_name": "钉钉姓名",
"oidc_connect_enabled": false,
"oidc_connect_provider_name": "OIDC",
"oidc_connect_client_id": "",
@ -744,6 +789,11 @@ func TestAPIContracts(t *testing.T) {
"auth_source_default_wechat_subscriptions": [],
"auth_source_default_wechat_grant_on_signup": false,
"auth_source_default_wechat_grant_on_first_bind": false,
"auth_source_default_dingtalk_balance": 0,
"auth_source_default_dingtalk_concurrency": 5,
"auth_source_default_dingtalk_subscriptions": [],
"auth_source_default_dingtalk_grant_on_signup": false,
"auth_source_default_dingtalk_grant_on_first_bind": false,
"force_email_on_third_party_signup": false,
"default_concurrency": 5,
"default_balance": 1.25,
@ -784,14 +834,7 @@ func TestAPIContracts(t *testing.T) {
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": true,
"openai_fast_policy_settings": {
"rules": [
{
"service_tier": "priority",
"action": "filter",
"scope": "all",
"fallback_action": "pass"
}
]
"rules": []
},
"custom_menu_items": [],
"custom_endpoints": [],
@ -917,6 +960,22 @@ func TestAPIContracts(t *testing.T) {
"linuxdo_connect_client_id": "",
"linuxdo_connect_client_secret_configured": false,
"linuxdo_connect_redirect_url": "",
"dingtalk_connect_enabled": false,
"dingtalk_connect_bypass_registration": false,
"dingtalk_connect_client_id": "",
"dingtalk_connect_client_secret_configured": false,
"dingtalk_connect_redirect_url": "",
"dingtalk_connect_internal_corp_id": "",
"dingtalk_connect_corp_restriction_policy": "",
"dingtalk_connect_sync_corp_email": false,
"dingtalk_connect_sync_corp_email_attr_key": "dingtalk_email",
"dingtalk_connect_sync_corp_email_attr_name": "钉钉企业邮箱",
"dingtalk_connect_sync_dept": false,
"dingtalk_connect_sync_dept_attr_key": "dingtalk_department",
"dingtalk_connect_sync_dept_attr_name": "钉钉部门",
"dingtalk_connect_sync_display_name": false,
"dingtalk_connect_sync_display_name_attr_key": "dingtalk_name",
"dingtalk_connect_sync_display_name_attr_name": "钉钉姓名",
"oidc_connect_enabled": true,
"oidc_connect_provider_name": "ConfigOIDC",
"oidc_connect_client_id": "oidc-config-client",
@ -999,14 +1058,7 @@ func TestAPIContracts(t *testing.T) {
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": false,
"openai_fast_policy_settings": {
"rules": [
{
"service_tier": "priority",
"action": "filter",
"scope": "all",
"fallback_action": "pass"
}
]
"rules": []
},
"payment_enabled": false,
"payment_min_amount": 0,
@ -1084,6 +1136,11 @@ func TestAPIContracts(t *testing.T) {
"auth_source_default_wechat_subscriptions": [],
"auth_source_default_wechat_grant_on_signup": false,
"auth_source_default_wechat_grant_on_first_bind": false,
"auth_source_default_dingtalk_balance": 0,
"auth_source_default_dingtalk_concurrency": 5,
"auth_source_default_dingtalk_subscriptions": [],
"auth_source_default_dingtalk_grant_on_signup": false,
"auth_source_default_dingtalk_grant_on_first_bind": false,
"force_email_on_third_party_signup": false
}
}`,
@ -1194,10 +1251,10 @@ func newContractDeps(t *testing.T) *contractDeps {
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil, nil)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) {

View File

@ -92,6 +92,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
clientIP := ip.GetTrustedClientIP(c)
allowed, _ := ip.CheckIPRestrictionWithCompiledRules(clientIP, apiKey.CompiledIPWhitelist, apiKey.CompiledIPBlacklist)
if !allowed {
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonIPRestriction)
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
return
}

View File

@ -333,6 +333,15 @@ func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := gin.New()
require.NoError(t, router.SetTrustedProxies(nil))
var markedBusinessLimited bool
var businessLimitedReason string
router.Use(func(c *gin.Context) {
c.Next()
markedBusinessLimited = service.HasOpsClientBusinessLimited(c)
if v, ok := c.Get(service.OpsClientBusinessLimitedReasonKey); ok {
businessLimitedReason, _ = v.(string)
}
})
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
@ -349,6 +358,8 @@ func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T)
require.Equal(t, http.StatusForbidden, w.Code)
require.Contains(t, w.Body.String(), "ACCESS_DENIED")
require.True(t, markedBusinessLimited)
require.Equal(t, service.OpsClientBusinessLimitedReasonIPRestriction, businessLimitedReason)
}
func TestAPIKeyAuthTouchesLastUsedOnSuccess(t *testing.T) {

View File

@ -42,15 +42,19 @@ func backendModeAllowsAuthPath(path string) bool {
"/auth/oauth/oidc/callback",
"/auth/oauth/github/callback",
"/auth/oauth/google/callback",
"/auth/oauth/dingtalk/callback",
"/auth/oauth/linuxdo/complete-registration",
"/auth/oauth/wechat/complete-registration",
"/auth/oauth/oidc/complete-registration",
"/auth/oauth/dingtalk/complete-registration",
"/auth/oauth/linuxdo/create-account",
"/auth/oauth/wechat/create-account",
"/auth/oauth/oidc/create-account",
"/auth/oauth/dingtalk/create-account",
"/auth/oauth/linuxdo/bind-login",
"/auth/oauth/wechat/bind-login",
"/auth/oauth/oidc/bind-login",
"/auth/oauth/dingtalk/bind-login",
} {
if strings.HasSuffix(path, suffix) {
return true

View File

@ -270,6 +270,36 @@ func TestBackendModeAuthGuard(t *testing.T) {
path: "/api/v1/auth/oauth/google/callback",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_dingtalk_oauth_start",
enabled: "true",
path: "/api/v1/auth/oauth/dingtalk/start",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_allows_dingtalk_oauth_callback",
enabled: "true",
path: "/api/v1/auth/oauth/dingtalk/callback",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_dingtalk_complete_registration",
enabled: "true",
path: "/api/v1/auth/oauth/dingtalk/complete-registration",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_dingtalk_create_account",
enabled: "true",
path: "/api/v1/auth/oauth/dingtalk/create-account",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_dingtalk_bind_login",
enabled: "true",
path: "/api/v1/auth/oauth/dingtalk/bind-login",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_oauth_pending_exchange",
enabled: "true",

View File

@ -303,6 +303,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
accounts.POST("/:id/models/sync-upstream", h.Admin.Account.SyncUpstreamModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
accounts.GET("/data", h.Admin.Account.ExportData)
accounts.POST("/data", h.Admin.Account.ImportData)

View File

@ -182,6 +182,32 @@ func RegisterAuthRoutes(
}),
h.Auth.CreateOIDCOAuthAccount,
)
auth.GET("/oauth/dingtalk/start", h.Auth.DingTalkOAuthStart)
auth.GET("/oauth/dingtalk/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.DingTalkOAuthStart(c)
})
auth.GET("/oauth/dingtalk/callback", h.Auth.DingTalkOAuthCallback)
auth.POST("/oauth/dingtalk/complete-registration",
rateLimiter.LimitWithOptions("oauth-dingtalk-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteDingTalkOAuthRegistration,
)
auth.POST("/oauth/dingtalk/bind-login",
rateLimiter.LimitWithOptions("oauth-dingtalk-bind-login", 20, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.BindDingTalkOAuthLogin,
)
auth.POST("/oauth/dingtalk/create-account",
rateLimiter.LimitWithOptions("oauth-dingtalk-create-account", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CreateDingTalkOAuthAccount,
)
}
// 公开设置(无需认证)

View File

@ -0,0 +1,50 @@
package service
// SensitiveCredentialKeys 列出 Account.Credentials JSON map 中绝不允许返回到前端的子键。
// dto 层做响应脱敏、service 层做更新合并都引用此清单——新增凭证类型时务必同步。
var SensitiveCredentialKeys = []string{
// OAuth
"access_token", "refresh_token", "id_token",
// API Key 类
"api_key", "session_key", "cookie",
// 云服务凭据
"aws_secret_access_key", "aws_session_token",
"service_account_json", "service_account", "private_key",
}
var sensitiveCredentialKeySet = func() map[string]struct{} {
m := make(map[string]struct{}, len(SensitiveCredentialKeys))
for _, k := range SensitiveCredentialKeys {
m[k] = struct{}{}
}
return m
}()
// IsSensitiveCredentialKey 判断指定键是否为敏感凭证子键。
func IsSensitiveCredentialKey(key string) bool {
_, ok := sensitiveCredentialKeySet[key]
return ok
}
// MergePreservingSensitiveCreds 把 incoming 写入 existing 之上,但敏感子键采用"incoming 没提供就保留 existing"
// 的语义。返回新的 map不修改入参。
//
// 用途:前端编辑账号通常采用"全对象 PUT"模式;脱敏后前端 spread 旧 credentials 时不会带上敏感键,
// 直接覆盖会清空已有 token。此函数保证
// - 非敏感键:完全由 incoming 决定(用户可以编辑、删除非敏感字段)。
// - 敏感键incoming 显式提供则覆盖(用户主动旋转 token否则保留 existing。
func MergePreservingSensitiveCreds(existing, incoming map[string]any) map[string]any {
out := make(map[string]any, len(incoming)+len(SensitiveCredentialKeys))
for k, v := range incoming {
out[k] = v
}
for _, key := range SensitiveCredentialKeys {
if _, hasIncoming := incoming[key]; hasIncoming {
continue
}
if existingVal, ok := existing[key]; ok {
out[key] = existingVal
}
}
return out
}

View File

@ -0,0 +1,90 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestMergePreservingSensitiveCreds_PreservesSensitiveWhenIncomingMissing(t *testing.T) {
existing := map[string]any{
"refresh_token": "rt-old",
"access_token": "at-old",
"api_key": "sk-old",
"base_url": "https://old.example.com",
}
incoming := map[string]any{
"base_url": "https://new.example.com",
"model_mapping": map[string]any{"foo": "bar"},
}
out := MergePreservingSensitiveCreds(existing, incoming)
require.Equal(t, "rt-old", out["refresh_token"], "incoming 没传 refresh_token应保留 existing")
require.Equal(t, "at-old", out["access_token"])
require.Equal(t, "sk-old", out["api_key"])
require.Equal(t, "https://new.example.com", out["base_url"], "非敏感键由 incoming 决定")
require.Equal(t, map[string]any{"foo": "bar"}, out["model_mapping"])
}
func TestMergePreservingSensitiveCreds_OverwritesWhenIncomingProvidesSensitive(t *testing.T) {
existing := map[string]any{
"refresh_token": "rt-old",
"api_key": "sk-old",
}
incoming := map[string]any{
"refresh_token": "rt-new",
// 显式没传 api_key —— 应保留
}
out := MergePreservingSensitiveCreds(existing, incoming)
require.Equal(t, "rt-new", out["refresh_token"], "incoming 显式传入应覆盖")
require.Equal(t, "sk-old", out["api_key"], "incoming 没传应保留")
}
func TestMergePreservingSensitiveCreds_DoesNotMutateInputs(t *testing.T) {
existing := map[string]any{"refresh_token": "rt"}
incoming := map[string]any{"base_url": "x"}
_ = MergePreservingSensitiveCreds(existing, incoming)
require.Equal(t, "rt", existing["refresh_token"])
require.NotContains(t, existing, "base_url")
require.Equal(t, "x", incoming["base_url"])
require.NotContains(t, incoming, "refresh_token")
}
func TestMergePreservingSensitiveCreds_NilInputs(t *testing.T) {
out := MergePreservingSensitiveCreds(nil, map[string]any{"base_url": "x"})
require.Equal(t, "x", out["base_url"])
require.NotContains(t, out, "refresh_token")
out2 := MergePreservingSensitiveCreds(map[string]any{"refresh_token": "rt"}, nil)
require.Equal(t, "rt", out2["refresh_token"])
}
func TestMergePreservingSensitiveCreds_NonSensitiveDeletionAllowed(t *testing.T) {
existing := map[string]any{
"refresh_token": "rt",
"base_url": "https://old",
"project_id": "p1",
}
incoming := map[string]any{
"base_url": "https://new",
// 不带 project_id —— 等同删除(非敏感键由 incoming 决定)
}
out := MergePreservingSensitiveCreds(existing, incoming)
require.Equal(t, "rt", out["refresh_token"], "敏感键保留")
require.Equal(t, "https://new", out["base_url"])
require.NotContains(t, out, "project_id", "非敏感键 incoming 不传 = 删除")
}
func TestIsSensitiveCredentialKey(t *testing.T) {
require.True(t, IsSensitiveCredentialKey("refresh_token"))
require.True(t, IsSensitiveCredentialKey("api_key"))
require.True(t, IsSensitiveCredentialKey("private_key"))
require.False(t, IsSensitiveCredentialKey("base_url"))
require.False(t, IsSensitiveCredentialKey(""))
require.False(t, IsSensitiveCredentialKey("model_mapping"))
}

View File

@ -397,6 +397,7 @@ type GenerateRedeemCodesInput struct {
Value float64
GroupID *int64 // 订阅类型专用关联的分组ID
ValidityDays int // 订阅类型专用:有效天数
ExpiresAt *time.Time
}
type ProxyBatchDeleteResult struct {
@ -1238,7 +1239,7 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
providerKey := strings.TrimSpace(input.ProviderKey)
providerSubject := strings.TrimSpace(input.ProviderSubject)
if providerType == "" {
return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, or wechat")
return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, wechat, or dingtalk")
}
if providerKey == "" || providerSubject == "" {
return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required")
@ -1493,6 +1494,8 @@ func normalizeAdminAuthIdentityProviderType(input string) string {
return "oidc"
case "wechat":
return "wechat"
case "dingtalk":
return "dingtalk"
default:
return ""
}
@ -2470,7 +2473,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.Notes = normalizeAccountNotes(input.Notes)
}
if len(input.Credentials) > 0 {
account.Credentials = input.Credentials
// 敏感子键采用"incoming 没提供就保留"的合并语义:前端响应已脱敏,
// 全对象 PUT 编辑时不会再带回 token避免覆盖时清空已有凭证。
account.Credentials = MergePreservingSensitiveCreds(account.Credentials, input.Credentials)
}
// Extra 使用 map需要区分“未提供(nil)”与“显式清空({})”。
// 关闭配额限制时前端会删除 quota_* 键并提交 extra:{},此时也必须落库。
@ -2966,6 +2971,10 @@ func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*Redeem
}
func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) {
if input.ExpiresAt != nil && !input.ExpiresAt.After(time.Now()) {
return nil, ErrRedeemCodeExpired
}
// 如果是订阅类型,验证必须有 GroupID
if input.Type == RedeemTypeSubscription {
if input.GroupID == nil {
@ -2988,10 +2997,11 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener
return nil, err
}
code := RedeemCode{
Code: codeValue,
Type: input.Type,
Value: input.Value,
Status: StatusUnused,
Code: codeValue,
Type: input.Type,
Value: input.Value,
Status: StatusUnused,
ExpiresAt: input.ExpiresAt,
}
// 订阅类型专用字段
if input.Type == RedeemTypeSubscription {

View File

@ -0,0 +1,117 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
type updateAccountCredsRepoStub struct {
mockAccountRepoForGemini
account *Account
updateCalls int
}
func (r *updateAccountCredsRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
return r.account, nil
}
func (r *updateAccountCredsRepoStub) Update(ctx context.Context, account *Account) error {
r.updateCalls++
r.account = account
return nil
}
func TestUpdateAccount_PreservesSensitiveCredsWhenIncomingOmits(t *testing.T) {
accountID := int64(202)
repo := &updateAccountCredsRepoStub{
account: &Account{
ID: accountID,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Status: StatusActive,
Credentials: map[string]any{
"refresh_token": "rt-existing",
"access_token": "at-existing",
"id_token": "id-existing",
"base_url": "https://old.example.com",
},
},
}
svc := &adminServiceImpl{accountRepo: repo}
// 模拟前端编辑:仅修改 base_url没有传 token脱敏后前端 spread 拿不到敏感键)
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
Credentials: map[string]any{
"base_url": "https://new.example.com",
},
})
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, 1, repo.updateCalls)
// 敏感键应保留
require.Equal(t, "rt-existing", repo.account.Credentials["refresh_token"])
require.Equal(t, "at-existing", repo.account.Credentials["access_token"])
require.Equal(t, "id-existing", repo.account.Credentials["id_token"])
// 非敏感键被替换
require.Equal(t, "https://new.example.com", repo.account.Credentials["base_url"])
}
func TestUpdateAccount_ExplicitNewTokenOverwrites(t *testing.T) {
accountID := int64(203)
repo := &updateAccountCredsRepoStub{
account: &Account{
ID: accountID,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Status: StatusActive,
Credentials: map[string]any{
"refresh_token": "rt-old",
"api_key": "sk-old",
},
},
}
svc := &adminServiceImpl{accountRepo: repo}
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
Credentials: map[string]any{
"refresh_token": "rt-new",
// api_key 没传 → 应保留旧值
},
})
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, "rt-new", repo.account.Credentials["refresh_token"])
require.Equal(t, "sk-old", repo.account.Credentials["api_key"])
}
func TestUpdateAccount_EmptyCredentialsSkipsUpdate(t *testing.T) {
accountID := int64(204)
repo := &updateAccountCredsRepoStub{
account: &Account{
ID: accountID,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Status: StatusActive,
Credentials: map[string]any{
"refresh_token": "rt-existing",
},
},
}
svc := &adminServiceImpl{accountRepo: repo}
_, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
Credentials: map[string]any{}, // len == 0 → 闸门跳过
Name: "renamed",
})
require.NoError(t, err)
require.Equal(t, "rt-existing", repo.account.Credentials["refresh_token"], "空 credentials 不应触碰已有 token")
require.Equal(t, "renamed", repo.account.Name)
}

View File

@ -2094,7 +2094,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
}
// 解析请求以获取 image_size用于图片计费
imageSize := s.extractImageSize(body)
imageInputSize := s.extractImageInputSize(body)
imageSize := normalizeOpenAIImageSizeTier(imageInputSize)
switch action {
case "generateContent", "streamGenerateContent":
@ -2465,6 +2466,7 @@ handleSuccess:
ClientDisconnect: clientDisconnect,
ImageCount: imageCount,
ImageSize: imageSize,
ImageInputSize: imageInputSize,
}, nil
}
@ -4063,21 +4065,17 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
}
}
// extractImageSize 从 Gemini 请求中提取 image_size 参数
func (s *AntigravityGatewayService) extractImageSize(body []byte) string {
func (s *AntigravityGatewayService) extractImageInputSize(body []byte) string {
var req antigravity.GeminiRequest
if err := json.Unmarshal(body, &req); err != nil {
return "2K" // 默认 2K
return ""
}
if req.GenerationConfig != nil && req.GenerationConfig.ImageConfig != nil {
size := strings.ToUpper(strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize))
if size == "1K" || size == "2K" || size == "4K" {
return size
}
return strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize)
}
return "2K" // 默认 2K
return ""
}
// isImageGenerationModel 判断模型是否为图片生成模型

View File

@ -46,15 +46,15 @@ func TestExtractImageSize_ValidSizes(t *testing.T) {
// 1K
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"1K"}}}`)
require.Equal(t, "1K", svc.extractImageSize(body))
require.Equal(t, "1K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
// 2K
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"2K"}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
// 4K
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"4K"}}}`)
require.Equal(t, "4K", svc.extractImageSize(body))
require.Equal(t, "4K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
}
// TestExtractImageSize_CaseInsensitive 测试大小写不敏感
@ -62,10 +62,10 @@ func TestExtractImageSize_CaseInsensitive(t *testing.T) {
svc := &AntigravityGatewayService{}
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"1k"}}}`)
require.Equal(t, "1K", svc.extractImageSize(body))
require.Equal(t, "1K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"4k"}}}`)
require.Equal(t, "4K", svc.extractImageSize(body))
require.Equal(t, "4K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
}
// TestExtractImageSize_Default 测试无 imageConfig 返回默认 2K
@ -74,15 +74,15 @@ func TestExtractImageSize_Default(t *testing.T) {
// 无 generationConfig
body := []byte(`{"contents":[]}`)
require.Equal(t, "2K", svc.extractImageSize(body))
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
// 有 generationConfig 但无 imageConfig
body = []byte(`{"generationConfig":{"temperature":0.7}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
// 有 imageConfig 但无 imageSize
body = []byte(`{"generationConfig":{"imageConfig":{}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
}
// TestExtractImageSize_InvalidJSON 测试非法 JSON 返回默认 2K
@ -90,10 +90,10 @@ func TestExtractImageSize_InvalidJSON(t *testing.T) {
svc := &AntigravityGatewayService{}
body := []byte(`not valid json`)
require.Equal(t, "2K", svc.extractImageSize(body))
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
body = []byte(`{"broken":`)
require.Equal(t, "2K", svc.extractImageSize(body))
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
}
// TestExtractImageSize_EmptySize 测试空 imageSize 返回默认 2K
@ -101,11 +101,11 @@ func TestExtractImageSize_EmptySize(t *testing.T) {
svc := &AntigravityGatewayService{}
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":""}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
// 空格
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":" "}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
}
// TestExtractImageSize_InvalidSize 测试无效尺寸返回默认 2K
@ -113,11 +113,11 @@ func TestExtractImageSize_InvalidSize(t *testing.T) {
svc := &AntigravityGatewayService{}
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"3K"}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"8K"}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"invalid"}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
}

View File

@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"log/slog"
"net/mail"
"strings"
"time"
@ -18,7 +19,7 @@ func normalizeOAuthSignupSource(signupSource string) string {
switch signupSource {
case "", "email":
return "email"
case "linuxdo", "wechat", "oidc", "github", "google":
case "linuxdo", "wechat", "oidc", "github", "google", "dingtalk":
return signupSource
default:
return "email"
@ -71,7 +72,7 @@ func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, i
if err != nil {
return nil, ErrInvitationCodeInvalid
}
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
if redeemCode.Type != RedeemTypeInvitation || !redeemCode.CanUse() {
return nil, ErrInvitationCodeInvalid
}
return redeemCode, nil
@ -109,7 +110,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
if s == nil {
return nil, nil, ErrServiceUnavailable
}
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
if s.settingService == nil || (!s.settingService.IsRegistrationEnabled(ctx) && !s.canBypassRegistrationDisabledForOAuth(ctx, signupSource)) {
return nil, nil, ErrRegDisabled
}
@ -118,18 +119,22 @@ func (s *AuthService) RegisterOAuthEmailAccount(
return nil, nil, ErrEmailReserved
}
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
slog.Error("oauth email register: policy rejected", "email", email, "error", err.Error())
return nil, nil, err
}
if err := s.VerifyOAuthEmailCode(ctx, email, verifyCode); err != nil {
slog.Error("oauth email register: verify code failed", "email", email, "error", err.Error())
return nil, nil, err
}
if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil {
slog.Error("oauth email register: invitation failed", "email", email, "error", err.Error())
return nil, nil, err
}
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
slog.Error("oauth email register: ExistsByEmail failed", "email", email, "error", err.Error())
return nil, nil, ErrServiceUnavailable
}
if existsEmail {
@ -158,6 +163,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
if errors.Is(err, ErrEmailExists) {
return nil, nil, ErrEmailExists
}
slog.Error("oauth email register: userRepo.Create failed", "email", email, "signup_source", signupSource, "error", err.Error())
return nil, nil, ErrServiceUnavailable
}
@ -181,7 +187,7 @@ func (s *AuthService) RegisterVerifiedOAuthEmailAccount(
if s == nil {
return nil, nil, ErrServiceUnavailable
}
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
if s.settingService == nil || (!s.settingService.IsRegistrationEnabled(ctx) && !s.canBypassRegistrationDisabledForOAuth(ctx, signupSource)) {
return nil, nil, ErrRegDisabled
}
@ -358,6 +364,7 @@ func (s *AuthService) loadOAuthRegistrationInvitation(ctx context.Context, invit
UsedAt: entity.UsedAt,
Notes: oauthEmailFlowStringValue(entity.Notes),
CreatedAt: entity.CreatedAt,
ExpiresAt: entity.ExpiresAt,
GroupID: entity.GroupID,
ValidityDays: entity.ValidityDays,
}, nil
@ -368,7 +375,11 @@ func (s *AuthService) loadOAuthRegistrationInvitation(ctx context.Context, invit
func (s *AuthService) useOAuthRegistrationInvitation(ctx context.Context, invitationID, userID int64) error {
if client := s.oauthEmailFlowClient(ctx); client != nil {
affected, err := client.RedeemCode.Update().
Where(redeemcode.IDEQ(invitationID), redeemcode.StatusEQ(StatusUnused)).
Where(
redeemcode.IDEQ(invitationID),
redeemcode.StatusEQ(StatusUnused),
redeemcode.Or(redeemcode.ExpiresAtIsNil(), redeemcode.ExpiresAtGT(time.Now().UTC())),
).
SetStatus(StatusUsed).
SetUsedBy(userID).
SetUsedAt(time.Now().UTC()).
@ -396,6 +407,11 @@ func (s *AuthService) updateOAuthRegistrationInvitation(ctx context.Context, cod
SetStatus(code.Status).
SetNotes(code.Notes).
SetValidityDays(code.ValidityDays)
if code.ExpiresAt != nil {
update = update.SetExpiresAt(*code.ExpiresAt)
} else {
update = update.ClearExpiresAt()
}
if code.UsedBy != nil {
update = update.SetUsedBy(*code.UsedBy)
} else {

View File

@ -157,7 +157,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrInvitationCodeInvalid
}
// 检查类型和状态
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
if redeemCode.Type != RedeemTypeInvitation || !redeemCode.CanUse() {
logger.LegacyPrintf("service.auth", "[Auth] Invitation code invalid: type=%s, status=%s", redeemCode.Type, redeemCode.Status)
return "", nil, ErrInvitationCodeInvalid
}
@ -560,11 +560,25 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
return token, user, nil
}
// canBypassRegistrationDisabledForOAuth 在钉钉企业模式internal_only
// dingtalk_connect_bypass_registration=true 时,允许跳过全局 registration_enabled 检查。
func (s *AuthService) canBypassRegistrationDisabledForOAuth(ctx context.Context, signupSource string) bool {
if signupSource != "dingtalk" {
return false
}
cfg, err := s.settingService.GetDingTalkConnectOAuthConfig(ctx)
if err != nil || !cfg.Enabled || !cfg.BypassRegistration {
return false
}
return cfg.CorpRestrictionPolicy == "internal_only"
}
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
// affiliateCode 用于邀请返利绑定,仅在新用户注册时使用。
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode string) (*TokenPair, *User, error) {
// signupSource 标识来源渠道("dingtalk"/"linuxdo"/"wechat"/"oidc" 等),仅用于豁免检查。
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode, signupSource string) (*TokenPair, *User, error) {
// 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil {
return nil, nil, errors.New("refresh token cache not configured")
@ -587,7 +601,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
if err != nil {
if errors.Is(err, ErrUserNotFound) {
// OAuth 首次登录视为注册
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
if s.settingService == nil || (!s.settingService.IsRegistrationEnabled(ctx) && !s.canBypassRegistrationDisabledForOAuth(ctx, signupSource)) {
return nil, nil, ErrRegDisabled
}
@ -601,7 +615,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
if err != nil {
return nil, nil, ErrInvitationCodeInvalid
}
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
if redeemCode.Type != RedeemTypeInvitation || !redeemCode.CanUse() {
return nil, nil, ErrInvitationCodeInvalid
}
invitationRedeemCode = redeemCode
@ -617,7 +631,11 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, fmt.Errorf("hash password: %w", err)
}
signupSource := inferLegacySignupSource(email)
// 优先用 caller 显式传入的 signupSource如 "dingtalk" / "linuxdo" / "oidc" / "wechat"
// 否则才按邮箱后缀推断——避免有真实邮箱的 OAuth 用户被推断为 "email" 渠道,导致渠道授权错读。
if strings.TrimSpace(signupSource) == "" {
signupSource = inferLegacySignupSource(email)
}
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
var defaultRPMLimit int
if s.settingService != nil {
@ -779,6 +797,8 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource
return defaults.GitHub, true
case "google":
return defaults.Google, true
case "dingtalk":
return defaults.DingTalk, true
default:
return ProviderDefaultGrantSettings{}, false
}
@ -992,6 +1012,8 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User, s
func inferLegacySignupSource(email string) string {
normalized := strings.ToLower(strings.TrimSpace(email))
switch {
case strings.HasSuffix(normalized, DingTalkConnectSyntheticEmailDomain):
return "dingtalk"
case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain):
return "linuxdo"
case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain):
@ -1086,7 +1108,8 @@ func isReservedEmail(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain)
strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, DingTalkConnectSyntheticEmailDomain)
}
// GenerateToken 生成JWT access token

View File

@ -602,7 +602,7 @@ func TestAuthService_Register_GrantOnSignupMergesSourceOverridesWithGlobalDefaul
require.NoError(t, err)
require.NotNil(t, user)
require.Equal(t, 9.5, user.Balance)
require.Equal(t, 2, user.Concurrency)
require.Equal(t, 5, user.Concurrency)
require.Len(t, assigner.calls, 1)
require.Equal(t, int64(31), assigner.calls[0].GroupID)
require.Equal(t, 5, assigner.calls[0].ValidityDays)
@ -622,7 +622,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefa
service.defaultSubAssigner = assigner
service.refreshTokenCache = &refreshTokenCacheStub{}
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "", "")
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "", "", "linuxdo")
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.NotNil(t, user)
@ -658,7 +658,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA
service.defaultSubAssigner = assigner
service.refreshTokenCache = &refreshTokenCacheStub{}
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "", "")
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "", "", "linuxdo")
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.Equal(t, existing.ID, user.ID)
@ -667,3 +667,99 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA
require.Empty(t, repo.created)
require.Empty(t, assigner.calls)
}
// newAuthServiceWithDingTalkCfg 构建一个含完整 DingTalk config 的 AuthService
// 用于测试 canBypassRegistrationDisabledForOAuth。
func newAuthServiceWithDingTalkCfg(settings map[string]string, dtCfg config.DingTalkConnectConfig) *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1},
Default: config.DefaultConfig{UserBalance: 3.5, UserConcurrency: 2},
DingTalk: dtCfg,
}
settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
return NewAuthService(nil, nil, nil, nil, cfg, settingService, nil, nil, nil, nil, nil, nil)
}
// minDingTalkURLs 返回一个包含必填字段的基础 DingTalkConnectConfig不设 Enabled/BypassRegistration/Policy
func minDingTalkURLs() config.DingTalkConnectConfig {
return config.DingTalkConnectConfig{
ClientID: "test-client",
ClientSecret: "test-secret",
AuthorizeURL: "https://example.com/oauth2/auth",
TokenURL: "https://example.com/oauth2/token",
UserInfoURL: "https://example.com/oauth2/userinfo",
RedirectURL: "https://example.com/callback",
FrontendRedirectURL: "https://example.com/auth/callback",
DingTalkAppKind: "internal_app",
AppType: "internal",
}
}
func TestCanBypassRegistrationDisabledForOAuth(t *testing.T) {
cases := []struct {
name string
signupSource string
settings map[string]string
dtCfg config.DingTalkConnectConfig
want bool
}{
{
name: "non-dingtalk source → false",
signupSource: "linuxdo",
settings: map[string]string{},
dtCfg: minDingTalkURLs(),
want: false,
},
{
name: "dingtalk but cfg.Enabled=false → false",
signupSource: "dingtalk",
settings: map[string]string{
SettingKeyDingTalkConnectEnabled: "false",
SettingKeyDingTalkConnectBypassRegistration: "true",
SettingKeyDingTalkConnectCorpRestrictionPolicy: "internal_only",
},
dtCfg: minDingTalkURLs(),
want: false,
},
{
name: "dingtalk enabled but BypassRegistration=false → false",
signupSource: "dingtalk",
settings: map[string]string{
SettingKeyDingTalkConnectEnabled: "true",
SettingKeyDingTalkConnectBypassRegistration: "false",
SettingKeyDingTalkConnectCorpRestrictionPolicy: "internal_only",
},
dtCfg: minDingTalkURLs(),
want: false,
},
{
name: "dingtalk enabled + bypass=true but policy=none → false",
signupSource: "dingtalk",
settings: map[string]string{
SettingKeyDingTalkConnectEnabled: "true",
SettingKeyDingTalkConnectBypassRegistration: "true",
SettingKeyDingTalkConnectCorpRestrictionPolicy: "none",
},
dtCfg: minDingTalkURLs(),
want: false,
},
{
name: "dingtalk enabled + bypass=true + policy=internal_only → true",
signupSource: "dingtalk",
settings: map[string]string{
SettingKeyDingTalkConnectEnabled: "true",
SettingKeyDingTalkConnectBypassRegistration: "true",
SettingKeyDingTalkConnectCorpRestrictionPolicy: "internal_only",
},
dtCfg: minDingTalkURLs(),
want: true,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
svc := newAuthServiceWithDingTalkCfg(tc.settings, tc.dtCfg)
got := svc.canBypassRegistrationDisabledForOAuth(context.Background(), tc.signupSource)
require.Equal(t, tc.want, got)
})
}
}

View File

@ -0,0 +1,13 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestIsReservedEmail_DingTalkDomain(t *testing.T) {
require.True(t, isReservedEmail("dingtalk-123@dingtalk-connect.invalid"))
require.True(t, isReservedEmail("DINGTALK-456@DINGTALK-CONNECT.INVALID")) // case-insensitive
require.False(t, isReservedEmail("real@dingtalk.com"))
}

View File

@ -809,6 +809,7 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
if imageCount <= 0 {
return &CostBreakdown{}
}
imageSize = NormalizeImageBillingTierOrDefault(imageSize)
// 获取单价
unitPrice := s.getImageUnitPrice(model, imageSize, groupConfig)

Some files were not shown because too many files have changed in this diff Show More