Merge origin/main into fix/deepseek-reasoning-content
This commit is contained in:
commit
6082d02d22
@ -20,8 +20,8 @@ FROM ${NODE_IMAGE} AS frontend-builder
|
||||
|
||||
WORKDIR /app/frontend
|
||||
|
||||
# Install pnpm
|
||||
RUN corepack enable && corepack prepare pnpm@latest --activate
|
||||
# Install pnpm (pinned to v9 to match CI and keep builds reproducible)
|
||||
RUN corepack enable && corepack prepare pnpm@9 --activate
|
||||
|
||||
# Install dependencies first (better caching)
|
||||
COPY frontend/package.json frontend/pnpm-lock.yaml ./
|
||||
|
||||
@ -81,7 +81,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
}
|
||||
totpCache := repository.NewTotpCache(redisClient)
|
||||
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
|
||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
|
||||
userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
|
||||
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
|
||||
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
|
||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService, userAttributeService)
|
||||
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||
@ -198,7 +201,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
registry := payment.ProvideRegistry()
|
||||
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
|
||||
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService)
|
||||
opsHandler := admin.NewOpsHandler(opsService)
|
||||
updateCache := repository.NewUpdateCache(redisClient)
|
||||
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
|
||||
@ -211,9 +214,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
usageCleanupRepository := repository.NewUsageCleanupRepository(client, db)
|
||||
usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig)
|
||||
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService, usageCleanupService)
|
||||
userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
|
||||
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
|
||||
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
|
||||
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
|
||||
errorPassthroughRepository := repository.NewErrorPassthroughRepository(client)
|
||||
errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient)
|
||||
|
||||
@ -1120,6 +1120,7 @@ var (
|
||||
{Name: "used_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "notes", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
||||
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "validity_days", Type: field.TypeInt, Default: 30},
|
||||
{Name: "group_id", Type: field.TypeInt64, Nullable: true},
|
||||
{Name: "used_by", Type: field.TypeInt64, Nullable: true},
|
||||
@ -1132,13 +1133,13 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "redeem_codes_groups_redeem_codes",
|
||||
Columns: []*schema.Column{RedeemCodesColumns[9]},
|
||||
Columns: []*schema.Column{RedeemCodesColumns[10]},
|
||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
{
|
||||
Symbol: "redeem_codes_users_redeem_codes",
|
||||
Columns: []*schema.Column{RedeemCodesColumns[10]},
|
||||
Columns: []*schema.Column{RedeemCodesColumns[11]},
|
||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
@ -1152,12 +1153,17 @@ var (
|
||||
{
|
||||
Name: "redeemcode_used_by",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{RedeemCodesColumns[10]},
|
||||
Columns: []*schema.Column{RedeemCodesColumns[11]},
|
||||
},
|
||||
{
|
||||
Name: "redeemcode_group_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{RedeemCodesColumns[9]},
|
||||
Columns: []*schema.Column{RedeemCodesColumns[10]},
|
||||
},
|
||||
{
|
||||
Name: "redeemcode_expires_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{RedeemCodesColumns[8]},
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -1318,6 +1324,10 @@ var (
|
||||
{Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45},
|
||||
{Name: "image_count", Type: field.TypeInt, Default: 0},
|
||||
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
|
||||
{Name: "image_input_size", Type: field.TypeString, Nullable: true, Size: 32},
|
||||
{Name: "image_output_size", Type: field.TypeString, Nullable: true, Size: 32},
|
||||
{Name: "image_size_source", Type: field.TypeString, Nullable: true, Size: 16},
|
||||
{Name: "image_size_breakdown", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false},
|
||||
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "api_key_id", Type: field.TypeInt64},
|
||||
@ -1334,31 +1344,31 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "usage_logs_api_keys_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[37]},
|
||||
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_accounts_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[34]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[38]},
|
||||
RefColumns: []*schema.Column{AccountsColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_groups_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[35]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[39]},
|
||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_users_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[36]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[40]},
|
||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[37]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[41]},
|
||||
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
@ -1367,32 +1377,32 @@ var (
|
||||
{
|
||||
Name: "usagelog_user_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[36]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[40]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_api_key_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[37]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_account_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[34]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[38]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_group_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[35]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[39]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_subscription_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[37]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[41]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[36]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_model",
|
||||
@ -1412,17 +1422,17 @@ var (
|
||||
{
|
||||
Name: "usagelog_user_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[32]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[40], UsageLogsColumns[36]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_api_key_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[32]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[37], UsageLogsColumns[36]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_group_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[35], UsageLogsColumns[32]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[39], UsageLogsColumns[36]},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -28602,6 +28602,7 @@ type RedeemCodeMutation struct {
|
||||
used_at *time.Time
|
||||
notes *string
|
||||
created_at *time.Time
|
||||
expires_at *time.Time
|
||||
validity_days *int
|
||||
addvalidity_days *int
|
||||
clearedFields map[string]struct{}
|
||||
@ -29059,6 +29060,55 @@ func (m *RedeemCodeMutation) ResetCreatedAt() {
|
||||
m.created_at = nil
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (m *RedeemCodeMutation) SetExpiresAt(t time.Time) {
|
||||
m.expires_at = &t
|
||||
}
|
||||
|
||||
// ExpiresAt returns the value of the "expires_at" field in the mutation.
|
||||
func (m *RedeemCodeMutation) ExpiresAt() (r time.Time, exists bool) {
|
||||
v := m.expires_at
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldExpiresAt returns the old "expires_at" field's value of the RedeemCode entity.
|
||||
// If the RedeemCode object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *RedeemCodeMutation) OldExpiresAt(ctx context.Context) (v *time.Time, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldExpiresAt requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err)
|
||||
}
|
||||
return oldValue.ExpiresAt, nil
|
||||
}
|
||||
|
||||
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||
func (m *RedeemCodeMutation) ClearExpiresAt() {
|
||||
m.expires_at = nil
|
||||
m.clearedFields[redeemcode.FieldExpiresAt] = struct{}{}
|
||||
}
|
||||
|
||||
// ExpiresAtCleared returns if the "expires_at" field was cleared in this mutation.
|
||||
func (m *RedeemCodeMutation) ExpiresAtCleared() bool {
|
||||
_, ok := m.clearedFields[redeemcode.FieldExpiresAt]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetExpiresAt resets all changes to the "expires_at" field.
|
||||
func (m *RedeemCodeMutation) ResetExpiresAt() {
|
||||
m.expires_at = nil
|
||||
delete(m.clearedFields, redeemcode.FieldExpiresAt)
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (m *RedeemCodeMutation) SetGroupID(i int64) {
|
||||
m.group = &i
|
||||
@ -29265,7 +29315,7 @@ func (m *RedeemCodeMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *RedeemCodeMutation) Fields() []string {
|
||||
fields := make([]string, 0, 10)
|
||||
fields := make([]string, 0, 11)
|
||||
if m.code != nil {
|
||||
fields = append(fields, redeemcode.FieldCode)
|
||||
}
|
||||
@ -29290,6 +29340,9 @@ func (m *RedeemCodeMutation) Fields() []string {
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, redeemcode.FieldCreatedAt)
|
||||
}
|
||||
if m.expires_at != nil {
|
||||
fields = append(fields, redeemcode.FieldExpiresAt)
|
||||
}
|
||||
if m.group != nil {
|
||||
fields = append(fields, redeemcode.FieldGroupID)
|
||||
}
|
||||
@ -29320,6 +29373,8 @@ func (m *RedeemCodeMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.Notes()
|
||||
case redeemcode.FieldCreatedAt:
|
||||
return m.CreatedAt()
|
||||
case redeemcode.FieldExpiresAt:
|
||||
return m.ExpiresAt()
|
||||
case redeemcode.FieldGroupID:
|
||||
return m.GroupID()
|
||||
case redeemcode.FieldValidityDays:
|
||||
@ -29349,6 +29404,8 @@ func (m *RedeemCodeMutation) OldField(ctx context.Context, name string) (ent.Val
|
||||
return m.OldNotes(ctx)
|
||||
case redeemcode.FieldCreatedAt:
|
||||
return m.OldCreatedAt(ctx)
|
||||
case redeemcode.FieldExpiresAt:
|
||||
return m.OldExpiresAt(ctx)
|
||||
case redeemcode.FieldGroupID:
|
||||
return m.OldGroupID(ctx)
|
||||
case redeemcode.FieldValidityDays:
|
||||
@ -29418,6 +29475,13 @@ func (m *RedeemCodeMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetCreatedAt(v)
|
||||
return nil
|
||||
case redeemcode.FieldExpiresAt:
|
||||
v, ok := value.(time.Time)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetExpiresAt(v)
|
||||
return nil
|
||||
case redeemcode.FieldGroupID:
|
||||
v, ok := value.(int64)
|
||||
if !ok {
|
||||
@ -29498,6 +29562,9 @@ func (m *RedeemCodeMutation) ClearedFields() []string {
|
||||
if m.FieldCleared(redeemcode.FieldNotes) {
|
||||
fields = append(fields, redeemcode.FieldNotes)
|
||||
}
|
||||
if m.FieldCleared(redeemcode.FieldExpiresAt) {
|
||||
fields = append(fields, redeemcode.FieldExpiresAt)
|
||||
}
|
||||
if m.FieldCleared(redeemcode.FieldGroupID) {
|
||||
fields = append(fields, redeemcode.FieldGroupID)
|
||||
}
|
||||
@ -29524,6 +29591,9 @@ func (m *RedeemCodeMutation) ClearField(name string) error {
|
||||
case redeemcode.FieldNotes:
|
||||
m.ClearNotes()
|
||||
return nil
|
||||
case redeemcode.FieldExpiresAt:
|
||||
m.ClearExpiresAt()
|
||||
return nil
|
||||
case redeemcode.FieldGroupID:
|
||||
m.ClearGroupID()
|
||||
return nil
|
||||
@ -29559,6 +29629,9 @@ func (m *RedeemCodeMutation) ResetField(name string) error {
|
||||
case redeemcode.FieldCreatedAt:
|
||||
m.ResetCreatedAt()
|
||||
return nil
|
||||
case redeemcode.FieldExpiresAt:
|
||||
m.ResetExpiresAt()
|
||||
return nil
|
||||
case redeemcode.FieldGroupID:
|
||||
m.ResetGroupID()
|
||||
return nil
|
||||
@ -34260,6 +34333,10 @@ type UsageLogMutation struct {
|
||||
image_count *int
|
||||
addimage_count *int
|
||||
image_size *string
|
||||
image_input_size *string
|
||||
image_output_size *string
|
||||
image_size_source *string
|
||||
image_size_breakdown *map[string]int
|
||||
cache_ttl_overridden *bool
|
||||
created_at *time.Time
|
||||
clearedFields map[string]struct{}
|
||||
@ -36202,6 +36279,202 @@ func (m *UsageLogMutation) ResetImageSize() {
|
||||
delete(m.clearedFields, usagelog.FieldImageSize)
|
||||
}
|
||||
|
||||
// SetImageInputSize sets the "image_input_size" field.
|
||||
func (m *UsageLogMutation) SetImageInputSize(s string) {
|
||||
m.image_input_size = &s
|
||||
}
|
||||
|
||||
// ImageInputSize returns the value of the "image_input_size" field in the mutation.
|
||||
func (m *UsageLogMutation) ImageInputSize() (r string, exists bool) {
|
||||
v := m.image_input_size
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldImageInputSize returns the old "image_input_size" field's value of the UsageLog entity.
|
||||
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UsageLogMutation) OldImageInputSize(ctx context.Context) (v *string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldImageInputSize is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldImageInputSize requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldImageInputSize: %w", err)
|
||||
}
|
||||
return oldValue.ImageInputSize, nil
|
||||
}
|
||||
|
||||
// ClearImageInputSize clears the value of the "image_input_size" field.
|
||||
func (m *UsageLogMutation) ClearImageInputSize() {
|
||||
m.image_input_size = nil
|
||||
m.clearedFields[usagelog.FieldImageInputSize] = struct{}{}
|
||||
}
|
||||
|
||||
// ImageInputSizeCleared returns if the "image_input_size" field was cleared in this mutation.
|
||||
func (m *UsageLogMutation) ImageInputSizeCleared() bool {
|
||||
_, ok := m.clearedFields[usagelog.FieldImageInputSize]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetImageInputSize resets all changes to the "image_input_size" field.
|
||||
func (m *UsageLogMutation) ResetImageInputSize() {
|
||||
m.image_input_size = nil
|
||||
delete(m.clearedFields, usagelog.FieldImageInputSize)
|
||||
}
|
||||
|
||||
// SetImageOutputSize sets the "image_output_size" field.
|
||||
func (m *UsageLogMutation) SetImageOutputSize(s string) {
|
||||
m.image_output_size = &s
|
||||
}
|
||||
|
||||
// ImageOutputSize returns the value of the "image_output_size" field in the mutation.
|
||||
func (m *UsageLogMutation) ImageOutputSize() (r string, exists bool) {
|
||||
v := m.image_output_size
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldImageOutputSize returns the old "image_output_size" field's value of the UsageLog entity.
|
||||
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UsageLogMutation) OldImageOutputSize(ctx context.Context) (v *string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldImageOutputSize is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldImageOutputSize requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldImageOutputSize: %w", err)
|
||||
}
|
||||
return oldValue.ImageOutputSize, nil
|
||||
}
|
||||
|
||||
// ClearImageOutputSize clears the value of the "image_output_size" field.
|
||||
func (m *UsageLogMutation) ClearImageOutputSize() {
|
||||
m.image_output_size = nil
|
||||
m.clearedFields[usagelog.FieldImageOutputSize] = struct{}{}
|
||||
}
|
||||
|
||||
// ImageOutputSizeCleared returns if the "image_output_size" field was cleared in this mutation.
|
||||
func (m *UsageLogMutation) ImageOutputSizeCleared() bool {
|
||||
_, ok := m.clearedFields[usagelog.FieldImageOutputSize]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetImageOutputSize resets all changes to the "image_output_size" field.
|
||||
func (m *UsageLogMutation) ResetImageOutputSize() {
|
||||
m.image_output_size = nil
|
||||
delete(m.clearedFields, usagelog.FieldImageOutputSize)
|
||||
}
|
||||
|
||||
// SetImageSizeSource sets the "image_size_source" field.
|
||||
func (m *UsageLogMutation) SetImageSizeSource(s string) {
|
||||
m.image_size_source = &s
|
||||
}
|
||||
|
||||
// ImageSizeSource returns the value of the "image_size_source" field in the mutation.
|
||||
func (m *UsageLogMutation) ImageSizeSource() (r string, exists bool) {
|
||||
v := m.image_size_source
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldImageSizeSource returns the old "image_size_source" field's value of the UsageLog entity.
|
||||
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UsageLogMutation) OldImageSizeSource(ctx context.Context) (v *string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldImageSizeSource is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldImageSizeSource requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldImageSizeSource: %w", err)
|
||||
}
|
||||
return oldValue.ImageSizeSource, nil
|
||||
}
|
||||
|
||||
// ClearImageSizeSource clears the value of the "image_size_source" field.
|
||||
func (m *UsageLogMutation) ClearImageSizeSource() {
|
||||
m.image_size_source = nil
|
||||
m.clearedFields[usagelog.FieldImageSizeSource] = struct{}{}
|
||||
}
|
||||
|
||||
// ImageSizeSourceCleared returns if the "image_size_source" field was cleared in this mutation.
|
||||
func (m *UsageLogMutation) ImageSizeSourceCleared() bool {
|
||||
_, ok := m.clearedFields[usagelog.FieldImageSizeSource]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetImageSizeSource resets all changes to the "image_size_source" field.
|
||||
func (m *UsageLogMutation) ResetImageSizeSource() {
|
||||
m.image_size_source = nil
|
||||
delete(m.clearedFields, usagelog.FieldImageSizeSource)
|
||||
}
|
||||
|
||||
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
|
||||
func (m *UsageLogMutation) SetImageSizeBreakdown(value map[string]int) {
|
||||
m.image_size_breakdown = &value
|
||||
}
|
||||
|
||||
// ImageSizeBreakdown returns the value of the "image_size_breakdown" field in the mutation.
|
||||
func (m *UsageLogMutation) ImageSizeBreakdown() (r map[string]int, exists bool) {
|
||||
v := m.image_size_breakdown
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldImageSizeBreakdown returns the old "image_size_breakdown" field's value of the UsageLog entity.
|
||||
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UsageLogMutation) OldImageSizeBreakdown(ctx context.Context) (v map[string]int, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldImageSizeBreakdown is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldImageSizeBreakdown requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldImageSizeBreakdown: %w", err)
|
||||
}
|
||||
return oldValue.ImageSizeBreakdown, nil
|
||||
}
|
||||
|
||||
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
|
||||
func (m *UsageLogMutation) ClearImageSizeBreakdown() {
|
||||
m.image_size_breakdown = nil
|
||||
m.clearedFields[usagelog.FieldImageSizeBreakdown] = struct{}{}
|
||||
}
|
||||
|
||||
// ImageSizeBreakdownCleared returns if the "image_size_breakdown" field was cleared in this mutation.
|
||||
func (m *UsageLogMutation) ImageSizeBreakdownCleared() bool {
|
||||
_, ok := m.clearedFields[usagelog.FieldImageSizeBreakdown]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetImageSizeBreakdown resets all changes to the "image_size_breakdown" field.
|
||||
func (m *UsageLogMutation) ResetImageSizeBreakdown() {
|
||||
m.image_size_breakdown = nil
|
||||
delete(m.clearedFields, usagelog.FieldImageSizeBreakdown)
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) {
|
||||
m.cache_ttl_overridden = &b
|
||||
@ -36443,7 +36716,7 @@ func (m *UsageLogMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *UsageLogMutation) Fields() []string {
|
||||
fields := make([]string, 0, 37)
|
||||
fields := make([]string, 0, 41)
|
||||
if m.user != nil {
|
||||
fields = append(fields, usagelog.FieldUserID)
|
||||
}
|
||||
@ -36549,6 +36822,18 @@ func (m *UsageLogMutation) Fields() []string {
|
||||
if m.image_size != nil {
|
||||
fields = append(fields, usagelog.FieldImageSize)
|
||||
}
|
||||
if m.image_input_size != nil {
|
||||
fields = append(fields, usagelog.FieldImageInputSize)
|
||||
}
|
||||
if m.image_output_size != nil {
|
||||
fields = append(fields, usagelog.FieldImageOutputSize)
|
||||
}
|
||||
if m.image_size_source != nil {
|
||||
fields = append(fields, usagelog.FieldImageSizeSource)
|
||||
}
|
||||
if m.image_size_breakdown != nil {
|
||||
fields = append(fields, usagelog.FieldImageSizeBreakdown)
|
||||
}
|
||||
if m.cache_ttl_overridden != nil {
|
||||
fields = append(fields, usagelog.FieldCacheTTLOverridden)
|
||||
}
|
||||
@ -36633,6 +36918,14 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.ImageCount()
|
||||
case usagelog.FieldImageSize:
|
||||
return m.ImageSize()
|
||||
case usagelog.FieldImageInputSize:
|
||||
return m.ImageInputSize()
|
||||
case usagelog.FieldImageOutputSize:
|
||||
return m.ImageOutputSize()
|
||||
case usagelog.FieldImageSizeSource:
|
||||
return m.ImageSizeSource()
|
||||
case usagelog.FieldImageSizeBreakdown:
|
||||
return m.ImageSizeBreakdown()
|
||||
case usagelog.FieldCacheTTLOverridden:
|
||||
return m.CacheTTLOverridden()
|
||||
case usagelog.FieldCreatedAt:
|
||||
@ -36716,6 +37009,14 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
|
||||
return m.OldImageCount(ctx)
|
||||
case usagelog.FieldImageSize:
|
||||
return m.OldImageSize(ctx)
|
||||
case usagelog.FieldImageInputSize:
|
||||
return m.OldImageInputSize(ctx)
|
||||
case usagelog.FieldImageOutputSize:
|
||||
return m.OldImageOutputSize(ctx)
|
||||
case usagelog.FieldImageSizeSource:
|
||||
return m.OldImageSizeSource(ctx)
|
||||
case usagelog.FieldImageSizeBreakdown:
|
||||
return m.OldImageSizeBreakdown(ctx)
|
||||
case usagelog.FieldCacheTTLOverridden:
|
||||
return m.OldCacheTTLOverridden(ctx)
|
||||
case usagelog.FieldCreatedAt:
|
||||
@ -36974,6 +37275,34 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetImageSize(v)
|
||||
return nil
|
||||
case usagelog.FieldImageInputSize:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetImageInputSize(v)
|
||||
return nil
|
||||
case usagelog.FieldImageOutputSize:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetImageOutputSize(v)
|
||||
return nil
|
||||
case usagelog.FieldImageSizeSource:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetImageSizeSource(v)
|
||||
return nil
|
||||
case usagelog.FieldImageSizeBreakdown:
|
||||
v, ok := value.(map[string]int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetImageSizeBreakdown(v)
|
||||
return nil
|
||||
case usagelog.FieldCacheTTLOverridden:
|
||||
v, ok := value.(bool)
|
||||
if !ok {
|
||||
@ -37291,6 +37620,18 @@ func (m *UsageLogMutation) ClearedFields() []string {
|
||||
if m.FieldCleared(usagelog.FieldImageSize) {
|
||||
fields = append(fields, usagelog.FieldImageSize)
|
||||
}
|
||||
if m.FieldCleared(usagelog.FieldImageInputSize) {
|
||||
fields = append(fields, usagelog.FieldImageInputSize)
|
||||
}
|
||||
if m.FieldCleared(usagelog.FieldImageOutputSize) {
|
||||
fields = append(fields, usagelog.FieldImageOutputSize)
|
||||
}
|
||||
if m.FieldCleared(usagelog.FieldImageSizeSource) {
|
||||
fields = append(fields, usagelog.FieldImageSizeSource)
|
||||
}
|
||||
if m.FieldCleared(usagelog.FieldImageSizeBreakdown) {
|
||||
fields = append(fields, usagelog.FieldImageSizeBreakdown)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@ -37347,6 +37688,18 @@ func (m *UsageLogMutation) ClearField(name string) error {
|
||||
case usagelog.FieldImageSize:
|
||||
m.ClearImageSize()
|
||||
return nil
|
||||
case usagelog.FieldImageInputSize:
|
||||
m.ClearImageInputSize()
|
||||
return nil
|
||||
case usagelog.FieldImageOutputSize:
|
||||
m.ClearImageOutputSize()
|
||||
return nil
|
||||
case usagelog.FieldImageSizeSource:
|
||||
m.ClearImageSizeSource()
|
||||
return nil
|
||||
case usagelog.FieldImageSizeBreakdown:
|
||||
m.ClearImageSizeBreakdown()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown UsageLog nullable field %s", name)
|
||||
}
|
||||
@ -37460,6 +37813,18 @@ func (m *UsageLogMutation) ResetField(name string) error {
|
||||
case usagelog.FieldImageSize:
|
||||
m.ResetImageSize()
|
||||
return nil
|
||||
case usagelog.FieldImageInputSize:
|
||||
m.ResetImageInputSize()
|
||||
return nil
|
||||
case usagelog.FieldImageOutputSize:
|
||||
m.ResetImageOutputSize()
|
||||
return nil
|
||||
case usagelog.FieldImageSizeSource:
|
||||
m.ResetImageSizeSource()
|
||||
return nil
|
||||
case usagelog.FieldImageSizeBreakdown:
|
||||
m.ResetImageSizeBreakdown()
|
||||
return nil
|
||||
case usagelog.FieldCacheTTLOverridden:
|
||||
m.ResetCacheTTLOverridden()
|
||||
return nil
|
||||
|
||||
@ -35,6 +35,8 @@ type RedeemCode struct {
|
||||
Notes *string `json:"notes,omitempty"`
|
||||
// CreatedAt holds the value of the "created_at" field.
|
||||
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
// ExpiresAt holds the value of the "expires_at" field.
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
// GroupID holds the value of the "group_id" field.
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
// ValidityDays holds the value of the "validity_days" field.
|
||||
@ -89,7 +91,7 @@ func (*RedeemCode) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new(sql.NullInt64)
|
||||
case redeemcode.FieldCode, redeemcode.FieldType, redeemcode.FieldStatus, redeemcode.FieldNotes:
|
||||
values[i] = new(sql.NullString)
|
||||
case redeemcode.FieldUsedAt, redeemcode.FieldCreatedAt:
|
||||
case redeemcode.FieldUsedAt, redeemcode.FieldCreatedAt, redeemcode.FieldExpiresAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
default:
|
||||
values[i] = new(sql.UnknownType)
|
||||
@ -163,6 +165,13 @@ func (_m *RedeemCode) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.CreatedAt = value.Time
|
||||
}
|
||||
case redeemcode.FieldExpiresAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field expires_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ExpiresAt = new(time.Time)
|
||||
*_m.ExpiresAt = value.Time
|
||||
}
|
||||
case redeemcode.FieldGroupID:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field group_id", values[i])
|
||||
@ -252,6 +261,11 @@ func (_m *RedeemCode) String() string {
|
||||
builder.WriteString("created_at=")
|
||||
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ExpiresAt; v != nil {
|
||||
builder.WriteString("expires_at=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.GroupID; v != nil {
|
||||
builder.WriteString("group_id=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
|
||||
@ -30,6 +30,8 @@ const (
|
||||
FieldNotes = "notes"
|
||||
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
||||
FieldCreatedAt = "created_at"
|
||||
// FieldExpiresAt holds the string denoting the expires_at field in the database.
|
||||
FieldExpiresAt = "expires_at"
|
||||
// FieldGroupID holds the string denoting the group_id field in the database.
|
||||
FieldGroupID = "group_id"
|
||||
// FieldValidityDays holds the string denoting the validity_days field in the database.
|
||||
@ -67,6 +69,7 @@ var Columns = []string{
|
||||
FieldUsedAt,
|
||||
FieldNotes,
|
||||
FieldCreatedAt,
|
||||
FieldExpiresAt,
|
||||
FieldGroupID,
|
||||
FieldValidityDays,
|
||||
}
|
||||
@ -148,6 +151,11 @@ func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByExpiresAt orders the results by the expires_at field.
|
||||
func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByGroupID orders the results by the group_id field.
|
||||
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldGroupID, opts...).ToFunc()
|
||||
|
||||
@ -95,6 +95,11 @@ func CreatedAt(v time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldEQ(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ.
|
||||
func ExpiresAt(v time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldEQ(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
|
||||
func GroupID(v int64) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldEQ(FieldGroupID, v))
|
||||
@ -535,6 +540,56 @@ func CreatedAtLTE(v time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldLTE(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtEQ applies the EQ predicate on the "expires_at" field.
|
||||
func ExpiresAtEQ(v time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldEQ(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field.
|
||||
func ExpiresAtNEQ(v time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldNEQ(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtIn applies the In predicate on the "expires_at" field.
|
||||
func ExpiresAtIn(vs ...time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldIn(FieldExpiresAt, vs...))
|
||||
}
|
||||
|
||||
// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field.
|
||||
func ExpiresAtNotIn(vs ...time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldNotIn(FieldExpiresAt, vs...))
|
||||
}
|
||||
|
||||
// ExpiresAtGT applies the GT predicate on the "expires_at" field.
|
||||
func ExpiresAtGT(v time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldGT(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtGTE applies the GTE predicate on the "expires_at" field.
|
||||
func ExpiresAtGTE(v time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldGTE(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtLT applies the LT predicate on the "expires_at" field.
|
||||
func ExpiresAtLT(v time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldLT(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtLTE applies the LTE predicate on the "expires_at" field.
|
||||
func ExpiresAtLTE(v time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldLTE(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtIsNil applies the IsNil predicate on the "expires_at" field.
|
||||
func ExpiresAtIsNil() predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldIsNull(FieldExpiresAt))
|
||||
}
|
||||
|
||||
// ExpiresAtNotNil applies the NotNil predicate on the "expires_at" field.
|
||||
func ExpiresAtNotNil() predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldNotNull(FieldExpiresAt))
|
||||
}
|
||||
|
||||
// GroupIDEQ applies the EQ predicate on the "group_id" field.
|
||||
func GroupIDEQ(v int64) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldEQ(FieldGroupID, v))
|
||||
|
||||
@ -128,6 +128,20 @@ func (_c *RedeemCodeCreate) SetNillableCreatedAt(v *time.Time) *RedeemCodeCreate
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (_c *RedeemCodeCreate) SetExpiresAt(v time.Time) *RedeemCodeCreate {
|
||||
_c.mutation.SetExpiresAt(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
|
||||
func (_c *RedeemCodeCreate) SetNillableExpiresAt(v *time.Time) *RedeemCodeCreate {
|
||||
if v != nil {
|
||||
_c.SetExpiresAt(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (_c *RedeemCodeCreate) SetGroupID(v int64) *RedeemCodeCreate {
|
||||
_c.mutation.SetGroupID(v)
|
||||
@ -327,6 +341,10 @@ func (_c *RedeemCodeCreate) createSpec() (*RedeemCode, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(redeemcode.FieldCreatedAt, field.TypeTime, value)
|
||||
_node.CreatedAt = value
|
||||
}
|
||||
if value, ok := _c.mutation.ExpiresAt(); ok {
|
||||
_spec.SetField(redeemcode.FieldExpiresAt, field.TypeTime, value)
|
||||
_node.ExpiresAt = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ValidityDays(); ok {
|
||||
_spec.SetField(redeemcode.FieldValidityDays, field.TypeInt, value)
|
||||
_node.ValidityDays = value
|
||||
@ -525,6 +543,24 @@ func (u *RedeemCodeUpsert) ClearNotes() *RedeemCodeUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (u *RedeemCodeUpsert) SetExpiresAt(v time.Time) *RedeemCodeUpsert {
|
||||
u.Set(redeemcode.FieldExpiresAt, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
|
||||
func (u *RedeemCodeUpsert) UpdateExpiresAt() *RedeemCodeUpsert {
|
||||
u.SetExcluded(redeemcode.FieldExpiresAt)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||
func (u *RedeemCodeUpsert) ClearExpiresAt() *RedeemCodeUpsert {
|
||||
u.SetNull(redeemcode.FieldExpiresAt)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (u *RedeemCodeUpsert) SetGroupID(v int64) *RedeemCodeUpsert {
|
||||
u.Set(redeemcode.FieldGroupID, v)
|
||||
@ -732,6 +768,27 @@ func (u *RedeemCodeUpsertOne) ClearNotes() *RedeemCodeUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (u *RedeemCodeUpsertOne) SetExpiresAt(v time.Time) *RedeemCodeUpsertOne {
|
||||
return u.Update(func(s *RedeemCodeUpsert) {
|
||||
s.SetExpiresAt(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
|
||||
func (u *RedeemCodeUpsertOne) UpdateExpiresAt() *RedeemCodeUpsertOne {
|
||||
return u.Update(func(s *RedeemCodeUpsert) {
|
||||
s.UpdateExpiresAt()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||
func (u *RedeemCodeUpsertOne) ClearExpiresAt() *RedeemCodeUpsertOne {
|
||||
return u.Update(func(s *RedeemCodeUpsert) {
|
||||
s.ClearExpiresAt()
|
||||
})
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (u *RedeemCodeUpsertOne) SetGroupID(v int64) *RedeemCodeUpsertOne {
|
||||
return u.Update(func(s *RedeemCodeUpsert) {
|
||||
@ -1111,6 +1168,27 @@ func (u *RedeemCodeUpsertBulk) ClearNotes() *RedeemCodeUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (u *RedeemCodeUpsertBulk) SetExpiresAt(v time.Time) *RedeemCodeUpsertBulk {
|
||||
return u.Update(func(s *RedeemCodeUpsert) {
|
||||
s.SetExpiresAt(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
|
||||
func (u *RedeemCodeUpsertBulk) UpdateExpiresAt() *RedeemCodeUpsertBulk {
|
||||
return u.Update(func(s *RedeemCodeUpsert) {
|
||||
s.UpdateExpiresAt()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||
func (u *RedeemCodeUpsertBulk) ClearExpiresAt() *RedeemCodeUpsertBulk {
|
||||
return u.Update(func(s *RedeemCodeUpsert) {
|
||||
s.ClearExpiresAt()
|
||||
})
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (u *RedeemCodeUpsertBulk) SetGroupID(v int64) *RedeemCodeUpsertBulk {
|
||||
return u.Update(func(s *RedeemCodeUpsert) {
|
||||
|
||||
@ -153,6 +153,26 @@ func (_u *RedeemCodeUpdate) ClearNotes() *RedeemCodeUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (_u *RedeemCodeUpdate) SetExpiresAt(v time.Time) *RedeemCodeUpdate {
|
||||
_u.mutation.SetExpiresAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
|
||||
func (_u *RedeemCodeUpdate) SetNillableExpiresAt(v *time.Time) *RedeemCodeUpdate {
|
||||
if v != nil {
|
||||
_u.SetExpiresAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||
func (_u *RedeemCodeUpdate) ClearExpiresAt() *RedeemCodeUpdate {
|
||||
_u.mutation.ClearExpiresAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (_u *RedeemCodeUpdate) SetGroupID(v int64) *RedeemCodeUpdate {
|
||||
_u.mutation.SetGroupID(v)
|
||||
@ -321,6 +341,12 @@ func (_u *RedeemCodeUpdate) sqlSave(ctx context.Context) (_node int, err error)
|
||||
if _u.mutation.NotesCleared() {
|
||||
_spec.ClearField(redeemcode.FieldNotes, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ExpiresAt(); ok {
|
||||
_spec.SetField(redeemcode.FieldExpiresAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.ExpiresAtCleared() {
|
||||
_spec.ClearField(redeemcode.FieldExpiresAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.ValidityDays(); ok {
|
||||
_spec.SetField(redeemcode.FieldValidityDays, field.TypeInt, value)
|
||||
}
|
||||
@ -528,6 +554,26 @@ func (_u *RedeemCodeUpdateOne) ClearNotes() *RedeemCodeUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (_u *RedeemCodeUpdateOne) SetExpiresAt(v time.Time) *RedeemCodeUpdateOne {
|
||||
_u.mutation.SetExpiresAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
|
||||
func (_u *RedeemCodeUpdateOne) SetNillableExpiresAt(v *time.Time) *RedeemCodeUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetExpiresAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||
func (_u *RedeemCodeUpdateOne) ClearExpiresAt() *RedeemCodeUpdateOne {
|
||||
_u.mutation.ClearExpiresAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (_u *RedeemCodeUpdateOne) SetGroupID(v int64) *RedeemCodeUpdateOne {
|
||||
_u.mutation.SetGroupID(v)
|
||||
@ -726,6 +772,12 @@ func (_u *RedeemCodeUpdateOne) sqlSave(ctx context.Context) (_node *RedeemCode,
|
||||
if _u.mutation.NotesCleared() {
|
||||
_spec.ClearField(redeemcode.FieldNotes, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ExpiresAt(); ok {
|
||||
_spec.SetField(redeemcode.FieldExpiresAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.ExpiresAtCleared() {
|
||||
_spec.ClearField(redeemcode.FieldExpiresAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.ValidityDays(); ok {
|
||||
_spec.SetField(redeemcode.FieldValidityDays, field.TypeInt, value)
|
||||
}
|
||||
|
||||
@ -1386,7 +1386,7 @@ func init() {
|
||||
// redeemcode.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
redeemcode.DefaultCreatedAt = redeemcodeDescCreatedAt.Default.(func() time.Time)
|
||||
// redeemcodeDescValidityDays is the schema descriptor for validity_days field.
|
||||
redeemcodeDescValidityDays := redeemcodeFields[9].Descriptor()
|
||||
redeemcodeDescValidityDays := redeemcodeFields[10].Descriptor()
|
||||
// redeemcode.DefaultValidityDays holds the default value on creation for the validity_days field.
|
||||
redeemcode.DefaultValidityDays = redeemcodeDescValidityDays.Default.(int)
|
||||
securitysecretMixin := schema.SecuritySecret{}.Mixin()
|
||||
@ -1722,12 +1722,24 @@ func init() {
|
||||
usagelogDescImageSize := usagelogFields[34].Descriptor()
|
||||
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
||||
// usagelogDescImageInputSize is the schema descriptor for image_input_size field.
|
||||
usagelogDescImageInputSize := usagelogFields[35].Descriptor()
|
||||
// usagelog.ImageInputSizeValidator is a validator for the "image_input_size" field. It is called by the builders before save.
|
||||
usagelog.ImageInputSizeValidator = usagelogDescImageInputSize.Validators[0].(func(string) error)
|
||||
// usagelogDescImageOutputSize is the schema descriptor for image_output_size field.
|
||||
usagelogDescImageOutputSize := usagelogFields[36].Descriptor()
|
||||
// usagelog.ImageOutputSizeValidator is a validator for the "image_output_size" field. It is called by the builders before save.
|
||||
usagelog.ImageOutputSizeValidator = usagelogDescImageOutputSize.Validators[0].(func(string) error)
|
||||
// usagelogDescImageSizeSource is the schema descriptor for image_size_source field.
|
||||
usagelogDescImageSizeSource := usagelogFields[37].Descriptor()
|
||||
// usagelog.ImageSizeSourceValidator is a validator for the "image_size_source" field. It is called by the builders before save.
|
||||
usagelog.ImageSizeSourceValidator = usagelogDescImageSizeSource.Validators[0].(func(string) error)
|
||||
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
|
||||
usagelogDescCacheTTLOverridden := usagelogFields[35].Descriptor()
|
||||
usagelogDescCacheTTLOverridden := usagelogFields[39].Descriptor()
|
||||
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
|
||||
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
|
||||
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
||||
usagelogDescCreatedAt := usagelogFields[36].Descriptor()
|
||||
usagelogDescCreatedAt := usagelogFields[40].Descriptor()
|
||||
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
||||
userMixin := schema.User{}.Mixin()
|
||||
|
||||
@ -15,12 +15,13 @@ import (
|
||||
)
|
||||
|
||||
var authProviderTypes = map[string]struct{}{
|
||||
"email": {},
|
||||
"github": {},
|
||||
"google": {},
|
||||
"linuxdo": {},
|
||||
"oidc": {},
|
||||
"wechat": {},
|
||||
"email": {},
|
||||
"github": {},
|
||||
"google": {},
|
||||
"linuxdo": {},
|
||||
"oidc": {},
|
||||
"wechat": {},
|
||||
"dingtalk": {},
|
||||
}
|
||||
|
||||
func validateAuthProviderType(value string) error {
|
||||
|
||||
@ -83,7 +83,7 @@ func TestAuthIdentityFoundationSchemas(t *testing.T) {
|
||||
require.Equal(t, 1, signupSource.Validators)
|
||||
|
||||
validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source")
|
||||
for _, value := range []string{"email", "linuxdo", "wechat", "oidc", "github", "google"} {
|
||||
for _, value := range []string{"email", "linuxdo", "wechat", "oidc", "github", "google", "dingtalk"} {
|
||||
require.NoError(t, validator(value))
|
||||
}
|
||||
require.Error(t, validator("unknown"))
|
||||
|
||||
@ -63,6 +63,10 @@ func (RedeemCode) Fields() []ent.Field {
|
||||
Immutable().
|
||||
Default(time.Now).
|
||||
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
|
||||
field.Time("expires_at").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
|
||||
field.Int64("group_id").
|
||||
Optional().
|
||||
Nillable(),
|
||||
@ -90,5 +94,6 @@ func (RedeemCode) Indexes() []ent.Index {
|
||||
index.Fields("status"),
|
||||
index.Fields("used_by"),
|
||||
index.Fields("group_id"),
|
||||
index.Fields("expires_at"),
|
||||
}
|
||||
}
|
||||
|
||||
@ -134,6 +134,21 @@ func (UsageLog) Fields() []ent.Field {
|
||||
MaxLen(10).
|
||||
Optional().
|
||||
Nillable(),
|
||||
field.String("image_input_size").
|
||||
MaxLen(32).
|
||||
Optional().
|
||||
Nillable(),
|
||||
field.String("image_output_size").
|
||||
MaxLen(32).
|
||||
Optional().
|
||||
Nillable(),
|
||||
field.String("image_size_source").
|
||||
MaxLen(16).
|
||||
Optional().
|
||||
Nillable(),
|
||||
field.JSON("image_size_breakdown", map[string]int{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
|
||||
// Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
|
||||
field.Bool("cache_ttl_overridden").
|
||||
Default(false),
|
||||
|
||||
@ -77,10 +77,10 @@ func (User) Fields() []ent.Field {
|
||||
field.String("signup_source").
|
||||
Validate(func(value string) error {
|
||||
switch value {
|
||||
case "email", "linuxdo", "wechat", "oidc", "github", "google":
|
||||
case "email", "linuxdo", "wechat", "oidc", "github", "google", "dingtalk":
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc, github, google")
|
||||
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc, github, google, dingtalk")
|
||||
}
|
||||
}).
|
||||
Default("email"),
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
package ent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@ -92,6 +93,14 @@ type UsageLog struct {
|
||||
ImageCount int `json:"image_count,omitempty"`
|
||||
// ImageSize holds the value of the "image_size" field.
|
||||
ImageSize *string `json:"image_size,omitempty"`
|
||||
// ImageInputSize holds the value of the "image_input_size" field.
|
||||
ImageInputSize *string `json:"image_input_size,omitempty"`
|
||||
// ImageOutputSize holds the value of the "image_output_size" field.
|
||||
ImageOutputSize *string `json:"image_output_size,omitempty"`
|
||||
// ImageSizeSource holds the value of the "image_size_source" field.
|
||||
ImageSizeSource *string `json:"image_size_source,omitempty"`
|
||||
// ImageSizeBreakdown holds the value of the "image_size_breakdown" field.
|
||||
ImageSizeBreakdown map[string]int `json:"image_size_breakdown,omitempty"`
|
||||
// CacheTTLOverridden holds the value of the "cache_ttl_overridden" field.
|
||||
CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"`
|
||||
// CreatedAt holds the value of the "created_at" field.
|
||||
@ -179,13 +188,15 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
|
||||
values := make([]any, len(columns))
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case usagelog.FieldImageSizeBreakdown:
|
||||
values[i] = new([]byte)
|
||||
case usagelog.FieldStream, usagelog.FieldCacheTTLOverridden:
|
||||
values[i] = new(sql.NullBool)
|
||||
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldChannelID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize:
|
||||
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldImageInputSize, usagelog.FieldImageOutputSize, usagelog.FieldImageSizeSource:
|
||||
values[i] = new(sql.NullString)
|
||||
case usagelog.FieldCreatedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
@ -434,6 +445,35 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
|
||||
_m.ImageSize = new(string)
|
||||
*_m.ImageSize = value.String
|
||||
}
|
||||
case usagelog.FieldImageInputSize:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field image_input_size", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ImageInputSize = new(string)
|
||||
*_m.ImageInputSize = value.String
|
||||
}
|
||||
case usagelog.FieldImageOutputSize:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field image_output_size", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ImageOutputSize = new(string)
|
||||
*_m.ImageOutputSize = value.String
|
||||
}
|
||||
case usagelog.FieldImageSizeSource:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field image_size_source", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ImageSizeSource = new(string)
|
||||
*_m.ImageSizeSource = value.String
|
||||
}
|
||||
case usagelog.FieldImageSizeBreakdown:
|
||||
if value, ok := values[i].(*[]byte); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field image_size_breakdown", values[i])
|
||||
} else if value != nil && len(*value) > 0 {
|
||||
if err := json.Unmarshal(*value, &_m.ImageSizeBreakdown); err != nil {
|
||||
return fmt.Errorf("unmarshal field image_size_breakdown: %w", err)
|
||||
}
|
||||
}
|
||||
case usagelog.FieldCacheTTLOverridden:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i])
|
||||
@ -640,6 +680,24 @@ func (_m *UsageLog) String() string {
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ImageInputSize; v != nil {
|
||||
builder.WriteString("image_input_size=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ImageOutputSize; v != nil {
|
||||
builder.WriteString("image_output_size=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ImageSizeSource; v != nil {
|
||||
builder.WriteString("image_size_source=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("image_size_breakdown=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.ImageSizeBreakdown))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("cache_ttl_overridden=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden))
|
||||
builder.WriteString(", ")
|
||||
|
||||
@ -84,6 +84,14 @@ const (
|
||||
FieldImageCount = "image_count"
|
||||
// FieldImageSize holds the string denoting the image_size field in the database.
|
||||
FieldImageSize = "image_size"
|
||||
// FieldImageInputSize holds the string denoting the image_input_size field in the database.
|
||||
FieldImageInputSize = "image_input_size"
|
||||
// FieldImageOutputSize holds the string denoting the image_output_size field in the database.
|
||||
FieldImageOutputSize = "image_output_size"
|
||||
// FieldImageSizeSource holds the string denoting the image_size_source field in the database.
|
||||
FieldImageSizeSource = "image_size_source"
|
||||
// FieldImageSizeBreakdown holds the string denoting the image_size_breakdown field in the database.
|
||||
FieldImageSizeBreakdown = "image_size_breakdown"
|
||||
// FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database.
|
||||
FieldCacheTTLOverridden = "cache_ttl_overridden"
|
||||
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
||||
@ -175,6 +183,10 @@ var Columns = []string{
|
||||
FieldIPAddress,
|
||||
FieldImageCount,
|
||||
FieldImageSize,
|
||||
FieldImageInputSize,
|
||||
FieldImageOutputSize,
|
||||
FieldImageSizeSource,
|
||||
FieldImageSizeBreakdown,
|
||||
FieldCacheTTLOverridden,
|
||||
FieldCreatedAt,
|
||||
}
|
||||
@ -242,6 +254,12 @@ var (
|
||||
DefaultImageCount int
|
||||
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||
ImageSizeValidator func(string) error
|
||||
// ImageInputSizeValidator is a validator for the "image_input_size" field. It is called by the builders before save.
|
||||
ImageInputSizeValidator func(string) error
|
||||
// ImageOutputSizeValidator is a validator for the "image_output_size" field. It is called by the builders before save.
|
||||
ImageOutputSizeValidator func(string) error
|
||||
// ImageSizeSourceValidator is a validator for the "image_size_source" field. It is called by the builders before save.
|
||||
ImageSizeSourceValidator func(string) error
|
||||
// DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field.
|
||||
DefaultCacheTTLOverridden bool
|
||||
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||
@ -431,6 +449,21 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImageSize, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByImageInputSize orders the results by the image_input_size field.
|
||||
func ByImageInputSize(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImageInputSize, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByImageOutputSize orders the results by the image_output_size field.
|
||||
func ByImageOutputSize(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImageOutputSize, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByImageSizeSource orders the results by the image_size_source field.
|
||||
func ByImageSizeSource(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImageSizeSource, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field.
|
||||
func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc()
|
||||
|
||||
@ -230,6 +230,21 @@ func ImageSize(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSize applies equality check predicate on the "image_input_size" field. It's identical to ImageInputSizeEQ.
|
||||
func ImageInputSize(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSize applies equality check predicate on the "image_output_size" field. It's identical to ImageOutputSizeEQ.
|
||||
func ImageOutputSize(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageSizeSource applies equality check predicate on the "image_size_source" field. It's identical to ImageSizeSourceEQ.
|
||||
func ImageSizeSource(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ.
|
||||
func CacheTTLOverridden(v bool) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
|
||||
@ -1900,6 +1915,241 @@ func ImageSizeContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeEQ applies the EQ predicate on the "image_input_size" field.
|
||||
func ImageInputSizeEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeNEQ applies the NEQ predicate on the "image_input_size" field.
|
||||
func ImageInputSizeNEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeIn applies the In predicate on the "image_input_size" field.
|
||||
func ImageInputSizeIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldImageInputSize, vs...))
|
||||
}
|
||||
|
||||
// ImageInputSizeNotIn applies the NotIn predicate on the "image_input_size" field.
|
||||
func ImageInputSizeNotIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldImageInputSize, vs...))
|
||||
}
|
||||
|
||||
// ImageInputSizeGT applies the GT predicate on the "image_input_size" field.
|
||||
func ImageInputSizeGT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeGTE applies the GTE predicate on the "image_input_size" field.
|
||||
func ImageInputSizeGTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeLT applies the LT predicate on the "image_input_size" field.
|
||||
func ImageInputSizeLT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeLTE applies the LTE predicate on the "image_input_size" field.
|
||||
func ImageInputSizeLTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeContains applies the Contains predicate on the "image_input_size" field.
|
||||
func ImageInputSizeContains(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContains(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeHasPrefix applies the HasPrefix predicate on the "image_input_size" field.
|
||||
func ImageInputSizeHasPrefix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeHasSuffix applies the HasSuffix predicate on the "image_input_size" field.
|
||||
func ImageInputSizeHasSuffix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeIsNil applies the IsNil predicate on the "image_input_size" field.
|
||||
func ImageInputSizeIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldImageInputSize))
|
||||
}
|
||||
|
||||
// ImageInputSizeNotNil applies the NotNil predicate on the "image_input_size" field.
|
||||
func ImageInputSizeNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldImageInputSize))
|
||||
}
|
||||
|
||||
// ImageInputSizeEqualFold applies the EqualFold predicate on the "image_input_size" field.
|
||||
func ImageInputSizeEqualFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEqualFold(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeContainsFold applies the ContainsFold predicate on the "image_input_size" field.
|
||||
func ImageInputSizeContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeEQ applies the EQ predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeNEQ applies the NEQ predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeNEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeIn applies the In predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldImageOutputSize, vs...))
|
||||
}
|
||||
|
||||
// ImageOutputSizeNotIn applies the NotIn predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeNotIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldImageOutputSize, vs...))
|
||||
}
|
||||
|
||||
// ImageOutputSizeGT applies the GT predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeGT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeGTE applies the GTE predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeGTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeLT applies the LT predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeLT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeLTE applies the LTE predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeLTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeContains applies the Contains predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeContains(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContains(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeHasPrefix applies the HasPrefix predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeHasPrefix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeHasSuffix applies the HasSuffix predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeHasSuffix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeIsNil applies the IsNil predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldImageOutputSize))
|
||||
}
|
||||
|
||||
// ImageOutputSizeNotNil applies the NotNil predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldImageOutputSize))
|
||||
}
|
||||
|
||||
// ImageOutputSizeEqualFold applies the EqualFold predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeEqualFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEqualFold(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeContainsFold applies the ContainsFold predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceEQ applies the EQ predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceNEQ applies the NEQ predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceNEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceIn applies the In predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldImageSizeSource, vs...))
|
||||
}
|
||||
|
||||
// ImageSizeSourceNotIn applies the NotIn predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceNotIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldImageSizeSource, vs...))
|
||||
}
|
||||
|
||||
// ImageSizeSourceGT applies the GT predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceGT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceGTE applies the GTE predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceGTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceLT applies the LT predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceLT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceLTE applies the LTE predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceLTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceContains applies the Contains predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceContains(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContains(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceHasPrefix applies the HasPrefix predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceHasPrefix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceHasSuffix applies the HasSuffix predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceHasSuffix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceIsNil applies the IsNil predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldImageSizeSource))
|
||||
}
|
||||
|
||||
// ImageSizeSourceNotNil applies the NotNil predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldImageSizeSource))
|
||||
}
|
||||
|
||||
// ImageSizeSourceEqualFold applies the EqualFold predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceEqualFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEqualFold(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceContainsFold applies the ContainsFold predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeBreakdownIsNil applies the IsNil predicate on the "image_size_breakdown" field.
|
||||
func ImageSizeBreakdownIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldImageSizeBreakdown))
|
||||
}
|
||||
|
||||
// ImageSizeBreakdownNotNil applies the NotNil predicate on the "image_size_breakdown" field.
|
||||
func ImageSizeBreakdownNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldImageSizeBreakdown))
|
||||
}
|
||||
|
||||
// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field.
|
||||
func CacheTTLOverriddenEQ(v bool) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
|
||||
|
||||
@ -477,6 +477,54 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetImageInputSize sets the "image_input_size" field.
|
||||
func (_c *UsageLogCreate) SetImageInputSize(v string) *UsageLogCreate {
|
||||
_c.mutation.SetImageInputSize(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableImageInputSize sets the "image_input_size" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableImageInputSize(v *string) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetImageInputSize(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetImageOutputSize sets the "image_output_size" field.
|
||||
func (_c *UsageLogCreate) SetImageOutputSize(v string) *UsageLogCreate {
|
||||
_c.mutation.SetImageOutputSize(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableImageOutputSize sets the "image_output_size" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableImageOutputSize(v *string) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetImageOutputSize(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetImageSizeSource sets the "image_size_source" field.
|
||||
func (_c *UsageLogCreate) SetImageSizeSource(v string) *UsageLogCreate {
|
||||
_c.mutation.SetImageSizeSource(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableImageSizeSource sets the "image_size_source" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableImageSizeSource(v *string) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetImageSizeSource(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
|
||||
func (_c *UsageLogCreate) SetImageSizeBreakdown(v map[string]int) *UsageLogCreate {
|
||||
_c.mutation.SetImageSizeBreakdown(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate {
|
||||
_c.mutation.SetCacheTTLOverridden(v)
|
||||
@ -754,6 +802,21 @@ func (_c *UsageLogCreate) check() error {
|
||||
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _c.mutation.ImageInputSize(); ok {
|
||||
if err := usagelog.ImageInputSizeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_input_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_input_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _c.mutation.ImageOutputSize(); ok {
|
||||
if err := usagelog.ImageOutputSizeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_output_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_output_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _c.mutation.ImageSizeSource(); ok {
|
||||
if err := usagelog.ImageSizeSourceValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_size_source", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size_source": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
|
||||
return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)}
|
||||
}
|
||||
@ -916,6 +979,22 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(usagelog.FieldImageSize, field.TypeString, value)
|
||||
_node.ImageSize = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ImageInputSize(); ok {
|
||||
_spec.SetField(usagelog.FieldImageInputSize, field.TypeString, value)
|
||||
_node.ImageInputSize = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ImageOutputSize(); ok {
|
||||
_spec.SetField(usagelog.FieldImageOutputSize, field.TypeString, value)
|
||||
_node.ImageOutputSize = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ImageSizeSource(); ok {
|
||||
_spec.SetField(usagelog.FieldImageSizeSource, field.TypeString, value)
|
||||
_node.ImageSizeSource = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ImageSizeBreakdown(); ok {
|
||||
_spec.SetField(usagelog.FieldImageSizeBreakdown, field.TypeJSON, value)
|
||||
_node.ImageSizeBreakdown = value
|
||||
}
|
||||
if value, ok := _c.mutation.CacheTTLOverridden(); ok {
|
||||
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
||||
_node.CacheTTLOverridden = value
|
||||
@ -1679,6 +1758,78 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetImageInputSize sets the "image_input_size" field.
|
||||
func (u *UsageLogUpsert) SetImageInputSize(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldImageInputSize, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateImageInputSize sets the "image_input_size" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateImageInputSize() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldImageInputSize)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearImageInputSize clears the value of the "image_input_size" field.
|
||||
func (u *UsageLogUpsert) ClearImageInputSize() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldImageInputSize)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetImageOutputSize sets the "image_output_size" field.
|
||||
func (u *UsageLogUpsert) SetImageOutputSize(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldImageOutputSize, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateImageOutputSize sets the "image_output_size" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateImageOutputSize() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldImageOutputSize)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearImageOutputSize clears the value of the "image_output_size" field.
|
||||
func (u *UsageLogUpsert) ClearImageOutputSize() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldImageOutputSize)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetImageSizeSource sets the "image_size_source" field.
|
||||
func (u *UsageLogUpsert) SetImageSizeSource(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldImageSizeSource, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateImageSizeSource sets the "image_size_source" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateImageSizeSource() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldImageSizeSource)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearImageSizeSource clears the value of the "image_size_source" field.
|
||||
func (u *UsageLogUpsert) ClearImageSizeSource() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldImageSizeSource)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
|
||||
func (u *UsageLogUpsert) SetImageSizeBreakdown(v map[string]int) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldImageSizeBreakdown, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateImageSizeBreakdown sets the "image_size_breakdown" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateImageSizeBreakdown() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldImageSizeBreakdown)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
|
||||
func (u *UsageLogUpsert) ClearImageSizeBreakdown() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldImageSizeBreakdown)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldCacheTTLOverridden, v)
|
||||
@ -2457,6 +2608,90 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageInputSize sets the "image_input_size" field.
|
||||
func (u *UsageLogUpsertOne) SetImageInputSize(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetImageInputSize(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageInputSize sets the "image_input_size" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateImageInputSize() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateImageInputSize()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImageInputSize clears the value of the "image_input_size" field.
|
||||
func (u *UsageLogUpsertOne) ClearImageInputSize() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearImageInputSize()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageOutputSize sets the "image_output_size" field.
|
||||
func (u *UsageLogUpsertOne) SetImageOutputSize(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetImageOutputSize(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageOutputSize sets the "image_output_size" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateImageOutputSize() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateImageOutputSize()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImageOutputSize clears the value of the "image_output_size" field.
|
||||
func (u *UsageLogUpsertOne) ClearImageOutputSize() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearImageOutputSize()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageSizeSource sets the "image_size_source" field.
|
||||
func (u *UsageLogUpsertOne) SetImageSizeSource(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetImageSizeSource(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageSizeSource sets the "image_size_source" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateImageSizeSource() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateImageSizeSource()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImageSizeSource clears the value of the "image_size_source" field.
|
||||
func (u *UsageLogUpsertOne) ClearImageSizeSource() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearImageSizeSource()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
|
||||
func (u *UsageLogUpsertOne) SetImageSizeBreakdown(v map[string]int) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetImageSizeBreakdown(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageSizeBreakdown sets the "image_size_breakdown" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateImageSizeBreakdown() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateImageSizeBreakdown()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
|
||||
func (u *UsageLogUpsertOne) ClearImageSizeBreakdown() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearImageSizeBreakdown()
|
||||
})
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
@ -3403,6 +3638,90 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageInputSize sets the "image_input_size" field.
|
||||
func (u *UsageLogUpsertBulk) SetImageInputSize(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetImageInputSize(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageInputSize sets the "image_input_size" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateImageInputSize() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateImageInputSize()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImageInputSize clears the value of the "image_input_size" field.
|
||||
func (u *UsageLogUpsertBulk) ClearImageInputSize() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearImageInputSize()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageOutputSize sets the "image_output_size" field.
|
||||
func (u *UsageLogUpsertBulk) SetImageOutputSize(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetImageOutputSize(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageOutputSize sets the "image_output_size" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateImageOutputSize() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateImageOutputSize()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImageOutputSize clears the value of the "image_output_size" field.
|
||||
func (u *UsageLogUpsertBulk) ClearImageOutputSize() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearImageOutputSize()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageSizeSource sets the "image_size_source" field.
|
||||
func (u *UsageLogUpsertBulk) SetImageSizeSource(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetImageSizeSource(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageSizeSource sets the "image_size_source" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateImageSizeSource() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateImageSizeSource()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImageSizeSource clears the value of the "image_size_source" field.
|
||||
func (u *UsageLogUpsertBulk) ClearImageSizeSource() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearImageSizeSource()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
|
||||
func (u *UsageLogUpsertBulk) SetImageSizeBreakdown(v map[string]int) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetImageSizeBreakdown(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageSizeBreakdown sets the "image_size_breakdown" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateImageSizeBreakdown() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateImageSizeBreakdown()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
|
||||
func (u *UsageLogUpsertBulk) ClearImageSizeBreakdown() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearImageSizeBreakdown()
|
||||
})
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
|
||||
@ -739,6 +739,78 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageInputSize sets the "image_input_size" field.
|
||||
func (_u *UsageLogUpdate) SetImageInputSize(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetImageInputSize(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageInputSize sets the "image_input_size" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableImageInputSize(v *string) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetImageInputSize(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImageInputSize clears the value of the "image_input_size" field.
|
||||
func (_u *UsageLogUpdate) ClearImageInputSize() *UsageLogUpdate {
|
||||
_u.mutation.ClearImageInputSize()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageOutputSize sets the "image_output_size" field.
|
||||
func (_u *UsageLogUpdate) SetImageOutputSize(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetImageOutputSize(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageOutputSize sets the "image_output_size" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableImageOutputSize(v *string) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetImageOutputSize(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImageOutputSize clears the value of the "image_output_size" field.
|
||||
func (_u *UsageLogUpdate) ClearImageOutputSize() *UsageLogUpdate {
|
||||
_u.mutation.ClearImageOutputSize()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageSizeSource sets the "image_size_source" field.
|
||||
func (_u *UsageLogUpdate) SetImageSizeSource(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetImageSizeSource(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageSizeSource sets the "image_size_source" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableImageSizeSource(v *string) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetImageSizeSource(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImageSizeSource clears the value of the "image_size_source" field.
|
||||
func (_u *UsageLogUpdate) ClearImageSizeSource() *UsageLogUpdate {
|
||||
_u.mutation.ClearImageSizeSource()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
|
||||
func (_u *UsageLogUpdate) SetImageSizeBreakdown(v map[string]int) *UsageLogUpdate {
|
||||
_u.mutation.SetImageSizeBreakdown(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
|
||||
func (_u *UsageLogUpdate) ClearImageSizeBreakdown() *UsageLogUpdate {
|
||||
_u.mutation.ClearImageSizeBreakdown()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate {
|
||||
_u.mutation.SetCacheTTLOverridden(v)
|
||||
@ -892,6 +964,21 @@ func (_u *UsageLogUpdate) check() error {
|
||||
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.ImageInputSize(); ok {
|
||||
if err := usagelog.ImageInputSizeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_input_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_input_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.ImageOutputSize(); ok {
|
||||
if err := usagelog.ImageOutputSizeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_output_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_output_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.ImageSizeSource(); ok {
|
||||
if err := usagelog.ImageSizeSourceValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_size_source", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size_source": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
|
||||
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
|
||||
}
|
||||
@ -1099,6 +1186,30 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if _u.mutation.ImageSizeCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageInputSize(); ok {
|
||||
_spec.SetField(usagelog.FieldImageInputSize, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.ImageInputSizeCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageInputSize, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageOutputSize(); ok {
|
||||
_spec.SetField(usagelog.FieldImageOutputSize, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.ImageOutputSizeCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageOutputSize, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageSizeSource(); ok {
|
||||
_spec.SetField(usagelog.FieldImageSizeSource, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.ImageSizeSourceCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageSizeSource, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageSizeBreakdown(); ok {
|
||||
_spec.SetField(usagelog.FieldImageSizeBreakdown, field.TypeJSON, value)
|
||||
}
|
||||
if _u.mutation.ImageSizeBreakdownCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageSizeBreakdown, field.TypeJSON)
|
||||
}
|
||||
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
|
||||
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
||||
}
|
||||
@ -1974,6 +2085,78 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageInputSize sets the "image_input_size" field.
|
||||
func (_u *UsageLogUpdateOne) SetImageInputSize(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetImageInputSize(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageInputSize sets the "image_input_size" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableImageInputSize(v *string) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetImageInputSize(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImageInputSize clears the value of the "image_input_size" field.
|
||||
func (_u *UsageLogUpdateOne) ClearImageInputSize() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearImageInputSize()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageOutputSize sets the "image_output_size" field.
|
||||
func (_u *UsageLogUpdateOne) SetImageOutputSize(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetImageOutputSize(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageOutputSize sets the "image_output_size" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableImageOutputSize(v *string) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetImageOutputSize(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImageOutputSize clears the value of the "image_output_size" field.
|
||||
func (_u *UsageLogUpdateOne) ClearImageOutputSize() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearImageOutputSize()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageSizeSource sets the "image_size_source" field.
|
||||
func (_u *UsageLogUpdateOne) SetImageSizeSource(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetImageSizeSource(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageSizeSource sets the "image_size_source" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableImageSizeSource(v *string) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetImageSizeSource(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImageSizeSource clears the value of the "image_size_source" field.
|
||||
func (_u *UsageLogUpdateOne) ClearImageSizeSource() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearImageSizeSource()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
|
||||
func (_u *UsageLogUpdateOne) SetImageSizeBreakdown(v map[string]int) *UsageLogUpdateOne {
|
||||
_u.mutation.SetImageSizeBreakdown(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
|
||||
func (_u *UsageLogUpdateOne) ClearImageSizeBreakdown() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearImageSizeBreakdown()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne {
|
||||
_u.mutation.SetCacheTTLOverridden(v)
|
||||
@ -2140,6 +2323,21 @@ func (_u *UsageLogUpdateOne) check() error {
|
||||
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.ImageInputSize(); ok {
|
||||
if err := usagelog.ImageInputSizeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_input_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_input_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.ImageOutputSize(); ok {
|
||||
if err := usagelog.ImageOutputSizeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_output_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_output_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.ImageSizeSource(); ok {
|
||||
if err := usagelog.ImageSizeSourceValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_size_source", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size_source": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
|
||||
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
|
||||
}
|
||||
@ -2364,6 +2562,30 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
|
||||
if _u.mutation.ImageSizeCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageInputSize(); ok {
|
||||
_spec.SetField(usagelog.FieldImageInputSize, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.ImageInputSizeCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageInputSize, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageOutputSize(); ok {
|
||||
_spec.SetField(usagelog.FieldImageOutputSize, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.ImageOutputSizeCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageOutputSize, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageSizeSource(); ok {
|
||||
_spec.SetField(usagelog.FieldImageSizeSource, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.ImageSizeSourceCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageSizeSource, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageSizeBreakdown(); ok {
|
||||
_spec.SetField(usagelog.FieldImageSizeBreakdown, field.TypeJSON, value)
|
||||
}
|
||||
if _u.mutation.ImageSizeBreakdownCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageSizeBreakdown, field.TypeJSON)
|
||||
}
|
||||
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
|
||||
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
||||
}
|
||||
|
||||
@ -216,6 +216,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||
@ -249,6 +251,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
@ -278,6 +282,8 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
|
||||
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
@ -310,6 +316,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
||||
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||
|
||||
@ -72,6 +72,7 @@ type Config struct {
|
||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||
WeChat WeChatConnectConfig `mapstructure:"wechat_connect"`
|
||||
OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
|
||||
DingTalk DingTalkConnectConfig `mapstructure:"dingtalk_connect"`
|
||||
GitHubOAuth EmailOAuthProviderConfig `mapstructure:"github_oauth"`
|
||||
GoogleOAuth EmailOAuthProviderConfig `mapstructure:"google_oauth"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
@ -242,6 +243,47 @@ type OIDCConnectConfig struct {
|
||||
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
|
||||
}
|
||||
|
||||
type DingTalkConnectConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
ClientID string `mapstructure:"client_id"`
|
||||
ClientSecret string `mapstructure:"client_secret"`
|
||||
AuthorizeURL string `mapstructure:"authorize_url"`
|
||||
TokenURL string `mapstructure:"token_url"`
|
||||
UserInfoURL string `mapstructure:"userinfo_url"`
|
||||
Scopes string `mapstructure:"scopes"`
|
||||
RedirectURL string `mapstructure:"redirect_url"`
|
||||
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"`
|
||||
|
||||
// 平台底座 + 业务行为
|
||||
DingTalkAppKind string `mapstructure:"dingtalk_app_kind"` // 仅 "internal_app"(V4 fail-closed)
|
||||
AppType string `mapstructure:"app_type"` // "public" (default) | "internal"
|
||||
|
||||
// Corp 限定(none | internal_only)
|
||||
CorpRestrictionPolicy string `mapstructure:"corp_restriction_policy"`
|
||||
InternalCorpID string `mapstructure:"internal_corp_id"`
|
||||
BypassRegistration bool `mapstructure:"bypass_registration"`
|
||||
SyncCorpEmail bool `mapstructure:"sync_corp_email"`
|
||||
SyncDisplayName bool `mapstructure:"sync_display_name"`
|
||||
SyncDept bool `mapstructure:"sync_dept"`
|
||||
SyncCorpEmailAttrKey string `mapstructure:"sync_corp_email_attr_key"`
|
||||
SyncDisplayNameAttrKey string `mapstructure:"sync_display_name_attr_key"`
|
||||
SyncDeptAttrKey string `mapstructure:"sync_dept_attr_key"`
|
||||
SyncCorpEmailAttrName string `mapstructure:"sync_corp_email_attr_name"`
|
||||
SyncDisplayNameAttrName string `mapstructure:"sync_display_name_attr_name"`
|
||||
SyncDeptAttrName string `mapstructure:"sync_dept_attr_name"`
|
||||
|
||||
// 邮箱 + Username
|
||||
RequireEmail bool `mapstructure:"require_email"`
|
||||
UsernameOverwritePolicy string `mapstructure:"username_overwrite_policy"`
|
||||
|
||||
// Attribute(私有版扩展点;开源版仅声明)
|
||||
UsernameAttributeKey string `mapstructure:"username_attribute_key"`
|
||||
EnableAttributeMatching bool `mapstructure:"enable_attribute_matching"`
|
||||
EnableAttributeSync bool `mapstructure:"enable_attribute_sync"`
|
||||
AttributeSyncFields []string `mapstructure:"attribute_sync_fields"`
|
||||
AttributeSyncOverwritePolicy string `mapstructure:"attribute_sync_overwrite_policy"`
|
||||
}
|
||||
|
||||
type EmailOAuthProviderConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
ClientID string `mapstructure:"client_id"`
|
||||
@ -1536,6 +1578,19 @@ func setDefaults() {
|
||||
viper.SetDefault("oidc_connect.userinfo_id_path", "")
|
||||
viper.SetDefault("oidc_connect.userinfo_username_path", "")
|
||||
|
||||
// DingTalk Connect OAuth 登录
|
||||
viper.SetDefault("dingtalk_connect.enabled", false)
|
||||
viper.SetDefault("dingtalk_connect.authorize_url", "https://login.dingtalk.com/oauth2/auth")
|
||||
viper.SetDefault("dingtalk_connect.token_url", "https://api.dingtalk.com/v1.0/oauth2/userAccessToken")
|
||||
viper.SetDefault("dingtalk_connect.userinfo_url", "https://api.dingtalk.com/v1.0/contact/users/me")
|
||||
viper.SetDefault("dingtalk_connect.scopes", "openid")
|
||||
viper.SetDefault("dingtalk_connect.frontend_redirect_url", "/auth/dingtalk/callback")
|
||||
viper.SetDefault("dingtalk_connect.dingtalk_app_kind", "internal_app")
|
||||
viper.SetDefault("dingtalk_connect.app_type", "public")
|
||||
viper.SetDefault("dingtalk_connect.corp_restriction_policy", "none")
|
||||
viper.SetDefault("dingtalk_connect.require_email", true)
|
||||
viper.SetDefault("dingtalk_connect.username_overwrite_policy", "if_empty")
|
||||
|
||||
// Database
|
||||
viper.SetDefault("database.host", "localhost")
|
||||
viper.SetDefault("database.port", 5432)
|
||||
@ -2608,6 +2663,9 @@ func (c *Config) Validate() error {
|
||||
if c.Concurrency.PingInterval < 5 || c.Concurrency.PingInterval > 30 {
|
||||
return fmt.Errorf("concurrency.ping_interval must be between 5-30 seconds")
|
||||
}
|
||||
if err := ValidateDingTalkConfig(c.DingTalk); err != nil {
|
||||
return fmt.Errorf("dingtalk_connect: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
30
backend/internal/config/validate_dingtalk.go
Normal file
30
backend/internal/config/validate_dingtalk.go
Normal file
@ -0,0 +1,30 @@
|
||||
// Package config 包含钉钉连接配置的校验逻辑。
|
||||
//
|
||||
// internal_only 模式安全模型(方案 A):
|
||||
// 不再要求 admin 填写 InternalCorpID 做二次 corpID 比对。
|
||||
// 安全边界由钉钉"企业内部应用"类型本身保证——只有应用所属企业的员工才能完成 OAuth,
|
||||
// 因此 ValidateDingTalkConfig 只要求 app_type=internal(V1),不再要求 InternalCorpID 非空(原 V3 已删除)。
|
||||
// InternalCorpID 字段保留,admin 可选填;若填写,checkDingTalkCorpAllowed 不会使用它做约束。
|
||||
package config
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrDingTalkV1AppTypeMismatch = errors.New("dingtalk: internal_only requires app_type=internal")
|
||||
ErrDingTalkV4InvalidAppKind = errors.New("dingtalk: dingtalk_app_kind must be internal_app")
|
||||
)
|
||||
|
||||
func ValidateDingTalkConfig(cfg DingTalkConnectConfig) error {
|
||||
if !cfg.Enabled {
|
||||
return nil
|
||||
}
|
||||
if cfg.DingTalkAppKind != "internal_app" {
|
||||
return ErrDingTalkV4InvalidAppKind
|
||||
}
|
||||
if cfg.CorpRestrictionPolicy == "internal_only" {
|
||||
if cfg.AppType != "internal" {
|
||||
return ErrDingTalkV1AppTypeMismatch
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
53
backend/internal/config/validate_dingtalk_test.go
Normal file
53
backend/internal/config/validate_dingtalk_test.go
Normal file
@ -0,0 +1,53 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestValidateDingTalkConfig_Disabled_Skip(t *testing.T) {
|
||||
require.NoError(t, ValidateDingTalkConfig(DingTalkConnectConfig{Enabled: false}))
|
||||
}
|
||||
|
||||
func TestValidateDingTalkConfig_V4_DingTalkAppKind(t *testing.T) {
|
||||
err := ValidateDingTalkConfig(DingTalkConnectConfig{
|
||||
Enabled: true,
|
||||
DingTalkAppKind: "third_party_enterprise_app",
|
||||
CorpRestrictionPolicy: "none",
|
||||
})
|
||||
require.ErrorIs(t, err, ErrDingTalkV4InvalidAppKind)
|
||||
}
|
||||
|
||||
func TestValidateDingTalkConfig_V1_InternalOnlyRequiresInternalAppType(t *testing.T) {
|
||||
err := ValidateDingTalkConfig(DingTalkConnectConfig{
|
||||
Enabled: true,
|
||||
DingTalkAppKind: "internal_app",
|
||||
AppType: "public",
|
||||
CorpRestrictionPolicy: "internal_only",
|
||||
InternalCorpID: "dingABC",
|
||||
})
|
||||
require.ErrorIs(t, err, ErrDingTalkV1AppTypeMismatch)
|
||||
}
|
||||
|
||||
// TestValidateDingTalkConfig_V3_InternalOnlyAllowsEmptyCorpID 验证方案 A:
|
||||
// internal_only 策略下,InternalCorpID="" 应通过校验(企业隔离由钉钉 AppType=internal 保证)。
|
||||
func TestValidateDingTalkConfig_V3_InternalOnlyAllowsEmptyCorpID(t *testing.T) {
|
||||
err := ValidateDingTalkConfig(DingTalkConnectConfig{
|
||||
Enabled: true,
|
||||
DingTalkAppKind: "internal_app",
|
||||
AppType: "internal",
|
||||
CorpRestrictionPolicy: "internal_only",
|
||||
InternalCorpID: "",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestValidateDingTalkConfig_HappyPath_None(t *testing.T) {
|
||||
require.NoError(t, ValidateDingTalkConfig(DingTalkConnectConfig{
|
||||
Enabled: true,
|
||||
DingTalkAppKind: "internal_app",
|
||||
AppType: "public",
|
||||
CorpRestrictionPolicy: "none",
|
||||
}))
|
||||
}
|
||||
@ -43,6 +43,9 @@ type DataProxy struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// DataAccount 是管理员显式备份导出使用的账号结构,故意不走 dto.Account 的脱敏路径,
|
||||
// Credentials 原文返回。这是"管理员备份"这一显式行为的一部分;如未来需要导出脱敏版本,
|
||||
// 应新增独立结构而非修改这里。
|
||||
type DataAccount struct {
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes,omitempty"`
|
||||
|
||||
@ -1994,6 +1994,48 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
response.Success(c, models)
|
||||
}
|
||||
|
||||
// SyncUpstreamModels handles syncing live supported models from an account's upstream.
|
||||
// POST /api/v1/admin/accounts/:id/models/sync-upstream
|
||||
func (h *AccountHandler) SyncUpstreamModels(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
if h.accountTestService == nil {
|
||||
response.InternalError(c, "Account test service is not configured")
|
||||
return
|
||||
}
|
||||
|
||||
models, err := h.accountTestService.FetchUpstreamSupportedModels(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
var syncErr *service.UpstreamModelSyncError
|
||||
if errors.As(err, &syncErr) {
|
||||
switch syncErr.Kind {
|
||||
case service.UpstreamModelSyncErrorConfiguration, service.UpstreamModelSyncErrorUnsupported:
|
||||
response.BadRequest(c, syncErr.SafeMessage())
|
||||
default:
|
||||
slog.Warn("sync_upstream_models_failed", "account_id", accountID, "kind", syncErr.Kind)
|
||||
response.Error(c, http.StatusBadGateway, syncErr.SafeMessage())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
slog.Warn("sync_upstream_models_failed", "account_id", accountID)
|
||||
response.Error(c, http.StatusBadGateway, "Failed to sync upstream models from upstream")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"models": models})
|
||||
}
|
||||
|
||||
// SetPrivacy handles setting privacy for a single OpenAI/Antigravity OAuth account
|
||||
// POST /api/v1/admin/accounts/:id/set-privacy
|
||||
func (h *AccountHandler) SetPrivacy(c *gin.Context) {
|
||||
|
||||
@ -3,10 +3,14 @@ package admin
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -33,6 +37,39 @@ func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
|
||||
return router
|
||||
}
|
||||
|
||||
type syncUpstreamHTTPUpstream struct {
|
||||
resp *http.Response
|
||||
err error
|
||||
}
|
||||
|
||||
func (u *syncUpstreamHTTPUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||
if u.err != nil {
|
||||
return nil, u.err
|
||||
}
|
||||
return u.resp, nil
|
||||
}
|
||||
|
||||
func (u *syncUpstreamHTTPUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
|
||||
return u.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
func setupSyncUpstreamModelsRouter(adminSvc service.AdminService, upstream service.HTTPUpstream) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
accountTestSvc := service.NewAccountTestService(
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
upstream,
|
||||
&config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
|
||||
nil,
|
||||
)
|
||||
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, accountTestSvc, nil, nil, nil, nil, nil)
|
||||
router.POST("/api/v1/admin/accounts/:id/models/sync-upstream", handler.SyncUpstreamModels)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
@ -103,3 +140,58 @@ func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefau
|
||||
require.NotEmpty(t, resp.Data)
|
||||
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
|
||||
func TestAccountHandlerSyncUpstreamModels_ConfigErrorReturnsBadRequest(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 44,
|
||||
Name: "openai-apikey-missing-key",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"base_url": "https://openai.example.com/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupSyncUpstreamModelsRouter(svc, &syncUpstreamHTTPUpstream{})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/44/models/sync-upstream", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
require.Contains(t, rec.Body.String(), "No OpenAI API key is available")
|
||||
}
|
||||
|
||||
func TestAccountHandlerSyncUpstreamModels_UpstreamErrorDoesNotExposeBody(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 45,
|
||||
Name: "openai-apikey-upstream-error",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "openai-key",
|
||||
"base_url": "https://openai.example.com/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
upstream := &syncUpstreamHTTPUpstream{resp: &http.Response{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":"SECRET_TOKEN should not be exposed"}`)),
|
||||
}}
|
||||
router := setupSyncUpstreamModelsRouter(svc, upstream)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/45/models/sync-upstream", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadGateway, rec.Code)
|
||||
require.Contains(t, rec.Body.String(), "Upstream model list request failed with HTTP 502")
|
||||
require.NotContains(t, rec.Body.String(), "SECRET_TOKEN")
|
||||
}
|
||||
|
||||
@ -546,9 +546,14 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// cacheKey 必须包含当日日期,否则跨午夜后 30s 内会复用昨天的 "today_*" 结果。
|
||||
keyRaw, _ := json.Marshal(struct {
|
||||
V int `json:"v"`
|
||||
Day string `json:"day"`
|
||||
UserIDs []int64 `json:"user_ids"`
|
||||
}{
|
||||
V: 2, // bump 当响应结构变化(如加入 by_platform 时)
|
||||
Day: timezone.Today().Format("2006-01-02"),
|
||||
UserIDs: userIDs,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
@ -33,23 +34,51 @@ func NewRedeemHandler(adminService service.AdminService, redeemService *service.
|
||||
|
||||
// GenerateRedeemCodesRequest represents generate redeem codes request
|
||||
type GenerateRedeemCodesRequest struct {
|
||||
Count int `json:"count" binding:"required,min=1,max=100"`
|
||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
|
||||
Value float64 `json:"value"`
|
||||
GroupID *int64 `json:"group_id"` // 订阅类型必填
|
||||
ValidityDays int `json:"validity_days"` // 订阅类型使用,正数增加/负数退款扣减
|
||||
Count int `json:"count" binding:"required,min=1,max=100"`
|
||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
|
||||
Value float64 `json:"value"`
|
||||
GroupID *int64 `json:"group_id"` // 订阅类型必填
|
||||
ValidityDays int `json:"validity_days"` // 订阅类型使用,正数增加/负数退款扣减
|
||||
ExpiresAt *time.Time `json:"expires_at"`
|
||||
ExpiresInDays *int `json:"expires_in_days" binding:"omitempty,min=1,max=3650"`
|
||||
}
|
||||
|
||||
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
|
||||
// Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance)。
|
||||
type CreateAndRedeemCodeRequest struct {
|
||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容)
|
||||
Value float64 `json:"value" binding:"required"`
|
||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||
GroupID *int64 `json:"group_id"` // subscription 类型必填
|
||||
ValidityDays int `json:"validity_days"` // subscription 类型:正数增加,负数退款扣减
|
||||
Notes string `json:"notes"`
|
||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容)
|
||||
Value float64 `json:"value" binding:"required"`
|
||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||
GroupID *int64 `json:"group_id"` // subscription 类型必填
|
||||
ValidityDays int `json:"validity_days"` // subscription 类型:正数增加,负数退款扣减
|
||||
Notes string `json:"notes"`
|
||||
ExpiresAt *time.Time `json:"expires_at"`
|
||||
ExpiresInDays *int `json:"expires_in_days" binding:"omitempty,min=1,max=3650"`
|
||||
}
|
||||
|
||||
func resolveRedeemCodeExpiresAt(expiresAt *time.Time, expiresInDays *int) (*time.Time, error) {
|
||||
if expiresAt != nil && expiresInDays != nil {
|
||||
return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRY_CONFLICT", "expires_at and expires_in_days cannot both be set")
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
if expiresInDays != nil {
|
||||
if *expiresInDays <= 0 {
|
||||
return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRES_IN_DAYS_INVALID", "expires_in_days must be greater than zero")
|
||||
}
|
||||
expires := now.AddDate(0, 0, *expiresInDays)
|
||||
return &expires, nil
|
||||
}
|
||||
if expiresAt == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
expires := expiresAt.UTC()
|
||||
if !expires.After(now) {
|
||||
return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRES_AT_INVALID", "expires_at must be in the future")
|
||||
}
|
||||
return &expires, nil
|
||||
}
|
||||
|
||||
// List handles listing all redeem codes with pagination
|
||||
@ -107,6 +136,12 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt, err := resolveRedeemCodeExpiresAt(req.ExpiresAt, req.ExpiresInDays)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
executeAdminIdempotentJSON(c, "admin.redeem_codes.generate", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
codes, execErr := h.adminService.GenerateRedeemCodes(ctx, &service.GenerateRedeemCodesInput{
|
||||
Count: req.Count,
|
||||
@ -114,6 +149,7 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
|
||||
Value: req.Value,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
ExpiresAt: expiresAt,
|
||||
})
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
@ -158,6 +194,12 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
expiresAt, err := resolveRedeemCodeExpiresAt(req.ExpiresAt, req.ExpiresInDays)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
existing, err := h.redeemService.GetByCode(ctx, req.Code)
|
||||
if err == nil {
|
||||
@ -175,6 +217,7 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
||||
Notes: req.Notes,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
ExpiresAt: expiresAt,
|
||||
})
|
||||
if createErr != nil {
|
||||
// Unique code race: if code now exists, use idempotent semantics by used_by.
|
||||
@ -199,6 +242,9 @@ func (h *RedeemHandler) resolveCreateAndRedeemExisting(ctx context.Context, exis
|
||||
}
|
||||
|
||||
// If previous run created the code but crashed before redeem, redeem it now.
|
||||
if existing.IsExpired() {
|
||||
return nil, service.ErrRedeemCodeExpired
|
||||
}
|
||||
if existing.CanUse() {
|
||||
redeemed, err := h.redeemService.Redeem(ctx, userID, existing.Code)
|
||||
if err == nil {
|
||||
@ -321,7 +367,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
writer := csv.NewWriter(&buf)
|
||||
|
||||
// Write header
|
||||
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_by_email", "used_at", "created_at"}); err != nil {
|
||||
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_by_email", "used_at", "expires_at", "created_at"}); err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
@ -340,6 +386,10 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
if code.UsedAt != nil {
|
||||
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
expiresAt := ""
|
||||
if code.ExpiresAt != nil {
|
||||
expiresAt = code.ExpiresAt.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
if err := writer.Write([]string{
|
||||
fmt.Sprintf("%d", code.ID),
|
||||
code.Code,
|
||||
@ -349,6 +399,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
usedBy,
|
||||
usedByEmail,
|
||||
usedAt,
|
||||
expiresAt,
|
||||
code.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||
}); err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
@ -139,3 +140,33 @@ func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) {
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"balance type should not require group_id or validity_days")
|
||||
}
|
||||
|
||||
func TestResolveRedeemCodeExpiresAt_FromDays(t *testing.T) {
|
||||
days := 3
|
||||
expiresAt, err := resolveRedeemCodeExpiresAt(nil, &days)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expiresAt)
|
||||
require.WithinDuration(t, time.Now().UTC().AddDate(0, 0, days), *expiresAt, 2*time.Second)
|
||||
}
|
||||
|
||||
func TestResolveRedeemCodeExpiresAt_RejectsPastAbsoluteTime(t *testing.T) {
|
||||
past := time.Now().UTC().Add(-time.Minute)
|
||||
expiresAt, err := resolveRedeemCodeExpiresAt(&past, nil)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, expiresAt)
|
||||
}
|
||||
|
||||
func TestResolveRedeemCodeExpiresAt_RejectsNonPositiveDays(t *testing.T) {
|
||||
days := 0
|
||||
expiresAt, err := resolveRedeemCodeExpiresAt(nil, &days)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, expiresAt)
|
||||
}
|
||||
|
||||
func TestResolveRedeemCodeExpiresAt_RejectsConflictingInputs(t *testing.T) {
|
||||
future := time.Now().UTC().Add(time.Hour)
|
||||
days := 3
|
||||
expiresAt, err := resolveRedeemCodeExpiresAt(&future, &days)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, expiresAt)
|
||||
}
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
@ -60,10 +62,11 @@ type SettingHandler struct {
|
||||
opsService *service.OpsService
|
||||
paymentConfigService *service.PaymentConfigService
|
||||
paymentService *service.PaymentService
|
||||
userAttributeService *service.UserAttributeService
|
||||
}
|
||||
|
||||
// NewSettingHandler 创建系统设置处理器
|
||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, paymentConfigService *service.PaymentConfigService, paymentService *service.PaymentService) *SettingHandler {
|
||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, paymentConfigService *service.PaymentConfigService, paymentService *service.PaymentService, userAttributeService *service.UserAttributeService) *SettingHandler {
|
||||
return &SettingHandler{
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
@ -71,6 +74,7 @@ func NewSettingHandler(settingService *service.SettingService, emailService *ser
|
||||
opsService: opsService,
|
||||
paymentConfigService: paymentConfigService,
|
||||
paymentService: paymentService,
|
||||
userAttributeService: userAttributeService,
|
||||
}
|
||||
}
|
||||
|
||||
@ -135,6 +139,22 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
|
||||
LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
|
||||
DingTalkConnectEnabled: settings.DingTalkConnectEnabled,
|
||||
DingTalkConnectClientID: settings.DingTalkConnectClientID,
|
||||
DingTalkConnectClientSecretConfigured: settings.DingTalkConnectClientSecretConfigured,
|
||||
DingTalkConnectRedirectURL: settings.DingTalkConnectRedirectURL,
|
||||
DingTalkConnectCorpRestrictionPolicy: settings.DingTalkConnectCorpRestrictionPolicy,
|
||||
DingTalkConnectInternalCorpID: settings.DingTalkConnectInternalCorpID,
|
||||
DingTalkConnectBypassRegistration: settings.DingTalkConnectBypassRegistration,
|
||||
DingTalkConnectSyncCorpEmail: settings.DingTalkConnectSyncCorpEmail,
|
||||
DingTalkConnectSyncDisplayName: settings.DingTalkConnectSyncDisplayName,
|
||||
DingTalkConnectSyncDept: settings.DingTalkConnectSyncDept,
|
||||
DingTalkConnectSyncCorpEmailAttrKey: settings.DingTalkConnectSyncCorpEmailAttrKey,
|
||||
DingTalkConnectSyncDisplayNameAttrKey: settings.DingTalkConnectSyncDisplayNameAttrKey,
|
||||
DingTalkConnectSyncDeptAttrKey: settings.DingTalkConnectSyncDeptAttrKey,
|
||||
DingTalkConnectSyncCorpEmailAttrName: settings.DingTalkConnectSyncCorpEmailAttrName,
|
||||
DingTalkConnectSyncDisplayNameAttrName: settings.DingTalkConnectSyncDisplayNameAttrName,
|
||||
DingTalkConnectSyncDeptAttrName: settings.DingTalkConnectSyncDeptAttrName,
|
||||
WeChatConnectEnabled: settings.WeChatConnectEnabled,
|
||||
WeChatConnectAppID: settings.WeChatConnectAppID,
|
||||
WeChatConnectAppSecretConfigured: settings.WeChatConnectAppSecretConfigured,
|
||||
@ -376,6 +396,24 @@ type UpdateSettingsRequest struct {
|
||||
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
// DingTalk Connect OAuth 登录
|
||||
DingTalkConnectEnabled bool `json:"dingtalk_connect_enabled"`
|
||||
DingTalkConnectClientID string `json:"dingtalk_connect_client_id"`
|
||||
DingTalkConnectClientSecret string `json:"dingtalk_connect_client_secret"`
|
||||
DingTalkConnectRedirectURL string `json:"dingtalk_connect_redirect_url"`
|
||||
DingTalkConnectCorpRestrictionPolicy string `json:"dingtalk_connect_corp_restriction_policy"`
|
||||
DingTalkConnectInternalCorpID string `json:"dingtalk_connect_internal_corp_id"`
|
||||
DingTalkConnectBypassRegistration bool `json:"dingtalk_connect_bypass_registration"`
|
||||
DingTalkConnectSyncCorpEmail bool `json:"dingtalk_connect_sync_corp_email"`
|
||||
DingTalkConnectSyncDisplayName bool `json:"dingtalk_connect_sync_display_name"`
|
||||
DingTalkConnectSyncDept bool `json:"dingtalk_connect_sync_dept"`
|
||||
DingTalkConnectSyncCorpEmailAttrKey string `json:"dingtalk_connect_sync_corp_email_attr_key"`
|
||||
DingTalkConnectSyncDisplayNameAttrKey string `json:"dingtalk_connect_sync_display_name_attr_key"`
|
||||
DingTalkConnectSyncDeptAttrKey string `json:"dingtalk_connect_sync_dept_attr_key"`
|
||||
DingTalkConnectSyncCorpEmailAttrName string `json:"dingtalk_connect_sync_corp_email_attr_name"`
|
||||
DingTalkConnectSyncDisplayNameAttrName string `json:"dingtalk_connect_sync_display_name_attr_name"`
|
||||
DingTalkConnectSyncDeptAttrName string `json:"dingtalk_connect_sync_dept_attr_name"`
|
||||
|
||||
// WeChat Connect OAuth 登录
|
||||
WeChatConnectEnabled bool `json:"wechat_connect_enabled"`
|
||||
WeChatConnectAppID string `json:"wechat_connect_app_id"`
|
||||
@ -446,45 +484,50 @@ type UpdateSettingsRequest struct {
|
||||
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
|
||||
AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
|
||||
AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
|
||||
AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
|
||||
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
|
||||
AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
|
||||
AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"`
|
||||
AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"`
|
||||
AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"`
|
||||
AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"`
|
||||
AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"`
|
||||
AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"`
|
||||
AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"`
|
||||
AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"`
|
||||
AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"`
|
||||
AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"`
|
||||
AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"`
|
||||
AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"`
|
||||
AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"`
|
||||
AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"`
|
||||
AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"`
|
||||
AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
|
||||
AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
|
||||
AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
|
||||
AuthSourceDefaultGitHubBalance *float64 `json:"auth_source_default_github_balance"`
|
||||
AuthSourceDefaultGitHubConcurrency *int `json:"auth_source_default_github_concurrency"`
|
||||
AuthSourceDefaultGitHubSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_github_subscriptions"`
|
||||
AuthSourceDefaultGitHubGrantOnSignup *bool `json:"auth_source_default_github_grant_on_signup"`
|
||||
AuthSourceDefaultGitHubGrantOnFirstBind *bool `json:"auth_source_default_github_grant_on_first_bind"`
|
||||
AuthSourceDefaultGoogleBalance *float64 `json:"auth_source_default_google_balance"`
|
||||
AuthSourceDefaultGoogleConcurrency *int `json:"auth_source_default_google_concurrency"`
|
||||
AuthSourceDefaultGoogleSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_google_subscriptions"`
|
||||
AuthSourceDefaultGoogleGrantOnSignup *bool `json:"auth_source_default_google_grant_on_signup"`
|
||||
AuthSourceDefaultGoogleGrantOnFirstBind *bool `json:"auth_source_default_google_grant_on_first_bind"`
|
||||
ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
|
||||
AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
|
||||
AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
|
||||
AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
|
||||
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
|
||||
AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
|
||||
AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"`
|
||||
AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"`
|
||||
AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"`
|
||||
AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"`
|
||||
AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"`
|
||||
AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"`
|
||||
AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"`
|
||||
AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"`
|
||||
AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"`
|
||||
AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"`
|
||||
AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"`
|
||||
AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"`
|
||||
AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"`
|
||||
AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"`
|
||||
AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"`
|
||||
AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
|
||||
AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
|
||||
AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
|
||||
AuthSourceDefaultGitHubBalance *float64 `json:"auth_source_default_github_balance"`
|
||||
AuthSourceDefaultGitHubConcurrency *int `json:"auth_source_default_github_concurrency"`
|
||||
AuthSourceDefaultGitHubSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_github_subscriptions"`
|
||||
AuthSourceDefaultGitHubGrantOnSignup *bool `json:"auth_source_default_github_grant_on_signup"`
|
||||
AuthSourceDefaultGitHubGrantOnFirstBind *bool `json:"auth_source_default_github_grant_on_first_bind"`
|
||||
AuthSourceDefaultGoogleBalance *float64 `json:"auth_source_default_google_balance"`
|
||||
AuthSourceDefaultGoogleConcurrency *int `json:"auth_source_default_google_concurrency"`
|
||||
AuthSourceDefaultGoogleSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_google_subscriptions"`
|
||||
AuthSourceDefaultGoogleGrantOnSignup *bool `json:"auth_source_default_google_grant_on_signup"`
|
||||
AuthSourceDefaultGoogleGrantOnFirstBind *bool `json:"auth_source_default_google_grant_on_first_bind"`
|
||||
AuthSourceDefaultDingTalkBalance *float64 `json:"auth_source_default_dingtalk_balance"`
|
||||
AuthSourceDefaultDingTalkConcurrency *int `json:"auth_source_default_dingtalk_concurrency"`
|
||||
AuthSourceDefaultDingTalkSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_dingtalk_subscriptions"`
|
||||
AuthSourceDefaultDingTalkGrantOnSignup *bool `json:"auth_source_default_dingtalk_grant_on_signup"`
|
||||
AuthSourceDefaultDingTalkGrantOnFirstBind *bool `json:"auth_source_default_dingtalk_grant_on_first_bind"`
|
||||
ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
|
||||
|
||||
// Model fallback configuration
|
||||
EnableModelFallback bool `json:"enable_model_fallback"`
|
||||
@ -661,6 +704,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
req.AuthSourceDefaultLinuxDoSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultLinuxDoSubscriptions)
|
||||
req.AuthSourceDefaultOIDCSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultOIDCSubscriptions)
|
||||
req.AuthSourceDefaultWeChatSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultWeChatSubscriptions)
|
||||
req.AuthSourceDefaultDingTalkSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultDingTalkSubscriptions)
|
||||
|
||||
// SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置
|
||||
// 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置
|
||||
@ -777,6 +821,100 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// DingTalk Connect 参数验证
|
||||
// 防御性:任何写入路径上把已废弃的 corp_restriction_policy=whitelist 入参 coerce 为 none,
|
||||
// 避免任何直连 admin API 的客户端把死值写回 DB(前端 UI 已无此选项)。
|
||||
req.DingTalkConnectCorpRestrictionPolicy = service.CoerceDingTalkCorpPolicyForWrite(req.DingTalkConnectCorpRestrictionPolicy)
|
||||
|
||||
if req.DingTalkConnectEnabled {
|
||||
req.DingTalkConnectClientID = strings.TrimSpace(req.DingTalkConnectClientID)
|
||||
req.DingTalkConnectClientSecret = strings.TrimSpace(req.DingTalkConnectClientSecret)
|
||||
req.DingTalkConnectRedirectURL = strings.TrimSpace(req.DingTalkConnectRedirectURL)
|
||||
req.DingTalkConnectCorpRestrictionPolicy = strings.TrimSpace(req.DingTalkConnectCorpRestrictionPolicy)
|
||||
req.DingTalkConnectInternalCorpID = strings.TrimSpace(req.DingTalkConnectInternalCorpID)
|
||||
|
||||
if req.DingTalkConnectClientID == "" {
|
||||
response.BadRequest(c, "DingTalk Client ID is required when enabled")
|
||||
return
|
||||
}
|
||||
if req.DingTalkConnectRedirectURL == "" {
|
||||
response.BadRequest(c, "DingTalk Redirect URL is required when enabled")
|
||||
return
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(req.DingTalkConnectRedirectURL); err != nil {
|
||||
response.BadRequest(c, "DingTalk Redirect URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
|
||||
// 如果未提供 client_secret,则保留现有值(如有)。
|
||||
if req.DingTalkConnectClientSecret == "" {
|
||||
if previousSettings.DingTalkConnectClientSecret == "" {
|
||||
response.BadRequest(c, "DingTalk Client Secret is required when enabled")
|
||||
return
|
||||
}
|
||||
req.DingTalkConnectClientSecret = previousSettings.DingTalkConnectClientSecret
|
||||
}
|
||||
|
||||
// Corp 策略校验(V1/V4 fail-closed)
|
||||
dingTalkCfg := config.DingTalkConnectConfig{
|
||||
Enabled: true,
|
||||
DingTalkAppKind: "internal_app", // 硬编码:settings 层仅支持 internal_app
|
||||
AppType: "internal", // 对于 internal_only 策略的默认值
|
||||
CorpRestrictionPolicy: req.DingTalkConnectCorpRestrictionPolicy,
|
||||
InternalCorpID: req.DingTalkConnectInternalCorpID,
|
||||
}
|
||||
// 若未填 corp_restriction_policy,保留已有配置
|
||||
if dingTalkCfg.CorpRestrictionPolicy == "" {
|
||||
dingTalkCfg.CorpRestrictionPolicy = previousSettings.DingTalkConnectCorpRestrictionPolicy
|
||||
}
|
||||
// 对于 internal_only 策略,app_type 必须为 internal(V1 校验)
|
||||
if dingTalkCfg.CorpRestrictionPolicy == "internal_only" {
|
||||
dingTalkCfg.AppType = "internal"
|
||||
} else {
|
||||
dingTalkCfg.AppType = "public"
|
||||
}
|
||||
if err := config.ValidateDingTalkConfig(dingTalkCfg); err != nil {
|
||||
response.ErrorWithDetails(c, http.StatusBadRequest, err.Error(), mapDingTalkValidateError(err), nil)
|
||||
return
|
||||
}
|
||||
|
||||
// bypass_registration 仅在 internal_only 模式下有意义;其它策略下强制为 false,
|
||||
// 防止 admin 在切换 policy 时把 bypass 残留在 DB 中(前端 UI 也已隐藏该开关)。
|
||||
if dingTalkCfg.CorpRestrictionPolicy != "internal_only" {
|
||||
req.DingTalkConnectBypassRegistration = false
|
||||
// 身份同步三开关同理:仅 internal_only 模式下有意义,其它策略强制 false。
|
||||
req.DingTalkConnectSyncCorpEmail = false
|
||||
req.DingTalkConnectSyncDisplayName = false
|
||||
req.DingTalkConnectSyncDept = false
|
||||
}
|
||||
// 身份同步目标 attr key:trimSpace + 空值 fallback 到默认值
|
||||
req.DingTalkConnectSyncCorpEmailAttrKey = strings.TrimSpace(req.DingTalkConnectSyncCorpEmailAttrKey)
|
||||
if req.DingTalkConnectSyncCorpEmailAttrKey == "" {
|
||||
req.DingTalkConnectSyncCorpEmailAttrKey = "dingtalk_email"
|
||||
}
|
||||
req.DingTalkConnectSyncDisplayNameAttrKey = strings.TrimSpace(req.DingTalkConnectSyncDisplayNameAttrKey)
|
||||
if req.DingTalkConnectSyncDisplayNameAttrKey == "" {
|
||||
req.DingTalkConnectSyncDisplayNameAttrKey = "dingtalk_name"
|
||||
}
|
||||
req.DingTalkConnectSyncDeptAttrKey = strings.TrimSpace(req.DingTalkConnectSyncDeptAttrKey)
|
||||
if req.DingTalkConnectSyncDeptAttrKey == "" {
|
||||
req.DingTalkConnectSyncDeptAttrKey = "dingtalk_department"
|
||||
}
|
||||
// 身份同步目标 attr 显示名称:trim + 空值 fallback 到默认中文名
|
||||
req.DingTalkConnectSyncCorpEmailAttrName = strings.TrimSpace(req.DingTalkConnectSyncCorpEmailAttrName)
|
||||
if req.DingTalkConnectSyncCorpEmailAttrName == "" {
|
||||
req.DingTalkConnectSyncCorpEmailAttrName = "钉钉企业邮箱"
|
||||
}
|
||||
req.DingTalkConnectSyncDisplayNameAttrName = strings.TrimSpace(req.DingTalkConnectSyncDisplayNameAttrName)
|
||||
if req.DingTalkConnectSyncDisplayNameAttrName == "" {
|
||||
req.DingTalkConnectSyncDisplayNameAttrName = "钉钉姓名"
|
||||
}
|
||||
req.DingTalkConnectSyncDeptAttrName = strings.TrimSpace(req.DingTalkConnectSyncDeptAttrName)
|
||||
if req.DingTalkConnectSyncDeptAttrName == "" {
|
||||
req.DingTalkConnectSyncDeptAttrName = "钉钉部门"
|
||||
}
|
||||
}
|
||||
|
||||
if req.WeChatConnectEnabled {
|
||||
req.WeChatConnectAppID = strings.TrimSpace(req.WeChatConnectAppID)
|
||||
req.WeChatConnectAppSecret = strings.TrimSpace(req.WeChatConnectAppSecret)
|
||||
@ -1272,113 +1410,129 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
FrontendURL: req.FrontendURL,
|
||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
LoginAgreementEnabled: req.LoginAgreementEnabled,
|
||||
LoginAgreementMode: loginAgreementMode,
|
||||
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
|
||||
LoginAgreementDocuments: loginAgreementDocuments,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
WeChatConnectEnabled: req.WeChatConnectEnabled,
|
||||
WeChatConnectAppID: req.WeChatConnectAppID,
|
||||
WeChatConnectAppSecret: req.WeChatConnectAppSecret,
|
||||
WeChatConnectOpenAppID: req.WeChatConnectOpenAppID,
|
||||
WeChatConnectOpenAppSecret: req.WeChatConnectOpenAppSecret,
|
||||
WeChatConnectMPAppID: req.WeChatConnectMPAppID,
|
||||
WeChatConnectMPAppSecret: req.WeChatConnectMPAppSecret,
|
||||
WeChatConnectMobileAppID: req.WeChatConnectMobileAppID,
|
||||
WeChatConnectMobileAppSecret: req.WeChatConnectMobileAppSecret,
|
||||
WeChatConnectOpenEnabled: req.WeChatConnectOpenEnabled,
|
||||
WeChatConnectMPEnabled: req.WeChatConnectMPEnabled,
|
||||
WeChatConnectMobileEnabled: req.WeChatConnectMobileEnabled,
|
||||
WeChatConnectMode: req.WeChatConnectMode,
|
||||
WeChatConnectScopes: req.WeChatConnectScopes,
|
||||
WeChatConnectRedirectURL: req.WeChatConnectRedirectURL,
|
||||
WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL,
|
||||
OIDCConnectEnabled: req.OIDCConnectEnabled,
|
||||
OIDCConnectProviderName: req.OIDCConnectProviderName,
|
||||
OIDCConnectClientID: req.OIDCConnectClientID,
|
||||
OIDCConnectClientSecret: req.OIDCConnectClientSecret,
|
||||
OIDCConnectIssuerURL: req.OIDCConnectIssuerURL,
|
||||
OIDCConnectDiscoveryURL: req.OIDCConnectDiscoveryURL,
|
||||
OIDCConnectAuthorizeURL: req.OIDCConnectAuthorizeURL,
|
||||
OIDCConnectTokenURL: req.OIDCConnectTokenURL,
|
||||
OIDCConnectUserInfoURL: req.OIDCConnectUserInfoURL,
|
||||
OIDCConnectJWKSURL: req.OIDCConnectJWKSURL,
|
||||
OIDCConnectScopes: req.OIDCConnectScopes,
|
||||
OIDCConnectRedirectURL: req.OIDCConnectRedirectURL,
|
||||
OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL,
|
||||
OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod,
|
||||
OIDCConnectUsePKCE: oidcUsePKCE,
|
||||
OIDCConnectValidateIDToken: oidcValidateIDToken,
|
||||
OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs,
|
||||
OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds,
|
||||
OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified,
|
||||
OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath,
|
||||
OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath,
|
||||
OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath,
|
||||
GitHubOAuthEnabled: req.GitHubOAuthEnabled,
|
||||
GitHubOAuthClientID: req.GitHubOAuthClientID,
|
||||
GitHubOAuthClientSecret: req.GitHubOAuthClientSecret,
|
||||
GitHubOAuthRedirectURL: req.GitHubOAuthRedirectURL,
|
||||
GitHubOAuthFrontendRedirectURL: req.GitHubOAuthFrontendRedirectURL,
|
||||
GoogleOAuthEnabled: req.GoogleOAuthEnabled,
|
||||
GoogleOAuthClientID: req.GoogleOAuthClientID,
|
||||
GoogleOAuthClientSecret: req.GoogleOAuthClientSecret,
|
||||
GoogleOAuthRedirectURL: req.GoogleOAuthRedirectURL,
|
||||
GoogleOAuthFrontendRedirectURL: req.GoogleOAuthFrontendRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||
PurchaseSubscriptionURL: purchaseURL,
|
||||
TableDefaultPageSize: req.TableDefaultPageSize,
|
||||
TablePageSizeOptions: req.TablePageSizeOptions,
|
||||
CustomMenuItems: customMenuJSON,
|
||||
CustomEndpoints: customEndpointsJSON,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
AffiliateRebateRate: affiliateRebateRate,
|
||||
AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
|
||||
AffiliateRebateDurationDays: affiliateRebateDurationDays,
|
||||
AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
|
||||
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||
MaxClaudeCodeVersion: req.MaxClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: req.BackendModeEnabled,
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
FrontendURL: req.FrontendURL,
|
||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
LoginAgreementEnabled: req.LoginAgreementEnabled,
|
||||
LoginAgreementMode: loginAgreementMode,
|
||||
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
|
||||
LoginAgreementDocuments: loginAgreementDocuments,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
DingTalkConnectEnabled: req.DingTalkConnectEnabled,
|
||||
DingTalkConnectClientID: req.DingTalkConnectClientID,
|
||||
DingTalkConnectClientSecret: req.DingTalkConnectClientSecret,
|
||||
DingTalkConnectRedirectURL: req.DingTalkConnectRedirectURL,
|
||||
DingTalkConnectCorpRestrictionPolicy: req.DingTalkConnectCorpRestrictionPolicy,
|
||||
DingTalkConnectInternalCorpID: req.DingTalkConnectInternalCorpID,
|
||||
DingTalkConnectBypassRegistration: req.DingTalkConnectBypassRegistration,
|
||||
DingTalkConnectSyncCorpEmail: req.DingTalkConnectSyncCorpEmail,
|
||||
DingTalkConnectSyncDisplayName: req.DingTalkConnectSyncDisplayName,
|
||||
DingTalkConnectSyncDept: req.DingTalkConnectSyncDept,
|
||||
DingTalkConnectSyncCorpEmailAttrKey: req.DingTalkConnectSyncCorpEmailAttrKey,
|
||||
DingTalkConnectSyncDisplayNameAttrKey: req.DingTalkConnectSyncDisplayNameAttrKey,
|
||||
DingTalkConnectSyncDeptAttrKey: req.DingTalkConnectSyncDeptAttrKey,
|
||||
DingTalkConnectSyncCorpEmailAttrName: req.DingTalkConnectSyncCorpEmailAttrName,
|
||||
DingTalkConnectSyncDisplayNameAttrName: req.DingTalkConnectSyncDisplayNameAttrName,
|
||||
DingTalkConnectSyncDeptAttrName: req.DingTalkConnectSyncDeptAttrName,
|
||||
WeChatConnectEnabled: req.WeChatConnectEnabled,
|
||||
WeChatConnectAppID: req.WeChatConnectAppID,
|
||||
WeChatConnectAppSecret: req.WeChatConnectAppSecret,
|
||||
WeChatConnectOpenAppID: req.WeChatConnectOpenAppID,
|
||||
WeChatConnectOpenAppSecret: req.WeChatConnectOpenAppSecret,
|
||||
WeChatConnectMPAppID: req.WeChatConnectMPAppID,
|
||||
WeChatConnectMPAppSecret: req.WeChatConnectMPAppSecret,
|
||||
WeChatConnectMobileAppID: req.WeChatConnectMobileAppID,
|
||||
WeChatConnectMobileAppSecret: req.WeChatConnectMobileAppSecret,
|
||||
WeChatConnectOpenEnabled: req.WeChatConnectOpenEnabled,
|
||||
WeChatConnectMPEnabled: req.WeChatConnectMPEnabled,
|
||||
WeChatConnectMobileEnabled: req.WeChatConnectMobileEnabled,
|
||||
WeChatConnectMode: req.WeChatConnectMode,
|
||||
WeChatConnectScopes: req.WeChatConnectScopes,
|
||||
WeChatConnectRedirectURL: req.WeChatConnectRedirectURL,
|
||||
WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL,
|
||||
OIDCConnectEnabled: req.OIDCConnectEnabled,
|
||||
OIDCConnectProviderName: req.OIDCConnectProviderName,
|
||||
OIDCConnectClientID: req.OIDCConnectClientID,
|
||||
OIDCConnectClientSecret: req.OIDCConnectClientSecret,
|
||||
OIDCConnectIssuerURL: req.OIDCConnectIssuerURL,
|
||||
OIDCConnectDiscoveryURL: req.OIDCConnectDiscoveryURL,
|
||||
OIDCConnectAuthorizeURL: req.OIDCConnectAuthorizeURL,
|
||||
OIDCConnectTokenURL: req.OIDCConnectTokenURL,
|
||||
OIDCConnectUserInfoURL: req.OIDCConnectUserInfoURL,
|
||||
OIDCConnectJWKSURL: req.OIDCConnectJWKSURL,
|
||||
OIDCConnectScopes: req.OIDCConnectScopes,
|
||||
OIDCConnectRedirectURL: req.OIDCConnectRedirectURL,
|
||||
OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL,
|
||||
OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod,
|
||||
OIDCConnectUsePKCE: oidcUsePKCE,
|
||||
OIDCConnectValidateIDToken: oidcValidateIDToken,
|
||||
OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs,
|
||||
OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds,
|
||||
OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified,
|
||||
OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath,
|
||||
OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath,
|
||||
OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath,
|
||||
GitHubOAuthEnabled: req.GitHubOAuthEnabled,
|
||||
GitHubOAuthClientID: req.GitHubOAuthClientID,
|
||||
GitHubOAuthClientSecret: req.GitHubOAuthClientSecret,
|
||||
GitHubOAuthRedirectURL: req.GitHubOAuthRedirectURL,
|
||||
GitHubOAuthFrontendRedirectURL: req.GitHubOAuthFrontendRedirectURL,
|
||||
GoogleOAuthEnabled: req.GoogleOAuthEnabled,
|
||||
GoogleOAuthClientID: req.GoogleOAuthClientID,
|
||||
GoogleOAuthClientSecret: req.GoogleOAuthClientSecret,
|
||||
GoogleOAuthRedirectURL: req.GoogleOAuthRedirectURL,
|
||||
GoogleOAuthFrontendRedirectURL: req.GoogleOAuthFrontendRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||
PurchaseSubscriptionURL: purchaseURL,
|
||||
TableDefaultPageSize: req.TableDefaultPageSize,
|
||||
TablePageSizeOptions: req.TablePageSizeOptions,
|
||||
CustomMenuItems: customMenuJSON,
|
||||
CustomEndpoints: customEndpointsJSON,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
AffiliateRebateRate: affiliateRebateRate,
|
||||
AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
|
||||
AffiliateRebateDurationDays: affiliateRebateDurationDays,
|
||||
AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
|
||||
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||
MaxClaudeCodeVersion: req.MaxClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: req.BackendModeEnabled,
|
||||
OpsMonitoringEnabled: func() bool {
|
||||
if req.OpsMonitoringEnabled != nil {
|
||||
return *req.OpsMonitoringEnabled
|
||||
@ -1574,6 +1728,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnSignup, previousAuthSourceDefaults.Google.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnFirstBind, previousAuthSourceDefaults.Google.GrantOnFirstBind),
|
||||
},
|
||||
DingTalk: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultDingTalkBalance, previousAuthSourceDefaults.DingTalk.Balance),
|
||||
Concurrency: intValueOrDefault(req.AuthSourceDefaultDingTalkConcurrency, previousAuthSourceDefaults.DingTalk.Concurrency),
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultDingTalkSubscriptions, previousAuthSourceDefaults.DingTalk.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnSignup, previousAuthSourceDefaults.DingTalk.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnFirstBind, previousAuthSourceDefaults.DingTalk.GrantOnFirstBind),
|
||||
},
|
||||
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
|
||||
}
|
||||
if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil {
|
||||
@ -1632,6 +1793,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
h.ensureDingTalkSyncAttributes(c.Request.Context(), updatedSettings)
|
||||
updatedAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@ -1682,6 +1844,22 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
|
||||
LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
|
||||
DingTalkConnectEnabled: updatedSettings.DingTalkConnectEnabled,
|
||||
DingTalkConnectClientID: updatedSettings.DingTalkConnectClientID,
|
||||
DingTalkConnectClientSecretConfigured: updatedSettings.DingTalkConnectClientSecretConfigured,
|
||||
DingTalkConnectRedirectURL: updatedSettings.DingTalkConnectRedirectURL,
|
||||
DingTalkConnectCorpRestrictionPolicy: updatedSettings.DingTalkConnectCorpRestrictionPolicy,
|
||||
DingTalkConnectInternalCorpID: updatedSettings.DingTalkConnectInternalCorpID,
|
||||
DingTalkConnectBypassRegistration: updatedSettings.DingTalkConnectBypassRegistration,
|
||||
DingTalkConnectSyncCorpEmail: updatedSettings.DingTalkConnectSyncCorpEmail,
|
||||
DingTalkConnectSyncDisplayName: updatedSettings.DingTalkConnectSyncDisplayName,
|
||||
DingTalkConnectSyncDept: updatedSettings.DingTalkConnectSyncDept,
|
||||
DingTalkConnectSyncCorpEmailAttrKey: updatedSettings.DingTalkConnectSyncCorpEmailAttrKey,
|
||||
DingTalkConnectSyncDisplayNameAttrKey: updatedSettings.DingTalkConnectSyncDisplayNameAttrKey,
|
||||
DingTalkConnectSyncDeptAttrKey: updatedSettings.DingTalkConnectSyncDeptAttrKey,
|
||||
DingTalkConnectSyncCorpEmailAttrName: updatedSettings.DingTalkConnectSyncCorpEmailAttrName,
|
||||
DingTalkConnectSyncDisplayNameAttrName: updatedSettings.DingTalkConnectSyncDisplayNameAttrName,
|
||||
DingTalkConnectSyncDeptAttrName: updatedSettings.DingTalkConnectSyncDeptAttrName,
|
||||
WeChatConnectEnabled: updatedSettings.WeChatConnectEnabled,
|
||||
WeChatConnectAppID: updatedSettings.WeChatConnectAppID,
|
||||
WeChatConnectAppSecretConfigured: updatedSettings.WeChatConnectAppSecretConfigured,
|
||||
@ -1822,6 +2000,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
// hasPaymentFields returns true if any payment-related field was explicitly provided.
|
||||
// mapDingTalkValidateError maps ValidateDingTalkConfig errors to machine-readable reason codes.
|
||||
func mapDingTalkValidateError(err error) string {
|
||||
switch {
|
||||
case errors.Is(err, config.ErrDingTalkV1AppTypeMismatch):
|
||||
return "dingtalk_apptype_mismatch"
|
||||
case errors.Is(err, config.ErrDingTalkV4InvalidAppKind):
|
||||
return "dingtalk_app_kind_invalid"
|
||||
default:
|
||||
return "dingtalk_corp_config_invalid"
|
||||
}
|
||||
}
|
||||
|
||||
func hasPaymentFields(req UpdateSettingsRequest) bool {
|
||||
return req.PaymentEnabled != nil || req.PaymentMinAmount != nil ||
|
||||
req.PaymentMaxAmount != nil || req.PaymentDailyLimit != nil ||
|
||||
@ -1935,6 +2125,45 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL {
|
||||
changed = append(changed, "linuxdo_connect_redirect_url")
|
||||
}
|
||||
if before.DingTalkConnectEnabled != after.DingTalkConnectEnabled {
|
||||
changed = append(changed, "dingtalk_connect_enabled")
|
||||
}
|
||||
if before.DingTalkConnectClientID != after.DingTalkConnectClientID {
|
||||
changed = append(changed, "dingtalk_connect_client_id")
|
||||
}
|
||||
if req.DingTalkConnectClientSecret != "" {
|
||||
changed = append(changed, "dingtalk_connect_client_secret")
|
||||
}
|
||||
if before.DingTalkConnectRedirectURL != after.DingTalkConnectRedirectURL {
|
||||
changed = append(changed, "dingtalk_connect_redirect_url")
|
||||
}
|
||||
if before.DingTalkConnectCorpRestrictionPolicy != after.DingTalkConnectCorpRestrictionPolicy {
|
||||
changed = append(changed, "dingtalk_connect_corp_restriction_policy")
|
||||
}
|
||||
if before.DingTalkConnectInternalCorpID != after.DingTalkConnectInternalCorpID {
|
||||
changed = append(changed, "dingtalk_connect_internal_corp_id")
|
||||
}
|
||||
if before.DingTalkConnectBypassRegistration != after.DingTalkConnectBypassRegistration {
|
||||
changed = append(changed, "dingtalk_connect_bypass_registration")
|
||||
}
|
||||
if before.DingTalkConnectSyncCorpEmail != after.DingTalkConnectSyncCorpEmail {
|
||||
changed = append(changed, "dingtalk_connect_sync_corp_email")
|
||||
}
|
||||
if before.DingTalkConnectSyncDisplayName != after.DingTalkConnectSyncDisplayName {
|
||||
changed = append(changed, "dingtalk_connect_sync_display_name")
|
||||
}
|
||||
if before.DingTalkConnectSyncDept != after.DingTalkConnectSyncDept {
|
||||
changed = append(changed, "dingtalk_connect_sync_dept")
|
||||
}
|
||||
if before.DingTalkConnectSyncCorpEmailAttrKey != after.DingTalkConnectSyncCorpEmailAttrKey {
|
||||
changed = append(changed, "dingtalk_connect_sync_corp_email_attr_key")
|
||||
}
|
||||
if before.DingTalkConnectSyncDisplayNameAttrKey != after.DingTalkConnectSyncDisplayNameAttrKey {
|
||||
changed = append(changed, "dingtalk_connect_sync_display_name_attr_key")
|
||||
}
|
||||
if before.DingTalkConnectSyncDeptAttrKey != after.DingTalkConnectSyncDeptAttrKey {
|
||||
changed = append(changed, "dingtalk_connect_sync_dept_attr_key")
|
||||
}
|
||||
if before.WeChatConnectEnabled != after.WeChatConnectEnabled {
|
||||
changed = append(changed, "wechat_connect_enabled")
|
||||
}
|
||||
@ -2246,6 +2475,7 @@ func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSource
|
||||
{name: "wechat", before: before.WeChat, after: after.WeChat},
|
||||
{name: "github", before: before.GitHub, after: after.GitHub},
|
||||
{name: "google", before: before.Google, after: after.Google},
|
||||
{name: "dingtalk", before: before.DingTalk, after: after.DingTalk},
|
||||
}
|
||||
for _, field := range fields {
|
||||
if field.before.Balance != field.after.Balance {
|
||||
@ -2350,6 +2580,11 @@ func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults
|
||||
data["auth_source_default_linuxdo_subscriptions"] = authSourceDefaults.LinuxDo.Subscriptions
|
||||
data["auth_source_default_linuxdo_grant_on_signup"] = authSourceDefaults.LinuxDo.GrantOnSignup
|
||||
data["auth_source_default_linuxdo_grant_on_first_bind"] = authSourceDefaults.LinuxDo.GrantOnFirstBind
|
||||
data["auth_source_default_dingtalk_balance"] = authSourceDefaults.DingTalk.Balance
|
||||
data["auth_source_default_dingtalk_concurrency"] = authSourceDefaults.DingTalk.Concurrency
|
||||
data["auth_source_default_dingtalk_subscriptions"] = authSourceDefaults.DingTalk.Subscriptions
|
||||
data["auth_source_default_dingtalk_grant_on_signup"] = authSourceDefaults.DingTalk.GrantOnSignup
|
||||
data["auth_source_default_dingtalk_grant_on_first_bind"] = authSourceDefaults.DingTalk.GrantOnFirstBind
|
||||
data["auth_source_default_oidc_balance"] = authSourceDefaults.OIDC.Balance
|
||||
data["auth_source_default_oidc_concurrency"] = authSourceDefaults.OIDC.Concurrency
|
||||
data["auth_source_default_oidc_subscriptions"] = authSourceDefaults.OIDC.Subscriptions
|
||||
@ -3044,3 +3279,56 @@ func (h *SettingHandler) TestWebSearchEmulation(c *gin.Context) {
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// ensureDingTalkSyncAttributes 在保存 settings 后,按 admin 配置的 (attr key, attr name)
|
||||
// 兜底 upsert 对应 user attribute definition:不存在则创建;存在但 name 不同则更新 name
|
||||
// (type/options/required 不变)。仅 internal_only + 对应 sync 开关开启时执行。
|
||||
// 失败仅记录日志,不阻塞 settings 保存。
|
||||
func (h *SettingHandler) ensureDingTalkSyncAttributes(ctx context.Context, settings *service.SystemSettings) {
|
||||
if h.userAttributeService == nil || settings == nil {
|
||||
return
|
||||
}
|
||||
if settings.DingTalkConnectCorpRestrictionPolicy != "internal_only" {
|
||||
return
|
||||
}
|
||||
if settings.DingTalkConnectSyncDisplayName {
|
||||
h.ensureUserAttributeDefinition(ctx, settings.DingTalkConnectSyncDisplayNameAttrKey, settings.DingTalkConnectSyncDisplayNameAttrName, "钉钉 internal_only 登录时同步的钉钉姓名", service.AttributeTypeText)
|
||||
}
|
||||
if settings.DingTalkConnectSyncCorpEmail {
|
||||
h.ensureUserAttributeDefinition(ctx, settings.DingTalkConnectSyncCorpEmailAttrKey, settings.DingTalkConnectSyncCorpEmailAttrName, "钉钉 internal_only 登录时同步的企业邮箱", service.AttributeTypeEmail)
|
||||
}
|
||||
if settings.DingTalkConnectSyncDept {
|
||||
h.ensureUserAttributeDefinition(ctx, settings.DingTalkConnectSyncDeptAttrKey, settings.DingTalkConnectSyncDeptAttrName, "钉钉 internal_only 登录时同步的完整部门路径(如:公司/研发部)", service.AttributeTypeText)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *SettingHandler) ensureUserAttributeDefinition(ctx context.Context, key, name, description string, attrType service.UserAttributeType) {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
existing, err := h.userAttributeService.GetDefinitionByKey(ctx, key)
|
||||
if err == nil && existing != nil {
|
||||
if strings.TrimSpace(name) != "" && existing.Name != name {
|
||||
if _, err := h.userAttributeService.UpdateDefinition(ctx, existing.ID, service.UpdateAttributeDefinitionInput{
|
||||
Name: &name,
|
||||
}); err != nil {
|
||||
slog.Warn("dingtalk: update user attribute definition name failed", "key", key, "err", err.Error())
|
||||
return
|
||||
}
|
||||
slog.Info("dingtalk: updated user attribute definition name", "key", key, "name", name)
|
||||
}
|
||||
return
|
||||
}
|
||||
if _, err := h.userAttributeService.CreateDefinition(ctx, service.CreateAttributeDefinitionInput{
|
||||
Key: key,
|
||||
Name: name,
|
||||
Description: description,
|
||||
Type: attrType,
|
||||
Enabled: true,
|
||||
}); err != nil {
|
||||
slog.Warn("dingtalk: ensure user attribute definition failed", "key", key, "err", err.Error())
|
||||
return
|
||||
}
|
||||
slog.Info("dingtalk: created user attribute definition", "key", key, "name", name, "type", attrType)
|
||||
}
|
||||
|
||||
@ -137,7 +137,7 @@ func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) {
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
@ -174,7 +174,7 @@ func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *tes
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
body := map[string]any{
|
||||
"registration_enabled": true,
|
||||
@ -214,7 +214,7 @@ func TestSettingHandler_UpdateSettings_PersistsPaymentVisibleMethodsAndAdvancedS
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
body := map[string]any{
|
||||
"promo_code_enabled": true,
|
||||
@ -264,7 +264,7 @@ func TestSettingHandler_UpdateSettings_PreservesLegacyBlankPaymentVisibleMethodS
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
body := map[string]any{
|
||||
"promo_code_enabled": false,
|
||||
@ -309,7 +309,7 @@ func TestSettingHandler_UpdateSettings_PersistsExplicitFalseOIDCCompatibilityFla
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
body := map[string]any{
|
||||
"promo_code_enabled": true,
|
||||
@ -388,7 +388,7 @@ func TestSettingHandler_UpdateSettings_DoesNotSolidifyImplicitOIDCSecurityDefaul
|
||||
ClockSkewSeconds: 120,
|
||||
},
|
||||
})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
body := map[string]any{
|
||||
"promo_code_enabled": true,
|
||||
@ -417,7 +417,7 @@ func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
body := map[string]any{
|
||||
"promo_code_enabled": true,
|
||||
@ -450,7 +450,7 @@ func TestSettingHandler_UpdateSettings_DoesNotPersistPartialSystemSettingsWhenAu
|
||||
err: errors.New("write auth source defaults failed"),
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
body := map[string]any{
|
||||
"registration_enabled": true,
|
||||
|
||||
319
backend/internal/handler/admin/setting_handler_dingtalk_test.go
Normal file
319
backend/internal/handler/admin/setting_handler_dingtalk_test.go
Normal file
@ -0,0 +1,319 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// dingtalkSettingsRepoStub 复用 settingHandlerRepoStub(已在 setting_handler_auth_source_defaults_test.go 定义)
|
||||
|
||||
func newDingTalkSettingsHandler() (*SettingHandler, *settingHandlerRepoStub) {
|
||||
repo := &settingHandlerRepoStub{values: map[string]string{}}
|
||||
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
|
||||
return handler, repo
|
||||
}
|
||||
|
||||
// baseValidDingTalkBody 返回一个可以通过所有校验的最小合法 body。
|
||||
func baseValidDingTalkBody() map[string]any {
|
||||
return map[string]any{
|
||||
"dingtalk_connect_enabled": true,
|
||||
"dingtalk_connect_client_id": "test-client-id",
|
||||
"dingtalk_connect_client_secret": "test-client-secret",
|
||||
"dingtalk_connect_redirect_url": "https://example.com/auth/dingtalk/callback",
|
||||
"dingtalk_connect_corp_restriction_policy": "none",
|
||||
}
|
||||
}
|
||||
|
||||
// TestSettingsPUT_DingTalk_V3_InternalOnlyAllowsEmptyCorpID 验证方案 A:
|
||||
// internal_only + internal_corp_id="" 应通过校验(→ 200),不再是 400。
|
||||
func TestSettingsPUT_DingTalk_V3_InternalOnlyAllowsEmptyCorpID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
handler, _ := newDingTalkSettingsHandler()
|
||||
|
||||
body := baseValidDingTalkBody()
|
||||
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
|
||||
body["dingtalk_connect_internal_corp_id"] = "" // 空值现在合法
|
||||
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestSettingsPUT_DingTalk_HappyPath_None 验证 none policy → 200
|
||||
func TestSettingsPUT_DingTalk_HappyPath_None(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
handler, _ := newDingTalkSettingsHandler()
|
||||
|
||||
body := baseValidDingTalkBody()
|
||||
body["dingtalk_connect_corp_restriction_policy"] = "none"
|
||||
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp response.Response
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
data, ok := resp.Data.(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, true, data["dingtalk_connect_enabled"])
|
||||
}
|
||||
|
||||
// TestSettingsPUT_DingTalk_HappyPath_InternalOnly_WithCorpID 验证 internal_only + corp_id → 200
|
||||
func TestSettingsPUT_DingTalk_HappyPath_InternalOnly_WithCorpID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
handler, _ := newDingTalkSettingsHandler()
|
||||
|
||||
body := baseValidDingTalkBody()
|
||||
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
|
||||
body["dingtalk_connect_internal_corp_id"] = "ding-corp-123"
|
||||
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestSettingsPUT_DingTalk_BypassRegistration_RoundTrip 验证 bypass_registration 字段 save+load。
|
||||
// 必须用 policy=internal_only:bypass 仅在该 policy 下生效,其它 policy 写入层会 coerce 为 false。
|
||||
func TestSettingsPUT_DingTalk_BypassRegistration_RoundTrip(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
handler, _ := newDingTalkSettingsHandler()
|
||||
|
||||
body := baseValidDingTalkBody()
|
||||
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
|
||||
body["dingtalk_connect_bypass_registration"] = true
|
||||
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp response.Response
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
data, ok := resp.Data.(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, true, data["dingtalk_connect_bypass_registration"])
|
||||
}
|
||||
|
||||
// TestSettingsPUT_DingTalk_Disabled_SkipsValidation 验证 disabled 时跳过 corp 校验 → 200。
|
||||
// 用 enabled=true 时必然触发"Client ID is required when enabled"的空 client_id 作为
|
||||
// 哨兵——只要 enabled=false 仍能 200 就证明跳过了。
|
||||
func TestSettingsPUT_DingTalk_Disabled_SkipsValidation(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
handler, _ := newDingTalkSettingsHandler()
|
||||
|
||||
body := map[string]any{
|
||||
"dingtalk_connect_enabled": false,
|
||||
"dingtalk_connect_client_id": "", // 这种空值在 enabled=true 时会被 400 拒绝
|
||||
"dingtalk_connect_corp_restriction_policy": "internal_only",
|
||||
}
|
||||
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestSettingsPUT_DingTalk_SyncFlags_InternalOnly_RoundTrip 验证三个 sync 开关在 internal_only 下可正常 save+load。
|
||||
func TestSettingsPUT_DingTalk_SyncFlags_InternalOnly_RoundTrip(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
handler, _ := newDingTalkSettingsHandler()
|
||||
|
||||
body := baseValidDingTalkBody()
|
||||
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
|
||||
body["dingtalk_connect_sync_corp_email"] = true
|
||||
body["dingtalk_connect_sync_display_name"] = true
|
||||
body["dingtalk_connect_sync_dept"] = true
|
||||
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp response.Response
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
data, ok := resp.Data.(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, true, data["dingtalk_connect_sync_corp_email"], "sync_corp_email should be true for internal_only")
|
||||
require.Equal(t, true, data["dingtalk_connect_sync_display_name"], "sync_display_name should be true for internal_only")
|
||||
require.Equal(t, true, data["dingtalk_connect_sync_dept"], "sync_dept should be true for internal_only")
|
||||
}
|
||||
|
||||
// TestSettingsPUT_DingTalk_SyncFlags_PolicyNone_CoercedToFalse 验证 policy=none 时三个 sync 开关被 coerce 为 false。
|
||||
func TestSettingsPUT_DingTalk_SyncFlags_PolicyNone_CoercedToFalse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
handler, _ := newDingTalkSettingsHandler()
|
||||
|
||||
body := baseValidDingTalkBody()
|
||||
body["dingtalk_connect_corp_restriction_policy"] = "none"
|
||||
body["dingtalk_connect_sync_corp_email"] = true
|
||||
body["dingtalk_connect_sync_display_name"] = true
|
||||
body["dingtalk_connect_sync_dept"] = true
|
||||
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp response.Response
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
data, ok := resp.Data.(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, false, data["dingtalk_connect_sync_corp_email"], "sync_corp_email must be coerced to false when policy=none")
|
||||
require.Equal(t, false, data["dingtalk_connect_sync_display_name"], "sync_display_name must be coerced to false when policy=none")
|
||||
require.Equal(t, false, data["dingtalk_connect_sync_dept"], "sync_dept must be coerced to false when policy=none")
|
||||
}
|
||||
|
||||
// TestSettingsPUT_DingTalk_StaleWhitelist_CoercedToNone 验证升级兼容:
|
||||
// admin 直接把 corp_restriction_policy=whitelist 提交(前端 UI 已无此选项,但 API 仍可命中)
|
||||
// 不应导致 400 失败,应该被静默 coerce 为 none 后通过校验。
|
||||
func TestSettingsPUT_DingTalk_StaleWhitelist_CoercedToNone(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
handler, repo := newDingTalkSettingsHandler()
|
||||
|
||||
body := baseValidDingTalkBody()
|
||||
body["dingtalk_connect_corp_restriction_policy"] = "whitelist"
|
||||
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "none", repo.values[service.SettingKeyDingTalkConnectCorpRestrictionPolicy],
|
||||
"stale whitelist 应在写入路径被 coerce 为 none")
|
||||
}
|
||||
|
||||
// TestSettingsPUT_DingTalk_SyncAttrKey_RoundTrip 验证 3 个 attr key 字段 save+load + 空值 fallback 到默认值。
|
||||
func TestSettingsPUT_DingTalk_SyncAttrKey_RoundTrip(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Run("custom_attr_keys_saved", func(t *testing.T) {
|
||||
handler, repo := newDingTalkSettingsHandler()
|
||||
|
||||
body := baseValidDingTalkBody()
|
||||
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
|
||||
body["dingtalk_connect_sync_corp_email"] = true
|
||||
body["dingtalk_connect_sync_display_name"] = true
|
||||
body["dingtalk_connect_sync_dept"] = true
|
||||
body["dingtalk_connect_sync_corp_email_attr_key"] = "my_email_attr"
|
||||
body["dingtalk_connect_sync_display_name_attr_key"] = "my_name_attr"
|
||||
body["dingtalk_connect_sync_dept_attr_key"] = "my_dept_attr"
|
||||
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 验证写入 DB 的 key
|
||||
require.Equal(t, "my_email_attr", repo.values[service.SettingKeyDingTalkConnectSyncCorpEmailAttrKey])
|
||||
require.Equal(t, "my_name_attr", repo.values[service.SettingKeyDingTalkConnectSyncDisplayNameAttrKey])
|
||||
require.Equal(t, "my_dept_attr", repo.values[service.SettingKeyDingTalkConnectSyncDeptAttrKey])
|
||||
|
||||
// 验证响应中的 attr key
|
||||
var resp response.Response
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
data, ok := resp.Data.(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "my_email_attr", data["dingtalk_connect_sync_corp_email_attr_key"])
|
||||
require.Equal(t, "my_name_attr", data["dingtalk_connect_sync_display_name_attr_key"])
|
||||
require.Equal(t, "my_dept_attr", data["dingtalk_connect_sync_dept_attr_key"])
|
||||
})
|
||||
|
||||
t.Run("empty_attr_keys_fallback_to_defaults", func(t *testing.T) {
|
||||
handler, repo := newDingTalkSettingsHandler()
|
||||
|
||||
body := baseValidDingTalkBody()
|
||||
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
|
||||
// 不传 attr key → 写入层 fallback 到默认值
|
||||
body["dingtalk_connect_sync_corp_email_attr_key"] = ""
|
||||
body["dingtalk_connect_sync_display_name_attr_key"] = ""
|
||||
body["dingtalk_connect_sync_dept_attr_key"] = ""
|
||||
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 空值应 fallback 到默认值并持久化
|
||||
require.Equal(t, "dingtalk_email", repo.values[service.SettingKeyDingTalkConnectSyncCorpEmailAttrKey])
|
||||
require.Equal(t, "dingtalk_name", repo.values[service.SettingKeyDingTalkConnectSyncDisplayNameAttrKey])
|
||||
require.Equal(t, "dingtalk_department", repo.values[service.SettingKeyDingTalkConnectSyncDeptAttrKey])
|
||||
})
|
||||
}
|
||||
398
backend/internal/handler/auth_dingtalk_client.go
Normal file
398
backend/internal/handler/auth_dingtalk_client.go
Normal file
@ -0,0 +1,398 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// dingTalkClientConfig 是 DingTalkClient 需要的最小配置子集
|
||||
type dingTalkClientConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
TokenURL string
|
||||
UserInfoURL string
|
||||
}
|
||||
|
||||
type DingTalkClient struct {
|
||||
cfg dingTalkClientConfig
|
||||
appToken string
|
||||
appTokenExp time.Time // 钉钉 7200s,留 200s 余量 → 7000s
|
||||
mu sync.Mutex
|
||||
httpClient *http.Client
|
||||
// TODO(multi-instance): Redis 集中缓存 appToken
|
||||
}
|
||||
|
||||
type DingTalkUserTokenResp struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ExpireIn int64 `json:"expireIn"`
|
||||
CorpID string `json:"corpId"`
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) ExchangeCodeForUserToken(ctx context.Context, code string) (*DingTalkUserTokenResp, error) {
|
||||
body := map[string]string{
|
||||
"clientId": c.cfg.ClientID,
|
||||
"clientSecret": c.cfg.ClientSecret,
|
||||
"code": code,
|
||||
"grantType": "authorization_code",
|
||||
}
|
||||
payload, _ := json.Marshal(body)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.TokenURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
var out DingTalkUserTokenResp
|
||||
if err := json.Unmarshal(raw, &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(out.AccessToken) == "" {
|
||||
return nil, parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
type DingTalkAPIError struct {
|
||||
Code string
|
||||
Message string
|
||||
HTTP int
|
||||
}
|
||||
|
||||
func (e *DingTalkAPIError) Error() string {
|
||||
return fmt.Sprintf("dingtalk api error code=%s msg=%s http=%d", e.Code, e.Message, e.HTTP)
|
||||
}
|
||||
|
||||
func parseDingTalkErr(raw []byte, status int) error {
|
||||
var v struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
_ = json.Unmarshal(raw, &v)
|
||||
code := v.Code
|
||||
if code == "" && v.ErrCode != 0 {
|
||||
code = fmt.Sprintf("%d", v.ErrCode)
|
||||
}
|
||||
msg := v.Message
|
||||
if msg == "" {
|
||||
msg = v.ErrMsg
|
||||
}
|
||||
return &DingTalkAPIError{Code: code, Message: msg, HTTP: status}
|
||||
}
|
||||
|
||||
// GetUnionIdByUserToken 调用 /v1.0/contact/users/me 返回 unionId 与用户自设昵称 nick。
|
||||
// nick 来自钉钉新版 OIDC 接口(用户在 App 个人资料填的昵称),与旧版 user/get.nickname 不同源。
|
||||
func (c *DingTalkClient) GetUnionIdByUserToken(ctx context.Context, userToken string) (unionID string, nick string, err error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.cfg.UserInfoURL, nil)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
req.Header.Set("x-acs-dingtalk-access-token", userToken)
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", "", parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
var v struct {
|
||||
UnionID string `json:"unionId"`
|
||||
Nick string `json:"nick"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &v); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
if strings.TrimSpace(v.UnionID) == "" {
|
||||
return "", "", parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
return v.UnionID, v.Nick, nil
|
||||
}
|
||||
|
||||
type DingTalkStaffInfo struct {
|
||||
UserID string
|
||||
Name string // 企业内真实姓名(钉钉企业管理后台配置)
|
||||
Nickname string // 钉钉个人昵称(用户自己设置)
|
||||
Email string
|
||||
DeptIDs []int64
|
||||
// CorpID 不来自 staff 接口,来自 userToken;不在此 struct
|
||||
}
|
||||
|
||||
// dingTalkOAPIBase 推导钉钉旧版 OAPI base URL(host: api.dingtalk.com → oapi.dingtalk.com)。
|
||||
// getbyunionid 与 topapi/v2/user/get 仅在旧版 OAPI 提供,不在 v1.0 OpenAPI。
|
||||
func (c *DingTalkClient) dingTalkOAPIBase() string {
|
||||
u, err := url.Parse(c.cfg.UserInfoURL)
|
||||
if err != nil || u.Scheme == "" || u.Host == "" {
|
||||
return "https://oapi.dingtalk.com"
|
||||
}
|
||||
host := u.Host
|
||||
if strings.HasPrefix(host, "api.") {
|
||||
host = "oapi." + strings.TrimPrefix(host, "api.")
|
||||
}
|
||||
return u.Scheme + "://" + host
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) GetAppToken(ctx context.Context) (string, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.appToken != "" && time.Now().Before(c.appTokenExp) {
|
||||
return c.appToken, nil
|
||||
}
|
||||
body := map[string]string{"appKey": c.cfg.ClientID, "appSecret": c.cfg.ClientSecret}
|
||||
payload, _ := json.Marshal(body)
|
||||
// 钉钉新版 v1.0 企业内部应用 access_token: POST /v1.0/oauth2/accessToken
|
||||
// 此 token 也可作为旧版 OAPI 的 access_token 使用(钉钉文档已说明)
|
||||
appTokenURL := strings.Replace(c.cfg.TokenURL, "/oauth2/userAccessToken", "/oauth2/accessToken", 1)
|
||||
if !strings.Contains(appTokenURL, "accessToken") && !strings.Contains(appTokenURL, "gettoken") {
|
||||
appTokenURL = c.cfg.TokenURL // fallback for test stub
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, appTokenURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
var v struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
ExpireIn int64 `json:"expireIn"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &v); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if v.AccessToken == "" {
|
||||
return "", parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
c.appToken = v.AccessToken
|
||||
ttl := v.ExpireIn
|
||||
if ttl > 200 {
|
||||
ttl -= 200
|
||||
}
|
||||
c.appTokenExp = time.Now().Add(time.Duration(ttl) * time.Second)
|
||||
return c.appToken, nil
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) GetUserIdByUnionId(ctx context.Context, unionID string) (string, error) {
|
||||
appToken, err := c.GetAppToken(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
body := map[string]string{"unionid": unionID}
|
||||
payload, _ := json.Marshal(body)
|
||||
// 钉钉旧版 OAPI: POST https://oapi.dingtalk.com/topapi/user/getbyunionid?access_token=XXX
|
||||
// access_token 通过 query string 传递(不是 header)
|
||||
var targetURL string
|
||||
if strings.Contains(c.cfg.UserInfoURL, "/contact/users/me") {
|
||||
targetURL = c.dingTalkOAPIBase() + "/topapi/user/getbyunionid?access_token=" + url.QueryEscape(appToken)
|
||||
} else {
|
||||
targetURL = c.cfg.UserInfoURL // fallback for test stub
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
var v struct {
|
||||
Result struct {
|
||||
UserID string `json:"userid"`
|
||||
} `json:"result"`
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &v); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if v.ErrCode != 0 {
|
||||
return "", parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
if strings.TrimSpace(v.Result.UserID) == "" {
|
||||
return "", parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
return v.Result.UserID, nil
|
||||
}
|
||||
|
||||
// DingTalkDeptInfo 部门信息(topapi/v2/department/get 返回子集)
|
||||
type DingTalkDeptInfo struct {
|
||||
DeptID int64
|
||||
Name string
|
||||
ParentID int64
|
||||
}
|
||||
|
||||
// GetDeptInfo 查询单个部门信息(用于递归拼部门路径)。
|
||||
// 调用钉钉旧版 OAPI: POST /topapi/v2/department/get?access_token=XXX
|
||||
func (c *DingTalkClient) GetDeptInfo(ctx context.Context, deptID int64) (*DingTalkDeptInfo, error) {
|
||||
appToken, err := c.GetAppToken(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body := map[string]any{"dept_id": deptID, "language": "zh_CN"}
|
||||
payload, _ := json.Marshal(body)
|
||||
var targetURL string
|
||||
if strings.Contains(c.cfg.UserInfoURL, "/contact/users/me") {
|
||||
targetURL = c.dingTalkOAPIBase() + "/topapi/v2/department/get?access_token=" + url.QueryEscape(appToken)
|
||||
} else {
|
||||
targetURL = c.cfg.UserInfoURL // test stub fallback
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
var v struct {
|
||||
Result struct {
|
||||
DeptID int64 `json:"dept_id"`
|
||||
Name string `json:"name"`
|
||||
ParentID int64 `json:"parent_id"`
|
||||
} `json:"result"`
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v.ErrCode != 0 {
|
||||
return nil, parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
return &DingTalkDeptInfo{
|
||||
DeptID: v.Result.DeptID,
|
||||
Name: v.Result.Name,
|
||||
ParentID: v.Result.ParentID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) GetStaffInfoByUserId(ctx context.Context, userID string) (*DingTalkStaffInfo, error) {
|
||||
appToken, err := c.GetAppToken(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body := map[string]string{"userid": userID}
|
||||
payload, _ := json.Marshal(body)
|
||||
// 钉钉旧版 OAPI: POST https://oapi.dingtalk.com/topapi/v2/user/get?access_token=XXX
|
||||
var targetURL string
|
||||
if strings.Contains(c.cfg.UserInfoURL, "/contact/users/me") {
|
||||
targetURL = c.dingTalkOAPIBase() + "/topapi/v2/user/get?access_token=" + url.QueryEscape(appToken)
|
||||
} else {
|
||||
targetURL = c.cfg.UserInfoURL // fallback for test stub
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
var v struct {
|
||||
Result struct {
|
||||
UserID string `json:"userid"`
|
||||
Name string `json:"name"`
|
||||
Nickname string `json:"nickname"`
|
||||
Email string `json:"email"`
|
||||
OrgEmail string `json:"org_email"`
|
||||
Extension string `json:"extension"`
|
||||
DeptID []int64 `json:"dept_id_list"`
|
||||
} `json:"result"`
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v.ErrCode != 0 {
|
||||
return nil, parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
if strings.TrimSpace(v.Result.UserID) == "" {
|
||||
return nil, parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
// 邮箱三级 fallback:org_email > email > extension["企业邮箱"](钉钉自定义扩展字段,JSON string)
|
||||
email := strings.TrimSpace(v.Result.OrgEmail)
|
||||
emailSource := "org_email"
|
||||
if email == "" {
|
||||
email = strings.TrimSpace(v.Result.Email)
|
||||
emailSource = "email"
|
||||
}
|
||||
extensionParsed := false
|
||||
if email == "" && strings.TrimSpace(v.Result.Extension) != "" {
|
||||
var ext map[string]string
|
||||
if err := json.Unmarshal([]byte(v.Result.Extension), &ext); err == nil {
|
||||
extensionParsed = true
|
||||
if v, ok := ext["企业邮箱"]; ok {
|
||||
email = strings.TrimSpace(v)
|
||||
emailSource = "extension.企业邮箱"
|
||||
}
|
||||
}
|
||||
}
|
||||
if email == "" {
|
||||
emailSource = "none"
|
||||
}
|
||||
slog.Info("dingtalk staff fetched",
|
||||
"userid", v.Result.UserID,
|
||||
"name_present", v.Result.Name != "",
|
||||
"nickname_present", v.Result.Nickname != "",
|
||||
"name_eq_nickname", v.Result.Name != "" && v.Result.Name == v.Result.Nickname,
|
||||
"email_present", v.Result.Email != "",
|
||||
"org_email_present", v.Result.OrgEmail != "",
|
||||
"extension_present", v.Result.Extension != "",
|
||||
"extension_parsed", extensionParsed,
|
||||
"email_source", emailSource,
|
||||
"dept_count", len(v.Result.DeptID),
|
||||
)
|
||||
return &DingTalkStaffInfo{
|
||||
UserID: v.Result.UserID,
|
||||
Name: v.Result.Name,
|
||||
Nickname: v.Result.Nickname,
|
||||
Email: email,
|
||||
DeptIDs: v.Result.DeptID,
|
||||
}, nil
|
||||
}
|
||||
143
backend/internal/handler/auth_dingtalk_client_test.go
Normal file
143
backend/internal/handler/auth_dingtalk_client_test.go
Normal file
@ -0,0 +1,143 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDingTalkClient_ExchangeCodeForUserToken_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "POST", r.Method)
|
||||
require.Equal(t, "/v1.0/oauth2/userAccessToken", r.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"USER_TOKEN_X","expireIn":7200,"refreshToken":"R","corpId":"dingABC"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cli := &DingTalkClient{
|
||||
cfg: dingTalkClientConfig{
|
||||
ClientID: "k", ClientSecret: "s",
|
||||
TokenURL: server.URL + "/v1.0/oauth2/userAccessToken",
|
||||
},
|
||||
httpClient: server.Client(),
|
||||
}
|
||||
resp, err := cli.ExchangeCodeForUserToken(context.Background(), "AUTH_CODE")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "USER_TOKEN_X", resp.AccessToken)
|
||||
require.Equal(t, "dingABC", resp.CorpID)
|
||||
}
|
||||
|
||||
func TestDingTalkClient_GetUnionIdByUserToken_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "USER_TOKEN_X", r.Header.Get("x-acs-dingtalk-access-token"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"nick":"张三","unionId":"UID_AAA","openId":"OPEN","avatarUrl":"http://x"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cli := &DingTalkClient{
|
||||
cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/v1.0/contact/users/me"},
|
||||
httpClient: server.Client(),
|
||||
}
|
||||
unionID, nick, err := cli.GetUnionIdByUserToken(context.Background(), "USER_TOKEN_X")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "UID_AAA", unionID)
|
||||
require.Equal(t, "张三", nick)
|
||||
}
|
||||
|
||||
func TestDingTalkClient_GetAppToken_Cached(t *testing.T) {
|
||||
callCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
_, _ = w.Write([]byte(`{"accessToken":"APP_TKN","expireIn":7200}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cli := &DingTalkClient{
|
||||
cfg: dingTalkClientConfig{ClientID: "k", ClientSecret: "s", TokenURL: server.URL + "/gettoken"},
|
||||
httpClient: server.Client(),
|
||||
}
|
||||
t1, err := cli.GetAppToken(context.Background())
|
||||
require.NoError(t, err)
|
||||
t2, err := cli.GetAppToken(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, t1, t2)
|
||||
require.Equal(t, 1, callCount, "second call should hit cache")
|
||||
}
|
||||
|
||||
func TestDingTalkClient_GetUserIdByUnionId_60011(t *testing.T) {
|
||||
appTokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte(`{"accessToken":"APP_TKN","expireIn":7200}`))
|
||||
}))
|
||||
defer appTokenServer.Close()
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"errcode":60011,"errmsg":"not in directory"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cli := &DingTalkClient{
|
||||
cfg: dingTalkClientConfig{TokenURL: appTokenServer.URL + "/gettoken"},
|
||||
httpClient: server.Client(),
|
||||
}
|
||||
cli.appToken = "APP_TKN"
|
||||
cli.appTokenExp = time.Now().Add(time.Hour)
|
||||
cli.cfg.UserInfoURL = server.URL + "/v1.0/contact/users/byUnionId"
|
||||
|
||||
_, err := cli.GetUserIdByUnionId(context.Background(), "UID_AAA")
|
||||
require.Error(t, err)
|
||||
apiErr, ok := err.(*DingTalkAPIError)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "60011", apiErr.Code)
|
||||
}
|
||||
|
||||
// TestDingTalkClient_GetDeptInfo_Success 验证 GetDeptInfo 正常情况返回部门信息。
|
||||
func TestDingTalkClient_GetDeptInfo_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"errcode":0,"errmsg":"ok","result":{"dept_id":42,"name":"AI数据","parent_id":1}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cli := &DingTalkClient{
|
||||
cfg: dingTalkClientConfig{
|
||||
UserInfoURL: server.URL + "/stub", // 不含 /contact/users/me,走 test stub 路径
|
||||
},
|
||||
httpClient: server.Client(),
|
||||
}
|
||||
cli.appToken = "APP_TKN"
|
||||
cli.appTokenExp = time.Now().Add(time.Hour)
|
||||
|
||||
info, err := cli.GetDeptInfo(context.Background(), 42)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(42), info.DeptID)
|
||||
require.Equal(t, "AI数据", info.Name)
|
||||
require.Equal(t, int64(1), info.ParentID)
|
||||
}
|
||||
|
||||
// TestDingTalkClient_GetDeptInfo_ErrCode60003 验证 errcode=60003(部门不存在)时返回错误。
|
||||
func TestDingTalkClient_GetDeptInfo_ErrCode60003(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"errcode":60003,"errmsg":"dept not found"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cli := &DingTalkClient{
|
||||
cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/stub"},
|
||||
httpClient: server.Client(),
|
||||
}
|
||||
cli.appToken = "APP_TKN"
|
||||
cli.appTokenExp = time.Now().Add(time.Hour)
|
||||
|
||||
_, err := cli.GetDeptInfo(context.Background(), 999)
|
||||
require.Error(t, err)
|
||||
apiErr, ok := err.(*DingTalkAPIError)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "60003", apiErr.Code)
|
||||
}
|
||||
1066
backend/internal/handler/auth_dingtalk_oauth.go
Normal file
1066
backend/internal/handler/auth_dingtalk_oauth.go
Normal file
File diff suppressed because it is too large
Load Diff
391
backend/internal/handler/auth_dingtalk_oauth_test.go
Normal file
391
backend/internal/handler/auth_dingtalk_oauth_test.go
Normal file
@ -0,0 +1,391 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestDingTalkOAuthStart_Disabled は sentinel テスト。
|
||||
// TODO(task-1.10): newTestAuthHandlerWithDingTalk helper が追加されたら t.Skip を外す。
|
||||
func TestDingTalkOAuthStart_Disabled(t *testing.T) {
|
||||
t.Skip("helper newTestAuthHandlerWithDingTalk added in Task 1.10; sentinel only")
|
||||
}
|
||||
|
||||
// TestBuildDingTalkSyntheticEmail_UsesUnionID 验证合成邮箱种子使用 unionID。
|
||||
func TestBuildDingTalkSyntheticEmail_UsesUnionID(t *testing.T) {
|
||||
unionID := "union_AbCdEf123"
|
||||
email := buildDingTalkSyntheticEmail(unionID)
|
||||
|
||||
want := "dingtalk-union_abcdef123@dingtalk-connect.invalid"
|
||||
require.Equal(t, want, email)
|
||||
|
||||
// 确保结果都是小写(邮箱大小写不敏感,统一小写)
|
||||
require.True(t, strings.ToLower(email) == email, "synthetic email should be all lowercase")
|
||||
|
||||
// 确保前缀正确
|
||||
require.True(t, strings.HasPrefix(email, "dingtalk-"), "should have dingtalk- prefix")
|
||||
|
||||
// 确保后缀是合成邮箱域名
|
||||
require.True(t, strings.HasSuffix(email, "@dingtalk-connect.invalid"), "should have reserved domain suffix")
|
||||
}
|
||||
|
||||
// TestBuildDingTalkSyntheticEmail_TrimsSpace 验证 unionID 空白被修剪。
|
||||
func TestBuildDingTalkSyntheticEmail_TrimsSpace(t *testing.T) {
|
||||
email := buildDingTalkSyntheticEmail(" UID_XYZ ")
|
||||
require.Equal(t, "dingtalk-uid_xyz@dingtalk-connect.invalid", email)
|
||||
}
|
||||
|
||||
// TestBuildDingTalkUpstreamClaims_EmptyStaff 验证 staff 为空 struct(跨组织降级路径)时:
|
||||
// - subject 等于 unionID(与 identityKey.ProviderSubject 一致)
|
||||
// - corp_user_id 为空字符串(跨组织时拿不到企业 userid)
|
||||
// - email/username 为空字符串
|
||||
// B/C: Step 3/4 失败降级时 staff = &DingTalkStaffInfo{},claims 不应有 nil。
|
||||
func TestBuildDingTalkUpstreamClaims_EmptyStaff(t *testing.T) {
|
||||
staff := &DingTalkStaffInfo{}
|
||||
claims := buildDingTalkUpstreamClaims(staff, "UNION_AAA", "CORP_X")
|
||||
|
||||
require.Equal(t, "", claims["email"])
|
||||
require.Equal(t, "", claims["username"])
|
||||
// 重构后 subject = unionID(与 identityKey.ProviderSubject 保持一致)
|
||||
require.Equal(t, "UNION_AAA", claims["subject"])
|
||||
require.Equal(t, "", claims["corp_user_id"]) // 企业 userid 跨组织时为空
|
||||
require.Equal(t, "UNION_AAA", claims["union_id"])
|
||||
require.Equal(t, "CORP_X", claims["corp_id"])
|
||||
}
|
||||
|
||||
// TestCheckDingTalkCorpAllowed_CrossOrgPolicy 验证 policy=none 时允许任意 corp。
|
||||
// D: corp 校验提前后逻辑不变。
|
||||
func TestCheckDingTalkCorpAllowed_CrossOrgPolicy(t *testing.T) {
|
||||
cfg := config.DingTalkConnectConfig{CorpRestrictionPolicy: "none"}
|
||||
|
||||
assert.True(t, checkDingTalkCorpAllowed(cfg, "dingABC"), "policy=none should allow any corp")
|
||||
assert.True(t, checkDingTalkCorpAllowed(cfg, ""), "policy=none should allow empty corp")
|
||||
assert.True(t, checkDingTalkCorpAllowed(cfg, "foreign_corp"), "policy=none should allow foreign corp")
|
||||
}
|
||||
|
||||
// TestCheckDingTalkCorpAllowed_InternalOnly 验证 policy=internal_only 时的 corp 校验语义(方案 A 修订)。
|
||||
// 钉钉 userAccessToken 在部分授权场景(扫码登录、非企业工作台入口)不返回 corpId 字段,
|
||||
// 因此 checkDingTalkCorpAllowed 完全不校验 corpID,由 step 3 GetUserIdByUnionId 做真实判定
|
||||
// (跨企业用户会被钉钉错误码 60011/60121 拒绝,mapDingTalkErrorCode 映射回 corp_rejected)。
|
||||
func TestCheckDingTalkCorpAllowed_InternalOnly(t *testing.T) {
|
||||
cfgWithCorpID := config.DingTalkConnectConfig{
|
||||
CorpRestrictionPolicy: "internal_only",
|
||||
InternalCorpID: "dingInternal",
|
||||
}
|
||||
assert.True(t, checkDingTalkCorpAllowed(cfgWithCorpID, "dingInternal"), "internal_only: matching corpID allowed")
|
||||
assert.True(t, checkDingTalkCorpAllowed(cfgWithCorpID, "foreign_corp"), "internal_only: corpID 字段不再用于决策,step 3 兜底")
|
||||
assert.True(t, checkDingTalkCorpAllowed(cfgWithCorpID, ""), "internal_only: 空 corpID 也通过(钉钉部分授权场景不返回 corpId)")
|
||||
|
||||
cfgNoCorpID := config.DingTalkConnectConfig{
|
||||
CorpRestrictionPolicy: "internal_only",
|
||||
InternalCorpID: "",
|
||||
}
|
||||
assert.True(t, checkDingTalkCorpAllowed(cfgNoCorpID, "dingAnyNonEmpty"), "internal_only + no InternalCorpID: 非空 corpID 通过")
|
||||
assert.True(t, checkDingTalkCorpAllowed(cfgNoCorpID, ""), "internal_only + no InternalCorpID: 空 corpID 也通过")
|
||||
}
|
||||
|
||||
// TestDecideDingTalkStep34Strategy_PolicyNone 验证 policy=none 时
|
||||
// Step 3/4 失败应降级(shouldFallback=true, isFatal=false)。
|
||||
func TestDecideDingTalkStep34Strategy_PolicyNone(t *testing.T) {
|
||||
step3Err := &DingTalkAPIError{Code: "60011", Message: "not in directory", HTTP: 403}
|
||||
|
||||
shouldFallback, isFatal := decideDingTalkStep34Strategy("none", step3Err)
|
||||
|
||||
require.True(t, shouldFallback, "policy=none: step3 failure should trigger fallback")
|
||||
require.False(t, isFatal, "policy=none: step3 failure should NOT be fatal")
|
||||
}
|
||||
|
||||
// TestDecideDingTalkStep34Strategy_PolicyNoneEmpty 验证 policy="" 时行为与 "none" 相同。
|
||||
func TestDecideDingTalkStep34Strategy_PolicyNoneEmpty(t *testing.T) {
|
||||
stepErr := &DingTalkAPIError{Code: "60011", Message: "not in directory", HTTP: 403}
|
||||
|
||||
shouldFallback, isFatal := decideDingTalkStep34Strategy("", stepErr)
|
||||
|
||||
require.True(t, shouldFallback, "policy='': step failure should trigger fallback")
|
||||
require.False(t, isFatal, "policy='': step failure should NOT be fatal")
|
||||
}
|
||||
|
||||
// TestDecideDingTalkStep34Strategy_PolicyInternalOnly 验证 policy=internal_only 时
|
||||
// Step 3/4 失败应 hard fail(isFatal=true)。
|
||||
func TestDecideDingTalkStep34Strategy_PolicyInternalOnly(t *testing.T) {
|
||||
step3Err := &DingTalkAPIError{Code: "60011", Message: "not in directory", HTTP: 403}
|
||||
|
||||
shouldFallback, isFatal := decideDingTalkStep34Strategy("internal_only", step3Err)
|
||||
|
||||
require.False(t, shouldFallback, "policy=internal_only: should NOT fallback on step3 error")
|
||||
require.True(t, isFatal, "policy=internal_only: step3 failure should be fatal")
|
||||
}
|
||||
|
||||
// TestDecideDingTalkStep34Strategy_NoError 验证 stepErr=nil 时两个返回值均为 false。
|
||||
func TestDecideDingTalkStep34Strategy_NoError(t *testing.T) {
|
||||
for _, policy := range []string{"none", "internal_only", ""} {
|
||||
shouldFallback, isFatal := decideDingTalkStep34Strategy(policy, nil)
|
||||
require.False(t, shouldFallback, "no error should not trigger fallback (policy=%q)", policy)
|
||||
require.False(t, isFatal, "no error should not be fatal (policy=%q)", policy)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompleteDingTalkRegistration_UsernameFromEmailLocalPart 验证 username 为空时
|
||||
// 退到 email local part(@ 之前的部分)。
|
||||
// E: CompleteDingTalkOAuthRegistration username fallback。
|
||||
func TestCompleteDingTalkRegistration_UsernameFromEmailLocalPart(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
username string
|
||||
wantUser string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "username empty, normal email → local part",
|
||||
email: "dingtalk-uid123@dingtalk-connect.invalid",
|
||||
username: "",
|
||||
wantUser: "dingtalk-uid123",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "username already set → keep original",
|
||||
email: "user@example.com",
|
||||
username: "张三",
|
||||
wantUser: "张三",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "username empty, no @ in email → use whole email",
|
||||
email: "noemail",
|
||||
username: "",
|
||||
wantUser: "noemail",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "both empty → invalid",
|
||||
email: "",
|
||||
username: "",
|
||||
wantUser: "",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
username := tc.username
|
||||
email := tc.email
|
||||
|
||||
// 模拟 CompleteDingTalkOAuthRegistration 中的 fallback 逻辑
|
||||
if username == "" {
|
||||
if at := strings.Index(email, "@"); at > 0 {
|
||||
username = email[:at]
|
||||
} else {
|
||||
username = email
|
||||
}
|
||||
}
|
||||
|
||||
isValid := email != "" && username != ""
|
||||
require.Equal(t, tc.wantUser, username, fmt.Sprintf("username for email=%q", tc.email))
|
||||
require.Equal(t, tc.wantValid, isValid, "validity check")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildDingTalkUpstreamClaims_SubjectEqualsUnionID 验证重构后 subject = unionID
|
||||
// 而非 staff.UserID,与 identityKey.ProviderSubject 保持一致。
|
||||
// §4.2: buildDingTalkUpstreamClaims subject 字段修正。
|
||||
func TestBuildDingTalkUpstreamClaims_SubjectEqualsUnionID(t *testing.T) {
|
||||
staff := &DingTalkStaffInfo{UserID: "user123", Name: "张三", Email: "zhangsan@corp.com"}
|
||||
claims := buildDingTalkUpstreamClaims(staff, "union456", "dingcorp789")
|
||||
|
||||
// 重构后 subject = unionID(全局唯一,与 identityKey.ProviderSubject 一致)
|
||||
require.Equal(t, "union456", claims["subject"], "subject should equal unionID after refactor")
|
||||
// 企业 userid 保留为独立字段,供 audit/debug 使用
|
||||
require.Equal(t, "user123", claims["corp_user_id"], "corp_user_id should be staff.UserID")
|
||||
// union_id 字段与 subject 相同(冗余保留,便于读取)
|
||||
require.Equal(t, "union456", claims["union_id"])
|
||||
require.Equal(t, "dingcorp789", claims["corp_id"])
|
||||
require.Equal(t, "张三", claims["username"])
|
||||
require.Equal(t, "zhangsan@corp.com", claims["email"])
|
||||
}
|
||||
|
||||
// TestBuildDingTalkUpstreamClaims_CrossOrgEmptyCorpUserID 验证跨组织降级时
|
||||
// corp_user_id 为空字符串(跨组织拿不到企业 userid),subject 仍为 unionID。
|
||||
func TestBuildDingTalkUpstreamClaims_CrossOrgEmptyCorpUserID(t *testing.T) {
|
||||
// 跨组织降级路径:staff = &DingTalkStaffInfo{}(所有字段为零值)
|
||||
staff := &DingTalkStaffInfo{}
|
||||
claims := buildDingTalkUpstreamClaims(staff, "union_cross_org", "foreign_corp")
|
||||
|
||||
require.Equal(t, "union_cross_org", claims["subject"], "subject should still be unionID for cross-org users")
|
||||
require.Equal(t, "", claims["corp_user_id"], "corp_user_id should be empty for cross-org fallback")
|
||||
require.Equal(t, "", claims["email"])
|
||||
require.Equal(t, "", claims["username"])
|
||||
}
|
||||
|
||||
// TestBuildDingTalkUpstreamClaims_PrimaryDeptIDInClaims 验证首个 dept_id 被存入 claims。
|
||||
func TestBuildDingTalkUpstreamClaims_PrimaryDeptIDInClaims(t *testing.T) {
|
||||
staff := &DingTalkStaffInfo{UserID: "u1", Name: "张三", Email: "a@b.com", DeptIDs: []int64{42, 99}}
|
||||
claims := buildDingTalkUpstreamClaims(staff, "uid1", "corpX")
|
||||
|
||||
// 只取首个 dept_id
|
||||
require.Equal(t, int64(42), claims["primary_dept_id"], "primary_dept_id should be the first dept_id")
|
||||
}
|
||||
|
||||
// TestBuildDingTalkUpstreamClaims_NoDeptIDs 验证无部门时 primary_dept_id=0。
|
||||
func TestBuildDingTalkUpstreamClaims_NoDeptIDs(t *testing.T) {
|
||||
staff := &DingTalkStaffInfo{UserID: "u2", Name: "李四"}
|
||||
claims := buildDingTalkUpstreamClaims(staff, "uid2", "corpY")
|
||||
|
||||
require.Equal(t, int64(0), claims["primary_dept_id"], "primary_dept_id should be 0 when no depts")
|
||||
}
|
||||
|
||||
// TestDingTalkStaffFromClaims_RoundTrip 验证 dingTalkStaffFromClaims 能从 claims 恢复 staff 信息。
|
||||
func TestDingTalkStaffFromClaims_RoundTrip(t *testing.T) {
|
||||
staff := &DingTalkStaffInfo{UserID: "u3", Name: "王五", Email: "ww@corp.com", DeptIDs: []int64{55}}
|
||||
claims := buildDingTalkUpstreamClaims(staff, "uid3", "corpZ")
|
||||
|
||||
recovered := dingTalkStaffFromClaims(claims)
|
||||
require.Equal(t, "王五", recovered.Name)
|
||||
require.Equal(t, "ww@corp.com", recovered.Email)
|
||||
require.Equal(t, "u3", recovered.UserID)
|
||||
require.Equal(t, []int64{55}, recovered.DeptIDs)
|
||||
}
|
||||
|
||||
// TestResolveDingTalkDeptPath_SingleLevel 验证单层部门(parent_id=1)返回部门名。
|
||||
func TestResolveDingTalkDeptPath_SingleLevel(t *testing.T) {
|
||||
handler := &AuthHandler{}
|
||||
callCount := 0
|
||||
responses := map[string]string{
|
||||
"42": `{"errcode":0,"result":{"dept_id":42,"name":"研发部","parent_id":1}}`,
|
||||
"1": `{"errcode":0,"result":{"dept_id":1,"name":"公司","parent_id":0}}`,
|
||||
}
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
var req struct {
|
||||
DeptID int64 `json:"dept_id"`
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if resp, ok := responses[fmt.Sprintf("%d", req.DeptID)]; ok {
|
||||
_, _ = w.Write([]byte(resp))
|
||||
} else {
|
||||
_, _ = w.Write([]byte(`{"errcode":60003,"errmsg":"not found"}`))
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cli := &DingTalkClient{
|
||||
cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/stub"},
|
||||
httpClient: server.Client(),
|
||||
}
|
||||
cli.appToken = "tok"
|
||||
cli.appTokenExp = time.Now().Add(time.Hour)
|
||||
|
||||
path, err := handler.resolveDingTalkDeptPath(context.Background(), cli, 42)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "研发部", path)
|
||||
require.Equal(t, 2, callCount)
|
||||
}
|
||||
|
||||
// TestSyncDingTalkIdentity_UsesCfgAttrKeys 验证 syncDingTalkIdentity 使用 cfg 中配置的 attr key
|
||||
// 而不是硬编码值。通过 userAttributeService=nil 使同步路径走 warn 跳过,但在此之前先验证
|
||||
// syncField 构建逻辑(即 attr key 从 cfg 读取)。
|
||||
// 间接验证:通过构造定制 cfg,确认不同 attr key 可以正确传入(编译时保证类型正确,运行时不 panic)。
|
||||
func TestSyncDingTalkIdentity_UsesCfgAttrKeys_NoopWithNilService(t *testing.T) {
|
||||
handler := &AuthHandler{
|
||||
userAttributeService: nil, // nil → 触发 warn 跳过,但不 panic
|
||||
}
|
||||
|
||||
cfg := config.DingTalkConnectConfig{
|
||||
CorpRestrictionPolicy: "internal_only",
|
||||
SyncCorpEmail: true,
|
||||
SyncDisplayName: true,
|
||||
SyncDept: true,
|
||||
// 自定义 attr key(非默认值)
|
||||
SyncCorpEmailAttrKey: "custom_email_key",
|
||||
SyncDisplayNameAttrKey: "custom_name_key",
|
||||
SyncDeptAttrKey: "custom_dept_key",
|
||||
}
|
||||
|
||||
staff := &DingTalkStaffInfo{
|
||||
Name: "张三",
|
||||
Email: "zhangsan@example.com",
|
||||
}
|
||||
|
||||
// 调用不应 panic(userAttributeService 为 nil 时走 warn 跳过路径)
|
||||
require.NotPanics(t, func() {
|
||||
handler.syncDingTalkIdentity(context.Background(), cfg, nil, 42, staff, false)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSyncDingTalkIdentity_DefaultAttrKeys_NoopWithNilService 验证 cfg 默认 attr key 为空时
|
||||
// 使用 fallback 默认值(dingtalk_email / dingtalk_name / dingtalk_department)。
|
||||
// 此测试主要验证调用路径不 panic;实际 key 赋值默认值的逻辑在 GetDingTalkConnectOAuthConfig 层。
|
||||
func TestSyncDingTalkIdentity_DefaultAttrKeys_NoopWithNilService(t *testing.T) {
|
||||
handler := &AuthHandler{
|
||||
userAttributeService: nil,
|
||||
}
|
||||
|
||||
cfg := config.DingTalkConnectConfig{
|
||||
CorpRestrictionPolicy: "internal_only",
|
||||
SyncCorpEmail: true,
|
||||
SyncDisplayName: true,
|
||||
SyncDept: false,
|
||||
// 不设置 attr key(等同于 GetDingTalkConnectOAuthConfig 未设置时 fallback 后的默认值已在调用前填充)
|
||||
SyncCorpEmailAttrKey: "dingtalk_email",
|
||||
SyncDisplayNameAttrKey: "dingtalk_name",
|
||||
SyncDeptAttrKey: "dingtalk_department",
|
||||
}
|
||||
|
||||
staff := &DingTalkStaffInfo{
|
||||
Name: "李四",
|
||||
Email: "lisi@corp.com",
|
||||
}
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
handler.syncDingTalkIdentity(context.Background(), cfg, nil, 99, staff, false)
|
||||
})
|
||||
}
|
||||
|
||||
// TestResolveDingTalkDeptPath_MultiLevel 验证多层部门路径拼接。
|
||||
func TestResolveDingTalkDeptPath_MultiLevel(t *testing.T) {
|
||||
handler := &AuthHandler{}
|
||||
// 模拟:42(AI研发) → parent=10(研发部) → parent=1(根)
|
||||
responses := map[string]string{
|
||||
"42": `{"errcode":0,"result":{"dept_id":42,"name":"AI研发","parent_id":10}}`,
|
||||
"10": `{"errcode":0,"result":{"dept_id":10,"name":"研发部","parent_id":1}}`,
|
||||
"1": `{"errcode":0,"result":{"dept_id":1,"name":"公司","parent_id":0}}`,
|
||||
}
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 解析请求 body 拿到 dept_id
|
||||
var req struct {
|
||||
DeptID int64 `json:"dept_id"`
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
key := fmt.Sprintf("%d", req.DeptID)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if resp, ok := responses[key]; ok {
|
||||
_, _ = w.Write([]byte(resp))
|
||||
} else {
|
||||
_, _ = w.Write([]byte(`{"errcode":60003,"errmsg":"not found"}`))
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cli := &DingTalkClient{
|
||||
cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/stub"},
|
||||
httpClient: server.Client(),
|
||||
}
|
||||
cli.appToken = "tok"
|
||||
cli.appTokenExp = time.Now().Add(time.Hour)
|
||||
|
||||
path, err := handler.resolveDingTalkDeptPath(context.Background(), cli, 42)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "研发部/AI研发", path)
|
||||
}
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
@ -18,25 +19,30 @@ import (
|
||||
|
||||
// AuthHandler handles authentication-related requests
|
||||
type AuthHandler struct {
|
||||
cfg *config.Config
|
||||
authService *service.AuthService
|
||||
userService *service.UserService
|
||||
settingSvc *service.SettingService
|
||||
promoService *service.PromoService
|
||||
redeemService *service.RedeemService
|
||||
totpService *service.TotpService
|
||||
cfg *config.Config
|
||||
authService *service.AuthService
|
||||
userService *service.UserService
|
||||
settingSvc *service.SettingService
|
||||
promoService *service.PromoService
|
||||
redeemService *service.RedeemService
|
||||
totpService *service.TotpService
|
||||
userAttributeService *service.UserAttributeService
|
||||
|
||||
dingTalkClientInstance *DingTalkClient
|
||||
dingTalkClientMu sync.Mutex
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService) *AuthHandler {
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService, userAttributeService *service.UserAttributeService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
settingSvc: settingService,
|
||||
promoService: promoService,
|
||||
redeemService: redeemService,
|
||||
totpService: totpService,
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
settingSvc: settingService,
|
||||
promoService: promoService,
|
||||
redeemService: redeemService,
|
||||
totpService: totpService,
|
||||
userAttributeService: userAttributeService,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -350,7 +350,8 @@ func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email stri
|
||||
if email == "" ||
|
||||
strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
|
||||
strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(email, service.DingTalkConnectSyntheticEmailDomain) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@ -519,7 +520,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode, "linuxdo")
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@ -195,6 +196,14 @@ func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPen
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("pending auth session create failed",
|
||||
"intent", strings.TrimSpace(payload.Intent),
|
||||
"provider_type", strings.TrimSpace(payload.Identity.ProviderType),
|
||||
"provider_key", strings.TrimSpace(payload.Identity.ProviderKey),
|
||||
"provider_subject_len", len(strings.TrimSpace(payload.Identity.ProviderSubject)),
|
||||
"resolved_email_len", len(strings.TrimSpace(payload.ResolvedEmail)),
|
||||
"has_target_user", payload.TargetUserID != nil,
|
||||
"error", err.Error())
|
||||
return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err)
|
||||
}
|
||||
|
||||
@ -266,6 +275,22 @@ func pendingSessionWantsInvitation(payload map[string]any) bool {
|
||||
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required")
|
||||
}
|
||||
|
||||
// pendingSessionRequiresEmailCompletion 判断 callback 写入的 completion payload 是否处于"补邮箱"状态。
|
||||
// 钉钉跨组织/staff 邮箱缺失时进入此状态:前端跳到补邮箱页,exchange 不应走 adoption apply。
|
||||
func pendingSessionRequiresEmailCompletion(payload map[string]any) bool {
|
||||
if v, ok := payload["requires_email_completion"].(bool); ok && v {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "email_completion")
|
||||
}
|
||||
|
||||
// pendingSessionRequiresBindLogin 判断 callback 写入的 completion payload 是否处于"必须绑定已有账户"状态。
|
||||
// 钉钉 signupBlocked=true(注册关 + 钉钉企业豁免关)时进入此状态:前端渲染 bind_login 表单,
|
||||
// exchange 不应消费 session,否则后续 /pending/bind-login 找不到 session。
|
||||
func pendingSessionRequiresBindLogin(payload map[string]any) bool {
|
||||
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "bind_login_required")
|
||||
}
|
||||
|
||||
func pendingOAuthCompletionCanIssueTokenPair(session *dbent.PendingAuthSession, payload map[string]any) bool {
|
||||
if session == nil {
|
||||
return false
|
||||
@ -1467,8 +1492,10 @@ func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string]
|
||||
delete(normalized, key)
|
||||
}
|
||||
step := strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step")))
|
||||
// 把多种 choice 别名归一为 oauthPendingChoiceStep;bind_login_required 是独立终态
|
||||
// (前端渲染 needsBindLogin 而非 needsChooser),故不能并入归一化列表。
|
||||
switch step {
|
||||
case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required":
|
||||
case "choice", "choose_account_action", "choose_account", "choose", "email_required":
|
||||
normalized["step"] = oauthPendingChoiceStep
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(normalized, "step")), oauthPendingChoiceStep) {
|
||||
@ -1594,6 +1621,8 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
|
||||
}
|
||||
|
||||
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
|
||||
// bindPendingOAuthLogin = 绑定已有账户登录,不动 users.username(用户已有自己的名字)
|
||||
h.maybeSyncDingTalkAfterLogin(c.Request.Context(), session, user.ID)
|
||||
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate token pair")
|
||||
@ -1792,6 +1821,8 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
||||
}
|
||||
|
||||
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
|
||||
// createPendingOAuthAccount = 注册新账户,需要把钉钉昵称同步到 users.username 作为初始值
|
||||
h.maybeSyncDingTalkAfterRegistration(c.Request.Context(), session, user.ID)
|
||||
clearCookies()
|
||||
writeOAuthTokenPairResponse(c, tokenPair)
|
||||
}
|
||||
@ -1893,6 +1924,14 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
|
||||
response.Success(c, payload)
|
||||
return
|
||||
}
|
||||
if pendingSessionRequiresEmailCompletion(payload) {
|
||||
response.Success(c, payload)
|
||||
return
|
||||
}
|
||||
if pendingSessionRequiresBindLogin(payload) {
|
||||
response.Success(c, payload)
|
||||
return
|
||||
}
|
||||
if !adoptionDecision.hasDecision() {
|
||||
adoptionRequired, _ := payload["adoption_required"].(bool)
|
||||
if adoptionRequired {
|
||||
|
||||
@ -502,7 +502,8 @@ func (h *AuthHandler) findOIDCCompatEmailUser(ctx context.Context, email string)
|
||||
if email == "" ||
|
||||
strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
|
||||
strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(email, service.DingTalkConnectSyntheticEmailDomain) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@ -666,7 +667,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode, "oidc")
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@ -548,7 +548,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode, "wechat")
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
67
backend/internal/handler/dto/account_mapper_redact_test.go
Normal file
67
backend/internal/handler/dto/account_mapper_redact_test.go
Normal file
@ -0,0 +1,67 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func TestAccountFromServiceShallow_RedactsSensitiveCredentials(t *testing.T) {
|
||||
src := &service.Account{
|
||||
ID: 42,
|
||||
Name: "demo",
|
||||
Platform: "anthropic",
|
||||
Type: "oauth",
|
||||
Credentials: map[string]any{
|
||||
"access_token": "at-secret",
|
||||
"refresh_token": "rt-secret",
|
||||
"id_token": "id-secret",
|
||||
"api_key": "sk-secret",
|
||||
"base_url": "https://api.example.com",
|
||||
"model_mapping": map[string]any{"foo": "bar"},
|
||||
},
|
||||
}
|
||||
|
||||
got := AccountFromServiceShallow(src)
|
||||
require.NotNil(t, got)
|
||||
|
||||
// 敏感键不在 Credentials 里
|
||||
require.NotContains(t, got.Credentials, "access_token")
|
||||
require.NotContains(t, got.Credentials, "refresh_token")
|
||||
require.NotContains(t, got.Credentials, "id_token")
|
||||
require.NotContains(t, got.Credentials, "api_key")
|
||||
// 非敏感键保留
|
||||
require.Equal(t, "https://api.example.com", got.Credentials["base_url"])
|
||||
require.Equal(t, map[string]any{"foo": "bar"}, got.Credentials["model_mapping"])
|
||||
|
||||
// 状态 map 标记敏感键存在
|
||||
require.True(t, got.CredentialsStatus["has_access_token"])
|
||||
require.True(t, got.CredentialsStatus["has_refresh_token"])
|
||||
require.True(t, got.CredentialsStatus["has_id_token"])
|
||||
require.True(t, got.CredentialsStatus["has_api_key"])
|
||||
|
||||
// JSON 序列化校验:响应体里不会出现敏感子串
|
||||
raw, err := json.Marshal(got)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(raw), "rt-secret")
|
||||
require.NotContains(t, string(raw), "at-secret")
|
||||
require.NotContains(t, string(raw), "sk-secret")
|
||||
require.NotContains(t, string(raw), "id-secret")
|
||||
// 状态标识应序列化进 JSON
|
||||
require.Contains(t, string(raw), "credentials_status")
|
||||
require.Contains(t, string(raw), "has_refresh_token")
|
||||
|
||||
// 原始 service.Account 不应被改动
|
||||
require.Equal(t, "rt-secret", src.Credentials["refresh_token"])
|
||||
}
|
||||
|
||||
func TestAccountFromServiceShallow_NilCredentialsOmitsStatus(t *testing.T) {
|
||||
src := &service.Account{ID: 1, Name: "n", Platform: "anthropic", Type: "oauth"}
|
||||
got := AccountFromServiceShallow(src)
|
||||
require.NotNil(t, got)
|
||||
require.Nil(t, got.Credentials)
|
||||
require.Nil(t, got.CredentialsStatus)
|
||||
}
|
||||
44
backend/internal/handler/dto/credentials_redact.go
Normal file
44
backend/internal/handler/dto/credentials_redact.go
Normal file
@ -0,0 +1,44 @@
|
||||
// Package dto provides data transfer objects for HTTP handlers.
|
||||
package dto
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
// RedactCredentials 复制一份 in,剥离 service.SensitiveCredentialKeys 列出的所有敏感子键,
|
||||
// 并产出一个 has_<key> 状态 map 表示哪些敏感键存在且非零值。
|
||||
//
|
||||
// 输入 nil 时返回 nil, nil(避免响应里出现空对象)。
|
||||
// 不修改入参;调用方拿到的 out 可安全序列化进 JSON 返回前端。
|
||||
func RedactCredentials(in map[string]any) (out map[string]any, status map[string]bool) {
|
||||
if in == nil {
|
||||
return nil, nil
|
||||
}
|
||||
out = make(map[string]any, len(in))
|
||||
for k, v := range in {
|
||||
if service.IsSensitiveCredentialKey(k) {
|
||||
if isCredentialValuePresent(v) {
|
||||
if status == nil {
|
||||
status = make(map[string]bool, 4)
|
||||
}
|
||||
status["has_"+k] = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
out[k] = v
|
||||
}
|
||||
return out, status
|
||||
}
|
||||
|
||||
// isCredentialValuePresent 判断值是否"存在且非零"。空字符串、nil、false 均视为未配置;
|
||||
// 其余非零类型(数字、对象、字符串等)视为已配置。
|
||||
func isCredentialValuePresent(v any) bool {
|
||||
switch x := v.(type) {
|
||||
case nil:
|
||||
return false
|
||||
case string:
|
||||
return x != ""
|
||||
case bool:
|
||||
return x
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
97
backend/internal/handler/dto/credentials_redact_test.go
Normal file
97
backend/internal/handler/dto/credentials_redact_test.go
Normal file
@ -0,0 +1,97 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRedactCredentials_NilInput(t *testing.T) {
|
||||
out, status := RedactCredentials(nil)
|
||||
require.Nil(t, out)
|
||||
require.Nil(t, status)
|
||||
}
|
||||
|
||||
func TestRedactCredentials_StripsSensitiveKeysAndReportsStatus(t *testing.T) {
|
||||
in := map[string]any{
|
||||
"refresh_token": "rt-secret",
|
||||
"access_token": "at-secret",
|
||||
"api_key": "sk-secret",
|
||||
"aws_secret_access_key": "aws-secret",
|
||||
"service_account_json": map[string]any{"private_key": "..."},
|
||||
"private_key": "raw-key",
|
||||
// 非敏感
|
||||
"base_url": "https://api.example.com",
|
||||
"model_mapping": map[string]any{"foo": "bar"},
|
||||
"project_id": "proj-1",
|
||||
"expires_at": int64(123456),
|
||||
}
|
||||
|
||||
out, status := RedactCredentials(in)
|
||||
|
||||
require.NotContains(t, out, "refresh_token")
|
||||
require.NotContains(t, out, "access_token")
|
||||
require.NotContains(t, out, "api_key")
|
||||
require.NotContains(t, out, "aws_secret_access_key")
|
||||
require.NotContains(t, out, "service_account_json")
|
||||
require.NotContains(t, out, "private_key")
|
||||
|
||||
require.Equal(t, "https://api.example.com", out["base_url"])
|
||||
require.Equal(t, map[string]any{"foo": "bar"}, out["model_mapping"])
|
||||
require.Equal(t, "proj-1", out["project_id"])
|
||||
require.Equal(t, int64(123456), out["expires_at"])
|
||||
|
||||
require.True(t, status["has_refresh_token"])
|
||||
require.True(t, status["has_access_token"])
|
||||
require.True(t, status["has_api_key"])
|
||||
require.True(t, status["has_aws_secret_access_key"])
|
||||
require.True(t, status["has_service_account_json"])
|
||||
require.True(t, status["has_private_key"])
|
||||
|
||||
// 状态 map 不应携带非敏感键的 has_*
|
||||
require.NotContains(t, status, "has_base_url")
|
||||
require.NotContains(t, status, "has_project_id")
|
||||
}
|
||||
|
||||
func TestRedactCredentials_EmptyValuesNotMarkedPresent(t *testing.T) {
|
||||
in := map[string]any{
|
||||
"refresh_token": "",
|
||||
"access_token": nil,
|
||||
"api_key": false,
|
||||
"id_token": "actual-id",
|
||||
}
|
||||
out, status := RedactCredentials(in)
|
||||
require.Empty(t, out, "敏感键即使为空也不应出现在 redacted output")
|
||||
require.False(t, status["has_refresh_token"])
|
||||
require.False(t, status["has_access_token"])
|
||||
require.False(t, status["has_api_key"])
|
||||
require.True(t, status["has_id_token"])
|
||||
}
|
||||
|
||||
func TestRedactCredentials_DoesNotMutateInput(t *testing.T) {
|
||||
in := map[string]any{
|
||||
"refresh_token": "secret",
|
||||
"base_url": "x",
|
||||
}
|
||||
_, _ = RedactCredentials(in)
|
||||
require.Equal(t, "secret", in["refresh_token"], "原始 map 不应被修改")
|
||||
require.Equal(t, "x", in["base_url"])
|
||||
}
|
||||
|
||||
func TestRedactCredentials_AllKnownSensitiveKeys(t *testing.T) {
|
||||
keys := []string{
|
||||
"access_token", "refresh_token", "id_token",
|
||||
"api_key", "session_key", "cookie",
|
||||
"aws_secret_access_key", "aws_session_token",
|
||||
"service_account_json", "service_account", "private_key",
|
||||
}
|
||||
in := make(map[string]any, len(keys))
|
||||
for _, k := range keys {
|
||||
in[k] = "filled"
|
||||
}
|
||||
out, status := RedactCredentials(in)
|
||||
require.Empty(t, out)
|
||||
for _, k := range keys {
|
||||
require.True(t, status["has_"+k], "key %s 应在 status 中标记为已配置", k)
|
||||
}
|
||||
}
|
||||
@ -198,13 +198,15 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
redactedCreds, credsStatus := RedactCredentials(a.Credentials)
|
||||
out := &Account{
|
||||
ID: a.ID,
|
||||
Name: a.Name,
|
||||
Notes: a.Notes,
|
||||
Platform: a.Platform,
|
||||
Type: a.Type,
|
||||
Credentials: a.Credentials,
|
||||
Credentials: redactedCreds,
|
||||
CredentialsStatus: credsStatus,
|
||||
Extra: a.Extra,
|
||||
ProxyID: a.ProxyID,
|
||||
Concurrency: a.Concurrency,
|
||||
@ -531,11 +533,15 @@ func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
|
||||
UsedBy: rc.UsedBy,
|
||||
UsedAt: rc.UsedAt,
|
||||
CreatedAt: rc.CreatedAt,
|
||||
ExpiresAt: rc.ExpiresAt,
|
||||
GroupID: rc.GroupID,
|
||||
ValidityDays: rc.ValidityDays,
|
||||
User: UserFromServiceShallow(rc.User),
|
||||
Group: GroupFromServiceShallow(rc.Group),
|
||||
}
|
||||
if rc.IsExpired() {
|
||||
out.Status = service.StatusExpired
|
||||
}
|
||||
|
||||
// For admin_balance/admin_concurrency types, include notes so users can see
|
||||
// why they were charged or credited by admin
|
||||
@ -600,6 +606,10 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
FirstTokenMs: l.FirstTokenMs,
|
||||
ImageCount: l.ImageCount,
|
||||
ImageSize: l.ImageSize,
|
||||
ImageInputSize: l.ImageInputSize,
|
||||
ImageOutputSize: l.ImageOutputSize,
|
||||
ImageSizeSource: l.ImageSizeSource,
|
||||
ImageSizeBreakdown: l.ImageSizeBreakdown,
|
||||
MediaType: l.MediaType,
|
||||
UserAgent: l.UserAgent,
|
||||
CacheTTLOverridden: l.CacheTTLOverridden,
|
||||
|
||||
@ -148,6 +148,65 @@ func TestUsageLogFromService_FallsBackToLegacyModelWhenRequestedModelMissing(t *
|
||||
require.Equal(t, "claude-3", adminDTO.Model)
|
||||
}
|
||||
|
||||
func TestUsageLogFromService_IncludesImageBillingMetadataForUserAndAdmin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
imageSize := "4K"
|
||||
inputSize := "1024x1024"
|
||||
outputSize := "3840x2160"
|
||||
source := "output"
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_image_metadata",
|
||||
Model: "gpt-image-2",
|
||||
ImageCount: 2,
|
||||
ImageSize: &imageSize,
|
||||
ImageInputSize: &inputSize,
|
||||
ImageOutputSize: &outputSize,
|
||||
ImageSizeSource: &source,
|
||||
ImageSizeBreakdown: map[string]int{"4K": 2},
|
||||
}
|
||||
|
||||
userDTO := UsageLogFromService(log)
|
||||
adminDTO := UsageLogFromServiceAdmin(log)
|
||||
|
||||
for _, got := range []*UsageLog{userDTO, &adminDTO.UsageLog} {
|
||||
require.Equal(t, 2, got.ImageCount)
|
||||
require.NotNil(t, got.ImageSize)
|
||||
require.Equal(t, imageSize, *got.ImageSize)
|
||||
require.NotNil(t, got.ImageInputSize)
|
||||
require.Equal(t, inputSize, *got.ImageInputSize)
|
||||
require.NotNil(t, got.ImageOutputSize)
|
||||
require.Equal(t, outputSize, *got.ImageOutputSize)
|
||||
require.NotNil(t, got.ImageSizeSource)
|
||||
require.Equal(t, source, *got.ImageSizeSource)
|
||||
require.Equal(t, map[string]int{"4K": 2}, got.ImageSizeBreakdown)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageLogFromService_PreservesHistoricalMissingImageSize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_legacy_image_missing_size",
|
||||
Model: "gpt-image-2",
|
||||
ImageCount: 1,
|
||||
ImageSize: nil,
|
||||
}
|
||||
|
||||
dto := UsageLogFromService(log)
|
||||
require.Equal(t, 1, dto.ImageCount)
|
||||
require.Nil(t, dto.ImageSize)
|
||||
require.Nil(t, dto.ImageInputSize)
|
||||
require.Nil(t, dto.ImageOutputSize)
|
||||
require.Nil(t, dto.ImageSizeSource)
|
||||
require.Nil(t, dto.ImageSizeBreakdown)
|
||||
|
||||
body, err := json.Marshal(dto)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(body), `"image_size":null`)
|
||||
require.NotContains(t, string(body), `"image_size":"2K"`)
|
||||
}
|
||||
|
||||
func f64Ptr(value float64) *float64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
@ -56,6 +56,23 @@ type SystemSettings struct {
|
||||
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
DingTalkConnectEnabled bool `json:"dingtalk_connect_enabled"`
|
||||
DingTalkConnectClientID string `json:"dingtalk_connect_client_id"`
|
||||
DingTalkConnectClientSecretConfigured bool `json:"dingtalk_connect_client_secret_configured"`
|
||||
DingTalkConnectRedirectURL string `json:"dingtalk_connect_redirect_url"`
|
||||
DingTalkConnectCorpRestrictionPolicy string `json:"dingtalk_connect_corp_restriction_policy"`
|
||||
DingTalkConnectInternalCorpID string `json:"dingtalk_connect_internal_corp_id"`
|
||||
DingTalkConnectBypassRegistration bool `json:"dingtalk_connect_bypass_registration"`
|
||||
DingTalkConnectSyncCorpEmail bool `json:"dingtalk_connect_sync_corp_email"`
|
||||
DingTalkConnectSyncDisplayName bool `json:"dingtalk_connect_sync_display_name"`
|
||||
DingTalkConnectSyncDept bool `json:"dingtalk_connect_sync_dept"`
|
||||
DingTalkConnectSyncCorpEmailAttrKey string `json:"dingtalk_connect_sync_corp_email_attr_key"`
|
||||
DingTalkConnectSyncDisplayNameAttrKey string `json:"dingtalk_connect_sync_display_name_attr_key"`
|
||||
DingTalkConnectSyncDeptAttrKey string `json:"dingtalk_connect_sync_dept_attr_key"`
|
||||
DingTalkConnectSyncCorpEmailAttrName string `json:"dingtalk_connect_sync_corp_email_attr_name"`
|
||||
DingTalkConnectSyncDisplayNameAttrName string `json:"dingtalk_connect_sync_display_name_attr_name"`
|
||||
DingTalkConnectSyncDeptAttrName string `json:"dingtalk_connect_sync_dept_attr_name"`
|
||||
|
||||
WeChatConnectEnabled bool `json:"wechat_connect_enabled"`
|
||||
WeChatConnectAppID string `json:"wechat_connect_app_id"`
|
||||
WeChatConnectAppSecretConfigured bool `json:"wechat_connect_app_secret_configured"`
|
||||
@ -260,6 +277,7 @@ type PublicSettings struct {
|
||||
TablePageSizeOptions []int `json:"table_page_size_options"`
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
|
||||
DingTalkOAuthEnabled bool `json:"dingtalk_oauth_enabled"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
|
||||
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
|
||||
|
||||
@ -149,25 +149,28 @@ type AdminGroup struct {
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
LoadFactor *int `json:"load_factor,omitempty"`
|
||||
Priority int `json:"priority"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
Status string `json:"status"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
LastUsedAt *time.Time `json:"last_used_at"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
AutoPauseOnExpired bool `json:"auto_pause_on_expired"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
// Credentials 经 RedactCredentials 处理后只含非敏感子键;敏感 token / api_key / 私钥
|
||||
// 的存在性通过 CredentialsStatus(has_<key>)暴露,原始值不返回前端。
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
CredentialsStatus map[string]bool `json:"credentials_status,omitempty"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
LoadFactor *int `json:"load_factor,omitempty"`
|
||||
Priority int `json:"priority"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
Status string `json:"status"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
LastUsedAt *time.Time `json:"last_used_at"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
AutoPauseOnExpired bool `json:"auto_pause_on_expired"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
Schedulable bool `json:"schedulable"`
|
||||
|
||||
@ -335,6 +338,7 @@ type RedeemCode struct {
|
||||
UsedBy *int64 `json:"used_by"`
|
||||
UsedAt *time.Time `json:"used_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
|
||||
GroupID *int64 `json:"group_id"`
|
||||
ValidityDays int `json:"validity_days"`
|
||||
@ -400,9 +404,13 @@ type UsageLog struct {
|
||||
FirstTokenMs *int `json:"first_token_ms"`
|
||||
|
||||
// 图片生成字段
|
||||
ImageCount int `json:"image_count"`
|
||||
ImageSize *string `json:"image_size"`
|
||||
MediaType *string `json:"media_type"`
|
||||
ImageCount int `json:"image_count"`
|
||||
ImageSize *string `json:"image_size"`
|
||||
ImageInputSize *string `json:"image_input_size"`
|
||||
ImageOutputSize *string `json:"image_output_size"`
|
||||
ImageSizeSource *string `json:"image_size_source"`
|
||||
ImageSizeBreakdown map[string]int `json:"image_size_breakdown"`
|
||||
MediaType *string `json:"media_type"`
|
||||
|
||||
// User-Agent
|
||||
UserAgent *string `json:"user_agent"`
|
||||
|
||||
@ -18,6 +18,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
@ -325,6 +326,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
reqLog.Warn("gateway.select_account_no_available",
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64p("group_id", apiKey.GroupID),
|
||||
@ -374,6 +376,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
reqLog.Warn("gateway.select_account_no_slot_no_wait_plan",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("model", reqModel),
|
||||
@ -566,6 +569,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID)
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
reqLog.Warn("gateway.select_account_no_available",
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64p("group_id", currentAPIKey.GroupID),
|
||||
@ -626,6 +630,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
reqLog.Warn("gateway.select_account_no_slot_no_wait_plan",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("model", reqModel),
|
||||
@ -946,8 +951,8 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
platform = forcedPlatform
|
||||
}
|
||||
|
||||
// Get available models from account configurations (without platform filter)
|
||||
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
|
||||
// Get available models from account configurations for the selected group platform.
|
||||
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, platform)
|
||||
|
||||
if len(availableModels) > 0 {
|
||||
// Build model list from whitelist
|
||||
@ -968,7 +973,7 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Fallback to default models
|
||||
if platform == "openai" {
|
||||
if platform == service.PlatformOpenAI {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": openai.DefaultModels,
|
||||
@ -976,6 +981,14 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if platform == service.PlatformGemini {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": geminicli.DefaultModels,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": claude.DefaultModels,
|
||||
@ -1312,6 +1325,11 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
|
||||
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
||||
statusCode := failoverErr.StatusCode
|
||||
responseBody := failoverErr.ResponseBody
|
||||
if service.IsOpenAISilentRefusalErrorBody(responseBody) {
|
||||
service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "")
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage(), streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 先检查透传规则
|
||||
if h.errorPassthroughService != nil && len(responseBody) > 0 {
|
||||
@ -1542,6 +1560,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
|
||||
if err != nil {
|
||||
reqLog.Warn("gateway.count_tokens_select_account_failed", zap.Error(err))
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
@ -161,14 +161,26 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
APIKeyID: apiKey.ID,
|
||||
}
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
groupPlatform := ""
|
||||
if apiKey.Group != nil {
|
||||
groupPlatform = apiKey.Group.Platform
|
||||
}
|
||||
selectionSessionHash := sessionHash
|
||||
if groupPlatform == service.PlatformGemini && selectionSessionHash != "" {
|
||||
selectionSessionHash = "gemini:" + selectionSessionHash
|
||||
}
|
||||
|
||||
// 3. Account selection + failover loop
|
||||
fs := NewFailoverState(h.maxAccountSwitches, false)
|
||||
if groupPlatform == service.PlatformGemini {
|
||||
fs = NewFailoverState(h.maxAccountSwitchesGemini, false)
|
||||
}
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, selectionSessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
@ -194,6 +206,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
|
||||
return
|
||||
}
|
||||
@ -213,13 +226,33 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
}
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
if groupPlatform == service.PlatformGemini && account.Platform != service.PlatformGemini {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
fs.FailedAccountIDs[account.ID] = struct{}{}
|
||||
continue
|
||||
}
|
||||
|
||||
// 5. Forward request
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
forwardBody := body
|
||||
if channelMapping.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
|
||||
var result *service.ForwardResult
|
||||
if account.Platform == service.PlatformGemini {
|
||||
if h.geminiCompatService == nil {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "upstream_error", "Gemini compatibility service is not configured")
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
return
|
||||
}
|
||||
result, err = h.geminiCompatService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody)
|
||||
} else {
|
||||
result, err = h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
|
||||
}
|
||||
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
@ -302,5 +335,10 @@ func (h *GatewayHandler) handleCCFailoverExhausted(c *gin.Context, lastErr *serv
|
||||
if lastErr != nil && lastErr.StatusCode > 0 {
|
||||
statusCode = lastErr.StatusCode
|
||||
}
|
||||
if lastErr != nil && service.IsOpenAISilentRefusalErrorBody(lastErr.ResponseBody) {
|
||||
service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "")
|
||||
h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage())
|
||||
return
|
||||
}
|
||||
h.chatCompletionsErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
|
||||
}
|
||||
|
||||
@ -174,6 +174,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
@ -199,6 +200,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
|
||||
return
|
||||
}
|
||||
@ -308,5 +310,10 @@ func (h *GatewayHandler) handleResponsesFailoverExhausted(c *gin.Context, lastEr
|
||||
if lastErr != nil && lastErr.StatusCode > 0 {
|
||||
statusCode = lastErr.StatusCode
|
||||
}
|
||||
if lastErr != nil && service.IsOpenAISilentRefusalErrorBody(lastErr.ResponseBody) {
|
||||
service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "")
|
||||
h.responsesErrorResponse(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage())
|
||||
return
|
||||
}
|
||||
h.responsesErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
|
||||
}
|
||||
|
||||
136
backend/internal/handler/gateway_models_test.go
Normal file
136
backend/internal/handler/gateway_models_test.go
Normal file
@ -0,0 +1,136 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type gatewayModelsAccountRepoStub struct {
|
||||
service.AccountRepository
|
||||
|
||||
byGroup map[int64][]service.Account
|
||||
}
|
||||
|
||||
type gatewayModelsResponseForTest struct {
|
||||
Object string `json:"object"`
|
||||
Data []gatewayModelItemForTest `json:"data"`
|
||||
}
|
||||
|
||||
type gatewayModelItemForTest struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func (s *gatewayModelsAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||
accounts, ok := s.byGroup[groupID]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
out := make([]service.Account, len(accounts))
|
||||
copy(out, accounts)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func newGatewayModelsHandlerForTest(repo service.AccountRepository) *GatewayHandler {
|
||||
return &GatewayHandler{
|
||||
gatewayService: service.NewGatewayService(
|
||||
repo,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayModels_GeminiGroupFallsBackToGeminiModels(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
groupID := int64(20)
|
||||
h := newGatewayModelsHandlerForTest(
|
||||
&gatewayModelsAccountRepoStub{
|
||||
byGroup: map[int64][]service.Account{
|
||||
groupID: {
|
||||
{ID: 1, Platform: service.PlatformGemini},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
|
||||
Group: &service.Group{ID: groupID, Platform: service.PlatformGemini},
|
||||
})
|
||||
|
||||
h.Models(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var got gatewayModelsResponseForTest
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
require.Equal(t, "list", got.Object)
|
||||
require.Contains(t, modelIDsForTest(got.Data), "gemini-2.5-flash")
|
||||
require.NotContains(t, modelIDsForTest(got.Data), "claude-sonnet-4-6")
|
||||
}
|
||||
|
||||
func TestGatewayModels_GeminiGroupFiltersMappedModelsByPlatform(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
groupID := int64(21)
|
||||
h := newGatewayModelsHandlerForTest(
|
||||
&gatewayModelsAccountRepoStub{
|
||||
byGroup: map[int64][]service.Account{
|
||||
groupID: {
|
||||
{
|
||||
ID: 1,
|
||||
Platform: service.PlatformAnthropic,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-6": "claude-sonnet-4-6",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Platform: service.PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
|
||||
Group: &service.Group{ID: groupID, Platform: service.PlatformGemini},
|
||||
})
|
||||
|
||||
h.Models(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var got gatewayModelsResponseForTest
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
require.Equal(t, []string{"gemini-2.5-flash"}, modelIDsForTest(got.Data))
|
||||
}
|
||||
|
||||
func modelIDsForTest(models []gatewayModelItemForTest) []string {
|
||||
ids := make([]string, 0, len(models))
|
||||
for _, model := range models {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
@ -61,6 +61,7 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gemini.FallbackModelsList())
|
||||
return
|
||||
}
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
@ -113,6 +114,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
|
||||
return
|
||||
}
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
@ -372,6 +374,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
@ -419,6 +422,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
|
||||
return
|
||||
}
|
||||
|
||||
@ -143,6 +143,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
} else {
|
||||
@ -155,6 +156,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
@ -176,6 +178,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
if channelMapping.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "")
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
@ -201,6 +204,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
} else {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleFailoverExhausted(c, failoverErr, true)
|
||||
return
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// Pool mode: retry on the same account
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
@ -292,7 +299,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
// resolveRawCCUpstreamEndpoint returns the actual upstream endpoint for
|
||||
// OpenAI Chat Completions requests. For APIKey accounts whose upstream
|
||||
// has been probed to not support the Responses API, the request is
|
||||
// is forced or probed to not support the Responses API, the request is
|
||||
// forwarded directly to /v1/chat/completions — not through the default
|
||||
// CC→Responses conversion path.
|
||||
func resolveRawCCUpstreamEndpoint(c *gin.Context, account *service.Account) string {
|
||||
|
||||
@ -282,6 +282,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
if errors.Is(err, service.ErrNoAvailableCompactAccounts) {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "compact_not_supported", "No available OpenAI accounts support /responses/compact", streamStarted)
|
||||
return
|
||||
@ -297,6 +298,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
@ -330,6 +332,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
if channelMapping.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
@ -354,6 +357,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
} else {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleFailoverExhausted(c, failoverErr, true)
|
||||
return
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
@ -677,6 +684,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
if err != nil {
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
@ -690,6 +698,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
@ -992,6 +1001,7 @@ func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot(
|
||||
reqLog *zap.Logger,
|
||||
) (func(), bool) {
|
||||
if selection == nil || selection.Account == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
@ -1002,6 +1012,7 @@ func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot(
|
||||
return wrapReleaseOnDone(ctx, selection.ReleaseFunc), true
|
||||
}
|
||||
if selection.WaitPlan == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
@ -1598,6 +1609,11 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error,
|
||||
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
|
||||
statusCode := failoverErr.StatusCode
|
||||
responseBody := failoverErr.ResponseBody
|
||||
if service.IsOpenAISilentRefusalErrorBody(responseBody) {
|
||||
service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "")
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage(), streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 先检查透传规则
|
||||
if h.errorPassthroughService != nil && len(responseBody) > 0 {
|
||||
|
||||
@ -157,6 +157,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
@ -168,6 +169,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
@ -22,10 +23,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
opsModelKey = "ops_model"
|
||||
opsStreamKey = "ops_stream"
|
||||
opsRequestBodyKey = "ops_request_body"
|
||||
opsAccountIDKey = "ops_account_id"
|
||||
opsModelKey = "ops_model"
|
||||
opsStreamKey = "ops_stream"
|
||||
opsRequestBodyKey = "ops_request_body"
|
||||
opsAccountIDKey = "ops_account_id"
|
||||
opsRoutingCapacityLimitedKey = "ops_routing_capacity_limited"
|
||||
|
||||
opsUpstreamModelKey = "ops_upstream_model"
|
||||
opsRequestTypeKey = "ops_request_type"
|
||||
@ -45,6 +47,8 @@ const (
|
||||
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
|
||||
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
|
||||
opsCodeUserInactive = "USER_INACTIVE"
|
||||
opsCodeInvalidAPIKey = "INVALID_API_KEY"
|
||||
opsCodeAPIKeyRequired = "API_KEY_REQUIRED"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -393,6 +397,42 @@ func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string)
|
||||
}
|
||||
}
|
||||
|
||||
func markOpsRoutingCapacityLimited(c *gin.Context) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.Set(opsRoutingCapacityLimitedKey, true)
|
||||
}
|
||||
|
||||
func markOpsRoutingCapacityLimitedIfNoAvailable(c *gin.Context, err error) {
|
||||
if !isOpsNoAvailableAccountError(err) {
|
||||
return
|
||||
}
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
}
|
||||
|
||||
func isOpsRoutingCapacityLimited(c *gin.Context) bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
v, ok := c.Get(opsRoutingCapacityLimitedKey)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
marked, _ := v.(bool)
|
||||
return marked
|
||||
}
|
||||
|
||||
func isOpsNoAvailableAccountError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, service.ErrNoAvailableAccounts) || errors.Is(err, service.ErrNoAvailableCompactAccounts) {
|
||||
return true
|
||||
}
|
||||
return isOpsNoAvailableAccountMessage(err.Error())
|
||||
}
|
||||
|
||||
type opsCaptureWriter struct {
|
||||
gin.ResponseWriter
|
||||
limit int
|
||||
@ -775,11 +815,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
|
||||
normalizedType := normalizeOpsErrorType(parsed.ErrorType, parsed.Code)
|
||||
|
||||
phase := classifyOpsPhase(normalizedType, parsed.Message, parsed.Code)
|
||||
isBusinessLimited := classifyOpsIsBusinessLimited(normalizedType, phase, parsed.Code, status, parsed.Message)
|
||||
|
||||
errorOwner := classifyOpsErrorOwner(phase, parsed.Message)
|
||||
errorSource := classifyOpsErrorSource(phase, parsed.Message)
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, normalizedType, parsed.Message, parsed.Code, status)
|
||||
|
||||
entry := &service.OpsInsertErrorLogInput{
|
||||
RequestID: requestID,
|
||||
@ -1114,6 +1150,9 @@ func classifyOpsPhase(errType, message, code string) string {
|
||||
msg := strings.ToLower(message)
|
||||
// Standardized phases: request|auth|routing|upstream|network|internal
|
||||
// Map billing/concurrency/response => request; scheduling => routing.
|
||||
if isOpsClientAuthError(code, msg) {
|
||||
return "auth"
|
||||
}
|
||||
switch strings.TrimSpace(code) {
|
||||
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||
return "request"
|
||||
@ -1134,7 +1173,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
||||
case "upstream_error", "overloaded_error":
|
||||
return "upstream"
|
||||
case "api_error":
|
||||
if strings.Contains(msg, opsErrNoAvailableAccounts) {
|
||||
if isOpsNoAvailableAccountMessage(msg) {
|
||||
return "routing"
|
||||
}
|
||||
return "internal"
|
||||
@ -1178,7 +1217,31 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
|
||||
func classifyOpsErrorLog(c *gin.Context, errType, message, code string, status int) (phase string, isBusinessLimited bool, errorOwner string, errorSource string) {
|
||||
phase = classifyOpsPhase(errType, message, code)
|
||||
routingCapacityLimited := isOpsRoutingCapacityLimited(c)
|
||||
clientBusinessLimited := service.HasOpsClientBusinessLimited(c)
|
||||
upstreamError := hasOpsUpstreamErrorContext(c)
|
||||
if upstreamError && !routingCapacityLimited {
|
||||
phase = "upstream"
|
||||
}
|
||||
if clientBusinessLimited && !upstreamError && !routingCapacityLimited {
|
||||
phase = "auth"
|
||||
}
|
||||
if routingCapacityLimited {
|
||||
phase = "routing"
|
||||
}
|
||||
localClientAuthError := !upstreamError && phase == "auth" && isOpsClientAuthError(code, strings.ToLower(message))
|
||||
isBusinessLimited = routingCapacityLimited || clientBusinessLimited || classifyOpsIsBusinessLimited(errType, phase, code, status, message, localClientAuthError)
|
||||
errorOwner = classifyOpsErrorOwner(phase, message)
|
||||
errorSource = classifyOpsErrorSource(phase, message)
|
||||
return phase, isBusinessLimited, errorOwner, errorSource
|
||||
}
|
||||
|
||||
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string, localClientAuthError ...bool) bool {
|
||||
if len(localClientAuthError) > 0 && localClientAuthError[0] {
|
||||
return true
|
||||
}
|
||||
switch strings.TrimSpace(code) {
|
||||
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive:
|
||||
return true
|
||||
@ -1195,6 +1258,47 @@ func classifyOpsIsBusinessLimited(errType, phase, code string, status int, messa
|
||||
return false
|
||||
}
|
||||
|
||||
func isOpsClientAuthError(code string, msg string) bool {
|
||||
switch strings.TrimSpace(code) {
|
||||
case opsCodeInvalidAPIKey, opsCodeAPIKeyRequired:
|
||||
return true
|
||||
}
|
||||
return strings.Contains(msg, "invalid api key") || strings.Contains(msg, "api key is required")
|
||||
}
|
||||
|
||||
func hasOpsUpstreamErrorContext(c *gin.Context) bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := c.Get(service.OpsUpstreamStatusCodeKey); ok {
|
||||
switch code := v.(type) {
|
||||
case int:
|
||||
if code > 0 {
|
||||
return true
|
||||
}
|
||||
case int64:
|
||||
if code > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
if v, ok := c.Get(service.OpsUpstreamErrorsKey); ok {
|
||||
if events, ok := v.([]*service.OpsUpstreamErrorEvent); ok && len(events) > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isOpsNoAvailableAccountMessage(message string) bool {
|
||||
msg := strings.ToLower(message)
|
||||
return strings.Contains(msg, opsErrNoAvailableAccounts) ||
|
||||
strings.Contains(msg, "no available account") ||
|
||||
strings.Contains(msg, "no available gemini accounts") ||
|
||||
strings.Contains(msg, "no available openai accounts") ||
|
||||
strings.Contains(msg, "no available compatible accounts")
|
||||
}
|
||||
|
||||
func classifyOpsErrorOwner(phase string, message string) string {
|
||||
// Standardized owners: client|provider|platform
|
||||
switch phase {
|
||||
|
||||
@ -275,6 +275,218 @@ func TestNormalizeOpsErrorType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyOpsNoAvailableAccountsExcludedFromSLA(t *testing.T) {
|
||||
const message = "No available accounts"
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
|
||||
errType := normalizeOpsErrorType("api_error", "")
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, message, "", http.StatusServiceUnavailable)
|
||||
|
||||
require.Equal(t, "api_error", errType)
|
||||
require.Equal(t, "routing", phase)
|
||||
require.True(t, isBusinessLimited)
|
||||
require.Equal(t, "platform", errorOwner)
|
||||
require.Equal(t, "gateway", errorSource)
|
||||
}
|
||||
|
||||
func TestClassifyOpsRoutingCapacityMarkerExcludesMaskedSelectionFailureFromSLA(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
|
||||
c,
|
||||
"api_error",
|
||||
"Service temporarily unavailable",
|
||||
"",
|
||||
http.StatusServiceUnavailable,
|
||||
)
|
||||
|
||||
require.Equal(t, "routing", phase)
|
||||
require.True(t, isBusinessLimited)
|
||||
require.Equal(t, "platform", errorOwner)
|
||||
require.Equal(t, "gateway", errorSource)
|
||||
}
|
||||
|
||||
func TestClassifyOpsAuthClientErrorsExcludedFromSLA(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errType string
|
||||
message string
|
||||
code string
|
||||
status int
|
||||
}{
|
||||
{
|
||||
name: "standard invalid API key",
|
||||
errType: "api_error",
|
||||
message: "Invalid API key",
|
||||
code: "INVALID_API_KEY",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "standard missing API key",
|
||||
errType: "api_error",
|
||||
message: "API key is required in Authorization header (Bearer scheme), x-api-key header, or x-goog-api-key header",
|
||||
code: "API_KEY_REQUIRED",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "google invalid API key",
|
||||
errType: "api_error",
|
||||
message: "Invalid API key",
|
||||
code: "401",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "google missing API key",
|
||||
errType: "api_error",
|
||||
message: "API key is required",
|
||||
code: "401",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
errType := normalizeOpsErrorType(tt.errType, tt.code)
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, tt.message, tt.code, tt.status)
|
||||
|
||||
require.Equal(t, "api_error", errType)
|
||||
require.Equal(t, "auth", phase)
|
||||
require.True(t, isBusinessLimited)
|
||||
require.Equal(t, "client", errorOwner)
|
||||
require.Equal(t, "client_request", errorSource)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyOpsIPRestrictionAccessDeniedExcludedFromSLA(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonIPRestriction)
|
||||
|
||||
errType := normalizeOpsErrorType("api_error", "ACCESS_DENIED")
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, "Access denied", "ACCESS_DENIED", http.StatusForbidden)
|
||||
|
||||
require.Equal(t, "api_error", errType)
|
||||
require.Equal(t, "auth", phase)
|
||||
require.True(t, isBusinessLimited)
|
||||
require.Equal(t, "client", errorOwner)
|
||||
require.Equal(t, "client_request", errorSource)
|
||||
}
|
||||
|
||||
func TestClassifyOpsOtherErrorsStillCountForSLA(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
errType := normalizeOpsErrorType("api_error", "INTERNAL_ERROR")
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, "Failed to validate API key", "INTERNAL_ERROR", http.StatusInternalServerError)
|
||||
|
||||
require.Equal(t, "api_error", errType)
|
||||
require.Equal(t, "internal", phase)
|
||||
require.False(t, isBusinessLimited)
|
||||
require.Equal(t, "platform", errorOwner)
|
||||
require.Equal(t, "gateway", errorSource)
|
||||
}
|
||||
|
||||
func TestClassifyOpsUnsupportedModelExcludedFromSLA(t *testing.T) {
|
||||
tests := []string{
|
||||
"No available accounts: no available accounts supporting model: made-up-model",
|
||||
"No available accounts: no available OpenAI accounts supporting model: made-up-model",
|
||||
"No available Gemini accounts: no available Gemini accounts supporting model: made-up-model",
|
||||
"No available accounts: no available accounts supporting model: made-up-model (channel pricing restriction)",
|
||||
}
|
||||
|
||||
for _, message := range tests {
|
||||
t.Run(message, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
|
||||
errType := normalizeOpsErrorType("api_error", "")
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, message, "", http.StatusServiceUnavailable)
|
||||
|
||||
require.Equal(t, "api_error", errType)
|
||||
require.Equal(t, "routing", phase)
|
||||
require.True(t, isBusinessLimited)
|
||||
require.Equal(t, "platform", errorOwner)
|
||||
require.Equal(t, "gateway", errorSource)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyOpsUnmarkedNoAvailableTextStillCountsForSLA(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
|
||||
c,
|
||||
"api_error",
|
||||
"No available accounts",
|
||||
"",
|
||||
http.StatusServiceUnavailable,
|
||||
)
|
||||
|
||||
require.Equal(t, "routing", phase)
|
||||
require.False(t, isBusinessLimited)
|
||||
require.Equal(t, "platform", errorOwner)
|
||||
require.Equal(t, "gateway", errorSource)
|
||||
}
|
||||
|
||||
func TestClassifyOpsUpstreamAuthTextStillCountsForSLA(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
service.SetOpsUpstreamError(c, http.StatusUnauthorized, "Invalid API key", "")
|
||||
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
|
||||
c,
|
||||
"api_error",
|
||||
"Invalid API key",
|
||||
"401",
|
||||
http.StatusUnauthorized,
|
||||
)
|
||||
|
||||
require.Equal(t, "upstream", phase)
|
||||
require.False(t, isBusinessLimited)
|
||||
require.Equal(t, "provider", errorOwner)
|
||||
require.Equal(t, "upstream_http", errorSource)
|
||||
}
|
||||
|
||||
func TestClassifyOpsUpstreamNoAvailableTextStillCountsForSLA(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
service.SetOpsUpstreamError(c, http.StatusServiceUnavailable, "No available accounts", "")
|
||||
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
|
||||
c,
|
||||
"api_error",
|
||||
"No available accounts",
|
||||
"",
|
||||
http.StatusServiceUnavailable,
|
||||
)
|
||||
|
||||
require.Equal(t, "upstream", phase)
|
||||
require.False(t, isBusinessLimited)
|
||||
require.Equal(t, "provider", errorOwner)
|
||||
require.Equal(t, "upstream_http", errorSource)
|
||||
}
|
||||
|
||||
func TestSetOpsEndpointContext_SetsContextKeys(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
@ -58,7 +58,7 @@ func TestResolvePageImagePath(t *testing.T) {
|
||||
if !ok {
|
||||
t.Fatal("expected direct image path to be accepted")
|
||||
}
|
||||
want := filepath.Join(base, "logo.png")
|
||||
want := mustEvalSymlinks(t, filepath.Join(base, "logo.png"))
|
||||
if got != want {
|
||||
t.Fatalf("path = %q, want %q", got, want)
|
||||
}
|
||||
@ -67,7 +67,7 @@ func TestResolvePageImagePath(t *testing.T) {
|
||||
if !ok {
|
||||
t.Fatal("expected nested image path to be accepted")
|
||||
}
|
||||
want = filepath.Join(base, "images", "logo.png")
|
||||
want = mustEvalSymlinks(t, filepath.Join(base, "images", "logo.png"))
|
||||
if got != want {
|
||||
t.Fatalf("path = %q, want %q", got, want)
|
||||
}
|
||||
@ -100,3 +100,13 @@ func TestResolvePageImagePathRejectsSymlinkEscape(t *testing.T) {
|
||||
t.Fatalf("expected symlink escape to be rejected, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func mustEvalSymlinks(t *testing.T, path string) string {
|
||||
t.Helper()
|
||||
|
||||
realPath, err := filepath.EvalSymlinks(path)
|
||||
if err != nil {
|
||||
t.Fatalf("eval symlinks for %q: %v", path, err)
|
||||
}
|
||||
return realPath
|
||||
}
|
||||
|
||||
@ -61,6 +61,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
TablePageSizeOptions: settings.TablePageSizeOptions,
|
||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
||||
DingTalkOAuthEnabled: settings.DingTalkOAuthEnabled,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
|
||||
WeChatOAuthOpenEnabled: settings.WeChatOAuthOpenEnabled,
|
||||
|
||||
@ -67,6 +67,7 @@ type userProfileResponse struct {
|
||||
LinuxDoBound bool `json:"linuxdo_bound"`
|
||||
OIDCBound bool `json:"oidc_bound"`
|
||||
WeChatBound bool `json:"wechat_bound"`
|
||||
DingTalkBound bool `json:"dingtalk_bound"`
|
||||
}
|
||||
|
||||
type userProfileSourceContext struct {
|
||||
@ -528,15 +529,17 @@ func userProfileResponseFromService(user *service.User, identities service.UserI
|
||||
LinuxDoBound: identities.LinuxDo.Bound,
|
||||
OIDCBound: identities.OIDC.Bound,
|
||||
WeChatBound: identities.WeChat.Bound,
|
||||
DingTalkBound: identities.DingTalk.Bound,
|
||||
}
|
||||
}
|
||||
|
||||
func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string]service.UserIdentitySummary {
|
||||
return map[string]service.UserIdentitySummary{
|
||||
"email": identities.Email,
|
||||
"linuxdo": identities.LinuxDo,
|
||||
"oidc": identities.OIDC,
|
||||
"wechat": identities.WeChat,
|
||||
"email": identities.Email,
|
||||
"linuxdo": identities.LinuxDo,
|
||||
"oidc": identities.OIDC,
|
||||
"wechat": identities.WeChat,
|
||||
"dingtalk": identities.DingTalk,
|
||||
}
|
||||
}
|
||||
|
||||
@ -585,7 +588,7 @@ func inferUserProfileSources(user *service.User, identities service.UserIdentity
|
||||
|
||||
func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary {
|
||||
out := make([]service.UserIdentitySummary, 0, 3)
|
||||
for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} {
|
||||
for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat, identities.DingTalk} {
|
||||
if summary.Bound {
|
||||
out = append(out, summary)
|
||||
}
|
||||
|
||||
@ -105,10 +105,16 @@ func (a *Alipay) MerchantIdentityMetadata() map[string]string {
|
||||
|
||||
// CreatePayment creates an Alipay payment using the following routing:
|
||||
// - Mobile (H5): alipay.trade.wap.pay — browser redirect into Alipay.
|
||||
// - Desktop: prefer alipay.trade.precreate to get a scan payload directly.
|
||||
// - Desktop fallback: if precreate is unavailable for the merchant, fall back
|
||||
// to alipay.trade.page.pay and expose both pay_url and qr_code so the
|
||||
// frontend can render a QR while still allowing direct page open.
|
||||
// - Desktop, default: prefer alipay.trade.precreate (FACE_TO_FACE_PAYMENT) to
|
||||
// get a scannable QR payload. If precreate is unavailable for the merchant,
|
||||
// fall back to alipay.trade.page.pay and expose pay_url only — the frontend
|
||||
// opens the Alipay checkout in a new tab.
|
||||
// - Desktop, paymentMode == "redirect": skip precreate and go straight to
|
||||
// alipay.trade.page.pay so the frontend always opens the Alipay checkout
|
||||
// in a new tab. Use this when the merchant has not enabled FACE_TO_FACE_PAYMENT.
|
||||
//
|
||||
// Note: alipay.trade.page.pay returns a checkout page URL, not a scannable
|
||||
// payment QR. Never expose it via the QRCode field.
|
||||
func (a *Alipay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
|
||||
client, err := a.getClient()
|
||||
if err != nil {
|
||||
@ -150,6 +156,13 @@ func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePayment
|
||||
}
|
||||
|
||||
func (a *Alipay) createDesktopTrade(ctx context.Context, client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
|
||||
// Explicit redirect mode: merchant opted into "always open the Alipay
|
||||
// checkout page in a new tab" via the provider instance's payment_mode.
|
||||
// Skip precreate to avoid a wasted API call.
|
||||
if strings.EqualFold(strings.TrimSpace(a.config["paymentMode"]), "redirect") {
|
||||
return a.createPagePayTrade(client, req, notifyURL, returnURL)
|
||||
}
|
||||
|
||||
resp, precreateErr := a.createPrecreateTrade(ctx, client, req, notifyURL)
|
||||
if precreateErr == nil {
|
||||
return resp, nil
|
||||
@ -204,10 +217,12 @@ func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePay
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("alipay TradePagePay: %w", err)
|
||||
}
|
||||
// Only PayURL is exposed: alipay.trade.page.pay returns a checkout page URL
|
||||
// that must be opened in a browser, not a scannable payment QR. Setting it
|
||||
// as QRCode would let the frontend render an unscannable image.
|
||||
return &payment.CreatePaymentResponse{
|
||||
TradeNo: req.OrderID,
|
||||
PayURL: payURL.String(),
|
||||
QRCode: payURL.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@ -189,8 +189,63 @@ func TestCreateTradeUsesPagePayForDesktop(t *testing.T) {
|
||||
if resp.PayURL == "" {
|
||||
t.Fatal("expected pay_url for desktop page pay")
|
||||
}
|
||||
if resp.QRCode != resp.PayURL {
|
||||
t.Fatalf("qr_code = %q, want same as pay_url %q", resp.QRCode, resp.PayURL)
|
||||
// page.pay returns a checkout page URL, not a scannable QR payload —
|
||||
// it must never be exposed via QRCode (the frontend would render an
|
||||
// unscannable image from it).
|
||||
if resp.QRCode != "" {
|
||||
t.Fatalf("qr_code = %q, want empty for page pay", resp.QRCode)
|
||||
}
|
||||
}
|
||||
|
||||
// When the provider instance is configured with paymentMode == "redirect",
|
||||
// the desktop flow must skip precreate and go straight to page.pay.
|
||||
func TestCreateTradeRedirectModeSkipsPrecreate(t *testing.T) {
|
||||
origPreCreate := alipayTradePreCreate
|
||||
origPagePay := alipayTradePagePay
|
||||
t.Cleanup(func() {
|
||||
alipayTradePreCreate = origPreCreate
|
||||
alipayTradePagePay = origPagePay
|
||||
})
|
||||
|
||||
preCreateCalls := 0
|
||||
pagePayCalls := 0
|
||||
alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) {
|
||||
preCreateCalls++
|
||||
return &alipay.TradePreCreateRsp{
|
||||
Error: alipay.Error{Code: alipay.CodeSuccess},
|
||||
QRCode: "https://qr.alipay.example.com/precreate-token",
|
||||
}, nil
|
||||
}
|
||||
alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
|
||||
pagePayCalls++
|
||||
if param.ProductCode != alipayProductCodePagePay {
|
||||
t.Fatalf("product_code = %q, want %q", param.ProductCode, alipayProductCodePagePay)
|
||||
}
|
||||
return url.Parse("https://openapi.alipay.com/gateway.do?page-pay")
|
||||
}
|
||||
|
||||
provider := &Alipay{
|
||||
config: map[string]string{"paymentMode": "redirect"},
|
||||
}
|
||||
resp, err := provider.createDesktopTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{
|
||||
OrderID: "sub2_103",
|
||||
Amount: "12.00",
|
||||
Subject: "Balance recharge",
|
||||
}, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if preCreateCalls != 0 {
|
||||
t.Fatalf("precreate calls = %d, want 0 (redirect mode must skip precreate)", preCreateCalls)
|
||||
}
|
||||
if pagePayCalls != 1 {
|
||||
t.Fatalf("page pay calls = %d, want 1", pagePayCalls)
|
||||
}
|
||||
if resp.PayURL == "" {
|
||||
t.Fatal("expected pay_url for redirect mode")
|
||||
}
|
||||
if resp.QRCode != "" {
|
||||
t.Fatalf("qr_code = %q, want empty for redirect mode", resp.QRCode)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -254,6 +254,8 @@ const (
|
||||
proxyTLSHandshakeTimeout = 5 * time.Second
|
||||
// clientTimeout 整体请求超时(含连接、发送、等待响应、读取 body)
|
||||
clientTimeout = 10 * time.Second
|
||||
// fetchAvailableModelsBodyLimit limits model-list responses to avoid unbounded memory use.
|
||||
fetchAvailableModelsBodyLimit int64 = 8 << 20
|
||||
)
|
||||
|
||||
func NewClient(proxyURL string) (*Client, error) {
|
||||
@ -655,6 +657,10 @@ type FetchAvailableModelsResponse struct {
|
||||
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
||||
// 支持 URL fallback:sandbox → daily → prod
|
||||
func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) {
|
||||
if c == nil || c.httpClient == nil {
|
||||
return nil, nil, errors.New("antigravity client is not configured")
|
||||
}
|
||||
|
||||
reqBody := FetchAvailableModelsRequest{Project: projectID}
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
@ -664,6 +670,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
// 固定顺序:prod -> daily
|
||||
availableURLs := BaseURLs
|
||||
|
||||
fetchClient := c.fetchAvailableModelsHTTPClient()
|
||||
var lastErr error
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
apiURL := baseURL + "/v1internal:fetchAvailableModels"
|
||||
@ -676,7 +683,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", GetUserAgentForContext(ctx))
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
resp, err := fetchClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
|
||||
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
@ -686,11 +693,14 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
|
||||
respBodyBytes, err := io.ReadAll(resp.Body)
|
||||
respBodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, fetchAvailableModelsBodyLimit+1))
|
||||
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
if int64(len(respBodyBytes)) > fetchAvailableModelsBodyLimit {
|
||||
return nil, nil, fmt.Errorf("响应超过 %d 字节", fetchAvailableModelsBodyLimit)
|
||||
}
|
||||
|
||||
// 检查是否需要 URL 降级
|
||||
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||
@ -726,6 +736,42 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
|
||||
func (c *Client) fetchAvailableModelsHTTPClient() *http.Client {
|
||||
fetchClient := *c.httpClient
|
||||
fetchClient.CheckRedirect = checkFetchAvailableModelsRedirect
|
||||
return &fetchClient
|
||||
}
|
||||
|
||||
func checkFetchAvailableModelsRedirect(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 10 {
|
||||
return errors.New("stopped after 10 redirects")
|
||||
}
|
||||
if req == nil || req.URL == nil {
|
||||
return errors.New("redirect url is nil")
|
||||
}
|
||||
if !isAllowedFetchAvailableModelsRedirectHost(req.URL.Hostname()) {
|
||||
return fmt.Errorf("redirect to unsupported host: %s", req.URL.Hostname())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isAllowedFetchAvailableModelsRedirectHost(host string) bool {
|
||||
host = strings.ToLower(strings.TrimSpace(host))
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
for _, baseURL := range BaseURLs {
|
||||
parsed, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(host, parsed.Hostname()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ── Privacy API ──────────────────────────────────────────────────────
|
||||
|
||||
// privacyBaseURL 隐私设置 API 仅使用 daily 端点(与 Antigravity 客户端行为一致)
|
||||
|
||||
@ -225,6 +225,41 @@ func TestChatCompletionsToResponses_WhitespaceOnlyBase64ImageURLSkipped(t *testi
|
||||
assert.Equal(t, "Describe this", parts[0].Text)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_EmptyContentNeverNull(t *testing.T) {
|
||||
// Regression for #2515: the upstream Responses API rejects an input item
|
||||
// whose content field is JSON null. Any chat-completions message that
|
||||
// yields no usable content parts must serialize content as a string.
|
||||
cases := []struct {
|
||||
name string
|
||||
content json.RawMessage
|
||||
}{
|
||||
{"null content", json.RawMessage(`null`)},
|
||||
{"empty array content", json.RawMessage(`[]`)},
|
||||
{"only empty text part", json.RawMessage(`[{"type":"text","text":""}]`)},
|
||||
{"only empty base64 image part", json.RawMessage(`[{"type":"image_url","image_url":{"url":"data:image/png;base64,"}}]`)},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-5.5",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: tc.content},
|
||||
},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.NotContains(t, string(resp.Input), `"content":null`,
|
||||
"converted input must not contain a null content field")
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
assert.Equal(t, `""`, string(items[0].Content),
|
||||
"content must be an empty string, not null")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_SystemArrayContent(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
|
||||
@ -339,7 +339,14 @@ func marshalChatInputContent(content chatMessageContent) (json.RawMessage, error
|
||||
if content.Text != nil {
|
||||
return json.Marshal(*content.Text)
|
||||
}
|
||||
return json.Marshal(convertChatContentPartsToResponses(content.Parts))
|
||||
parts := convertChatContentPartsToResponses(content.Parts)
|
||||
if len(parts) == 0 {
|
||||
// A nil slice marshals to JSON null, which the upstream Responses API
|
||||
// rejects ("expected an array of objects or string, but got null").
|
||||
// Fall back to an empty string when no usable parts remain.
|
||||
return json.Marshal("")
|
||||
}
|
||||
return json.Marshal(parts)
|
||||
}
|
||||
|
||||
func convertChatContentPartsToResponses(parts []ChatContentPart) []ResponsesContentPart {
|
||||
|
||||
@ -306,6 +306,37 @@ type ResponsesUsage struct {
|
||||
OutputTokensDetails *ResponsesOutputTokensDetails `json:"output_tokens_details,omitempty"`
|
||||
}
|
||||
|
||||
func (u *ResponsesUsage) UnmarshalJSON(data []byte) error {
|
||||
type responsesUsageAlias ResponsesUsage
|
||||
var aux struct {
|
||||
responsesUsageAlias
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
PromptTokensDetails *ResponsesInputTokensDetails `json:"prompt_tokens_details,omitempty"`
|
||||
CompletionTokensDetails *ResponsesOutputTokensDetails `json:"completion_tokens_details,omitempty"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
*u = ResponsesUsage(aux.responsesUsageAlias)
|
||||
if u.InputTokens == 0 && aux.PromptTokens != 0 {
|
||||
u.InputTokens = aux.PromptTokens
|
||||
}
|
||||
if u.OutputTokens == 0 && aux.CompletionTokens != 0 {
|
||||
u.OutputTokens = aux.CompletionTokens
|
||||
}
|
||||
if u.InputTokensDetails == nil && aux.PromptTokensDetails != nil {
|
||||
u.InputTokensDetails = aux.PromptTokensDetails
|
||||
}
|
||||
if u.OutputTokensDetails == nil && aux.CompletionTokensDetails != nil {
|
||||
u.OutputTokensDetails = aux.CompletionTokensDetails
|
||||
}
|
||||
if u.TotalTokens == 0 && (u.InputTokens != 0 || u.OutputTokens != 0) {
|
||||
u.TotalTokens = u.InputTokens + u.OutputTokens
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResponsesInputTokensDetails breaks down input token usage.
|
||||
type ResponsesInputTokensDetails struct {
|
||||
CachedTokens int `json:"cached_tokens,omitempty"`
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
// pensieve/short-term/maxims/preserve-existing-runtime-behavior-when-replacing-logic-in-stateful-systems)
|
||||
package openai_compat
|
||||
|
||||
// AccountResponsesSupport 描述账号上游对 OpenAI Responses API 的支持状态。
|
||||
// AccountResponsesSupport 描述账号上游对 OpenAI Responses API 的有效支持状态。
|
||||
//
|
||||
// 仅用于 platform=openai + type=apikey 的账号;其他账号类型不应调用本包判定。
|
||||
type AccountResponsesSupport int
|
||||
@ -35,11 +35,43 @@ const (
|
||||
ResponsesSupportNo
|
||||
)
|
||||
|
||||
// ExtraKeyResponsesSupported 是 accounts.extra JSON 中存储探测结果的键名。
|
||||
// ResponsesSupportMode 描述账号级 Responses API 路由覆盖模式。
|
||||
type ResponsesSupportMode string
|
||||
|
||||
const (
|
||||
// ResponsesSupportModeAuto 表示跟随自动探测结果。
|
||||
ResponsesSupportModeAuto ResponsesSupportMode = "auto"
|
||||
|
||||
// ResponsesSupportModeForceResponses 强制使用 /v1/responses。
|
||||
ResponsesSupportModeForceResponses ResponsesSupportMode = "force_responses"
|
||||
|
||||
// ResponsesSupportModeForceChatCompletions 强制使用 /v1/chat/completions。
|
||||
ResponsesSupportModeForceChatCompletions ResponsesSupportMode = "force_chat_completions"
|
||||
)
|
||||
|
||||
// ExtraKeyResponsesMode 是 accounts.extra JSON 中存储手动覆盖模式的键名。
|
||||
// 值类型为 string:auto=跟随探测,force_responses=强制 Responses,
|
||||
// force_chat_completions=强制 Chat Completions。
|
||||
const ExtraKeyResponsesMode = "openai_responses_mode"
|
||||
|
||||
// ExtraKeyResponsesSupported 是 accounts.extra JSON 中存储自动探测结果的键名。
|
||||
// 值类型为 bool:true=支持、false=不支持、键缺失=未探测。
|
||||
const ExtraKeyResponsesSupported = "openai_responses_supported"
|
||||
|
||||
// ResolveResponsesSupport 从账号的 extra map 中读取探测标记。
|
||||
// NormalizeResponsesSupportMode 归一化账号级 Responses API 路由覆盖模式。
|
||||
// 缺失或非法值按 auto 处理,以保持存量行为。
|
||||
func NormalizeResponsesSupportMode(mode string) ResponsesSupportMode {
|
||||
switch ResponsesSupportMode(mode) {
|
||||
case ResponsesSupportModeForceResponses:
|
||||
return ResponsesSupportModeForceResponses
|
||||
case ResponsesSupportModeForceChatCompletions:
|
||||
return ResponsesSupportModeForceChatCompletions
|
||||
default:
|
||||
return ResponsesSupportModeAuto
|
||||
}
|
||||
}
|
||||
|
||||
// ResolveResponsesSupport 从账号的 extra map 中读取手动覆盖模式与探测标记。
|
||||
//
|
||||
// 标记缺失或类型不匹配时返回 ResponsesSupportUnknown——调用方应按
|
||||
// "未探测=保留旧行为=走 Responses" 处理(参见 ShouldUseResponsesAPI)。
|
||||
@ -47,6 +79,14 @@ func ResolveResponsesSupport(extra map[string]any) AccountResponsesSupport {
|
||||
if extra == nil {
|
||||
return ResponsesSupportUnknown
|
||||
}
|
||||
if mode, ok := extra[ExtraKeyResponsesMode].(string); ok {
|
||||
switch NormalizeResponsesSupportMode(mode) {
|
||||
case ResponsesSupportModeForceResponses:
|
||||
return ResponsesSupportYes
|
||||
case ResponsesSupportModeForceChatCompletions:
|
||||
return ResponsesSupportNo
|
||||
}
|
||||
}
|
||||
v, ok := extra[ExtraKeyResponsesSupported]
|
||||
if !ok {
|
||||
return ResponsesSupportUnknown
|
||||
|
||||
@ -16,6 +16,12 @@ func TestResolveResponsesSupport(t *testing.T) {
|
||||
{"value wrong type string", map[string]any{ExtraKeyResponsesSupported: "true"}, ResponsesSupportUnknown},
|
||||
{"value wrong type number", map[string]any{ExtraKeyResponsesSupported: 1}, ResponsesSupportUnknown},
|
||||
{"value nil", map[string]any{ExtraKeyResponsesSupported: nil}, ResponsesSupportUnknown},
|
||||
{"force responses", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceResponses)}, ResponsesSupportYes},
|
||||
{"force chat completions", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceChatCompletions)}, ResponsesSupportNo},
|
||||
{"auto follows probe", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeAuto), ExtraKeyResponsesSupported: false}, ResponsesSupportNo},
|
||||
{"invalid mode follows probe", map[string]any{ExtraKeyResponsesMode: "bogus", ExtraKeyResponsesSupported: true}, ResponsesSupportYes},
|
||||
{"force responses overrides probe false", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceResponses), ExtraKeyResponsesSupported: false}, ResponsesSupportYes},
|
||||
{"force chat completions overrides probe true", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceChatCompletions), ExtraKeyResponsesSupported: true}, ResponsesSupportNo},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@ -42,6 +48,10 @@ func TestShouldUseResponsesAPI(t *testing.T) {
|
||||
// 已探测:标记决定
|
||||
{"explicitly supported", map[string]any{ExtraKeyResponsesSupported: true}, true},
|
||||
{"explicitly unsupported", map[string]any{ExtraKeyResponsesSupported: false}, false},
|
||||
|
||||
// 手动覆盖:覆盖自动探测结果
|
||||
{"force responses overrides unsupported probe", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceResponses), ExtraKeyResponsesSupported: false}, true},
|
||||
{"force chat completions overrides supported probe", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceChatCompletions), ExtraKeyResponsesSupported: true}, false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@ -53,3 +63,26 @@ func TestShouldUseResponsesAPI(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesSupportMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mode string
|
||||
want ResponsesSupportMode
|
||||
}{
|
||||
{"empty", "", ResponsesSupportModeAuto},
|
||||
{"auto", "auto", ResponsesSupportModeAuto},
|
||||
{"force responses", "force_responses", ResponsesSupportModeForceResponses},
|
||||
{"force chat completions", "force_chat_completions", ResponsesSupportModeForceChatCompletions},
|
||||
{"invalid", "enabled", ResponsesSupportModeAuto},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := NormalizeResponsesSupportMode(tc.mode)
|
||||
if got != tc.want {
|
||||
t.Errorf("NormalizeResponsesSupportMode(%q) = %q, want %q", tc.mode, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -230,6 +230,20 @@ type UserDashboardStats struct {
|
||||
// 性能指标
|
||||
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
|
||||
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
|
||||
|
||||
// 按"有效平台"维度拆分(与 ops 路径口径一致:group.platform 优先,否则 account.platform)
|
||||
ByPlatform []PlatformDashboardStats `json:"by_platform,omitempty"`
|
||||
}
|
||||
|
||||
// PlatformDashboardStats 单个平台的用量明细。
|
||||
type PlatformDashboardStats struct {
|
||||
Platform string `json:"platform"`
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
TodayRequests int64 `json:"today_requests"`
|
||||
TodayTokens int64 `json:"today_tokens"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
}
|
||||
|
||||
// UsageLogFilters represents filters for usage log queries
|
||||
@ -265,13 +279,22 @@ type UsageStats struct {
|
||||
EndpointPaths []EndpointStat `json:"endpoint_paths,omitempty"`
|
||||
}
|
||||
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
type BatchUserUsageStats struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
// PlatformUsage 表示某用户/某 API key 在单个"有效平台"维度的用量明细。
|
||||
// Platform 取值与 ops 路径口径一致:优先 groups.platform,否则 accounts.platform。
|
||||
type PlatformUsage struct {
|
||||
Platform string `json:"platform"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
type BatchUserUsageStats struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
ByPlatform []PlatformUsage `json:"by_platform,omitempty"`
|
||||
}
|
||||
|
||||
// BatchAPIKeyUsageStats represents usage stats for a single API key
|
||||
type BatchAPIKeyUsageStats struct {
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
|
||||
@ -12,3 +12,14 @@ func TestShouldEnqueueSchedulerOutboxForExtraUpdates_CompactCapabilityKeysAreRel
|
||||
t.Fatalf("expected compact capability updates to enqueue scheduler outbox")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldEnqueueSchedulerOutboxForExtraUpdates_OpenAIResponsesCapabilityKeysAreRelevant(t *testing.T) {
|
||||
updates := map[string]any{
|
||||
"openai_responses_mode": "force_chat_completions",
|
||||
"openai_responses_supported": false,
|
||||
}
|
||||
|
||||
if !shouldEnqueueSchedulerOutboxForExtraUpdates(updates) {
|
||||
t.Fatalf("expected responses capability updates to enqueue scheduler outbox")
|
||||
}
|
||||
}
|
||||
|
||||
@ -204,7 +204,8 @@ func (r *announcementRepository) ListActive(ctx context.Context, now time.Time)
|
||||
announcement.Or(announcement.StartsAtIsNil(), announcement.StartsAtLTE(now)),
|
||||
announcement.Or(announcement.EndsAtIsNil(), announcement.EndsAtGT(now)),
|
||||
).
|
||||
Order(dbent.Desc(announcement.FieldID))
|
||||
Order(dbent.Desc(announcement.FieldID)).
|
||||
Limit(200)
|
||||
|
||||
items, err := q.All(ctx)
|
||||
if err != nil {
|
||||
|
||||
@ -283,47 +283,90 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
}
|
||||
|
||||
func (r *groupRepository) listWithAccountCountSort(ctx context.Context, q *dbent.GroupQuery, params pagination.PaginationParams, total int) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
groups, err := q.
|
||||
// 第一步:只查 ID + sort_order(轻量,不做分页 — 需要全量排序 account_count)。
|
||||
rows, err := q.Clone().
|
||||
Select(group.FieldID, group.FieldSortOrder).
|
||||
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
groupIDs := make([]int64, 0, len(groups))
|
||||
outGroups := make([]service.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
g := groupEntityToService(groups[i])
|
||||
outGroups = append(outGroups, *g)
|
||||
groupIDs = append(groupIDs, g.ID)
|
||||
type sortEntry struct {
|
||||
id int64
|
||||
sortOrder int
|
||||
accountCount int64
|
||||
}
|
||||
entries := make([]sortEntry, 0, len(rows))
|
||||
groupIDs := make([]int64, len(rows))
|
||||
for i, r := range rows {
|
||||
groupIDs[i] = r.ID
|
||||
entries = append(entries, sortEntry{id: r.ID, sortOrder: r.SortOrder})
|
||||
}
|
||||
|
||||
// 第二步:批量加载 account counts(一次 SQL)。
|
||||
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for i := range outGroups {
|
||||
c := counts[outGroups[i].ID]
|
||||
outGroups[i].AccountCount = c.Total
|
||||
outGroups[i].ActiveAccountCount = c.Active
|
||||
outGroups[i].RateLimitedAccountCount = c.RateLimited
|
||||
for i := range entries {
|
||||
c := counts[entries[i].id]
|
||||
if c.Total > 0 {
|
||||
entries[i].accountCount = c.Total
|
||||
}
|
||||
}
|
||||
|
||||
// 第三步:Go 侧排序(数据量 = Group 总数,通常 < 200,安全)。
|
||||
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
|
||||
sort.SliceStable(outGroups, func(i, j int) bool {
|
||||
if outGroups[i].AccountCount == outGroups[j].AccountCount {
|
||||
if outGroups[i].SortOrder == outGroups[j].SortOrder {
|
||||
return outGroups[i].ID < outGroups[j].ID
|
||||
}
|
||||
return outGroups[i].SortOrder < outGroups[j].SortOrder
|
||||
tieCmp := func(a, b sortEntry) bool {
|
||||
if a.sortOrder == b.sortOrder {
|
||||
return a.id < b.id
|
||||
}
|
||||
return a.sortOrder < b.sortOrder
|
||||
}
|
||||
sort.SliceStable(entries, func(i, j int) bool {
|
||||
if entries[i].accountCount == entries[j].accountCount {
|
||||
return tieCmp(entries[i], entries[j])
|
||||
}
|
||||
if sortOrder == pagination.SortOrderAsc {
|
||||
return outGroups[i].AccountCount < outGroups[j].AccountCount
|
||||
return entries[i].accountCount < entries[j].accountCount
|
||||
}
|
||||
return outGroups[i].AccountCount > outGroups[j].AccountCount
|
||||
return entries[i].accountCount > entries[j].accountCount
|
||||
})
|
||||
|
||||
return paginateSlice(outGroups, params), paginationResultFromTotal(int64(total), params), nil
|
||||
// 第四步:分页,只加载当前页需要的完整 Group。
|
||||
page := paginateSlice(entries, params)
|
||||
if len(page) == 0 {
|
||||
return nil, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
pageIDs := make([]int64, len(page))
|
||||
pageIdx := make(map[int64]int, len(page))
|
||||
for i, e := range page {
|
||||
pageIDs[i] = e.id
|
||||
pageIdx[e.id] = i
|
||||
}
|
||||
|
||||
groups, err := r.client.Group.Query().
|
||||
Where(group.IDIn(pageIDs...)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outGroups := make([]service.Group, len(page))
|
||||
for i := range groups {
|
||||
g := groupEntityToService(groups[i])
|
||||
c := counts[g.ID]
|
||||
g.AccountCount = c.Total
|
||||
g.ActiveAccountCount = c.Active
|
||||
g.RateLimitedAccountCount = c.RateLimited
|
||||
if idx, ok := pageIdx[g.ID]; ok {
|
||||
outGroups[idx] = *g
|
||||
}
|
||||
}
|
||||
|
||||
return outGroups, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func groupListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
|
||||
|
||||
@ -44,6 +44,33 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
||||
requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false)
|
||||
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
|
||||
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
|
||||
requireColumn(t, tx, "usage_logs", "image_input_size", "character varying", 32, true)
|
||||
requireColumn(t, tx, "usage_logs", "image_output_size", "character varying", 32, true)
|
||||
requireColumn(t, tx, "usage_logs", "image_size_source", "character varying", 16, true)
|
||||
requireColumn(t, tx, "usage_logs", "image_size_breakdown", "jsonb", 0, true)
|
||||
requireConstraintDefinitionContains(
|
||||
t,
|
||||
tx,
|
||||
"usage_logs",
|
||||
"usage_logs_image_size_source_check",
|
||||
"image_size_source",
|
||||
"'output'",
|
||||
"'input'",
|
||||
"'default'",
|
||||
"'legacy'",
|
||||
)
|
||||
requireConstraintDefinitionContains(
|
||||
t,
|
||||
tx,
|
||||
"usage_logs",
|
||||
"usage_logs_image_billing_size_check",
|
||||
"image_count",
|
||||
"image_size IS NOT NULL",
|
||||
"'1K'",
|
||||
"'2K'",
|
||||
"'4K'",
|
||||
"'mixed'",
|
||||
)
|
||||
|
||||
// usage_billing_dedup: billing idempotency narrow table
|
||||
var usageBillingDedupRegclass sql.NullString
|
||||
|
||||
@ -30,6 +30,7 @@ func (r *redeemCodeRepository) Create(ctx context.Context, code *service.RedeemC
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes).
|
||||
SetValidityDays(code.ValidityDays).
|
||||
SetNillableExpiresAt(code.ExpiresAt).
|
||||
SetNillableUsedBy(code.UsedBy).
|
||||
SetNillableUsedAt(code.UsedAt).
|
||||
SetNillableGroupID(code.GroupID).
|
||||
@ -56,6 +57,7 @@ func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []service.
|
||||
SetStatus(c.Status).
|
||||
SetNotes(c.Notes).
|
||||
SetValidityDays(c.ValidityDays).
|
||||
SetNillableExpiresAt(c.ExpiresAt).
|
||||
SetNillableUsedBy(c.UsedBy).
|
||||
SetNillableUsedAt(c.UsedAt).
|
||||
SetNillableGroupID(c.GroupID)
|
||||
@ -107,7 +109,28 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
|
||||
q = q.Where(redeemcode.TypeEQ(codeType))
|
||||
}
|
||||
if status != "" {
|
||||
q = q.Where(redeemcode.StatusEQ(status))
|
||||
now := time.Now()
|
||||
switch status {
|
||||
case service.StatusExpired:
|
||||
q = q.Where(redeemcode.Or(
|
||||
redeemcode.StatusEQ(service.StatusExpired),
|
||||
redeemcode.And(
|
||||
redeemcode.StatusEQ(service.StatusUnused),
|
||||
redeemcode.ExpiresAtNotNil(),
|
||||
redeemcode.ExpiresAtLTE(now),
|
||||
),
|
||||
))
|
||||
case service.StatusUnused:
|
||||
q = q.Where(
|
||||
redeemcode.StatusEQ(service.StatusUnused),
|
||||
redeemcode.Or(
|
||||
redeemcode.ExpiresAtIsNil(),
|
||||
redeemcode.ExpiresAtGT(now),
|
||||
),
|
||||
)
|
||||
default:
|
||||
q = q.Where(redeemcode.StatusEQ(status))
|
||||
}
|
||||
}
|
||||
if search != "" {
|
||||
q = q.Where(
|
||||
@ -158,6 +181,8 @@ func redeemCodeListOrder(params pagination.PaginationParams) []func(*entsql.Sele
|
||||
field = redeemcode.FieldUsedAt
|
||||
case "created_at":
|
||||
field = redeemcode.FieldCreatedAt
|
||||
case "expires_at":
|
||||
field = redeemcode.FieldExpiresAt
|
||||
case "code":
|
||||
field = redeemcode.FieldCode
|
||||
default:
|
||||
@ -194,6 +219,11 @@ func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemC
|
||||
} else {
|
||||
up.ClearGroupID()
|
||||
}
|
||||
if code.ExpiresAt != nil {
|
||||
up.SetExpiresAt(*code.ExpiresAt)
|
||||
} else {
|
||||
up.ClearExpiresAt()
|
||||
}
|
||||
|
||||
updated, err := up.Save(ctx)
|
||||
if err != nil {
|
||||
@ -307,6 +337,7 @@ func redeemCodeEntityToService(m *dbent.RedeemCode) *service.RedeemCode {
|
||||
UsedAt: m.UsedAt,
|
||||
Notes: derefString(m.Notes),
|
||||
CreatedAt: m.CreatedAt,
|
||||
ExpiresAt: m.ExpiresAt,
|
||||
GroupID: m.GroupID,
|
||||
ValidityDays: m.ValidityDays,
|
||||
}
|
||||
|
||||
@ -51,11 +51,13 @@ func (s *RedeemCodeRepoSuite) createGroup(name string) *dbent.Group {
|
||||
// --- Create / CreateBatch / GetByID / GetByCode ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestCreate() {
|
||||
expiresAt := time.Now().UTC().Add(2 * time.Hour)
|
||||
code := &service.RedeemCode{
|
||||
Code: "TEST-CREATE",
|
||||
Type: service.RedeemTypeBalance,
|
||||
Value: 100,
|
||||
Status: service.StatusUnused,
|
||||
Code: "TEST-CREATE",
|
||||
Type: service.RedeemTypeBalance,
|
||||
Value: 100,
|
||||
Status: service.StatusUnused,
|
||||
ExpiresAt: &expiresAt,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, code)
|
||||
@ -65,6 +67,8 @@ func (s *RedeemCodeRepoSuite) TestCreate() {
|
||||
got, err := s.repo.GetByID(s.ctx, code.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("TEST-CREATE", got.Code)
|
||||
s.Require().NotNil(got.ExpiresAt)
|
||||
s.Require().WithinDuration(expiresAt, *got.ExpiresAt, time.Second)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestCreateBatch() {
|
||||
@ -166,6 +170,23 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
|
||||
s.Require().Equal(service.StatusUsed, codes[0].Status)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListWithFilters_StatusExpiredByExpiresAt() {
|
||||
past := time.Now().UTC().Add(-time.Hour)
|
||||
future := time.Now().UTC().Add(time.Hour)
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "STAT-EXPIRED-BY-TIME", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused, ExpiresAt: &past}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "STAT-UNUSED-FUTURE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused, ExpiresAt: &future}))
|
||||
|
||||
expired, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusExpired, "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(expired, 1)
|
||||
s.Require().Equal("STAT-EXPIRED-BY-TIME", expired[0].Code)
|
||||
|
||||
unused, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusUnused, "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(unused, 1)
|
||||
s.Require().Equal("STAT-UNUSED-FUTURE", unused[0].Code)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "ALPHA-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "BETA-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
|
||||
|
||||
@ -546,6 +546,8 @@ func filterSchedulerExtra(extra map[string]any) map[string]any {
|
||||
"responses_websockets_v2_enabled",
|
||||
"openai_ws_enabled",
|
||||
"openai_ws_force_http",
|
||||
"openai_responses_mode",
|
||||
"openai_responses_supported",
|
||||
}
|
||||
filtered := make(map[string]any)
|
||||
for _, key := range keys {
|
||||
|
||||
@ -18,6 +18,8 @@ func TestBuildSchedulerMetadataAccount_KeepsOpenAIWSFlags(t *testing.T) {
|
||||
"openai_oauth_responses_websockets_v2_enabled": true,
|
||||
"openai_oauth_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
|
||||
"openai_ws_force_http": true,
|
||||
"openai_responses_mode": "force_chat_completions",
|
||||
"openai_responses_supported": false,
|
||||
"mixed_scheduling": true,
|
||||
"unused_large_field": "drop-me",
|
||||
},
|
||||
@ -28,6 +30,8 @@ func TestBuildSchedulerMetadataAccount_KeepsOpenAIWSFlags(t *testing.T) {
|
||||
require.Equal(t, true, got.Extra["openai_oauth_responses_websockets_v2_enabled"])
|
||||
require.Equal(t, service.OpenAIWSIngressModePassthrough, got.Extra["openai_oauth_responses_websockets_v2_mode"])
|
||||
require.Equal(t, true, got.Extra["openai_ws_force_http"])
|
||||
require.Equal(t, "force_chat_completions", got.Extra["openai_responses_mode"])
|
||||
require.Equal(t, false, got.Extra["openai_responses_supported"])
|
||||
require.Equal(t, true, got.Extra["mixed_scheduling"])
|
||||
require.Nil(t, got.Extra["unused_large_field"])
|
||||
}
|
||||
|
||||
@ -28,7 +28,7 @@ import (
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, account_stats_cost, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, image_input_size, image_output_size, image_size_source, image_size_breakdown, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, account_stats_cost, created_at"
|
||||
|
||||
// usageLogInsertArgTypes must stay in the same order as:
|
||||
// 1. prepareUsageLogInsert().args
|
||||
@ -73,6 +73,10 @@ var usageLogInsertArgTypes = [...]string{
|
||||
"text", // ip_address
|
||||
"integer", // image_count
|
||||
"text", // image_size
|
||||
"text", // image_input_size
|
||||
"text", // image_output_size
|
||||
"text", // image_size_source
|
||||
"jsonb", // image_size_breakdown
|
||||
"text", // service_tier
|
||||
"text", // reasoning_effort
|
||||
"text", // inbound_endpoint
|
||||
@ -92,6 +96,22 @@ const rawUsageLogModelColumn = "model"
|
||||
// Historical rows may contain upstream/billing model values, while newer rows store requested_model.
|
||||
// Requested/upstream/mapping analytics must use resolveModelDimensionExpression instead.
|
||||
|
||||
// usageLogSuccessFilterUL 用于把"失败请求 usage log"(tokens=0、cost=0、不计费的占位记录)
|
||||
// 从统计性聚合中排除,避免污染 Dashboard / 用量拆分等指标。
|
||||
//
|
||||
// schema 中没有 success bool 列;新增列要做迁移,风险大;这里用 actual_cost > 0 作为代理:
|
||||
// 任何成功落账的请求都会产生 actual_cost(包括 token 计费、纯图片 token 计费、按次/按图计费),
|
||||
// 反之 failed-request usage log 的 actual_cost 为 0。
|
||||
// 早期版本用 4 项 token 和 > 0 判定会把"按次/按图计费"与"image_output_tokens 独立计费"的纯图片
|
||||
// 请求误判为失败,导致这部分请求从用量统计里消失,故改用 actual_cost。
|
||||
// 配合 `FROM usage_logs ul` JOIN 查询使用。
|
||||
const usageLogSuccessFilterUL = "ul.actual_cost > 0"
|
||||
|
||||
// usageLogEffectivePlatformExpr 用于按"有效平台"维度聚合 usage_logs:
|
||||
// 优先取请求实际走的分组 platform,若分组未设置 platform 再 fallback 到 account.platform。
|
||||
// 配套要求查询里 LEFT JOIN groups g ON g.id = ul.group_id 与 LEFT JOIN accounts a ON a.id = ul.account_id。
|
||||
const usageLogEffectivePlatformExpr = "COALESCE(NULLIF(g.platform,''), a.platform)"
|
||||
|
||||
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
|
||||
var dateFormatWhitelist = map[string]string{
|
||||
"hour": "YYYY-MM-DD HH24:00",
|
||||
@ -120,6 +140,24 @@ func appendRawUsageLogModelWhereCondition(conditions []string, args []any, model
|
||||
return conditions, args
|
||||
}
|
||||
|
||||
func appendUsageLogBillingModeWhereCondition(conditions []string, args []any, billingMode string) ([]string, []any) {
|
||||
mode := strings.TrimSpace(billingMode)
|
||||
if mode == "" {
|
||||
return conditions, args
|
||||
}
|
||||
placeholder := fmt.Sprintf("$%d", len(args)+1)
|
||||
switch service.BillingMode(mode) {
|
||||
case service.BillingModeImage:
|
||||
conditions = append(conditions, fmt.Sprintf("(billing_mode = %s OR COALESCE(image_count, 0) > 0)", placeholder))
|
||||
case service.BillingModeToken:
|
||||
conditions = append(conditions, fmt.Sprintf("(billing_mode = %s OR ((billing_mode IS NULL OR billing_mode = '') AND COALESCE(image_count, 0) <= 0))", placeholder))
|
||||
default:
|
||||
conditions = append(conditions, fmt.Sprintf("billing_mode = %s", placeholder))
|
||||
}
|
||||
args = append(args, mode)
|
||||
return conditions, args
|
||||
}
|
||||
|
||||
// appendRawUsageLogModelQueryFilter keeps direct model filters on the raw model column for backward
|
||||
// compatibility with historical rows. Requested/upstream analytics must use
|
||||
// resolveModelDimensionExpression instead.
|
||||
@ -352,6 +390,10 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
image_input_size,
|
||||
image_output_size,
|
||||
image_size_source,
|
||||
image_size_breakdown,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
@ -369,7 +411,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
||||
$10, $11, $12, $13,
|
||||
$14, $15, $16, $17,
|
||||
$18, $19, $20, $21, $22, $23,
|
||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
|
||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@ -790,6 +832,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
image_input_size,
|
||||
image_output_size,
|
||||
image_size_source,
|
||||
image_size_breakdown,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
@ -803,7 +849,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
args := make([]any, 0, len(keys)*46)
|
||||
args := make([]any, 0, len(keys)*50)
|
||||
argPos := 1
|
||||
for idx, key := range keys {
|
||||
if idx > 0 {
|
||||
@ -867,6 +913,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
image_input_size,
|
||||
image_output_size,
|
||||
image_size_source,
|
||||
image_size_breakdown,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
@ -915,6 +965,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
image_input_size,
|
||||
image_output_size,
|
||||
image_size_source,
|
||||
image_size_breakdown,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
@ -1003,6 +1057,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
image_input_size,
|
||||
image_output_size,
|
||||
image_size_source,
|
||||
image_size_breakdown,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
@ -1016,7 +1074,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
args := make([]any, 0, len(preparedList)*46)
|
||||
args := make([]any, 0, len(preparedList)*50)
|
||||
argPos := 1
|
||||
for idx, prepared := range preparedList {
|
||||
if idx > 0 {
|
||||
@ -1077,6 +1135,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
image_input_size,
|
||||
image_output_size,
|
||||
image_size_source,
|
||||
image_size_breakdown,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
@ -1125,6 +1187,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
image_input_size,
|
||||
image_output_size,
|
||||
image_size_source,
|
||||
image_size_breakdown,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
@ -1181,6 +1247,10 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
image_input_size,
|
||||
image_output_size,
|
||||
image_size_source,
|
||||
image_size_breakdown,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
@ -1198,7 +1268,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
||||
$10, $11, $12, $13,
|
||||
$14, $15, $16, $17,
|
||||
$18, $19, $20, $21, $22, $23,
|
||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
|
||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
`, prepared.args...)
|
||||
@ -1225,6 +1295,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||
userAgent := nullString(log.UserAgent)
|
||||
ipAddress := nullString(log.IPAddress)
|
||||
imageSize := nullString(log.ImageSize)
|
||||
imageInputSize := nullString(log.ImageInputSize)
|
||||
imageOutputSize := nullString(log.ImageOutputSize)
|
||||
imageSizeSource := nullString(log.ImageSizeSource)
|
||||
imageSizeBreakdown := nullStringIntMapJSON(log.ImageSizeBreakdown)
|
||||
serviceTier := nullString(log.ServiceTier)
|
||||
reasoningEffort := nullString(log.ReasoningEffort)
|
||||
inboundEndpoint := nullString(log.InboundEndpoint)
|
||||
@ -1285,6 +1359,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||
ipAddress,
|
||||
log.ImageCount,
|
||||
imageSize,
|
||||
imageInputSize,
|
||||
imageOutputSize,
|
||||
imageSizeSource,
|
||||
imageSizeBreakdown,
|
||||
serviceTier,
|
||||
reasoningEffort,
|
||||
inboundEndpoint,
|
||||
@ -2352,6 +2430,9 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
|
||||
// UserDashboardStats 用户仪表盘统计
|
||||
type UserDashboardStats = usagestats.UserDashboardStats
|
||||
|
||||
// PlatformDashboardStats 单平台用量明细
|
||||
type PlatformDashboardStats = usagestats.PlatformDashboardStats
|
||||
|
||||
// GetUserDashboardStats 获取用户专属的仪表盘统计
|
||||
func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) {
|
||||
stats := &UserDashboardStats{}
|
||||
@ -2447,6 +2528,57 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
stats.Rpm = rpm
|
||||
stats.Tpm = tpm
|
||||
|
||||
// 按"有效平台"维度拆分(group.platform 优先,否则 account.platform)。
|
||||
// 与 ops 路径口径一致;HAVING 过滤掉无法确定平台的行(避免出现空字符串平台)。
|
||||
// 与上面 totalStatsQuery/todayStatsQuery 的总值可能略微差异,原因有二:
|
||||
// 1) 无平台归属的极少数行(group/account 都没 platform)会被 HAVING 排除;
|
||||
// 2) usageLogSuccessFilterUL 会把 actual_cost = 0 的失败 placeholder 行排除,
|
||||
// 而 totalStatsQuery/todayStatsQuery 没有这层过滤、会把这些行的 request 计数算进去。
|
||||
platformQuery := `
|
||||
SELECT
|
||||
` + usageLogEffectivePlatformExpr + ` as platform,
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(ul.actual_cost), 0) as total_actual_cost,
|
||||
COUNT(*) FILTER (WHERE ul.created_at >= $2) as today_requests,
|
||||
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens) FILTER (WHERE ul.created_at >= $2), 0) as today_tokens,
|
||||
COALESCE(SUM(ul.actual_cost) FILTER (WHERE ul.created_at >= $2), 0) as today_actual_cost
|
||||
FROM usage_logs ul
|
||||
LEFT JOIN groups g ON g.id = ul.group_id
|
||||
LEFT JOIN accounts a ON a.id = ul.account_id
|
||||
WHERE ul.user_id = $1
|
||||
AND ` + usageLogSuccessFilterUL + `
|
||||
GROUP BY ` + usageLogEffectivePlatformExpr + `
|
||||
HAVING ` + usageLogEffectivePlatformExpr + ` IS NOT NULL AND ` + usageLogEffectivePlatformExpr + ` <> ''
|
||||
ORDER BY total_actual_cost DESC
|
||||
`
|
||||
rows, err := r.sql.QueryContext(ctx, platformQuery, userID, today)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for rows.Next() {
|
||||
var p PlatformDashboardStats
|
||||
if err := rows.Scan(
|
||||
&p.Platform,
|
||||
&p.TotalRequests,
|
||||
&p.TotalTokens,
|
||||
&p.TotalActualCost,
|
||||
&p.TodayRequests,
|
||||
&p.TodayTokens,
|
||||
&p.TodayActualCost,
|
||||
); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
stats.ByPlatform = append(stats.ByPlatform, p)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
@ -2662,10 +2794,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
||||
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
|
||||
args = append(args, int16(*filters.BillingType))
|
||||
}
|
||||
if filters.BillingMode != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1))
|
||||
args = append(args, filters.BillingMode)
|
||||
}
|
||||
conditions, args = appendUsageLogBillingModeWhereCondition(conditions, args, filters.BillingMode)
|
||||
if filters.StartTime != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
|
||||
args = append(args, *filters.StartTime)
|
||||
@ -2710,6 +2839,9 @@ type UsageStats = usagestats.UsageStats
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
type BatchUserUsageStats = usagestats.BatchUserUsageStats
|
||||
|
||||
// PlatformUsage represents per-platform usage breakdown
|
||||
type PlatformUsage = usagestats.PlatformUsage
|
||||
|
||||
func normalizePositiveInt64IDs(ids []int64) []int64 {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
@ -2750,15 +2882,21 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
||||
result[id] = &BatchUserUsageStats{UserID: id}
|
||||
}
|
||||
|
||||
// GROUP BY (user_id, effective_platform) 一次查询同时得到总值与按平台拆分。
|
||||
// 应用层把同一 user_id 的多行累加为总值,并把非空 platform 行收集到 ByPlatform。
|
||||
query := `
|
||||
SELECT
|
||||
user_id,
|
||||
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost
|
||||
FROM usage_logs
|
||||
WHERE user_id = ANY($1)
|
||||
AND created_at >= LEAST($2, $4)
|
||||
GROUP BY user_id
|
||||
ul.user_id,
|
||||
` + usageLogEffectivePlatformExpr + ` as platform,
|
||||
COALESCE(SUM(ul.actual_cost) FILTER (WHERE ul.created_at >= $2 AND ul.created_at < $3), 0) as total_cost,
|
||||
COALESCE(SUM(ul.actual_cost) FILTER (WHERE ul.created_at >= $4), 0) as today_cost
|
||||
FROM usage_logs ul
|
||||
LEFT JOIN groups g ON g.id = ul.group_id
|
||||
LEFT JOIN accounts a ON a.id = ul.account_id
|
||||
WHERE ul.user_id = ANY($1)
|
||||
AND ul.created_at >= LEAST($2, $4)
|
||||
AND ` + usageLogSuccessFilterUL + `
|
||||
GROUP BY ul.user_id, ` + usageLogEffectivePlatformExpr + `
|
||||
`
|
||||
today := timezone.Today()
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedUserIDs), startTime, endTime, today)
|
||||
@ -2767,15 +2905,25 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
||||
}
|
||||
for rows.Next() {
|
||||
var userID int64
|
||||
var platform sql.NullString
|
||||
var total float64
|
||||
var todayTotal float64
|
||||
if err := rows.Scan(&userID, &total, &todayTotal); err != nil {
|
||||
if err := rows.Scan(&userID, &platform, &total, &todayTotal); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
if stats, ok := result[userID]; ok {
|
||||
stats.TotalActualCost = total
|
||||
stats.TodayActualCost = todayTotal
|
||||
stats, ok := result[userID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
stats.TotalActualCost += total
|
||||
stats.TodayActualCost += todayTotal
|
||||
if platform.Valid && platform.String != "" {
|
||||
stats.ByPlatform = append(stats.ByPlatform, PlatformUsage{
|
||||
Platform: platform.String,
|
||||
TotalActualCost: total,
|
||||
TodayActualCost: todayTotal,
|
||||
})
|
||||
}
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
@ -3363,10 +3511,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
||||
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
|
||||
args = append(args, int16(*filters.BillingType))
|
||||
}
|
||||
if filters.BillingMode != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1))
|
||||
args = append(args, filters.BillingMode)
|
||||
}
|
||||
conditions, args = appendUsageLogBillingModeWhereCondition(conditions, args, filters.BillingMode)
|
||||
if filters.StartTime != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
|
||||
args = append(args, *filters.StartTime)
|
||||
@ -4084,6 +4229,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
ipAddress sql.NullString
|
||||
imageCount int
|
||||
imageSize sql.NullString
|
||||
imageInputSize sql.NullString
|
||||
imageOutputSize sql.NullString
|
||||
imageSizeSource sql.NullString
|
||||
imageSizeBreakdown sql.NullString
|
||||
serviceTier sql.NullString
|
||||
reasoningEffort sql.NullString
|
||||
inboundEndpoint sql.NullString
|
||||
@ -4134,6 +4283,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&ipAddress,
|
||||
&imageCount,
|
||||
&imageSize,
|
||||
&imageInputSize,
|
||||
&imageOutputSize,
|
||||
&imageSizeSource,
|
||||
&imageSizeBreakdown,
|
||||
&serviceTier,
|
||||
&reasoningEffort,
|
||||
&inboundEndpoint,
|
||||
@ -4212,6 +4365,16 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
if imageSize.Valid {
|
||||
log.ImageSize = &imageSize.String
|
||||
}
|
||||
if imageInputSize.Valid {
|
||||
log.ImageInputSize = &imageInputSize.String
|
||||
}
|
||||
if imageOutputSize.Valid {
|
||||
log.ImageOutputSize = &imageOutputSize.String
|
||||
}
|
||||
if imageSizeSource.Valid {
|
||||
log.ImageSizeSource = &imageSizeSource.String
|
||||
}
|
||||
log.ImageSizeBreakdown = stringIntMapFromNullJSON(imageSizeBreakdown)
|
||||
if serviceTier.Valid {
|
||||
log.ServiceTier = &serviceTier.String
|
||||
}
|
||||
@ -4378,6 +4541,31 @@ func nullString(v *string) sql.NullString {
|
||||
return sql.NullString{String: *v, Valid: true}
|
||||
}
|
||||
|
||||
func nullStringIntMapJSON(v map[string]int) any {
|
||||
if len(v) == 0 {
|
||||
return nil
|
||||
}
|
||||
payload, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return string(payload)
|
||||
}
|
||||
|
||||
func stringIntMapFromNullJSON(v sql.NullString) map[string]int {
|
||||
if !v.Valid || strings.TrimSpace(v.String) == "" {
|
||||
return nil
|
||||
}
|
||||
var out map[string]int
|
||||
if err := json.Unmarshal([]byte(v.String), &out); err != nil {
|
||||
return nil
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func coalesceTrimmedString(v sql.NullString, fallback string) string {
|
||||
if v.Valid && strings.TrimSpace(v.String) != "" {
|
||||
return v.String
|
||||
|
||||
@ -76,6 +76,10 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
sqlmock.AnyArg(), // ip_address
|
||||
log.ImageCount,
|
||||
sqlmock.AnyArg(), // image_size
|
||||
sqlmock.AnyArg(), // image_input_size
|
||||
sqlmock.AnyArg(), // image_output_size
|
||||
sqlmock.AnyArg(), // image_size_source
|
||||
sqlmock.AnyArg(), // image_size_breakdown
|
||||
sqlmock.AnyArg(), // service_tier
|
||||
sqlmock.AnyArg(), // reasoning_effort
|
||||
sqlmock.AnyArg(), // inbound_endpoint
|
||||
@ -155,6 +159,10 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
||||
sqlmock.AnyArg(),
|
||||
log.ImageCount,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(), // image_input_size
|
||||
sqlmock.AnyArg(), // image_output_size
|
||||
sqlmock.AnyArg(), // image_size_source
|
||||
sqlmock.AnyArg(), // image_size_breakdown
|
||||
serviceTier,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
@ -230,12 +238,74 @@ func TestPrepareUsageLogInsert_ArgCountMatchesTypes(t *testing.T) {
|
||||
require.Len(t, prepared.args, len(usageLogInsertArgTypes))
|
||||
}
|
||||
|
||||
func TestPrepareUsageLogInsert_PersistsImageSizeMetadata(t *testing.T) {
|
||||
imageSize := "4K"
|
||||
inputSize := "1024x1024"
|
||||
outputSize := "3840x2160"
|
||||
source := "output"
|
||||
prepared := prepareUsageLogInsert(&service.UsageLog{
|
||||
UserID: 1,
|
||||
APIKeyID: 2,
|
||||
AccountID: 3,
|
||||
RequestID: "req-image-metadata",
|
||||
Model: "gpt-image-2",
|
||||
RequestedModel: "gpt-image-2",
|
||||
ImageCount: 2,
|
||||
ImageSize: &imageSize,
|
||||
ImageInputSize: &inputSize,
|
||||
ImageOutputSize: &outputSize,
|
||||
ImageSizeSource: &source,
|
||||
ImageSizeBreakdown: map[string]int{"1K": 1, "4K": 1},
|
||||
CreatedAt: time.Date(2025, 1, 6, 12, 0, 0, 0, time.UTC),
|
||||
})
|
||||
|
||||
require.Equal(t, sql.NullString{String: imageSize, Valid: true}, prepared.args[34])
|
||||
require.Equal(t, sql.NullString{String: inputSize, Valid: true}, prepared.args[35])
|
||||
require.Equal(t, sql.NullString{String: outputSize, Valid: true}, prepared.args[36])
|
||||
require.Equal(t, sql.NullString{String: source, Valid: true}, prepared.args[37])
|
||||
breakdownJSON, ok := prepared.args[38].(string)
|
||||
require.True(t, ok)
|
||||
require.JSONEq(t, `{"1K":1,"4K":1}`, breakdownJSON)
|
||||
}
|
||||
|
||||
func TestCoalesceTrimmedString(t *testing.T) {
|
||||
require.Equal(t, "fallback", coalesceTrimmedString(sql.NullString{}, "fallback"))
|
||||
require.Equal(t, "fallback", coalesceTrimmedString(sql.NullString{Valid: true, String: " "}, "fallback"))
|
||||
require.Equal(t, "value", coalesceTrimmedString(sql.NullString{Valid: true, String: "value"}, "fallback"))
|
||||
}
|
||||
|
||||
func TestAppendUsageLogBillingModeWhereCondition(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
billingMode string
|
||||
wantCondition string
|
||||
}{
|
||||
{
|
||||
name: "image includes legacy image rows",
|
||||
billingMode: string(service.BillingModeImage),
|
||||
wantCondition: "(billing_mode = $1 OR COALESCE(image_count, 0) > 0)",
|
||||
},
|
||||
{
|
||||
name: "token includes legacy non-image rows",
|
||||
billingMode: string(service.BillingModeToken),
|
||||
wantCondition: "(billing_mode = $1 OR ((billing_mode IS NULL OR billing_mode = '') AND COALESCE(image_count, 0) <= 0))",
|
||||
},
|
||||
{
|
||||
name: "per request remains exact",
|
||||
billingMode: string(service.BillingModePerRequest),
|
||||
wantCondition: "billing_mode = $1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
conditions, args := appendUsageLogBillingModeWhereCondition(nil, nil, tt.billingMode)
|
||||
require.Equal(t, []string{tt.wantCondition}, conditions)
|
||||
require.Equal(t, []any{tt.billingMode}, args)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func anySliceToDriverValues(values []any) []driver.Value {
|
||||
out := make([]driver.Value, 0, len(values))
|
||||
for _, value := range values {
|
||||
@ -528,6 +598,63 @@ func (s usageLogScannerStub) Scan(dest ...any) error {
|
||||
}
|
||||
|
||||
func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
t.Run("image_size_metadata_is_scanned", func(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
log, err := scanUsageLog(usageLogScannerStub{values: []any{
|
||||
int64(4),
|
||||
int64(13),
|
||||
int64(23),
|
||||
int64(33),
|
||||
sql.NullString{Valid: true, String: "req-image-metadata"},
|
||||
"gpt-image-2",
|
||||
sql.NullString{Valid: true, String: "gpt-image-2"},
|
||||
sql.NullString{},
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
0, 0, 0, 0, 0, 0,
|
||||
0, 0.0, // image_output_tokens, image_output_cost
|
||||
0.0, 0.0, 0.0, 0.0, 0.8, 0.8,
|
||||
1.0,
|
||||
sql.NullFloat64{},
|
||||
int16(service.BillingTypeBalance),
|
||||
int16(service.RequestTypeSync),
|
||||
false,
|
||||
false,
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
2,
|
||||
sql.NullString{Valid: true, String: "4K"},
|
||||
sql.NullString{Valid: true, String: "1024x1024"},
|
||||
sql.NullString{Valid: true, String: "3840x2160"},
|
||||
sql.NullString{Valid: true, String: "output"},
|
||||
sql.NullString{Valid: true, String: `{"4K":2}`},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
sql.NullInt64{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullFloat64{},
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, log.ImageCount)
|
||||
require.NotNil(t, log.ImageSize)
|
||||
require.Equal(t, "4K", *log.ImageSize)
|
||||
require.NotNil(t, log.ImageInputSize)
|
||||
require.Equal(t, "1024x1024", *log.ImageInputSize)
|
||||
require.NotNil(t, log.ImageOutputSize)
|
||||
require.Equal(t, "3840x2160", *log.ImageOutputSize)
|
||||
require.NotNil(t, log.ImageSizeSource)
|
||||
require.Equal(t, "output", *log.ImageSizeSource)
|
||||
require.Equal(t, map[string]int{"4K": 2}, log.ImageSizeBreakdown)
|
||||
})
|
||||
|
||||
t.Run("request_type_ws_v2_overrides_legacy", func(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
log, err := scanUsageLog(usageLogScannerStub{values: []any{
|
||||
@ -567,6 +694,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{}, // image_input_size
|
||||
sql.NullString{}, // image_output_size
|
||||
sql.NullString{}, // image_size_source
|
||||
sql.NullString{}, // image_size_breakdown
|
||||
sql.NullString{Valid: true, String: "priority"},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
@ -615,6 +746,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{}, // image_input_size
|
||||
sql.NullString{}, // image_output_size
|
||||
sql.NullString{}, // image_size_source
|
||||
sql.NullString{}, // image_size_breakdown
|
||||
sql.NullString{Valid: true, String: "flex"},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
@ -663,6 +798,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{}, // image_input_size
|
||||
sql.NullString{}, // image_output_size
|
||||
sql.NullString{}, // image_size_source
|
||||
sql.NullString{}, // image_size_breakdown
|
||||
sql.NullString{Valid: true, String: "priority"},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
|
||||
@ -334,7 +334,8 @@ func normalizeEmailAuthIdentitySubject(email string) string {
|
||||
}
|
||||
if strings.HasSuffix(normalized, service.LinuxDoConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(normalized, service.OIDCConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(normalized, service.WeChatConnectSyntheticEmailDomain) {
|
||||
strings.HasSuffix(normalized, service.WeChatConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(normalized, service.DingTalkConnectSyntheticEmailDomain) {
|
||||
return ""
|
||||
}
|
||||
return normalized
|
||||
@ -956,7 +957,7 @@ func userSignupSourceOrDefault(signupSource string) string {
|
||||
switch strings.TrimSpace(strings.ToLower(signupSource)) {
|
||||
case "", "email":
|
||||
return "email"
|
||||
case "linuxdo", "wechat", "oidc":
|
||||
case "linuxdo", "wechat", "oidc", "dingtalk":
|
||||
return strings.TrimSpace(strings.ToLower(signupSource))
|
||||
default:
|
||||
return "email"
|
||||
|
||||
@ -68,6 +68,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"linuxdo_bound": false,
|
||||
"oidc_bound": false,
|
||||
"wechat_bound": false,
|
||||
"dingtalk_bound": false,
|
||||
"identities": {
|
||||
"email": {
|
||||
"provider": "email",
|
||||
@ -104,6 +105,14 @@ func TestAPIContracts(t *testing.T) {
|
||||
"can_bind": true,
|
||||
"can_unbind": false,
|
||||
"bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
|
||||
},
|
||||
"dingtalk": {
|
||||
"provider": "dingtalk",
|
||||
"bound": false,
|
||||
"bound_count": 0,
|
||||
"can_bind": true,
|
||||
"can_unbind": false,
|
||||
"bind_start_path": "/api/v1/auth/oauth/dingtalk/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
|
||||
}
|
||||
},
|
||||
"identity_bindings": {
|
||||
@ -142,6 +151,14 @@ func TestAPIContracts(t *testing.T) {
|
||||
"can_bind": true,
|
||||
"can_unbind": false,
|
||||
"bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
|
||||
},
|
||||
"dingtalk": {
|
||||
"provider": "dingtalk",
|
||||
"bound": false,
|
||||
"bound_count": 0,
|
||||
"can_bind": true,
|
||||
"can_unbind": false,
|
||||
"bind_start_path": "/api/v1/auth/oauth/dingtalk/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
|
||||
}
|
||||
},
|
||||
"auth_bindings": {
|
||||
@ -180,6 +197,14 @@ func TestAPIContracts(t *testing.T) {
|
||||
"can_bind": true,
|
||||
"can_unbind": false,
|
||||
"bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
|
||||
},
|
||||
"dingtalk": {
|
||||
"provider": "dingtalk",
|
||||
"bound": false,
|
||||
"bound_count": 0,
|
||||
"can_bind": true,
|
||||
"can_unbind": false,
|
||||
"bind_start_path": "/api/v1/auth/oauth/dingtalk/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
|
||||
}
|
||||
},
|
||||
"run_mode": "standard"
|
||||
@ -554,6 +579,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"first_token_ms": 50,
|
||||
"image_count": 0,
|
||||
"image_size": null,
|
||||
"image_input_size": null,
|
||||
"image_output_size": null,
|
||||
"image_size_source": null,
|
||||
"image_size_breakdown": null,
|
||||
"media_type": null,
|
||||
"cache_ttl_overridden": false,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
@ -672,6 +701,22 @@ func TestAPIContracts(t *testing.T) {
|
||||
"linuxdo_connect_client_id": "",
|
||||
"linuxdo_connect_client_secret_configured": false,
|
||||
"linuxdo_connect_redirect_url": "",
|
||||
"dingtalk_connect_enabled": false,
|
||||
"dingtalk_connect_bypass_registration": false,
|
||||
"dingtalk_connect_client_id": "",
|
||||
"dingtalk_connect_client_secret_configured": false,
|
||||
"dingtalk_connect_redirect_url": "",
|
||||
"dingtalk_connect_internal_corp_id": "",
|
||||
"dingtalk_connect_corp_restriction_policy": "",
|
||||
"dingtalk_connect_sync_corp_email": false,
|
||||
"dingtalk_connect_sync_corp_email_attr_key": "dingtalk_email",
|
||||
"dingtalk_connect_sync_corp_email_attr_name": "钉钉企业邮箱",
|
||||
"dingtalk_connect_sync_dept": false,
|
||||
"dingtalk_connect_sync_dept_attr_key": "dingtalk_department",
|
||||
"dingtalk_connect_sync_dept_attr_name": "钉钉部门",
|
||||
"dingtalk_connect_sync_display_name": false,
|
||||
"dingtalk_connect_sync_display_name_attr_key": "dingtalk_name",
|
||||
"dingtalk_connect_sync_display_name_attr_name": "钉钉姓名",
|
||||
"oidc_connect_enabled": false,
|
||||
"oidc_connect_provider_name": "OIDC",
|
||||
"oidc_connect_client_id": "",
|
||||
@ -744,6 +789,11 @@ func TestAPIContracts(t *testing.T) {
|
||||
"auth_source_default_wechat_subscriptions": [],
|
||||
"auth_source_default_wechat_grant_on_signup": false,
|
||||
"auth_source_default_wechat_grant_on_first_bind": false,
|
||||
"auth_source_default_dingtalk_balance": 0,
|
||||
"auth_source_default_dingtalk_concurrency": 5,
|
||||
"auth_source_default_dingtalk_subscriptions": [],
|
||||
"auth_source_default_dingtalk_grant_on_signup": false,
|
||||
"auth_source_default_dingtalk_grant_on_first_bind": false,
|
||||
"force_email_on_third_party_signup": false,
|
||||
"default_concurrency": 5,
|
||||
"default_balance": 1.25,
|
||||
@ -784,14 +834,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"payment_visible_method_wxpay_enabled": false,
|
||||
"openai_advanced_scheduler_enabled": true,
|
||||
"openai_fast_policy_settings": {
|
||||
"rules": [
|
||||
{
|
||||
"service_tier": "priority",
|
||||
"action": "filter",
|
||||
"scope": "all",
|
||||
"fallback_action": "pass"
|
||||
}
|
||||
]
|
||||
"rules": []
|
||||
},
|
||||
"custom_menu_items": [],
|
||||
"custom_endpoints": [],
|
||||
@ -917,6 +960,22 @@ func TestAPIContracts(t *testing.T) {
|
||||
"linuxdo_connect_client_id": "",
|
||||
"linuxdo_connect_client_secret_configured": false,
|
||||
"linuxdo_connect_redirect_url": "",
|
||||
"dingtalk_connect_enabled": false,
|
||||
"dingtalk_connect_bypass_registration": false,
|
||||
"dingtalk_connect_client_id": "",
|
||||
"dingtalk_connect_client_secret_configured": false,
|
||||
"dingtalk_connect_redirect_url": "",
|
||||
"dingtalk_connect_internal_corp_id": "",
|
||||
"dingtalk_connect_corp_restriction_policy": "",
|
||||
"dingtalk_connect_sync_corp_email": false,
|
||||
"dingtalk_connect_sync_corp_email_attr_key": "dingtalk_email",
|
||||
"dingtalk_connect_sync_corp_email_attr_name": "钉钉企业邮箱",
|
||||
"dingtalk_connect_sync_dept": false,
|
||||
"dingtalk_connect_sync_dept_attr_key": "dingtalk_department",
|
||||
"dingtalk_connect_sync_dept_attr_name": "钉钉部门",
|
||||
"dingtalk_connect_sync_display_name": false,
|
||||
"dingtalk_connect_sync_display_name_attr_key": "dingtalk_name",
|
||||
"dingtalk_connect_sync_display_name_attr_name": "钉钉姓名",
|
||||
"oidc_connect_enabled": true,
|
||||
"oidc_connect_provider_name": "ConfigOIDC",
|
||||
"oidc_connect_client_id": "oidc-config-client",
|
||||
@ -999,14 +1058,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"payment_visible_method_wxpay_enabled": false,
|
||||
"openai_advanced_scheduler_enabled": false,
|
||||
"openai_fast_policy_settings": {
|
||||
"rules": [
|
||||
{
|
||||
"service_tier": "priority",
|
||||
"action": "filter",
|
||||
"scope": "all",
|
||||
"fallback_action": "pass"
|
||||
}
|
||||
]
|
||||
"rules": []
|
||||
},
|
||||
"payment_enabled": false,
|
||||
"payment_min_amount": 0,
|
||||
@ -1084,6 +1136,11 @@ func TestAPIContracts(t *testing.T) {
|
||||
"auth_source_default_wechat_subscriptions": [],
|
||||
"auth_source_default_wechat_grant_on_signup": false,
|
||||
"auth_source_default_wechat_grant_on_first_bind": false,
|
||||
"auth_source_default_dingtalk_balance": 0,
|
||||
"auth_source_default_dingtalk_concurrency": 5,
|
||||
"auth_source_default_dingtalk_subscriptions": [],
|
||||
"auth_source_default_dingtalk_grant_on_signup": false,
|
||||
"auth_source_default_dingtalk_grant_on_first_bind": false,
|
||||
"force_email_on_third_party_signup": false
|
||||
}
|
||||
}`,
|
||||
@ -1194,10 +1251,10 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil, nil)
|
||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil, nil, nil)
|
||||
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
jwtAuth := func(c *gin.Context) {
|
||||
|
||||
@ -92,6 +92,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
clientIP := ip.GetTrustedClientIP(c)
|
||||
allowed, _ := ip.CheckIPRestrictionWithCompiledRules(clientIP, apiKey.CompiledIPWhitelist, apiKey.CompiledIPBlacklist)
|
||||
if !allowed {
|
||||
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonIPRestriction)
|
||||
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
|
||||
return
|
||||
}
|
||||
|
||||
@ -333,6 +333,15 @@ func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
require.NoError(t, router.SetTrustedProxies(nil))
|
||||
var markedBusinessLimited bool
|
||||
var businessLimitedReason string
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Next()
|
||||
markedBusinessLimited = service.HasOpsClientBusinessLimited(c)
|
||||
if v, ok := c.Get(service.OpsClientBusinessLimitedReasonKey); ok {
|
||||
businessLimitedReason, _ = v.(string)
|
||||
}
|
||||
})
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
@ -349,6 +358,8 @@ func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T)
|
||||
|
||||
require.Equal(t, http.StatusForbidden, w.Code)
|
||||
require.Contains(t, w.Body.String(), "ACCESS_DENIED")
|
||||
require.True(t, markedBusinessLimited)
|
||||
require.Equal(t, service.OpsClientBusinessLimitedReasonIPRestriction, businessLimitedReason)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthTouchesLastUsedOnSuccess(t *testing.T) {
|
||||
|
||||
@ -42,15 +42,19 @@ func backendModeAllowsAuthPath(path string) bool {
|
||||
"/auth/oauth/oidc/callback",
|
||||
"/auth/oauth/github/callback",
|
||||
"/auth/oauth/google/callback",
|
||||
"/auth/oauth/dingtalk/callback",
|
||||
"/auth/oauth/linuxdo/complete-registration",
|
||||
"/auth/oauth/wechat/complete-registration",
|
||||
"/auth/oauth/oidc/complete-registration",
|
||||
"/auth/oauth/dingtalk/complete-registration",
|
||||
"/auth/oauth/linuxdo/create-account",
|
||||
"/auth/oauth/wechat/create-account",
|
||||
"/auth/oauth/oidc/create-account",
|
||||
"/auth/oauth/dingtalk/create-account",
|
||||
"/auth/oauth/linuxdo/bind-login",
|
||||
"/auth/oauth/wechat/bind-login",
|
||||
"/auth/oauth/oidc/bind-login",
|
||||
"/auth/oauth/dingtalk/bind-login",
|
||||
} {
|
||||
if strings.HasSuffix(path, suffix) {
|
||||
return true
|
||||
|
||||
@ -270,6 +270,36 @@ func TestBackendModeAuthGuard(t *testing.T) {
|
||||
path: "/api/v1/auth/oauth/google/callback",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_dingtalk_oauth_start",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/dingtalk/start",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_dingtalk_oauth_callback",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/dingtalk/callback",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_dingtalk_complete_registration",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/dingtalk/complete-registration",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_dingtalk_create_account",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/dingtalk/create-account",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_dingtalk_bind_login",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/oauth/dingtalk/bind-login",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_oauth_pending_exchange",
|
||||
enabled: "true",
|
||||
|
||||
@ -303,6 +303,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
|
||||
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
|
||||
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
|
||||
accounts.POST("/:id/models/sync-upstream", h.Admin.Account.SyncUpstreamModels)
|
||||
accounts.POST("/batch", h.Admin.Account.BatchCreate)
|
||||
accounts.GET("/data", h.Admin.Account.ExportData)
|
||||
accounts.POST("/data", h.Admin.Account.ImportData)
|
||||
|
||||
@ -182,6 +182,32 @@ func RegisterAuthRoutes(
|
||||
}),
|
||||
h.Auth.CreateOIDCOAuthAccount,
|
||||
)
|
||||
auth.GET("/oauth/dingtalk/start", h.Auth.DingTalkOAuthStart)
|
||||
auth.GET("/oauth/dingtalk/bind/start", func(c *gin.Context) {
|
||||
query := c.Request.URL.Query()
|
||||
query.Set("intent", "bind_current_user")
|
||||
c.Request.URL.RawQuery = query.Encode()
|
||||
h.Auth.DingTalkOAuthStart(c)
|
||||
})
|
||||
auth.GET("/oauth/dingtalk/callback", h.Auth.DingTalkOAuthCallback)
|
||||
auth.POST("/oauth/dingtalk/complete-registration",
|
||||
rateLimiter.LimitWithOptions("oauth-dingtalk-complete", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.CompleteDingTalkOAuthRegistration,
|
||||
)
|
||||
auth.POST("/oauth/dingtalk/bind-login",
|
||||
rateLimiter.LimitWithOptions("oauth-dingtalk-bind-login", 20, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.BindDingTalkOAuthLogin,
|
||||
)
|
||||
auth.POST("/oauth/dingtalk/create-account",
|
||||
rateLimiter.LimitWithOptions("oauth-dingtalk-create-account", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.CreateDingTalkOAuthAccount,
|
||||
)
|
||||
}
|
||||
|
||||
// 公开设置(无需认证)
|
||||
|
||||
50
backend/internal/service/account_credentials_redact.go
Normal file
50
backend/internal/service/account_credentials_redact.go
Normal file
@ -0,0 +1,50 @@
|
||||
package service
|
||||
|
||||
// SensitiveCredentialKeys 列出 Account.Credentials JSON map 中绝不允许返回到前端的子键。
|
||||
// dto 层做响应脱敏、service 层做更新合并都引用此清单——新增凭证类型时务必同步。
|
||||
var SensitiveCredentialKeys = []string{
|
||||
// OAuth
|
||||
"access_token", "refresh_token", "id_token",
|
||||
// API Key 类
|
||||
"api_key", "session_key", "cookie",
|
||||
// 云服务凭据
|
||||
"aws_secret_access_key", "aws_session_token",
|
||||
"service_account_json", "service_account", "private_key",
|
||||
}
|
||||
|
||||
var sensitiveCredentialKeySet = func() map[string]struct{} {
|
||||
m := make(map[string]struct{}, len(SensitiveCredentialKeys))
|
||||
for _, k := range SensitiveCredentialKeys {
|
||||
m[k] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}()
|
||||
|
||||
// IsSensitiveCredentialKey 判断指定键是否为敏感凭证子键。
|
||||
func IsSensitiveCredentialKey(key string) bool {
|
||||
_, ok := sensitiveCredentialKeySet[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// MergePreservingSensitiveCreds 把 incoming 写入 existing 之上,但敏感子键采用"incoming 没提供就保留 existing"
|
||||
// 的语义。返回新的 map,不修改入参。
|
||||
//
|
||||
// 用途:前端编辑账号通常采用"全对象 PUT"模式;脱敏后前端 spread 旧 credentials 时不会带上敏感键,
|
||||
// 直接覆盖会清空已有 token。此函数保证:
|
||||
// - 非敏感键:完全由 incoming 决定(用户可以编辑、删除非敏感字段)。
|
||||
// - 敏感键:incoming 显式提供则覆盖(用户主动旋转 token),否则保留 existing。
|
||||
func MergePreservingSensitiveCreds(existing, incoming map[string]any) map[string]any {
|
||||
out := make(map[string]any, len(incoming)+len(SensitiveCredentialKeys))
|
||||
for k, v := range incoming {
|
||||
out[k] = v
|
||||
}
|
||||
for _, key := range SensitiveCredentialKeys {
|
||||
if _, hasIncoming := incoming[key]; hasIncoming {
|
||||
continue
|
||||
}
|
||||
if existingVal, ok := existing[key]; ok {
|
||||
out[key] = existingVal
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
90
backend/internal/service/account_credentials_redact_test.go
Normal file
90
backend/internal/service/account_credentials_redact_test.go
Normal file
@ -0,0 +1,90 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMergePreservingSensitiveCreds_PreservesSensitiveWhenIncomingMissing(t *testing.T) {
|
||||
existing := map[string]any{
|
||||
"refresh_token": "rt-old",
|
||||
"access_token": "at-old",
|
||||
"api_key": "sk-old",
|
||||
"base_url": "https://old.example.com",
|
||||
}
|
||||
incoming := map[string]any{
|
||||
"base_url": "https://new.example.com",
|
||||
"model_mapping": map[string]any{"foo": "bar"},
|
||||
}
|
||||
|
||||
out := MergePreservingSensitiveCreds(existing, incoming)
|
||||
|
||||
require.Equal(t, "rt-old", out["refresh_token"], "incoming 没传 refresh_token,应保留 existing")
|
||||
require.Equal(t, "at-old", out["access_token"])
|
||||
require.Equal(t, "sk-old", out["api_key"])
|
||||
require.Equal(t, "https://new.example.com", out["base_url"], "非敏感键由 incoming 决定")
|
||||
require.Equal(t, map[string]any{"foo": "bar"}, out["model_mapping"])
|
||||
}
|
||||
|
||||
func TestMergePreservingSensitiveCreds_OverwritesWhenIncomingProvidesSensitive(t *testing.T) {
|
||||
existing := map[string]any{
|
||||
"refresh_token": "rt-old",
|
||||
"api_key": "sk-old",
|
||||
}
|
||||
incoming := map[string]any{
|
||||
"refresh_token": "rt-new",
|
||||
// 显式没传 api_key —— 应保留
|
||||
}
|
||||
out := MergePreservingSensitiveCreds(existing, incoming)
|
||||
require.Equal(t, "rt-new", out["refresh_token"], "incoming 显式传入应覆盖")
|
||||
require.Equal(t, "sk-old", out["api_key"], "incoming 没传应保留")
|
||||
}
|
||||
|
||||
func TestMergePreservingSensitiveCreds_DoesNotMutateInputs(t *testing.T) {
|
||||
existing := map[string]any{"refresh_token": "rt"}
|
||||
incoming := map[string]any{"base_url": "x"}
|
||||
|
||||
_ = MergePreservingSensitiveCreds(existing, incoming)
|
||||
|
||||
require.Equal(t, "rt", existing["refresh_token"])
|
||||
require.NotContains(t, existing, "base_url")
|
||||
require.Equal(t, "x", incoming["base_url"])
|
||||
require.NotContains(t, incoming, "refresh_token")
|
||||
}
|
||||
|
||||
func TestMergePreservingSensitiveCreds_NilInputs(t *testing.T) {
|
||||
out := MergePreservingSensitiveCreds(nil, map[string]any{"base_url": "x"})
|
||||
require.Equal(t, "x", out["base_url"])
|
||||
require.NotContains(t, out, "refresh_token")
|
||||
|
||||
out2 := MergePreservingSensitiveCreds(map[string]any{"refresh_token": "rt"}, nil)
|
||||
require.Equal(t, "rt", out2["refresh_token"])
|
||||
}
|
||||
|
||||
func TestMergePreservingSensitiveCreds_NonSensitiveDeletionAllowed(t *testing.T) {
|
||||
existing := map[string]any{
|
||||
"refresh_token": "rt",
|
||||
"base_url": "https://old",
|
||||
"project_id": "p1",
|
||||
}
|
||||
incoming := map[string]any{
|
||||
"base_url": "https://new",
|
||||
// 不带 project_id —— 等同删除(非敏感键由 incoming 决定)
|
||||
}
|
||||
out := MergePreservingSensitiveCreds(existing, incoming)
|
||||
require.Equal(t, "rt", out["refresh_token"], "敏感键保留")
|
||||
require.Equal(t, "https://new", out["base_url"])
|
||||
require.NotContains(t, out, "project_id", "非敏感键 incoming 不传 = 删除")
|
||||
}
|
||||
|
||||
func TestIsSensitiveCredentialKey(t *testing.T) {
|
||||
require.True(t, IsSensitiveCredentialKey("refresh_token"))
|
||||
require.True(t, IsSensitiveCredentialKey("api_key"))
|
||||
require.True(t, IsSensitiveCredentialKey("private_key"))
|
||||
require.False(t, IsSensitiveCredentialKey("base_url"))
|
||||
require.False(t, IsSensitiveCredentialKey(""))
|
||||
require.False(t, IsSensitiveCredentialKey("model_mapping"))
|
||||
}
|
||||
@ -397,6 +397,7 @@ type GenerateRedeemCodesInput struct {
|
||||
Value float64
|
||||
GroupID *int64 // 订阅类型专用:关联的分组ID
|
||||
ValidityDays int // 订阅类型专用:有效天数
|
||||
ExpiresAt *time.Time
|
||||
}
|
||||
|
||||
type ProxyBatchDeleteResult struct {
|
||||
@ -1238,7 +1239,7 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
|
||||
providerKey := strings.TrimSpace(input.ProviderKey)
|
||||
providerSubject := strings.TrimSpace(input.ProviderSubject)
|
||||
if providerType == "" {
|
||||
return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, or wechat")
|
||||
return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, wechat, or dingtalk")
|
||||
}
|
||||
if providerKey == "" || providerSubject == "" {
|
||||
return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required")
|
||||
@ -1493,6 +1494,8 @@ func normalizeAdminAuthIdentityProviderType(input string) string {
|
||||
return "oidc"
|
||||
case "wechat":
|
||||
return "wechat"
|
||||
case "dingtalk":
|
||||
return "dingtalk"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
@ -2470,7 +2473,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
account.Notes = normalizeAccountNotes(input.Notes)
|
||||
}
|
||||
if len(input.Credentials) > 0 {
|
||||
account.Credentials = input.Credentials
|
||||
// 敏感子键采用"incoming 没提供就保留"的合并语义:前端响应已脱敏,
|
||||
// 全对象 PUT 编辑时不会再带回 token,避免覆盖时清空已有凭证。
|
||||
account.Credentials = MergePreservingSensitiveCreds(account.Credentials, input.Credentials)
|
||||
}
|
||||
// Extra 使用 map:需要区分“未提供(nil)”与“显式清空({})”。
|
||||
// 关闭配额限制时前端会删除 quota_* 键并提交 extra:{},此时也必须落库。
|
||||
@ -2966,6 +2971,10 @@ func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*Redeem
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) {
|
||||
if input.ExpiresAt != nil && !input.ExpiresAt.After(time.Now()) {
|
||||
return nil, ErrRedeemCodeExpired
|
||||
}
|
||||
|
||||
// 如果是订阅类型,验证必须有 GroupID
|
||||
if input.Type == RedeemTypeSubscription {
|
||||
if input.GroupID == nil {
|
||||
@ -2988,10 +2997,11 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener
|
||||
return nil, err
|
||||
}
|
||||
code := RedeemCode{
|
||||
Code: codeValue,
|
||||
Type: input.Type,
|
||||
Value: input.Value,
|
||||
Status: StatusUnused,
|
||||
Code: codeValue,
|
||||
Type: input.Type,
|
||||
Value: input.Value,
|
||||
Status: StatusUnused,
|
||||
ExpiresAt: input.ExpiresAt,
|
||||
}
|
||||
// 订阅类型专用字段
|
||||
if input.Type == RedeemTypeSubscription {
|
||||
|
||||
117
backend/internal/service/admin_service_credentials_merge_test.go
Normal file
117
backend/internal/service/admin_service_credentials_merge_test.go
Normal file
@ -0,0 +1,117 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type updateAccountCredsRepoStub struct {
|
||||
mockAccountRepoForGemini
|
||||
account *Account
|
||||
updateCalls int
|
||||
}
|
||||
|
||||
func (r *updateAccountCredsRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
return r.account, nil
|
||||
}
|
||||
|
||||
func (r *updateAccountCredsRepoStub) Update(ctx context.Context, account *Account) error {
|
||||
r.updateCalls++
|
||||
r.account = account
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestUpdateAccount_PreservesSensitiveCredsWhenIncomingOmits(t *testing.T) {
|
||||
accountID := int64(202)
|
||||
repo := &updateAccountCredsRepoStub{
|
||||
account: &Account{
|
||||
ID: accountID,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"refresh_token": "rt-existing",
|
||||
"access_token": "at-existing",
|
||||
"id_token": "id-existing",
|
||||
"base_url": "https://old.example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
// 模拟前端编辑:仅修改 base_url,没有传 token(脱敏后前端 spread 拿不到敏感键)
|
||||
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
|
||||
Credentials: map[string]any{
|
||||
"base_url": "https://new.example.com",
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
|
||||
// 敏感键应保留
|
||||
require.Equal(t, "rt-existing", repo.account.Credentials["refresh_token"])
|
||||
require.Equal(t, "at-existing", repo.account.Credentials["access_token"])
|
||||
require.Equal(t, "id-existing", repo.account.Credentials["id_token"])
|
||||
// 非敏感键被替换
|
||||
require.Equal(t, "https://new.example.com", repo.account.Credentials["base_url"])
|
||||
}
|
||||
|
||||
func TestUpdateAccount_ExplicitNewTokenOverwrites(t *testing.T) {
|
||||
accountID := int64(203)
|
||||
repo := &updateAccountCredsRepoStub{
|
||||
account: &Account{
|
||||
ID: accountID,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"refresh_token": "rt-old",
|
||||
"api_key": "sk-old",
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
|
||||
Credentials: map[string]any{
|
||||
"refresh_token": "rt-new",
|
||||
// api_key 没传 → 应保留旧值
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
|
||||
require.Equal(t, "rt-new", repo.account.Credentials["refresh_token"])
|
||||
require.Equal(t, "sk-old", repo.account.Credentials["api_key"])
|
||||
}
|
||||
|
||||
func TestUpdateAccount_EmptyCredentialsSkipsUpdate(t *testing.T) {
|
||||
accountID := int64(204)
|
||||
repo := &updateAccountCredsRepoStub{
|
||||
account: &Account{
|
||||
ID: accountID,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"refresh_token": "rt-existing",
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
_, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
|
||||
Credentials: map[string]any{}, // len == 0 → 闸门跳过
|
||||
Name: "renamed",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "rt-existing", repo.account.Credentials["refresh_token"], "空 credentials 不应触碰已有 token")
|
||||
require.Equal(t, "renamed", repo.account.Name)
|
||||
}
|
||||
@ -2094,7 +2094,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
}
|
||||
|
||||
// 解析请求以获取 image_size(用于图片计费)
|
||||
imageSize := s.extractImageSize(body)
|
||||
imageInputSize := s.extractImageInputSize(body)
|
||||
imageSize := normalizeOpenAIImageSizeTier(imageInputSize)
|
||||
|
||||
switch action {
|
||||
case "generateContent", "streamGenerateContent":
|
||||
@ -2465,6 +2466,7 @@ handleSuccess:
|
||||
ClientDisconnect: clientDisconnect,
|
||||
ImageCount: imageCount,
|
||||
ImageSize: imageSize,
|
||||
ImageInputSize: imageInputSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -4063,21 +4065,17 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
}
|
||||
}
|
||||
|
||||
// extractImageSize 从 Gemini 请求中提取 image_size 参数
|
||||
func (s *AntigravityGatewayService) extractImageSize(body []byte) string {
|
||||
func (s *AntigravityGatewayService) extractImageInputSize(body []byte) string {
|
||||
var req antigravity.GeminiRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return "2K" // 默认 2K
|
||||
return ""
|
||||
}
|
||||
|
||||
if req.GenerationConfig != nil && req.GenerationConfig.ImageConfig != nil {
|
||||
size := strings.ToUpper(strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize))
|
||||
if size == "1K" || size == "2K" || size == "4K" {
|
||||
return size
|
||||
}
|
||||
return strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize)
|
||||
}
|
||||
|
||||
return "2K" // 默认 2K
|
||||
return ""
|
||||
}
|
||||
|
||||
// isImageGenerationModel 判断模型是否为图片生成模型
|
||||
|
||||
@ -46,15 +46,15 @@ func TestExtractImageSize_ValidSizes(t *testing.T) {
|
||||
|
||||
// 1K
|
||||
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"1K"}}}`)
|
||||
require.Equal(t, "1K", svc.extractImageSize(body))
|
||||
require.Equal(t, "1K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
|
||||
|
||||
// 2K
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"2K"}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
|
||||
|
||||
// 4K
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"4K"}}}`)
|
||||
require.Equal(t, "4K", svc.extractImageSize(body))
|
||||
require.Equal(t, "4K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_CaseInsensitive 测试大小写不敏感
|
||||
@ -62,10 +62,10 @@ func TestExtractImageSize_CaseInsensitive(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"1k"}}}`)
|
||||
require.Equal(t, "1K", svc.extractImageSize(body))
|
||||
require.Equal(t, "1K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
|
||||
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"4k"}}}`)
|
||||
require.Equal(t, "4K", svc.extractImageSize(body))
|
||||
require.Equal(t, "4K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_Default 测试无 imageConfig 返回默认 2K
|
||||
@ -74,15 +74,15 @@ func TestExtractImageSize_Default(t *testing.T) {
|
||||
|
||||
// 无 generationConfig
|
||||
body := []byte(`{"contents":[]}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
|
||||
|
||||
// 有 generationConfig 但无 imageConfig
|
||||
body = []byte(`{"generationConfig":{"temperature":0.7}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
|
||||
|
||||
// 有 imageConfig 但无 imageSize
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_InvalidJSON 测试非法 JSON 返回默认 2K
|
||||
@ -90,10 +90,10 @@ func TestExtractImageSize_InvalidJSON(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
body := []byte(`not valid json`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
|
||||
|
||||
body = []byte(`{"broken":`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_EmptySize 测试空 imageSize 返回默认 2K
|
||||
@ -101,11 +101,11 @@ func TestExtractImageSize_EmptySize(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":""}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
|
||||
|
||||
// 空格
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":" "}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_InvalidSize 测试无效尺寸返回默认 2K
|
||||
@ -113,11 +113,11 @@ func TestExtractImageSize_InvalidSize(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"3K"}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
|
||||
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"8K"}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
|
||||
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"invalid"}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
@ -18,7 +19,7 @@ func normalizeOAuthSignupSource(signupSource string) string {
|
||||
switch signupSource {
|
||||
case "", "email":
|
||||
return "email"
|
||||
case "linuxdo", "wechat", "oidc", "github", "google":
|
||||
case "linuxdo", "wechat", "oidc", "github", "google", "dingtalk":
|
||||
return signupSource
|
||||
default:
|
||||
return "email"
|
||||
@ -71,7 +72,7 @@ func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, i
|
||||
if err != nil {
|
||||
return nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
|
||||
if redeemCode.Type != RedeemTypeInvitation || !redeemCode.CanUse() {
|
||||
return nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
return redeemCode, nil
|
||||
@ -109,7 +110,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
|
||||
if s == nil {
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
if s.settingService == nil || (!s.settingService.IsRegistrationEnabled(ctx) && !s.canBypassRegistrationDisabledForOAuth(ctx, signupSource)) {
|
||||
return nil, nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
@ -118,18 +119,22 @@ func (s *AuthService) RegisterOAuthEmailAccount(
|
||||
return nil, nil, ErrEmailReserved
|
||||
}
|
||||
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
|
||||
slog.Error("oauth email register: policy rejected", "email", email, "error", err.Error())
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := s.VerifyOAuthEmailCode(ctx, email, verifyCode); err != nil {
|
||||
slog.Error("oauth email register: verify code failed", "email", email, "error", err.Error())
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil {
|
||||
slog.Error("oauth email register: invitation failed", "email", email, "error", err.Error())
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
slog.Error("oauth email register: ExistsByEmail failed", "email", email, "error", err.Error())
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
if existsEmail {
|
||||
@ -158,6 +163,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
return nil, nil, ErrEmailExists
|
||||
}
|
||||
slog.Error("oauth email register: userRepo.Create failed", "email", email, "signup_source", signupSource, "error", err.Error())
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
@ -181,7 +187,7 @@ func (s *AuthService) RegisterVerifiedOAuthEmailAccount(
|
||||
if s == nil {
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
if s.settingService == nil || (!s.settingService.IsRegistrationEnabled(ctx) && !s.canBypassRegistrationDisabledForOAuth(ctx, signupSource)) {
|
||||
return nil, nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
@ -358,6 +364,7 @@ func (s *AuthService) loadOAuthRegistrationInvitation(ctx context.Context, invit
|
||||
UsedAt: entity.UsedAt,
|
||||
Notes: oauthEmailFlowStringValue(entity.Notes),
|
||||
CreatedAt: entity.CreatedAt,
|
||||
ExpiresAt: entity.ExpiresAt,
|
||||
GroupID: entity.GroupID,
|
||||
ValidityDays: entity.ValidityDays,
|
||||
}, nil
|
||||
@ -368,7 +375,11 @@ func (s *AuthService) loadOAuthRegistrationInvitation(ctx context.Context, invit
|
||||
func (s *AuthService) useOAuthRegistrationInvitation(ctx context.Context, invitationID, userID int64) error {
|
||||
if client := s.oauthEmailFlowClient(ctx); client != nil {
|
||||
affected, err := client.RedeemCode.Update().
|
||||
Where(redeemcode.IDEQ(invitationID), redeemcode.StatusEQ(StatusUnused)).
|
||||
Where(
|
||||
redeemcode.IDEQ(invitationID),
|
||||
redeemcode.StatusEQ(StatusUnused),
|
||||
redeemcode.Or(redeemcode.ExpiresAtIsNil(), redeemcode.ExpiresAtGT(time.Now().UTC())),
|
||||
).
|
||||
SetStatus(StatusUsed).
|
||||
SetUsedBy(userID).
|
||||
SetUsedAt(time.Now().UTC()).
|
||||
@ -396,6 +407,11 @@ func (s *AuthService) updateOAuthRegistrationInvitation(ctx context.Context, cod
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes).
|
||||
SetValidityDays(code.ValidityDays)
|
||||
if code.ExpiresAt != nil {
|
||||
update = update.SetExpiresAt(*code.ExpiresAt)
|
||||
} else {
|
||||
update = update.ClearExpiresAt()
|
||||
}
|
||||
if code.UsedBy != nil {
|
||||
update = update.SetUsedBy(*code.UsedBy)
|
||||
} else {
|
||||
|
||||
@ -157,7 +157,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
return "", nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
// 检查类型和状态
|
||||
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
|
||||
if redeemCode.Type != RedeemTypeInvitation || !redeemCode.CanUse() {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Invitation code invalid: type=%s, status=%s", redeemCode.Type, redeemCode.Status)
|
||||
return "", nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
@ -560,11 +560,25 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
return token, user, nil
|
||||
}
|
||||
|
||||
// canBypassRegistrationDisabledForOAuth 在钉钉企业模式(internal_only)且
|
||||
// dingtalk_connect_bypass_registration=true 时,允许跳过全局 registration_enabled 检查。
|
||||
func (s *AuthService) canBypassRegistrationDisabledForOAuth(ctx context.Context, signupSource string) bool {
|
||||
if signupSource != "dingtalk" {
|
||||
return false
|
||||
}
|
||||
cfg, err := s.settingService.GetDingTalkConnectOAuthConfig(ctx)
|
||||
if err != nil || !cfg.Enabled || !cfg.BypassRegistration {
|
||||
return false
|
||||
}
|
||||
return cfg.CorpRestrictionPolicy == "internal_only"
|
||||
}
|
||||
|
||||
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
|
||||
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
|
||||
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
|
||||
// affiliateCode 用于邀请返利绑定,仅在新用户注册时使用。
|
||||
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode string) (*TokenPair, *User, error) {
|
||||
// signupSource 标识来源渠道("dingtalk"/"linuxdo"/"wechat"/"oidc" 等),仅用于豁免检查。
|
||||
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode, signupSource string) (*TokenPair, *User, error) {
|
||||
// 检查 refreshTokenCache 是否可用
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil, nil, errors.New("refresh token cache not configured")
|
||||
@ -587,7 +601,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
// OAuth 首次登录视为注册
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
if s.settingService == nil || (!s.settingService.IsRegistrationEnabled(ctx) && !s.canBypassRegistrationDisabledForOAuth(ctx, signupSource)) {
|
||||
return nil, nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
@ -601,7 +615,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
if err != nil {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
|
||||
if redeemCode.Type != RedeemTypeInvitation || !redeemCode.CanUse() {
|
||||
return nil, nil, ErrInvitationCodeInvalid
|
||||
}
|
||||
invitationRedeemCode = redeemCode
|
||||
@ -617,7 +631,11 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
return nil, nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
signupSource := inferLegacySignupSource(email)
|
||||
// 优先用 caller 显式传入的 signupSource(如 "dingtalk" / "linuxdo" / "oidc" / "wechat"),
|
||||
// 否则才按邮箱后缀推断——避免有真实邮箱的 OAuth 用户被推断为 "email" 渠道,导致渠道授权错读。
|
||||
if strings.TrimSpace(signupSource) == "" {
|
||||
signupSource = inferLegacySignupSource(email)
|
||||
}
|
||||
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||
var defaultRPMLimit int
|
||||
if s.settingService != nil {
|
||||
@ -779,6 +797,8 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource
|
||||
return defaults.GitHub, true
|
||||
case "google":
|
||||
return defaults.Google, true
|
||||
case "dingtalk":
|
||||
return defaults.DingTalk, true
|
||||
default:
|
||||
return ProviderDefaultGrantSettings{}, false
|
||||
}
|
||||
@ -992,6 +1012,8 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User, s
|
||||
func inferLegacySignupSource(email string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
switch {
|
||||
case strings.HasSuffix(normalized, DingTalkConnectSyntheticEmailDomain):
|
||||
return "dingtalk"
|
||||
case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain):
|
||||
return "linuxdo"
|
||||
case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain):
|
||||
@ -1086,7 +1108,8 @@ func isReservedEmail(email string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain)
|
||||
strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(normalized, DingTalkConnectSyntheticEmailDomain)
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT access token
|
||||
|
||||
@ -602,7 +602,7 @@ func TestAuthService_Register_GrantOnSignupMergesSourceOverridesWithGlobalDefaul
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
require.Equal(t, 9.5, user.Balance)
|
||||
require.Equal(t, 2, user.Concurrency)
|
||||
require.Equal(t, 5, user.Concurrency)
|
||||
require.Len(t, assigner.calls, 1)
|
||||
require.Equal(t, int64(31), assigner.calls[0].GroupID)
|
||||
require.Equal(t, 5, assigner.calls[0].ValidityDays)
|
||||
@ -622,7 +622,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefa
|
||||
service.defaultSubAssigner = assigner
|
||||
service.refreshTokenCache = &refreshTokenCacheStub{}
|
||||
|
||||
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "", "")
|
||||
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "", "", "linuxdo")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tokenPair)
|
||||
require.NotNil(t, user)
|
||||
@ -658,7 +658,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA
|
||||
service.defaultSubAssigner = assigner
|
||||
service.refreshTokenCache = &refreshTokenCacheStub{}
|
||||
|
||||
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "", "")
|
||||
tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "", "", "linuxdo")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tokenPair)
|
||||
require.Equal(t, existing.ID, user.ID)
|
||||
@ -667,3 +667,99 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA
|
||||
require.Empty(t, repo.created)
|
||||
require.Empty(t, assigner.calls)
|
||||
}
|
||||
|
||||
// newAuthServiceWithDingTalkCfg 构建一个含完整 DingTalk config 的 AuthService,
|
||||
// 用于测试 canBypassRegistrationDisabledForOAuth。
|
||||
func newAuthServiceWithDingTalkCfg(settings map[string]string, dtCfg config.DingTalkConnectConfig) *AuthService {
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1},
|
||||
Default: config.DefaultConfig{UserBalance: 3.5, UserConcurrency: 2},
|
||||
DingTalk: dtCfg,
|
||||
}
|
||||
settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
|
||||
return NewAuthService(nil, nil, nil, nil, cfg, settingService, nil, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// minDingTalkURLs 返回一个包含必填字段的基础 DingTalkConnectConfig(不设 Enabled/BypassRegistration/Policy)。
|
||||
func minDingTalkURLs() config.DingTalkConnectConfig {
|
||||
return config.DingTalkConnectConfig{
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
AuthorizeURL: "https://example.com/oauth2/auth",
|
||||
TokenURL: "https://example.com/oauth2/token",
|
||||
UserInfoURL: "https://example.com/oauth2/userinfo",
|
||||
RedirectURL: "https://example.com/callback",
|
||||
FrontendRedirectURL: "https://example.com/auth/callback",
|
||||
DingTalkAppKind: "internal_app",
|
||||
AppType: "internal",
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanBypassRegistrationDisabledForOAuth(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
signupSource string
|
||||
settings map[string]string
|
||||
dtCfg config.DingTalkConnectConfig
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "non-dingtalk source → false",
|
||||
signupSource: "linuxdo",
|
||||
settings: map[string]string{},
|
||||
dtCfg: minDingTalkURLs(),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "dingtalk but cfg.Enabled=false → false",
|
||||
signupSource: "dingtalk",
|
||||
settings: map[string]string{
|
||||
SettingKeyDingTalkConnectEnabled: "false",
|
||||
SettingKeyDingTalkConnectBypassRegistration: "true",
|
||||
SettingKeyDingTalkConnectCorpRestrictionPolicy: "internal_only",
|
||||
},
|
||||
dtCfg: minDingTalkURLs(),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "dingtalk enabled but BypassRegistration=false → false",
|
||||
signupSource: "dingtalk",
|
||||
settings: map[string]string{
|
||||
SettingKeyDingTalkConnectEnabled: "true",
|
||||
SettingKeyDingTalkConnectBypassRegistration: "false",
|
||||
SettingKeyDingTalkConnectCorpRestrictionPolicy: "internal_only",
|
||||
},
|
||||
dtCfg: minDingTalkURLs(),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "dingtalk enabled + bypass=true but policy=none → false",
|
||||
signupSource: "dingtalk",
|
||||
settings: map[string]string{
|
||||
SettingKeyDingTalkConnectEnabled: "true",
|
||||
SettingKeyDingTalkConnectBypassRegistration: "true",
|
||||
SettingKeyDingTalkConnectCorpRestrictionPolicy: "none",
|
||||
},
|
||||
dtCfg: minDingTalkURLs(),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "dingtalk enabled + bypass=true + policy=internal_only → true",
|
||||
signupSource: "dingtalk",
|
||||
settings: map[string]string{
|
||||
SettingKeyDingTalkConnectEnabled: "true",
|
||||
SettingKeyDingTalkConnectBypassRegistration: "true",
|
||||
SettingKeyDingTalkConnectCorpRestrictionPolicy: "internal_only",
|
||||
},
|
||||
dtCfg: minDingTalkURLs(),
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
svc := newAuthServiceWithDingTalkCfg(tc.settings, tc.dtCfg)
|
||||
got := svc.canBypassRegistrationDisabledForOAuth(context.Background(), tc.signupSource)
|
||||
require.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
13
backend/internal/service/auth_service_test.go
Normal file
13
backend/internal/service/auth_service_test.go
Normal file
@ -0,0 +1,13 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsReservedEmail_DingTalkDomain(t *testing.T) {
|
||||
require.True(t, isReservedEmail("dingtalk-123@dingtalk-connect.invalid"))
|
||||
require.True(t, isReservedEmail("DINGTALK-456@DINGTALK-CONNECT.INVALID")) // case-insensitive
|
||||
require.False(t, isReservedEmail("real@dingtalk.com"))
|
||||
}
|
||||
@ -809,6 +809,7 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
|
||||
if imageCount <= 0 {
|
||||
return &CostBreakdown{}
|
||||
}
|
||||
imageSize = NormalizeImageBillingTierOrDefault(imageSize)
|
||||
|
||||
// 获取单价
|
||||
unitPrice := s.getImageUnitPrice(model, imageSize, groupConfig)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user