feat: rebuild auth identity foundation flow

This commit is contained in:
IanShaw027 2026-04-20 17:39:57 +08:00
parent fbd0a2e3c4
commit e9de839d87
123 changed files with 33599 additions and 772 deletions

266
backend/ent/authidentity.go Normal file
View File

@ -0,0 +1,266 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"encoding/json"
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/user"
)
// AuthIdentity is the model entity for the AuthIdentity schema.
type AuthIdentity struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
UpdatedAt time.Time `json:"updated_at,omitempty"`
// UserID holds the value of the "user_id" field.
UserID int64 `json:"user_id,omitempty"`
// ProviderType holds the value of the "provider_type" field.
ProviderType string `json:"provider_type,omitempty"`
// ProviderKey holds the value of the "provider_key" field.
ProviderKey string `json:"provider_key,omitempty"`
// ProviderSubject holds the value of the "provider_subject" field.
ProviderSubject string `json:"provider_subject,omitempty"`
// VerifiedAt holds the value of the "verified_at" field.
VerifiedAt *time.Time `json:"verified_at,omitempty"`
// Issuer holds the value of the "issuer" field.
Issuer *string `json:"issuer,omitempty"`
// Metadata holds the value of the "metadata" field.
Metadata map[string]interface{} `json:"metadata,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the AuthIdentityQuery when eager-loading is set.
Edges AuthIdentityEdges `json:"edges"`
selectValues sql.SelectValues
}
// AuthIdentityEdges holds the relations/edges for other nodes in the graph.
type AuthIdentityEdges struct {
// User holds the value of the user edge.
User *User `json:"user,omitempty"`
// Channels holds the value of the channels edge.
Channels []*AuthIdentityChannel `json:"channels,omitempty"`
// AdoptionDecisions holds the value of the adoption_decisions edge.
AdoptionDecisions []*IdentityAdoptionDecision `json:"adoption_decisions,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [3]bool
}
// UserOrErr returns the User value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e AuthIdentityEdges) UserOrErr() (*User, error) {
if e.User != nil {
return e.User, nil
} else if e.loadedTypes[0] {
return nil, &NotFoundError{label: user.Label}
}
return nil, &NotLoadedError{edge: "user"}
}
// ChannelsOrErr returns the Channels value or an error if the edge
// was not loaded in eager-loading.
func (e AuthIdentityEdges) ChannelsOrErr() ([]*AuthIdentityChannel, error) {
if e.loadedTypes[1] {
return e.Channels, nil
}
return nil, &NotLoadedError{edge: "channels"}
}
// AdoptionDecisionsOrErr returns the AdoptionDecisions value or an error if the edge
// was not loaded in eager-loading.
func (e AuthIdentityEdges) AdoptionDecisionsOrErr() ([]*IdentityAdoptionDecision, error) {
if e.loadedTypes[2] {
return e.AdoptionDecisions, nil
}
return nil, &NotLoadedError{edge: "adoption_decisions"}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*AuthIdentity) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case authidentity.FieldMetadata:
values[i] = new([]byte)
case authidentity.FieldID, authidentity.FieldUserID:
values[i] = new(sql.NullInt64)
case authidentity.FieldProviderType, authidentity.FieldProviderKey, authidentity.FieldProviderSubject, authidentity.FieldIssuer:
values[i] = new(sql.NullString)
case authidentity.FieldCreatedAt, authidentity.FieldUpdatedAt, authidentity.FieldVerifiedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the AuthIdentity fields.
func (_m *AuthIdentity) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case authidentity.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
case authidentity.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
case authidentity.FieldUpdatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
} else if value.Valid {
_m.UpdatedAt = value.Time
}
case authidentity.FieldUserID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field user_id", values[i])
} else if value.Valid {
_m.UserID = value.Int64
}
case authidentity.FieldProviderType:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field provider_type", values[i])
} else if value.Valid {
_m.ProviderType = value.String
}
case authidentity.FieldProviderKey:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field provider_key", values[i])
} else if value.Valid {
_m.ProviderKey = value.String
}
case authidentity.FieldProviderSubject:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field provider_subject", values[i])
} else if value.Valid {
_m.ProviderSubject = value.String
}
case authidentity.FieldVerifiedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field verified_at", values[i])
} else if value.Valid {
_m.VerifiedAt = new(time.Time)
*_m.VerifiedAt = value.Time
}
case authidentity.FieldIssuer:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field issuer", values[i])
} else if value.Valid {
_m.Issuer = new(string)
*_m.Issuer = value.String
}
case authidentity.FieldMetadata:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field metadata", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.Metadata); err != nil {
return fmt.Errorf("unmarshal field metadata: %w", err)
}
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentity.
// This includes values selected through modifiers, order, etc.
func (_m *AuthIdentity) Value(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// QueryUser queries the "user" edge of the AuthIdentity entity.
func (_m *AuthIdentity) QueryUser() *UserQuery {
return NewAuthIdentityClient(_m.config).QueryUser(_m)
}
// QueryChannels queries the "channels" edge of the AuthIdentity entity.
func (_m *AuthIdentity) QueryChannels() *AuthIdentityChannelQuery {
return NewAuthIdentityClient(_m.config).QueryChannels(_m)
}
// QueryAdoptionDecisions queries the "adoption_decisions" edge of the AuthIdentity entity.
func (_m *AuthIdentity) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery {
return NewAuthIdentityClient(_m.config).QueryAdoptionDecisions(_m)
}
// Update returns a builder for updating this AuthIdentity.
// Note that you need to call AuthIdentity.Unwrap() before calling this method if this AuthIdentity
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *AuthIdentity) Update() *AuthIdentityUpdateOne {
return NewAuthIdentityClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the AuthIdentity entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *AuthIdentity) Unwrap() *AuthIdentity {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("ent: AuthIdentity is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *AuthIdentity) String() string {
var builder strings.Builder
builder.WriteString("AuthIdentity(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("updated_at=")
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("user_id=")
builder.WriteString(fmt.Sprintf("%v", _m.UserID))
builder.WriteString(", ")
builder.WriteString("provider_type=")
builder.WriteString(_m.ProviderType)
builder.WriteString(", ")
builder.WriteString("provider_key=")
builder.WriteString(_m.ProviderKey)
builder.WriteString(", ")
builder.WriteString("provider_subject=")
builder.WriteString(_m.ProviderSubject)
builder.WriteString(", ")
if v := _m.VerifiedAt; v != nil {
builder.WriteString("verified_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.Issuer; v != nil {
builder.WriteString("issuer=")
builder.WriteString(*v)
}
builder.WriteString(", ")
builder.WriteString("metadata=")
builder.WriteString(fmt.Sprintf("%v", _m.Metadata))
builder.WriteByte(')')
return builder.String()
}
// AuthIdentities is a parsable slice of AuthIdentity.
type AuthIdentities []*AuthIdentity

View File

@ -0,0 +1,209 @@
// Code generated by ent, DO NOT EDIT.
package authidentity
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
const (
// Label holds the string label denoting the authidentity type in the database.
Label = "auth_identity"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at"
// FieldUserID holds the string denoting the user_id field in the database.
FieldUserID = "user_id"
// FieldProviderType holds the string denoting the provider_type field in the database.
FieldProviderType = "provider_type"
// FieldProviderKey holds the string denoting the provider_key field in the database.
FieldProviderKey = "provider_key"
// FieldProviderSubject holds the string denoting the provider_subject field in the database.
FieldProviderSubject = "provider_subject"
// FieldVerifiedAt holds the string denoting the verified_at field in the database.
FieldVerifiedAt = "verified_at"
// FieldIssuer holds the string denoting the issuer field in the database.
FieldIssuer = "issuer"
// FieldMetadata holds the string denoting the metadata field in the database.
FieldMetadata = "metadata"
// EdgeUser holds the string denoting the user edge name in mutations.
EdgeUser = "user"
// EdgeChannels holds the string denoting the channels edge name in mutations.
EdgeChannels = "channels"
// EdgeAdoptionDecisions holds the string denoting the adoption_decisions edge name in mutations.
EdgeAdoptionDecisions = "adoption_decisions"
// Table holds the table name of the authidentity in the database.
Table = "auth_identities"
// UserTable is the table that holds the user relation/edge.
UserTable = "auth_identities"
// UserInverseTable is the table name for the User entity.
// It exists in this package in order to avoid circular dependency with the "user" package.
UserInverseTable = "users"
// UserColumn is the table column denoting the user relation/edge.
UserColumn = "user_id"
// ChannelsTable is the table that holds the channels relation/edge.
ChannelsTable = "auth_identity_channels"
// ChannelsInverseTable is the table name for the AuthIdentityChannel entity.
// It exists in this package in order to avoid circular dependency with the "authidentitychannel" package.
ChannelsInverseTable = "auth_identity_channels"
// ChannelsColumn is the table column denoting the channels relation/edge.
ChannelsColumn = "identity_id"
// AdoptionDecisionsTable is the table that holds the adoption_decisions relation/edge.
AdoptionDecisionsTable = "identity_adoption_decisions"
// AdoptionDecisionsInverseTable is the table name for the IdentityAdoptionDecision entity.
// It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package.
AdoptionDecisionsInverseTable = "identity_adoption_decisions"
// AdoptionDecisionsColumn is the table column denoting the adoption_decisions relation/edge.
AdoptionDecisionsColumn = "identity_id"
)
// Columns holds all SQL columns for authidentity fields.
var Columns = []string{
FieldID,
FieldCreatedAt,
FieldUpdatedAt,
FieldUserID,
FieldProviderType,
FieldProviderKey,
FieldProviderSubject,
FieldVerifiedAt,
FieldIssuer,
FieldMetadata,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
var (
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
DefaultUpdatedAt func() time.Time
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
UpdateDefaultUpdatedAt func() time.Time
// ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
ProviderTypeValidator func(string) error
// ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
ProviderKeyValidator func(string) error
// ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
ProviderSubjectValidator func(string) error
// DefaultMetadata holds the default value on creation for the "metadata" field.
DefaultMetadata func() map[string]interface{}
)
// OrderOption defines the ordering options for the AuthIdentity queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByUserID orders the results by the user_id field.
func ByUserID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUserID, opts...).ToFunc()
}
// ByProviderType orders the results by the provider_type field.
func ByProviderType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldProviderType, opts...).ToFunc()
}
// ByProviderKey orders the results by the provider_key field.
func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
}
// ByProviderSubject orders the results by the provider_subject field.
func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldProviderSubject, opts...).ToFunc()
}
// ByVerifiedAt orders the results by the verified_at field.
func ByVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldVerifiedAt, opts...).ToFunc()
}
// ByIssuer orders the results by the issuer field.
func ByIssuer(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldIssuer, opts...).ToFunc()
}
// ByUserField orders the results by user field.
func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...))
}
}
// ByChannelsCount orders the results by channels count.
func ByChannelsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newChannelsStep(), opts...)
}
}
// ByChannels orders the results by channels terms.
func ByChannels(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newChannelsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
// ByAdoptionDecisionsCount orders the results by adoption_decisions count.
func ByAdoptionDecisionsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newAdoptionDecisionsStep(), opts...)
}
}
// ByAdoptionDecisions orders the results by adoption_decisions terms.
func ByAdoptionDecisions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
func newUserStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(UserInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
)
}
func newChannelsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(ChannelsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn),
)
}
func newAdoptionDecisionsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(AdoptionDecisionsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn),
)
}

View File

@ -0,0 +1,600 @@
// Code generated by ent, DO NOT EDIT.
package authidentity
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int64) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int64) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int64) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int64) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int64) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int64) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int64) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int64) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int64) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldLTE(FieldID, id))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v))
}
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
func UpdatedAt(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v))
}
// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
func UserID(v int64) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v))
}
// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ.
func ProviderType(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v))
}
// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
func ProviderKey(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v))
}
// ProviderSubject applies equality check predicate on the "provider_subject" field. It's identical to ProviderSubjectEQ.
func ProviderSubject(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v))
}
// VerifiedAt applies equality check predicate on the "verified_at" field. It's identical to VerifiedAtEQ.
func VerifiedAt(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v))
}
// Issuer applies equality check predicate on the "issuer" field. It's identical to IssuerEQ.
func Issuer(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldLTE(FieldCreatedAt, v))
}
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
func UpdatedAtEQ(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v))
}
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
func UpdatedAtNEQ(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldNEQ(FieldUpdatedAt, v))
}
// UpdatedAtIn applies the In predicate on the "updated_at" field.
func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldIn(FieldUpdatedAt, vs...))
}
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldNotIn(FieldUpdatedAt, vs...))
}
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
func UpdatedAtGT(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldGT(FieldUpdatedAt, v))
}
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
func UpdatedAtGTE(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldGTE(FieldUpdatedAt, v))
}
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
func UpdatedAtLT(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldLT(FieldUpdatedAt, v))
}
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
func UpdatedAtLTE(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldLTE(FieldUpdatedAt, v))
}
// UserIDEQ applies the EQ predicate on the "user_id" field.
func UserIDEQ(v int64) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v))
}
// UserIDNEQ applies the NEQ predicate on the "user_id" field.
func UserIDNEQ(v int64) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldNEQ(FieldUserID, v))
}
// UserIDIn applies the In predicate on the "user_id" field.
func UserIDIn(vs ...int64) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldIn(FieldUserID, vs...))
}
// UserIDNotIn applies the NotIn predicate on the "user_id" field.
func UserIDNotIn(vs ...int64) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldNotIn(FieldUserID, vs...))
}
// ProviderTypeEQ applies the EQ predicate on the "provider_type" field.
func ProviderTypeEQ(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v))
}
// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field.
func ProviderTypeNEQ(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderType, v))
}
// ProviderTypeIn applies the In predicate on the "provider_type" field.
func ProviderTypeIn(vs ...string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldIn(FieldProviderType, vs...))
}
// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field.
func ProviderTypeNotIn(vs ...string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderType, vs...))
}
// ProviderTypeGT applies the GT predicate on the "provider_type" field.
func ProviderTypeGT(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldGT(FieldProviderType, v))
}
// ProviderTypeGTE applies the GTE predicate on the "provider_type" field.
func ProviderTypeGTE(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldGTE(FieldProviderType, v))
}
// ProviderTypeLT applies the LT predicate on the "provider_type" field.
func ProviderTypeLT(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldLT(FieldProviderType, v))
}
// ProviderTypeLTE applies the LTE predicate on the "provider_type" field.
func ProviderTypeLTE(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldLTE(FieldProviderType, v))
}
// ProviderTypeContains applies the Contains predicate on the "provider_type" field.
func ProviderTypeContains(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldContains(FieldProviderType, v))
}
// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field.
func ProviderTypeHasPrefix(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderType, v))
}
// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field.
func ProviderTypeHasSuffix(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderType, v))
}
// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field.
func ProviderTypeEqualFold(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderType, v))
}
// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field.
func ProviderTypeContainsFold(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderType, v))
}
// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
func ProviderKeyEQ(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v))
}
// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
func ProviderKeyNEQ(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderKey, v))
}
// ProviderKeyIn applies the In predicate on the "provider_key" field.
func ProviderKeyIn(vs ...string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldIn(FieldProviderKey, vs...))
}
// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
func ProviderKeyNotIn(vs ...string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderKey, vs...))
}
// ProviderKeyGT applies the GT predicate on the "provider_key" field.
func ProviderKeyGT(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldGT(FieldProviderKey, v))
}
// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
func ProviderKeyGTE(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldGTE(FieldProviderKey, v))
}
// ProviderKeyLT applies the LT predicate on the "provider_key" field.
func ProviderKeyLT(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldLT(FieldProviderKey, v))
}
// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
func ProviderKeyLTE(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldLTE(FieldProviderKey, v))
}
// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
func ProviderKeyContains(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldContains(FieldProviderKey, v))
}
// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
func ProviderKeyHasPrefix(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderKey, v))
}
// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
func ProviderKeyHasSuffix(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderKey, v))
}
// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
func ProviderKeyEqualFold(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderKey, v))
}
// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
func ProviderKeyContainsFold(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderKey, v))
}
// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field.
func ProviderSubjectEQ(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v))
}
// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field.
func ProviderSubjectNEQ(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderSubject, v))
}
// ProviderSubjectIn applies the In predicate on the "provider_subject" field.
func ProviderSubjectIn(vs ...string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldIn(FieldProviderSubject, vs...))
}
// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field.
func ProviderSubjectNotIn(vs ...string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderSubject, vs...))
}
// ProviderSubjectGT applies the GT predicate on the "provider_subject" field.
func ProviderSubjectGT(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldGT(FieldProviderSubject, v))
}
// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field.
func ProviderSubjectGTE(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldGTE(FieldProviderSubject, v))
}
// ProviderSubjectLT applies the LT predicate on the "provider_subject" field.
func ProviderSubjectLT(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldLT(FieldProviderSubject, v))
}
// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field.
func ProviderSubjectLTE(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldLTE(FieldProviderSubject, v))
}
// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field.
func ProviderSubjectContains(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldContains(FieldProviderSubject, v))
}
// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field.
func ProviderSubjectHasPrefix(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderSubject, v))
}
// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field.
func ProviderSubjectHasSuffix(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderSubject, v))
}
// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field.
func ProviderSubjectEqualFold(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderSubject, v))
}
// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field.
func ProviderSubjectContainsFold(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderSubject, v))
}
// VerifiedAtEQ applies the EQ predicate on the "verified_at" field.
func VerifiedAtEQ(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v))
}
// VerifiedAtNEQ applies the NEQ predicate on the "verified_at" field.
func VerifiedAtNEQ(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldNEQ(FieldVerifiedAt, v))
}
// VerifiedAtIn applies the In predicate on the "verified_at" field.
func VerifiedAtIn(vs ...time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldIn(FieldVerifiedAt, vs...))
}
// VerifiedAtNotIn applies the NotIn predicate on the "verified_at" field.
func VerifiedAtNotIn(vs ...time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldNotIn(FieldVerifiedAt, vs...))
}
// VerifiedAtGT applies the GT predicate on the "verified_at" field.
func VerifiedAtGT(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldGT(FieldVerifiedAt, v))
}
// VerifiedAtGTE applies the GTE predicate on the "verified_at" field.
func VerifiedAtGTE(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldGTE(FieldVerifiedAt, v))
}
// VerifiedAtLT applies the LT predicate on the "verified_at" field.
func VerifiedAtLT(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldLT(FieldVerifiedAt, v))
}
// VerifiedAtLTE applies the LTE predicate on the "verified_at" field.
func VerifiedAtLTE(v time.Time) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldLTE(FieldVerifiedAt, v))
}
// VerifiedAtIsNil applies the IsNil predicate on the "verified_at" field.
func VerifiedAtIsNil() predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldIsNull(FieldVerifiedAt))
}
// VerifiedAtNotNil applies the NotNil predicate on the "verified_at" field.
func VerifiedAtNotNil() predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldNotNull(FieldVerifiedAt))
}
// IssuerEQ applies the EQ predicate on the "issuer" field.
func IssuerEQ(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v))
}
// IssuerNEQ applies the NEQ predicate on the "issuer" field.
func IssuerNEQ(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldNEQ(FieldIssuer, v))
}
// IssuerIn applies the In predicate on the "issuer" field.
func IssuerIn(vs ...string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldIn(FieldIssuer, vs...))
}
// IssuerNotIn applies the NotIn predicate on the "issuer" field.
func IssuerNotIn(vs ...string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldNotIn(FieldIssuer, vs...))
}
// IssuerGT applies the GT predicate on the "issuer" field.
func IssuerGT(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldGT(FieldIssuer, v))
}
// IssuerGTE applies the GTE predicate on the "issuer" field.
func IssuerGTE(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldGTE(FieldIssuer, v))
}
// IssuerLT applies the LT predicate on the "issuer" field.
func IssuerLT(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldLT(FieldIssuer, v))
}
// IssuerLTE applies the LTE predicate on the "issuer" field.
func IssuerLTE(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldLTE(FieldIssuer, v))
}
// IssuerContains applies the Contains predicate on the "issuer" field.
func IssuerContains(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldContains(FieldIssuer, v))
}
// IssuerHasPrefix applies the HasPrefix predicate on the "issuer" field.
func IssuerHasPrefix(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldHasPrefix(FieldIssuer, v))
}
// IssuerHasSuffix applies the HasSuffix predicate on the "issuer" field.
func IssuerHasSuffix(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldHasSuffix(FieldIssuer, v))
}
// IssuerIsNil applies the IsNil predicate on the "issuer" field.
func IssuerIsNil() predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldIsNull(FieldIssuer))
}
// IssuerNotNil applies the NotNil predicate on the "issuer" field.
func IssuerNotNil() predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldNotNull(FieldIssuer))
}
// IssuerEqualFold applies the EqualFold predicate on the "issuer" field.
func IssuerEqualFold(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldEqualFold(FieldIssuer, v))
}
// IssuerContainsFold applies the ContainsFold predicate on the "issuer" field.
func IssuerContainsFold(v string) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.FieldContainsFold(FieldIssuer, v))
}
// HasUser applies the HasEdge predicate on the "user" edge.
func HasUser() predicate.AuthIdentity {
return predicate.AuthIdentity(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates).
func HasUserWith(preds ...predicate.User) predicate.AuthIdentity {
return predicate.AuthIdentity(func(s *sql.Selector) {
step := newUserStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasChannels applies the HasEdge predicate on the "channels" edge.
func HasChannels() predicate.AuthIdentity {
return predicate.AuthIdentity(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasChannelsWith applies the HasEdge predicate on the "channels" edge with a given conditions (other predicates).
func HasChannelsWith(preds ...predicate.AuthIdentityChannel) predicate.AuthIdentity {
return predicate.AuthIdentity(func(s *sql.Selector) {
step := newChannelsStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasAdoptionDecisions applies the HasEdge predicate on the "adoption_decisions" edge.
func HasAdoptionDecisions() predicate.AuthIdentity {
return predicate.AuthIdentity(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasAdoptionDecisionsWith applies the HasEdge predicate on the "adoption_decisions" edge with a given conditions (other predicates).
func HasAdoptionDecisionsWith(preds ...predicate.IdentityAdoptionDecision) predicate.AuthIdentity {
return predicate.AuthIdentity(func(s *sql.Selector) {
step := newAdoptionDecisionsStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.AuthIdentity) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.AuthIdentity) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.AuthIdentity) predicate.AuthIdentity {
return predicate.AuthIdentity(sql.NotPredicates(p))
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// AuthIdentityDelete is the builder for deleting a AuthIdentity entity.
type AuthIdentityDelete struct {
config
hooks []Hook
mutation *AuthIdentityMutation
}
// Where appends a list predicates to the AuthIdentityDelete builder.
func (_d *AuthIdentityDelete) Where(ps ...predicate.AuthIdentity) *AuthIdentityDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *AuthIdentityDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *AuthIdentityDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *AuthIdentityDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// AuthIdentityDeleteOne is the builder for deleting a single AuthIdentity entity.
type AuthIdentityDeleteOne struct {
_d *AuthIdentityDelete
}
// Where appends a list predicates to the AuthIdentityDelete builder.
func (_d *AuthIdentityDeleteOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *AuthIdentityDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{authidentity.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *AuthIdentityDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@ -0,0 +1,797 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"database/sql/driver"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/user"
)
// AuthIdentityQuery is the builder for querying AuthIdentity entities.
type AuthIdentityQuery struct {
config
ctx *QueryContext
order []authidentity.OrderOption
inters []Interceptor
predicates []predicate.AuthIdentity
withUser *UserQuery
withChannels *AuthIdentityChannelQuery
withAdoptionDecisions *IdentityAdoptionDecisionQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the AuthIdentityQuery builder.
func (_q *AuthIdentityQuery) Where(ps ...predicate.AuthIdentity) *AuthIdentityQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *AuthIdentityQuery) Limit(limit int) *AuthIdentityQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *AuthIdentityQuery) Offset(offset int) *AuthIdentityQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *AuthIdentityQuery) Unique(unique bool) *AuthIdentityQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *AuthIdentityQuery) Order(o ...authidentity.OrderOption) *AuthIdentityQuery {
_q.order = append(_q.order, o...)
return _q
}
// QueryUser chains the current query on the "user" edge.
func (_q *AuthIdentityQuery) QueryUser() *UserQuery {
query := (&UserClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(authidentity.Table, authidentity.FieldID, selector),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryChannels chains the current query on the "channels" edge.
func (_q *AuthIdentityQuery) QueryChannels() *AuthIdentityChannelQuery {
query := (&AuthIdentityChannelClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(authidentity.Table, authidentity.FieldID, selector),
sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryAdoptionDecisions chains the current query on the "adoption_decisions" edge.
func (_q *AuthIdentityQuery) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery {
query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(authidentity.Table, authidentity.FieldID, selector),
sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first AuthIdentity entity from the query.
// Returns a *NotFoundError when no AuthIdentity was found.
func (_q *AuthIdentityQuery) First(ctx context.Context) (*AuthIdentity, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{authidentity.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *AuthIdentityQuery) FirstX(ctx context.Context) *AuthIdentity {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first AuthIdentity ID from the query.
// Returns a *NotFoundError when no AuthIdentity ID was found.
func (_q *AuthIdentityQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{authidentity.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *AuthIdentityQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single AuthIdentity entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one AuthIdentity entity is found.
// Returns a *NotFoundError when no AuthIdentity entities are found.
func (_q *AuthIdentityQuery) Only(ctx context.Context) (*AuthIdentity, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{authidentity.Label}
default:
return nil, &NotSingularError{authidentity.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *AuthIdentityQuery) OnlyX(ctx context.Context) *AuthIdentity {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only AuthIdentity ID in the query.
// Returns a *NotSingularError when more than one AuthIdentity ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *AuthIdentityQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{authidentity.Label}
default:
err = &NotSingularError{authidentity.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *AuthIdentityQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of AuthIdentities.
func (_q *AuthIdentityQuery) All(ctx context.Context) ([]*AuthIdentity, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*AuthIdentity, *AuthIdentityQuery]()
return withInterceptors[[]*AuthIdentity](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *AuthIdentityQuery) AllX(ctx context.Context) []*AuthIdentity {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of AuthIdentity IDs.
func (_q *AuthIdentityQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(authidentity.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *AuthIdentityQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *AuthIdentityQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *AuthIdentityQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *AuthIdentityQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *AuthIdentityQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the AuthIdentityQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *AuthIdentityQuery) Clone() *AuthIdentityQuery {
if _q == nil {
return nil
}
return &AuthIdentityQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]authidentity.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.AuthIdentity{}, _q.predicates...),
withUser: _q.withUser.Clone(),
withChannels: _q.withChannels.Clone(),
withAdoptionDecisions: _q.withAdoptionDecisions.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// WithUser tells the query-builder to eager-load the nodes that are connected to
// the "user" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *AuthIdentityQuery) WithUser(opts ...func(*UserQuery)) *AuthIdentityQuery {
query := (&UserClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withUser = query
return _q
}
// WithChannels tells the query-builder to eager-load the nodes that are connected to
// the "channels" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *AuthIdentityQuery) WithChannels(opts ...func(*AuthIdentityChannelQuery)) *AuthIdentityQuery {
query := (&AuthIdentityChannelClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withChannels = query
return _q
}
// WithAdoptionDecisions tells the query-builder to eager-load the nodes that are connected to
// the "adoption_decisions" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *AuthIdentityQuery) WithAdoptionDecisions(opts ...func(*IdentityAdoptionDecisionQuery)) *AuthIdentityQuery {
query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withAdoptionDecisions = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.AuthIdentity.Query().
// GroupBy(authidentity.FieldCreatedAt).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *AuthIdentityQuery) GroupBy(field string, fields ...string) *AuthIdentityGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &AuthIdentityGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = authidentity.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// }
//
// client.AuthIdentity.Query().
// Select(authidentity.FieldCreatedAt).
// Scan(ctx, &v)
func (_q *AuthIdentityQuery) Select(fields ...string) *AuthIdentitySelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &AuthIdentitySelect{AuthIdentityQuery: _q}
sbuild.label = authidentity.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a AuthIdentitySelect configured with the given aggregations.
func (_q *AuthIdentityQuery) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect {
return _q.Select().Aggregate(fns...)
}
func (_q *AuthIdentityQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !authidentity.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *AuthIdentityQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentity, error) {
var (
nodes = []*AuthIdentity{}
_spec = _q.querySpec()
loadedTypes = [3]bool{
_q.withUser != nil,
_q.withChannels != nil,
_q.withAdoptionDecisions != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*AuthIdentity).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &AuthIdentity{config: _q.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
if query := _q.withUser; query != nil {
if err := _q.loadUser(ctx, query, nodes, nil,
func(n *AuthIdentity, e *User) { n.Edges.User = e }); err != nil {
return nil, err
}
}
if query := _q.withChannels; query != nil {
if err := _q.loadChannels(ctx, query, nodes,
func(n *AuthIdentity) { n.Edges.Channels = []*AuthIdentityChannel{} },
func(n *AuthIdentity, e *AuthIdentityChannel) { n.Edges.Channels = append(n.Edges.Channels, e) }); err != nil {
return nil, err
}
}
if query := _q.withAdoptionDecisions; query != nil {
if err := _q.loadAdoptionDecisions(ctx, query, nodes,
func(n *AuthIdentity) { n.Edges.AdoptionDecisions = []*IdentityAdoptionDecision{} },
func(n *AuthIdentity, e *IdentityAdoptionDecision) {
n.Edges.AdoptionDecisions = append(n.Edges.AdoptionDecisions, e)
}); err != nil {
return nil, err
}
}
return nodes, nil
}
func (_q *AuthIdentityQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *User)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*AuthIdentity)
for i := range nodes {
fk := nodes[i].UserID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(user.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *AuthIdentityQuery) loadChannels(ctx context.Context, query *AuthIdentityChannelQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *AuthIdentityChannel)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*AuthIdentity)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(authidentitychannel.FieldIdentityID)
}
query.Where(predicate.AuthIdentityChannel(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(authidentity.ChannelsColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.IdentityID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *AuthIdentityQuery) loadAdoptionDecisions(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *IdentityAdoptionDecision)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*AuthIdentity)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(identityadoptiondecision.FieldIdentityID)
}
query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(authidentity.AdoptionDecisionsColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.IdentityID
if fk == nil {
return fmt.Errorf(`foreign-key "identity_id" is nil for node %v`, n.ID)
}
node, ok := nodeids[*fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, *fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *AuthIdentityQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *AuthIdentityQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID)
for i := range fields {
if fields[i] != authidentity.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
if _q.withUser != nil {
_spec.Node.AddColumnOnce(authidentity.FieldUserID)
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *AuthIdentityQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(authidentity.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = authidentity.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *AuthIdentityQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *AuthIdentityQuery) ForShare(opts ...sql.LockOption) *AuthIdentityQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// AuthIdentityGroupBy is the group-by builder for AuthIdentity entities.
type AuthIdentityGroupBy struct {
selector
build *AuthIdentityQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *AuthIdentityGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *AuthIdentityGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentityGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *AuthIdentityGroupBy) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// AuthIdentitySelect is the builder for selecting fields of AuthIdentity entities.
type AuthIdentitySelect struct {
*AuthIdentityQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *AuthIdentitySelect) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *AuthIdentitySelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentitySelect](ctx, _s.AuthIdentityQuery, _s, _s.inters, v)
}
func (_s *AuthIdentitySelect) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

View File

@ -0,0 +1,923 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/user"
)
// AuthIdentityUpdate is the builder for updating AuthIdentity entities.
type AuthIdentityUpdate struct {
config
hooks []Hook
mutation *AuthIdentityMutation
}
// Where appends a list predicates to the AuthIdentityUpdate builder.
func (_u *AuthIdentityUpdate) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *AuthIdentityUpdate) SetUpdatedAt(v time.Time) *AuthIdentityUpdate {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetUserID sets the "user_id" field.
func (_u *AuthIdentityUpdate) SetUserID(v int64) *AuthIdentityUpdate {
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *AuthIdentityUpdate) SetNillableUserID(v *int64) *AuthIdentityUpdate {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// SetProviderType sets the "provider_type" field.
func (_u *AuthIdentityUpdate) SetProviderType(v string) *AuthIdentityUpdate {
_u.mutation.SetProviderType(v)
return _u
}
// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
func (_u *AuthIdentityUpdate) SetNillableProviderType(v *string) *AuthIdentityUpdate {
if v != nil {
_u.SetProviderType(*v)
}
return _u
}
// SetProviderKey sets the "provider_key" field.
func (_u *AuthIdentityUpdate) SetProviderKey(v string) *AuthIdentityUpdate {
_u.mutation.SetProviderKey(v)
return _u
}
// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
func (_u *AuthIdentityUpdate) SetNillableProviderKey(v *string) *AuthIdentityUpdate {
if v != nil {
_u.SetProviderKey(*v)
}
return _u
}
// SetProviderSubject sets the "provider_subject" field.
func (_u *AuthIdentityUpdate) SetProviderSubject(v string) *AuthIdentityUpdate {
_u.mutation.SetProviderSubject(v)
return _u
}
// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
func (_u *AuthIdentityUpdate) SetNillableProviderSubject(v *string) *AuthIdentityUpdate {
if v != nil {
_u.SetProviderSubject(*v)
}
return _u
}
// SetVerifiedAt sets the "verified_at" field.
func (_u *AuthIdentityUpdate) SetVerifiedAt(v time.Time) *AuthIdentityUpdate {
_u.mutation.SetVerifiedAt(v)
return _u
}
// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil.
func (_u *AuthIdentityUpdate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdate {
if v != nil {
_u.SetVerifiedAt(*v)
}
return _u
}
// ClearVerifiedAt clears the value of the "verified_at" field.
func (_u *AuthIdentityUpdate) ClearVerifiedAt() *AuthIdentityUpdate {
_u.mutation.ClearVerifiedAt()
return _u
}
// SetIssuer sets the "issuer" field.
func (_u *AuthIdentityUpdate) SetIssuer(v string) *AuthIdentityUpdate {
_u.mutation.SetIssuer(v)
return _u
}
// SetNillableIssuer sets the "issuer" field if the given value is not nil.
func (_u *AuthIdentityUpdate) SetNillableIssuer(v *string) *AuthIdentityUpdate {
if v != nil {
_u.SetIssuer(*v)
}
return _u
}
// ClearIssuer clears the value of the "issuer" field.
func (_u *AuthIdentityUpdate) ClearIssuer() *AuthIdentityUpdate {
_u.mutation.ClearIssuer()
return _u
}
// SetMetadata sets the "metadata" field.
func (_u *AuthIdentityUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityUpdate {
_u.mutation.SetMetadata(v)
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *AuthIdentityUpdate) SetUser(v *User) *AuthIdentityUpdate {
return _u.SetUserID(v.ID)
}
// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs.
func (_u *AuthIdentityUpdate) AddChannelIDs(ids ...int64) *AuthIdentityUpdate {
_u.mutation.AddChannelIDs(ids...)
return _u
}
// AddChannels adds the "channels" edges to the AuthIdentityChannel entity.
func (_u *AuthIdentityUpdate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddChannelIDs(ids...)
}
// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
func (_u *AuthIdentityUpdate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate {
_u.mutation.AddAdoptionDecisionIDs(ids...)
return _u
}
// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity.
func (_u *AuthIdentityUpdate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddAdoptionDecisionIDs(ids...)
}
// Mutation returns the AuthIdentityMutation object of the builder.
func (_u *AuthIdentityUpdate) Mutation() *AuthIdentityMutation {
return _u.mutation
}
// ClearUser clears the "user" edge to the User entity.
func (_u *AuthIdentityUpdate) ClearUser() *AuthIdentityUpdate {
_u.mutation.ClearUser()
return _u
}
// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity.
func (_u *AuthIdentityUpdate) ClearChannels() *AuthIdentityUpdate {
_u.mutation.ClearChannels()
return _u
}
// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs.
func (_u *AuthIdentityUpdate) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdate {
_u.mutation.RemoveChannelIDs(ids...)
return _u
}
// RemoveChannels removes "channels" edges to AuthIdentityChannel entities.
func (_u *AuthIdentityUpdate) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveChannelIDs(ids...)
}
// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity.
func (_u *AuthIdentityUpdate) ClearAdoptionDecisions() *AuthIdentityUpdate {
_u.mutation.ClearAdoptionDecisions()
return _u
}
// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs.
func (_u *AuthIdentityUpdate) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate {
_u.mutation.RemoveAdoptionDecisionIDs(ids...)
return _u
}
// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities.
func (_u *AuthIdentityUpdate) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveAdoptionDecisionIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *AuthIdentityUpdate) Save(ctx context.Context) (int, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *AuthIdentityUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *AuthIdentityUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *AuthIdentityUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *AuthIdentityUpdate) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := authidentity.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *AuthIdentityUpdate) check() error {
if v, ok := _u.mutation.ProviderType(); ok {
if err := authidentity.ProviderTypeValidator(v); err != nil {
return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)}
}
}
if v, ok := _u.mutation.ProviderKey(); ok {
if err := authidentity.ProviderKeyValidator(v); err != nil {
return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)}
}
}
if v, ok := _u.mutation.ProviderSubject(); ok {
if err := authidentity.ProviderSubjectValidator(v); err != nil {
return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`)
}
return nil
}
func (_u *AuthIdentityUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.ProviderType(); ok {
_spec.SetField(authidentity.FieldProviderType, field.TypeString, value)
}
if value, ok := _u.mutation.ProviderKey(); ok {
_spec.SetField(authidentity.FieldProviderKey, field.TypeString, value)
}
if value, ok := _u.mutation.ProviderSubject(); ok {
_spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value)
}
if value, ok := _u.mutation.VerifiedAt(); ok {
_spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value)
}
if _u.mutation.VerifiedAtCleared() {
_spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime)
}
if value, ok := _u.mutation.Issuer(); ok {
_spec.SetField(authidentity.FieldIssuer, field.TypeString, value)
}
if _u.mutation.IssuerCleared() {
_spec.ClearField(authidentity.FieldIssuer, field.TypeString)
}
if value, ok := _u.mutation.Metadata(); ok {
_spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: authidentity.UserTable,
Columns: []string{authidentity.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: authidentity.UserTable,
Columns: []string{authidentity.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.ChannelsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: authidentity.ChannelsTable,
Columns: []string{authidentity.ChannelsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: authidentity.ChannelsTable,
Columns: []string{authidentity.ChannelsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: authidentity.ChannelsTable,
Columns: []string{authidentity.ChannelsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AdoptionDecisionsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: authidentity.AdoptionDecisionsTable,
Columns: []string{authidentity.AdoptionDecisionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: authidentity.AdoptionDecisionsTable,
Columns: []string{authidentity.AdoptionDecisionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: authidentity.AdoptionDecisionsTable,
Columns: []string{authidentity.AdoptionDecisionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{authidentity.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// AuthIdentityUpdateOne is the builder for updating a single AuthIdentity entity.
type AuthIdentityUpdateOne struct {
config
fields []string
hooks []Hook
mutation *AuthIdentityMutation
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *AuthIdentityUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityUpdateOne {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetUserID sets the "user_id" field.
func (_u *AuthIdentityUpdateOne) SetUserID(v int64) *AuthIdentityUpdateOne {
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *AuthIdentityUpdateOne) SetNillableUserID(v *int64) *AuthIdentityUpdateOne {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// SetProviderType sets the "provider_type" field.
func (_u *AuthIdentityUpdateOne) SetProviderType(v string) *AuthIdentityUpdateOne {
_u.mutation.SetProviderType(v)
return _u
}
// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
func (_u *AuthIdentityUpdateOne) SetNillableProviderType(v *string) *AuthIdentityUpdateOne {
if v != nil {
_u.SetProviderType(*v)
}
return _u
}
// SetProviderKey sets the "provider_key" field.
func (_u *AuthIdentityUpdateOne) SetProviderKey(v string) *AuthIdentityUpdateOne {
_u.mutation.SetProviderKey(v)
return _u
}
// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
func (_u *AuthIdentityUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityUpdateOne {
if v != nil {
_u.SetProviderKey(*v)
}
return _u
}
// SetProviderSubject sets the "provider_subject" field.
func (_u *AuthIdentityUpdateOne) SetProviderSubject(v string) *AuthIdentityUpdateOne {
_u.mutation.SetProviderSubject(v)
return _u
}
// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
func (_u *AuthIdentityUpdateOne) SetNillableProviderSubject(v *string) *AuthIdentityUpdateOne {
if v != nil {
_u.SetProviderSubject(*v)
}
return _u
}
// SetVerifiedAt sets the "verified_at" field.
func (_u *AuthIdentityUpdateOne) SetVerifiedAt(v time.Time) *AuthIdentityUpdateOne {
_u.mutation.SetVerifiedAt(v)
return _u
}
// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil.
func (_u *AuthIdentityUpdateOne) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdateOne {
if v != nil {
_u.SetVerifiedAt(*v)
}
return _u
}
// ClearVerifiedAt clears the value of the "verified_at" field.
func (_u *AuthIdentityUpdateOne) ClearVerifiedAt() *AuthIdentityUpdateOne {
_u.mutation.ClearVerifiedAt()
return _u
}
// SetIssuer sets the "issuer" field.
func (_u *AuthIdentityUpdateOne) SetIssuer(v string) *AuthIdentityUpdateOne {
_u.mutation.SetIssuer(v)
return _u
}
// SetNillableIssuer sets the "issuer" field if the given value is not nil.
func (_u *AuthIdentityUpdateOne) SetNillableIssuer(v *string) *AuthIdentityUpdateOne {
if v != nil {
_u.SetIssuer(*v)
}
return _u
}
// ClearIssuer clears the value of the "issuer" field.
func (_u *AuthIdentityUpdateOne) ClearIssuer() *AuthIdentityUpdateOne {
_u.mutation.ClearIssuer()
return _u
}
// SetMetadata sets the "metadata" field.
func (_u *AuthIdentityUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpdateOne {
_u.mutation.SetMetadata(v)
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *AuthIdentityUpdateOne) SetUser(v *User) *AuthIdentityUpdateOne {
return _u.SetUserID(v.ID)
}
// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs.
func (_u *AuthIdentityUpdateOne) AddChannelIDs(ids ...int64) *AuthIdentityUpdateOne {
_u.mutation.AddChannelIDs(ids...)
return _u
}
// AddChannels adds the "channels" edges to the AuthIdentityChannel entity.
func (_u *AuthIdentityUpdateOne) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddChannelIDs(ids...)
}
// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
func (_u *AuthIdentityUpdateOne) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne {
_u.mutation.AddAdoptionDecisionIDs(ids...)
return _u
}
// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity.
func (_u *AuthIdentityUpdateOne) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddAdoptionDecisionIDs(ids...)
}
// Mutation returns the AuthIdentityMutation object of the builder.
func (_u *AuthIdentityUpdateOne) Mutation() *AuthIdentityMutation {
return _u.mutation
}
// ClearUser clears the "user" edge to the User entity.
func (_u *AuthIdentityUpdateOne) ClearUser() *AuthIdentityUpdateOne {
_u.mutation.ClearUser()
return _u
}
// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity.
func (_u *AuthIdentityUpdateOne) ClearChannels() *AuthIdentityUpdateOne {
_u.mutation.ClearChannels()
return _u
}
// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs.
func (_u *AuthIdentityUpdateOne) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdateOne {
_u.mutation.RemoveChannelIDs(ids...)
return _u
}
// RemoveChannels removes "channels" edges to AuthIdentityChannel entities.
func (_u *AuthIdentityUpdateOne) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveChannelIDs(ids...)
}
// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity.
func (_u *AuthIdentityUpdateOne) ClearAdoptionDecisions() *AuthIdentityUpdateOne {
_u.mutation.ClearAdoptionDecisions()
return _u
}
// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs.
func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne {
_u.mutation.RemoveAdoptionDecisionIDs(ids...)
return _u
}
// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities.
func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveAdoptionDecisionIDs(ids...)
}
// Where appends a list predicates to the AuthIdentityUpdate builder.
func (_u *AuthIdentityUpdateOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *AuthIdentityUpdateOne) Select(field string, fields ...string) *AuthIdentityUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated AuthIdentity entity.
func (_u *AuthIdentityUpdateOne) Save(ctx context.Context) (*AuthIdentity, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *AuthIdentityUpdateOne) SaveX(ctx context.Context) *AuthIdentity {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *AuthIdentityUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *AuthIdentityUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *AuthIdentityUpdateOne) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := authidentity.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *AuthIdentityUpdateOne) check() error {
if v, ok := _u.mutation.ProviderType(); ok {
if err := authidentity.ProviderTypeValidator(v); err != nil {
return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)}
}
}
if v, ok := _u.mutation.ProviderKey(); ok {
if err := authidentity.ProviderKeyValidator(v); err != nil {
return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)}
}
}
if v, ok := _u.mutation.ProviderSubject(); ok {
if err := authidentity.ProviderSubjectValidator(v); err != nil {
return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`)
}
return nil
}
func (_u *AuthIdentityUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentity, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentity.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID)
for _, f := range fields {
if !authidentity.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != authidentity.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.ProviderType(); ok {
_spec.SetField(authidentity.FieldProviderType, field.TypeString, value)
}
if value, ok := _u.mutation.ProviderKey(); ok {
_spec.SetField(authidentity.FieldProviderKey, field.TypeString, value)
}
if value, ok := _u.mutation.ProviderSubject(); ok {
_spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value)
}
if value, ok := _u.mutation.VerifiedAt(); ok {
_spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value)
}
if _u.mutation.VerifiedAtCleared() {
_spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime)
}
if value, ok := _u.mutation.Issuer(); ok {
_spec.SetField(authidentity.FieldIssuer, field.TypeString, value)
}
if _u.mutation.IssuerCleared() {
_spec.ClearField(authidentity.FieldIssuer, field.TypeString)
}
if value, ok := _u.mutation.Metadata(); ok {
_spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: authidentity.UserTable,
Columns: []string{authidentity.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: authidentity.UserTable,
Columns: []string{authidentity.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.ChannelsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: authidentity.ChannelsTable,
Columns: []string{authidentity.ChannelsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: authidentity.ChannelsTable,
Columns: []string{authidentity.ChannelsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: authidentity.ChannelsTable,
Columns: []string{authidentity.ChannelsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AdoptionDecisionsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: authidentity.AdoptionDecisionsTable,
Columns: []string{authidentity.AdoptionDecisionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: authidentity.AdoptionDecisionsTable,
Columns: []string{authidentity.AdoptionDecisionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: authidentity.AdoptionDecisionsTable,
Columns: []string{authidentity.AdoptionDecisionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &AuthIdentity{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{authidentity.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}

View File

@ -0,0 +1,228 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"encoding/json"
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
)
// AuthIdentityChannel is the model entity for the AuthIdentityChannel schema.
type AuthIdentityChannel struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
UpdatedAt time.Time `json:"updated_at,omitempty"`
// IdentityID holds the value of the "identity_id" field.
IdentityID int64 `json:"identity_id,omitempty"`
// ProviderType holds the value of the "provider_type" field.
ProviderType string `json:"provider_type,omitempty"`
// ProviderKey holds the value of the "provider_key" field.
ProviderKey string `json:"provider_key,omitempty"`
// Channel holds the value of the "channel" field.
Channel string `json:"channel,omitempty"`
// ChannelAppID holds the value of the "channel_app_id" field.
ChannelAppID string `json:"channel_app_id,omitempty"`
// ChannelSubject holds the value of the "channel_subject" field.
ChannelSubject string `json:"channel_subject,omitempty"`
// Metadata holds the value of the "metadata" field.
Metadata map[string]interface{} `json:"metadata,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the AuthIdentityChannelQuery when eager-loading is set.
Edges AuthIdentityChannelEdges `json:"edges"`
selectValues sql.SelectValues
}
// AuthIdentityChannelEdges holds the relations/edges for other nodes in the graph.
type AuthIdentityChannelEdges struct {
// Identity holds the value of the identity edge.
Identity *AuthIdentity `json:"identity,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [1]bool
}
// IdentityOrErr returns the Identity value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e AuthIdentityChannelEdges) IdentityOrErr() (*AuthIdentity, error) {
if e.Identity != nil {
return e.Identity, nil
} else if e.loadedTypes[0] {
return nil, &NotFoundError{label: authidentity.Label}
}
return nil, &NotLoadedError{edge: "identity"}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*AuthIdentityChannel) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case authidentitychannel.FieldMetadata:
values[i] = new([]byte)
case authidentitychannel.FieldID, authidentitychannel.FieldIdentityID:
values[i] = new(sql.NullInt64)
case authidentitychannel.FieldProviderType, authidentitychannel.FieldProviderKey, authidentitychannel.FieldChannel, authidentitychannel.FieldChannelAppID, authidentitychannel.FieldChannelSubject:
values[i] = new(sql.NullString)
case authidentitychannel.FieldCreatedAt, authidentitychannel.FieldUpdatedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the AuthIdentityChannel fields.
func (_m *AuthIdentityChannel) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case authidentitychannel.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
case authidentitychannel.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
case authidentitychannel.FieldUpdatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
} else if value.Valid {
_m.UpdatedAt = value.Time
}
case authidentitychannel.FieldIdentityID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field identity_id", values[i])
} else if value.Valid {
_m.IdentityID = value.Int64
}
case authidentitychannel.FieldProviderType:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field provider_type", values[i])
} else if value.Valid {
_m.ProviderType = value.String
}
case authidentitychannel.FieldProviderKey:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field provider_key", values[i])
} else if value.Valid {
_m.ProviderKey = value.String
}
case authidentitychannel.FieldChannel:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field channel", values[i])
} else if value.Valid {
_m.Channel = value.String
}
case authidentitychannel.FieldChannelAppID:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field channel_app_id", values[i])
} else if value.Valid {
_m.ChannelAppID = value.String
}
case authidentitychannel.FieldChannelSubject:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field channel_subject", values[i])
} else if value.Valid {
_m.ChannelSubject = value.String
}
case authidentitychannel.FieldMetadata:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field metadata", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.Metadata); err != nil {
return fmt.Errorf("unmarshal field metadata: %w", err)
}
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentityChannel.
// This includes values selected through modifiers, order, etc.
func (_m *AuthIdentityChannel) Value(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// QueryIdentity queries the "identity" edge of the AuthIdentityChannel entity.
func (_m *AuthIdentityChannel) QueryIdentity() *AuthIdentityQuery {
return NewAuthIdentityChannelClient(_m.config).QueryIdentity(_m)
}
// Update returns a builder for updating this AuthIdentityChannel.
// Note that you need to call AuthIdentityChannel.Unwrap() before calling this method if this AuthIdentityChannel
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *AuthIdentityChannel) Update() *AuthIdentityChannelUpdateOne {
return NewAuthIdentityChannelClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the AuthIdentityChannel entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *AuthIdentityChannel) Unwrap() *AuthIdentityChannel {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("ent: AuthIdentityChannel is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *AuthIdentityChannel) String() string {
var builder strings.Builder
builder.WriteString("AuthIdentityChannel(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("updated_at=")
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("identity_id=")
builder.WriteString(fmt.Sprintf("%v", _m.IdentityID))
builder.WriteString(", ")
builder.WriteString("provider_type=")
builder.WriteString(_m.ProviderType)
builder.WriteString(", ")
builder.WriteString("provider_key=")
builder.WriteString(_m.ProviderKey)
builder.WriteString(", ")
builder.WriteString("channel=")
builder.WriteString(_m.Channel)
builder.WriteString(", ")
builder.WriteString("channel_app_id=")
builder.WriteString(_m.ChannelAppID)
builder.WriteString(", ")
builder.WriteString("channel_subject=")
builder.WriteString(_m.ChannelSubject)
builder.WriteString(", ")
builder.WriteString("metadata=")
builder.WriteString(fmt.Sprintf("%v", _m.Metadata))
builder.WriteByte(')')
return builder.String()
}
// AuthIdentityChannels is a parsable slice of AuthIdentityChannel.
type AuthIdentityChannels []*AuthIdentityChannel

View File

@ -0,0 +1,153 @@
// Code generated by ent, DO NOT EDIT.
package authidentitychannel
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
const (
// Label holds the string label denoting the authidentitychannel type in the database.
Label = "auth_identity_channel"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at"
// FieldIdentityID holds the string denoting the identity_id field in the database.
FieldIdentityID = "identity_id"
// FieldProviderType holds the string denoting the provider_type field in the database.
FieldProviderType = "provider_type"
// FieldProviderKey holds the string denoting the provider_key field in the database.
FieldProviderKey = "provider_key"
// FieldChannel holds the string denoting the channel field in the database.
FieldChannel = "channel"
// FieldChannelAppID holds the string denoting the channel_app_id field in the database.
FieldChannelAppID = "channel_app_id"
// FieldChannelSubject holds the string denoting the channel_subject field in the database.
FieldChannelSubject = "channel_subject"
// FieldMetadata holds the string denoting the metadata field in the database.
FieldMetadata = "metadata"
// EdgeIdentity holds the string denoting the identity edge name in mutations.
EdgeIdentity = "identity"
// Table holds the table name of the authidentitychannel in the database.
Table = "auth_identity_channels"
// IdentityTable is the table that holds the identity relation/edge.
IdentityTable = "auth_identity_channels"
// IdentityInverseTable is the table name for the AuthIdentity entity.
// It exists in this package in order to avoid circular dependency with the "authidentity" package.
IdentityInverseTable = "auth_identities"
// IdentityColumn is the table column denoting the identity relation/edge.
IdentityColumn = "identity_id"
)
// Columns holds all SQL columns for authidentitychannel fields.
var Columns = []string{
FieldID,
FieldCreatedAt,
FieldUpdatedAt,
FieldIdentityID,
FieldProviderType,
FieldProviderKey,
FieldChannel,
FieldChannelAppID,
FieldChannelSubject,
FieldMetadata,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
var (
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
DefaultUpdatedAt func() time.Time
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
UpdateDefaultUpdatedAt func() time.Time
// ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
ProviderTypeValidator func(string) error
// ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
ProviderKeyValidator func(string) error
// ChannelValidator is a validator for the "channel" field. It is called by the builders before save.
ChannelValidator func(string) error
// ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save.
ChannelAppIDValidator func(string) error
// ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save.
ChannelSubjectValidator func(string) error
// DefaultMetadata holds the default value on creation for the "metadata" field.
DefaultMetadata func() map[string]interface{}
)
// OrderOption defines the ordering options for the AuthIdentityChannel queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByIdentityID orders the results by the identity_id field.
func ByIdentityID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldIdentityID, opts...).ToFunc()
}
// ByProviderType orders the results by the provider_type field.
func ByProviderType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldProviderType, opts...).ToFunc()
}
// ByProviderKey orders the results by the provider_key field.
func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
}
// ByChannel orders the results by the channel field.
func ByChannel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldChannel, opts...).ToFunc()
}
// ByChannelAppID orders the results by the channel_app_id field.
func ByChannelAppID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldChannelAppID, opts...).ToFunc()
}
// ByChannelSubject orders the results by the channel_subject field.
func ByChannelSubject(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldChannelSubject, opts...).ToFunc()
}
// ByIdentityField orders the results by identity field.
func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...))
}
}
func newIdentityStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(IdentityInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
)
}

View File

@ -0,0 +1,559 @@
// Code generated by ent, DO NOT EDIT.
package authidentitychannel
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int64) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int64) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int64) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int64) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int64) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int64) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int64) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int64) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int64) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldLTE(FieldID, id))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v))
}
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
func UpdatedAt(v time.Time) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v))
}
// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ.
func IdentityID(v int64) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v))
}
// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ.
func ProviderType(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v))
}
// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
func ProviderKey(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v))
}
// Channel applies equality check predicate on the "channel" field. It's identical to ChannelEQ.
func Channel(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v))
}
// ChannelAppID applies equality check predicate on the "channel_app_id" field. It's identical to ChannelAppIDEQ.
func ChannelAppID(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v))
}
// ChannelSubject applies equality check predicate on the "channel_subject" field. It's identical to ChannelSubjectEQ.
func ChannelSubject(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldLTE(FieldCreatedAt, v))
}
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
func UpdatedAtEQ(v time.Time) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v))
}
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
func UpdatedAtNEQ(v time.Time) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldUpdatedAt, v))
}
// UpdatedAtIn applies the In predicate on the "updated_at" field.
func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldIn(FieldUpdatedAt, vs...))
}
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldUpdatedAt, vs...))
}
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
func UpdatedAtGT(v time.Time) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldGT(FieldUpdatedAt, v))
}
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
func UpdatedAtGTE(v time.Time) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldGTE(FieldUpdatedAt, v))
}
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
func UpdatedAtLT(v time.Time) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldLT(FieldUpdatedAt, v))
}
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
func UpdatedAtLTE(v time.Time) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldLTE(FieldUpdatedAt, v))
}
// IdentityIDEQ applies the EQ predicate on the "identity_id" field.
func IdentityIDEQ(v int64) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v))
}
// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field.
func IdentityIDNEQ(v int64) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldIdentityID, v))
}
// IdentityIDIn applies the In predicate on the "identity_id" field.
func IdentityIDIn(vs ...int64) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldIn(FieldIdentityID, vs...))
}
// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field.
func IdentityIDNotIn(vs ...int64) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldIdentityID, vs...))
}
// ProviderTypeEQ applies the EQ predicate on the "provider_type" field.
func ProviderTypeEQ(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v))
}
// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field.
func ProviderTypeNEQ(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderType, v))
}
// ProviderTypeIn applies the In predicate on the "provider_type" field.
func ProviderTypeIn(vs ...string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderType, vs...))
}
// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field.
func ProviderTypeNotIn(vs ...string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderType, vs...))
}
// ProviderTypeGT applies the GT predicate on the "provider_type" field.
func ProviderTypeGT(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderType, v))
}
// ProviderTypeGTE applies the GTE predicate on the "provider_type" field.
func ProviderTypeGTE(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderType, v))
}
// ProviderTypeLT applies the LT predicate on the "provider_type" field.
func ProviderTypeLT(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderType, v))
}
// ProviderTypeLTE applies the LTE predicate on the "provider_type" field.
func ProviderTypeLTE(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderType, v))
}
// ProviderTypeContains applies the Contains predicate on the "provider_type" field.
func ProviderTypeContains(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderType, v))
}
// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field.
func ProviderTypeHasPrefix(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderType, v))
}
// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field.
func ProviderTypeHasSuffix(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderType, v))
}
// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field.
func ProviderTypeEqualFold(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderType, v))
}
// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field.
func ProviderTypeContainsFold(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderType, v))
}
// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
func ProviderKeyEQ(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v))
}
// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
func ProviderKeyNEQ(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderKey, v))
}
// ProviderKeyIn applies the In predicate on the "provider_key" field.
func ProviderKeyIn(vs ...string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderKey, vs...))
}
// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
func ProviderKeyNotIn(vs ...string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderKey, vs...))
}
// ProviderKeyGT applies the GT predicate on the "provider_key" field.
func ProviderKeyGT(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderKey, v))
}
// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
func ProviderKeyGTE(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderKey, v))
}
// ProviderKeyLT applies the LT predicate on the "provider_key" field.
func ProviderKeyLT(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderKey, v))
}
// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
func ProviderKeyLTE(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderKey, v))
}
// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
func ProviderKeyContains(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderKey, v))
}
// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
func ProviderKeyHasPrefix(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderKey, v))
}
// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
func ProviderKeyHasSuffix(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderKey, v))
}
// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
func ProviderKeyEqualFold(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderKey, v))
}
// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
func ProviderKeyContainsFold(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderKey, v))
}
// ChannelEQ applies the EQ predicate on the "channel" field.
func ChannelEQ(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v))
}
// ChannelNEQ applies the NEQ predicate on the "channel" field.
func ChannelNEQ(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannel, v))
}
// ChannelIn applies the In predicate on the "channel" field.
func ChannelIn(vs ...string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannel, vs...))
}
// ChannelNotIn applies the NotIn predicate on the "channel" field.
func ChannelNotIn(vs ...string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannel, vs...))
}
// ChannelGT applies the GT predicate on the "channel" field.
func ChannelGT(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannel, v))
}
// ChannelGTE applies the GTE predicate on the "channel" field.
func ChannelGTE(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannel, v))
}
// ChannelLT applies the LT predicate on the "channel" field.
func ChannelLT(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannel, v))
}
// ChannelLTE applies the LTE predicate on the "channel" field.
func ChannelLTE(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannel, v))
}
// ChannelContains applies the Contains predicate on the "channel" field.
func ChannelContains(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannel, v))
}
// ChannelHasPrefix applies the HasPrefix predicate on the "channel" field.
func ChannelHasPrefix(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannel, v))
}
// ChannelHasSuffix applies the HasSuffix predicate on the "channel" field.
func ChannelHasSuffix(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannel, v))
}
// ChannelEqualFold applies the EqualFold predicate on the "channel" field.
func ChannelEqualFold(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannel, v))
}
// ChannelContainsFold applies the ContainsFold predicate on the "channel" field.
func ChannelContainsFold(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannel, v))
}
// ChannelAppIDEQ applies the EQ predicate on the "channel_app_id" field.
func ChannelAppIDEQ(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v))
}
// ChannelAppIDNEQ applies the NEQ predicate on the "channel_app_id" field.
func ChannelAppIDNEQ(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelAppID, v))
}
// ChannelAppIDIn applies the In predicate on the "channel_app_id" field.
func ChannelAppIDIn(vs ...string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelAppID, vs...))
}
// ChannelAppIDNotIn applies the NotIn predicate on the "channel_app_id" field.
func ChannelAppIDNotIn(vs ...string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelAppID, vs...))
}
// ChannelAppIDGT applies the GT predicate on the "channel_app_id" field.
func ChannelAppIDGT(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelAppID, v))
}
// ChannelAppIDGTE applies the GTE predicate on the "channel_app_id" field.
func ChannelAppIDGTE(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelAppID, v))
}
// ChannelAppIDLT applies the LT predicate on the "channel_app_id" field.
func ChannelAppIDLT(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelAppID, v))
}
// ChannelAppIDLTE applies the LTE predicate on the "channel_app_id" field.
func ChannelAppIDLTE(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelAppID, v))
}
// ChannelAppIDContains applies the Contains predicate on the "channel_app_id" field.
func ChannelAppIDContains(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelAppID, v))
}
// ChannelAppIDHasPrefix applies the HasPrefix predicate on the "channel_app_id" field.
func ChannelAppIDHasPrefix(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelAppID, v))
}
// ChannelAppIDHasSuffix applies the HasSuffix predicate on the "channel_app_id" field.
func ChannelAppIDHasSuffix(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelAppID, v))
}
// ChannelAppIDEqualFold applies the EqualFold predicate on the "channel_app_id" field.
func ChannelAppIDEqualFold(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelAppID, v))
}
// ChannelAppIDContainsFold applies the ContainsFold predicate on the "channel_app_id" field.
func ChannelAppIDContainsFold(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelAppID, v))
}
// ChannelSubjectEQ applies the EQ predicate on the "channel_subject" field.
func ChannelSubjectEQ(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v))
}
// ChannelSubjectNEQ applies the NEQ predicate on the "channel_subject" field.
func ChannelSubjectNEQ(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelSubject, v))
}
// ChannelSubjectIn applies the In predicate on the "channel_subject" field.
func ChannelSubjectIn(vs ...string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelSubject, vs...))
}
// ChannelSubjectNotIn applies the NotIn predicate on the "channel_subject" field.
func ChannelSubjectNotIn(vs ...string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelSubject, vs...))
}
// ChannelSubjectGT applies the GT predicate on the "channel_subject" field.
func ChannelSubjectGT(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelSubject, v))
}
// ChannelSubjectGTE applies the GTE predicate on the "channel_subject" field.
func ChannelSubjectGTE(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelSubject, v))
}
// ChannelSubjectLT applies the LT predicate on the "channel_subject" field.
func ChannelSubjectLT(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelSubject, v))
}
// ChannelSubjectLTE applies the LTE predicate on the "channel_subject" field.
func ChannelSubjectLTE(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelSubject, v))
}
// ChannelSubjectContains applies the Contains predicate on the "channel_subject" field.
func ChannelSubjectContains(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelSubject, v))
}
// ChannelSubjectHasPrefix applies the HasPrefix predicate on the "channel_subject" field.
func ChannelSubjectHasPrefix(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelSubject, v))
}
// ChannelSubjectHasSuffix applies the HasSuffix predicate on the "channel_subject" field.
func ChannelSubjectHasSuffix(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelSubject, v))
}
// ChannelSubjectEqualFold applies the EqualFold predicate on the "channel_subject" field.
func ChannelSubjectEqualFold(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelSubject, v))
}
// ChannelSubjectContainsFold applies the ContainsFold predicate on the "channel_subject" field.
func ChannelSubjectContainsFold(v string) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelSubject, v))
}
// HasIdentity applies the HasEdge predicate on the "identity" edge.
func HasIdentity() predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates).
func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(func(s *sql.Selector) {
step := newIdentityStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.AuthIdentityChannel) predicate.AuthIdentityChannel {
return predicate.AuthIdentityChannel(sql.NotPredicates(p))
}

View File

@ -0,0 +1,932 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
)
// AuthIdentityChannelCreate is the builder for creating a AuthIdentityChannel entity.
type AuthIdentityChannelCreate struct {
config
mutation *AuthIdentityChannelMutation
hooks []Hook
conflict []sql.ConflictOption
}
// SetCreatedAt sets the "created_at" field.
func (_c *AuthIdentityChannelCreate) SetCreatedAt(v time.Time) *AuthIdentityChannelCreate {
_c.mutation.SetCreatedAt(v)
return _c
}
// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
func (_c *AuthIdentityChannelCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityChannelCreate {
if v != nil {
_c.SetCreatedAt(*v)
}
return _c
}
// SetUpdatedAt sets the "updated_at" field.
func (_c *AuthIdentityChannelCreate) SetUpdatedAt(v time.Time) *AuthIdentityChannelCreate {
_c.mutation.SetUpdatedAt(v)
return _c
}
// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
func (_c *AuthIdentityChannelCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityChannelCreate {
if v != nil {
_c.SetUpdatedAt(*v)
}
return _c
}
// SetIdentityID sets the "identity_id" field.
func (_c *AuthIdentityChannelCreate) SetIdentityID(v int64) *AuthIdentityChannelCreate {
_c.mutation.SetIdentityID(v)
return _c
}
// SetProviderType sets the "provider_type" field.
func (_c *AuthIdentityChannelCreate) SetProviderType(v string) *AuthIdentityChannelCreate {
_c.mutation.SetProviderType(v)
return _c
}
// SetProviderKey sets the "provider_key" field.
func (_c *AuthIdentityChannelCreate) SetProviderKey(v string) *AuthIdentityChannelCreate {
_c.mutation.SetProviderKey(v)
return _c
}
// SetChannel sets the "channel" field.
func (_c *AuthIdentityChannelCreate) SetChannel(v string) *AuthIdentityChannelCreate {
_c.mutation.SetChannel(v)
return _c
}
// SetChannelAppID sets the "channel_app_id" field.
func (_c *AuthIdentityChannelCreate) SetChannelAppID(v string) *AuthIdentityChannelCreate {
_c.mutation.SetChannelAppID(v)
return _c
}
// SetChannelSubject sets the "channel_subject" field.
func (_c *AuthIdentityChannelCreate) SetChannelSubject(v string) *AuthIdentityChannelCreate {
_c.mutation.SetChannelSubject(v)
return _c
}
// SetMetadata sets the "metadata" field.
func (_c *AuthIdentityChannelCreate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelCreate {
_c.mutation.SetMetadata(v)
return _c
}
// SetIdentity sets the "identity" edge to the AuthIdentity entity.
func (_c *AuthIdentityChannelCreate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelCreate {
return _c.SetIdentityID(v.ID)
}
// Mutation returns the AuthIdentityChannelMutation object of the builder.
func (_c *AuthIdentityChannelCreate) Mutation() *AuthIdentityChannelMutation {
return _c.mutation
}
// Save creates the AuthIdentityChannel in the database.
func (_c *AuthIdentityChannelCreate) Save(ctx context.Context) (*AuthIdentityChannel, error) {
_c.defaults()
return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
}
// SaveX calls Save and panics if Save returns an error.
func (_c *AuthIdentityChannelCreate) SaveX(ctx context.Context) *AuthIdentityChannel {
v, err := _c.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (_c *AuthIdentityChannelCreate) Exec(ctx context.Context) error {
_, err := _c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_c *AuthIdentityChannelCreate) ExecX(ctx context.Context) {
if err := _c.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_c *AuthIdentityChannelCreate) defaults() {
if _, ok := _c.mutation.CreatedAt(); !ok {
v := authidentitychannel.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v)
}
if _, ok := _c.mutation.UpdatedAt(); !ok {
v := authidentitychannel.DefaultUpdatedAt()
_c.mutation.SetUpdatedAt(v)
}
if _, ok := _c.mutation.Metadata(); !ok {
v := authidentitychannel.DefaultMetadata()
_c.mutation.SetMetadata(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_c *AuthIdentityChannelCreate) check() error {
if _, ok := _c.mutation.CreatedAt(); !ok {
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.created_at"`)}
}
if _, ok := _c.mutation.UpdatedAt(); !ok {
return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.updated_at"`)}
}
if _, ok := _c.mutation.IdentityID(); !ok {
return &ValidationError{Name: "identity_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.identity_id"`)}
}
if _, ok := _c.mutation.ProviderType(); !ok {
return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_type"`)}
}
if v, ok := _c.mutation.ProviderType(); ok {
if err := authidentitychannel.ProviderTypeValidator(v); err != nil {
return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)}
}
}
if _, ok := _c.mutation.ProviderKey(); !ok {
return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_key"`)}
}
if v, ok := _c.mutation.ProviderKey(); ok {
if err := authidentitychannel.ProviderKeyValidator(v); err != nil {
return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)}
}
}
if _, ok := _c.mutation.Channel(); !ok {
return &ValidationError{Name: "channel", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel"`)}
}
if v, ok := _c.mutation.Channel(); ok {
if err := authidentitychannel.ChannelValidator(v); err != nil {
return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)}
}
}
if _, ok := _c.mutation.ChannelAppID(); !ok {
return &ValidationError{Name: "channel_app_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_app_id"`)}
}
if v, ok := _c.mutation.ChannelAppID(); ok {
if err := authidentitychannel.ChannelAppIDValidator(v); err != nil {
return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)}
}
}
if _, ok := _c.mutation.ChannelSubject(); !ok {
return &ValidationError{Name: "channel_subject", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_subject"`)}
}
if v, ok := _c.mutation.ChannelSubject(); ok {
if err := authidentitychannel.ChannelSubjectValidator(v); err != nil {
return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)}
}
}
if _, ok := _c.mutation.Metadata(); !ok {
return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentityChannel.metadata"`)}
}
if len(_c.mutation.IdentityIDs()) == 0 {
return &ValidationError{Name: "identity", err: errors.New(`ent: missing required edge "AuthIdentityChannel.identity"`)}
}
return nil
}
func (_c *AuthIdentityChannelCreate) sqlSave(ctx context.Context) (*AuthIdentityChannel, error) {
if err := _c.check(); err != nil {
return nil, err
}
_node, _spec := _c.createSpec()
if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
id := _spec.ID.Value.(int64)
_node.ID = int64(id)
_c.mutation.id = &_node.ID
_c.mutation.done = true
return _node, nil
}
func (_c *AuthIdentityChannelCreate) createSpec() (*AuthIdentityChannel, *sqlgraph.CreateSpec) {
var (
_node = &AuthIdentityChannel{config: _c.config}
_spec = sqlgraph.NewCreateSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
)
_spec.OnConflict = _c.conflict
if value, ok := _c.mutation.CreatedAt(); ok {
_spec.SetField(authidentitychannel.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value
}
if value, ok := _c.mutation.UpdatedAt(); ok {
_spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value)
_node.UpdatedAt = value
}
if value, ok := _c.mutation.ProviderType(); ok {
_spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value)
_node.ProviderType = value
}
if value, ok := _c.mutation.ProviderKey(); ok {
_spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value)
_node.ProviderKey = value
}
if value, ok := _c.mutation.Channel(); ok {
_spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value)
_node.Channel = value
}
if value, ok := _c.mutation.ChannelAppID(); ok {
_spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value)
_node.ChannelAppID = value
}
if value, ok := _c.mutation.ChannelSubject(); ok {
_spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value)
_node.ChannelSubject = value
}
if value, ok := _c.mutation.Metadata(); ok {
_spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value)
_node.Metadata = value
}
if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: authidentitychannel.IdentityTable,
Columns: []string{authidentitychannel.IdentityColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_node.IdentityID = nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
// client.AuthIdentityChannel.Create().
// SetCreatedAt(v).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
// sql.ResolveWithNewValues(),
// ).
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.AuthIdentityChannelUpsert) {
// SetCreatedAt(v+v).
// }).
// Exec(ctx)
func (_c *AuthIdentityChannelCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertOne {
_c.conflict = opts
return &AuthIdentityChannelUpsertOne{
create: _c,
}
}
// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
// client.AuthIdentityChannel.Create().
// OnConflict(sql.ConflictColumns(columns...)).
// Exec(ctx)
func (_c *AuthIdentityChannelCreate) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertOne {
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
return &AuthIdentityChannelUpsertOne{
create: _c,
}
}
type (
// AuthIdentityChannelUpsertOne is the builder for "upsert"-ing
// one AuthIdentityChannel node.
AuthIdentityChannelUpsertOne struct {
create *AuthIdentityChannelCreate
}
// AuthIdentityChannelUpsert is the "OnConflict" setter.
AuthIdentityChannelUpsert struct {
*sql.UpdateSet
}
)
// SetUpdatedAt sets the "updated_at" field.
func (u *AuthIdentityChannelUpsert) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsert {
u.Set(authidentitychannel.FieldUpdatedAt, v)
return u
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsert) UpdateUpdatedAt() *AuthIdentityChannelUpsert {
u.SetExcluded(authidentitychannel.FieldUpdatedAt)
return u
}
// SetIdentityID sets the "identity_id" field.
func (u *AuthIdentityChannelUpsert) SetIdentityID(v int64) *AuthIdentityChannelUpsert {
u.Set(authidentitychannel.FieldIdentityID, v)
return u
}
// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsert) UpdateIdentityID() *AuthIdentityChannelUpsert {
u.SetExcluded(authidentitychannel.FieldIdentityID)
return u
}
// SetProviderType sets the "provider_type" field.
func (u *AuthIdentityChannelUpsert) SetProviderType(v string) *AuthIdentityChannelUpsert {
u.Set(authidentitychannel.FieldProviderType, v)
return u
}
// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsert) UpdateProviderType() *AuthIdentityChannelUpsert {
u.SetExcluded(authidentitychannel.FieldProviderType)
return u
}
// SetProviderKey sets the "provider_key" field.
func (u *AuthIdentityChannelUpsert) SetProviderKey(v string) *AuthIdentityChannelUpsert {
u.Set(authidentitychannel.FieldProviderKey, v)
return u
}
// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsert) UpdateProviderKey() *AuthIdentityChannelUpsert {
u.SetExcluded(authidentitychannel.FieldProviderKey)
return u
}
// SetChannel sets the "channel" field.
func (u *AuthIdentityChannelUpsert) SetChannel(v string) *AuthIdentityChannelUpsert {
u.Set(authidentitychannel.FieldChannel, v)
return u
}
// UpdateChannel sets the "channel" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsert) UpdateChannel() *AuthIdentityChannelUpsert {
u.SetExcluded(authidentitychannel.FieldChannel)
return u
}
// SetChannelAppID sets the "channel_app_id" field.
func (u *AuthIdentityChannelUpsert) SetChannelAppID(v string) *AuthIdentityChannelUpsert {
u.Set(authidentitychannel.FieldChannelAppID, v)
return u
}
// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsert) UpdateChannelAppID() *AuthIdentityChannelUpsert {
u.SetExcluded(authidentitychannel.FieldChannelAppID)
return u
}
// SetChannelSubject sets the "channel_subject" field.
func (u *AuthIdentityChannelUpsert) SetChannelSubject(v string) *AuthIdentityChannelUpsert {
u.Set(authidentitychannel.FieldChannelSubject, v)
return u
}
// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsert) UpdateChannelSubject() *AuthIdentityChannelUpsert {
u.SetExcluded(authidentitychannel.FieldChannelSubject)
return u
}
// SetMetadata sets the "metadata" field.
func (u *AuthIdentityChannelUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsert {
u.Set(authidentitychannel.FieldMetadata, v)
return u
}
// UpdateMetadata sets the "metadata" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsert) UpdateMetadata() *AuthIdentityChannelUpsert {
u.SetExcluded(authidentitychannel.FieldMetadata)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
// client.AuthIdentityChannel.Create().
// OnConflict(
// sql.ResolveWithNewValues(),
// ).
// Exec(ctx)
func (u *AuthIdentityChannelUpsertOne) UpdateNewValues() *AuthIdentityChannelUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
if _, exists := u.create.mutation.CreatedAt(); exists {
s.SetIgnore(authidentitychannel.FieldCreatedAt)
}
}))
return u
}
// Ignore sets each column to itself in case of conflict.
// Using this option is equivalent to using:
//
// client.AuthIdentityChannel.Create().
// OnConflict(sql.ResolveWithIgnore()).
// Exec(ctx)
func (u *AuthIdentityChannelUpsertOne) Ignore() *AuthIdentityChannelUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
return u
}
// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported only by SQLite and PostgreSQL.
func (u *AuthIdentityChannelUpsertOne) DoNothing() *AuthIdentityChannelUpsertOne {
u.create.conflict = append(u.create.conflict, sql.DoNothing())
return u
}
// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreate.OnConflict
// documentation for more info.
func (u *AuthIdentityChannelUpsertOne) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
set(&AuthIdentityChannelUpsert{UpdateSet: update})
}))
return u
}
// SetUpdatedAt sets the "updated_at" field.
func (u *AuthIdentityChannelUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertOne {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.SetUpdatedAt(v)
})
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsertOne) UpdateUpdatedAt() *AuthIdentityChannelUpsertOne {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.UpdateUpdatedAt()
})
}
// SetIdentityID sets the "identity_id" field.
func (u *AuthIdentityChannelUpsertOne) SetIdentityID(v int64) *AuthIdentityChannelUpsertOne {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.SetIdentityID(v)
})
}
// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsertOne) UpdateIdentityID() *AuthIdentityChannelUpsertOne {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.UpdateIdentityID()
})
}
// SetProviderType sets the "provider_type" field.
func (u *AuthIdentityChannelUpsertOne) SetProviderType(v string) *AuthIdentityChannelUpsertOne {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.SetProviderType(v)
})
}
// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsertOne) UpdateProviderType() *AuthIdentityChannelUpsertOne {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.UpdateProviderType()
})
}
// SetProviderKey sets the "provider_key" field.
func (u *AuthIdentityChannelUpsertOne) SetProviderKey(v string) *AuthIdentityChannelUpsertOne {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.SetProviderKey(v)
})
}
// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsertOne) UpdateProviderKey() *AuthIdentityChannelUpsertOne {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.UpdateProviderKey()
})
}
// SetChannel sets the "channel" field.
func (u *AuthIdentityChannelUpsertOne) SetChannel(v string) *AuthIdentityChannelUpsertOne {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.SetChannel(v)
})
}
// UpdateChannel sets the "channel" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsertOne) UpdateChannel() *AuthIdentityChannelUpsertOne {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.UpdateChannel()
})
}
// SetChannelAppID sets the "channel_app_id" field.
func (u *AuthIdentityChannelUpsertOne) SetChannelAppID(v string) *AuthIdentityChannelUpsertOne {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.SetChannelAppID(v)
})
}
// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsertOne) UpdateChannelAppID() *AuthIdentityChannelUpsertOne {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.UpdateChannelAppID()
})
}
// SetChannelSubject sets the "channel_subject" field.
func (u *AuthIdentityChannelUpsertOne) SetChannelSubject(v string) *AuthIdentityChannelUpsertOne {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.SetChannelSubject(v)
})
}
// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsertOne) UpdateChannelSubject() *AuthIdentityChannelUpsertOne {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.UpdateChannelSubject()
})
}
// SetMetadata sets the "metadata" field.
func (u *AuthIdentityChannelUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertOne {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.SetMetadata(v)
})
}
// UpdateMetadata sets the "metadata" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsertOne) UpdateMetadata() *AuthIdentityChannelUpsertOne {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.UpdateMetadata()
})
}
// Exec executes the query.
func (u *AuthIdentityChannelUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
return errors.New("ent: missing options for AuthIdentityChannelCreate.OnConflict")
}
return u.create.Exec(ctx)
}
// ExecX is like Exec, but panics if an error occurs.
func (u *AuthIdentityChannelUpsertOne) ExecX(ctx context.Context) {
if err := u.create.Exec(ctx); err != nil {
panic(err)
}
}
// Exec executes the UPSERT query and returns the inserted/updated ID.
func (u *AuthIdentityChannelUpsertOne) ID(ctx context.Context) (id int64, err error) {
node, err := u.create.Save(ctx)
if err != nil {
return id, err
}
return node.ID, nil
}
// IDX is like ID, but panics if an error occurs.
func (u *AuthIdentityChannelUpsertOne) IDX(ctx context.Context) int64 {
id, err := u.ID(ctx)
if err != nil {
panic(err)
}
return id
}
// AuthIdentityChannelCreateBulk is the builder for creating many AuthIdentityChannel entities in bulk.
type AuthIdentityChannelCreateBulk struct {
config
err error
builders []*AuthIdentityChannelCreate
conflict []sql.ConflictOption
}
// Save creates the AuthIdentityChannel entities in the database.
func (_c *AuthIdentityChannelCreateBulk) Save(ctx context.Context) ([]*AuthIdentityChannel, error) {
if _c.err != nil {
return nil, _c.err
}
specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
nodes := make([]*AuthIdentityChannel, len(_c.builders))
mutators := make([]Mutator, len(_c.builders))
for i := range _c.builders {
func(i int, root context.Context) {
builder := _c.builders[i]
builder.defaults()
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*AuthIdentityChannelMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
if err := builder.check(); err != nil {
return nil, err
}
builder.mutation = mutation
var err error
nodes[i], specs[i] = builder.createSpec()
if i < len(mutators)-1 {
_, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
} else {
spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
spec.OnConflict = _c.conflict
// Invoke the actual operation on the latest mutation in the chain.
if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
}
}
if err != nil {
return nil, err
}
mutation.id = &nodes[i].ID
if specs[i].ID.Value != nil {
id := specs[i].ID.Value.(int64)
nodes[i].ID = int64(id)
}
mutation.done = true
return nodes[i], nil
})
for i := len(builder.hooks) - 1; i >= 0; i-- {
mut = builder.hooks[i](mut)
}
mutators[i] = mut
}(i, ctx)
}
if len(mutators) > 0 {
if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
return nil, err
}
}
return nodes, nil
}
// SaveX is like Save, but panics if an error occurs.
func (_c *AuthIdentityChannelCreateBulk) SaveX(ctx context.Context) []*AuthIdentityChannel {
v, err := _c.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (_c *AuthIdentityChannelCreateBulk) Exec(ctx context.Context) error {
_, err := _c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_c *AuthIdentityChannelCreateBulk) ExecX(ctx context.Context) {
if err := _c.Exec(ctx); err != nil {
panic(err)
}
}
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
// client.AuthIdentityChannel.CreateBulk(builders...).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
// sql.ResolveWithNewValues(),
// ).
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.AuthIdentityChannelUpsert) {
// SetCreatedAt(v+v).
// }).
// Exec(ctx)
func (_c *AuthIdentityChannelCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertBulk {
_c.conflict = opts
return &AuthIdentityChannelUpsertBulk{
create: _c,
}
}
// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
// client.AuthIdentityChannel.Create().
// OnConflict(sql.ConflictColumns(columns...)).
// Exec(ctx)
func (_c *AuthIdentityChannelCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertBulk {
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
return &AuthIdentityChannelUpsertBulk{
create: _c,
}
}
// AuthIdentityChannelUpsertBulk is the builder for "upsert"-ing
// a bulk of AuthIdentityChannel nodes.
type AuthIdentityChannelUpsertBulk struct {
create *AuthIdentityChannelCreateBulk
}
// UpdateNewValues updates the mutable fields using the new values that
// were set on create. Using this option is equivalent to using:
//
// client.AuthIdentityChannel.Create().
// OnConflict(
// sql.ResolveWithNewValues(),
// ).
// Exec(ctx)
func (u *AuthIdentityChannelUpsertBulk) UpdateNewValues() *AuthIdentityChannelUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
for _, b := range u.create.builders {
if _, exists := b.mutation.CreatedAt(); exists {
s.SetIgnore(authidentitychannel.FieldCreatedAt)
}
}
}))
return u
}
// Ignore sets each column to itself in case of conflict.
// Using this option is equivalent to using:
//
// client.AuthIdentityChannel.Create().
// OnConflict(sql.ResolveWithIgnore()).
// Exec(ctx)
func (u *AuthIdentityChannelUpsertBulk) Ignore() *AuthIdentityChannelUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
return u
}
// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported only by SQLite and PostgreSQL.
func (u *AuthIdentityChannelUpsertBulk) DoNothing() *AuthIdentityChannelUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.DoNothing())
return u
}
// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreateBulk.OnConflict
// documentation for more info.
func (u *AuthIdentityChannelUpsertBulk) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
set(&AuthIdentityChannelUpsert{UpdateSet: update})
}))
return u
}
// SetUpdatedAt sets the "updated_at" field.
func (u *AuthIdentityChannelUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertBulk {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.SetUpdatedAt(v)
})
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsertBulk) UpdateUpdatedAt() *AuthIdentityChannelUpsertBulk {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.UpdateUpdatedAt()
})
}
// SetIdentityID sets the "identity_id" field.
func (u *AuthIdentityChannelUpsertBulk) SetIdentityID(v int64) *AuthIdentityChannelUpsertBulk {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.SetIdentityID(v)
})
}
// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsertBulk) UpdateIdentityID() *AuthIdentityChannelUpsertBulk {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.UpdateIdentityID()
})
}
// SetProviderType sets the "provider_type" field.
func (u *AuthIdentityChannelUpsertBulk) SetProviderType(v string) *AuthIdentityChannelUpsertBulk {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.SetProviderType(v)
})
}
// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsertBulk) UpdateProviderType() *AuthIdentityChannelUpsertBulk {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.UpdateProviderType()
})
}
// SetProviderKey sets the "provider_key" field.
func (u *AuthIdentityChannelUpsertBulk) SetProviderKey(v string) *AuthIdentityChannelUpsertBulk {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.SetProviderKey(v)
})
}
// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsertBulk) UpdateProviderKey() *AuthIdentityChannelUpsertBulk {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.UpdateProviderKey()
})
}
// SetChannel sets the "channel" field.
func (u *AuthIdentityChannelUpsertBulk) SetChannel(v string) *AuthIdentityChannelUpsertBulk {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.SetChannel(v)
})
}
// UpdateChannel sets the "channel" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsertBulk) UpdateChannel() *AuthIdentityChannelUpsertBulk {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.UpdateChannel()
})
}
// SetChannelAppID sets the "channel_app_id" field.
func (u *AuthIdentityChannelUpsertBulk) SetChannelAppID(v string) *AuthIdentityChannelUpsertBulk {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.SetChannelAppID(v)
})
}
// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsertBulk) UpdateChannelAppID() *AuthIdentityChannelUpsertBulk {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.UpdateChannelAppID()
})
}
// SetChannelSubject sets the "channel_subject" field.
func (u *AuthIdentityChannelUpsertBulk) SetChannelSubject(v string) *AuthIdentityChannelUpsertBulk {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.SetChannelSubject(v)
})
}
// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsertBulk) UpdateChannelSubject() *AuthIdentityChannelUpsertBulk {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.UpdateChannelSubject()
})
}
// SetMetadata sets the "metadata" field.
func (u *AuthIdentityChannelUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertBulk {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.SetMetadata(v)
})
}
// UpdateMetadata sets the "metadata" field to the value that was provided on create.
func (u *AuthIdentityChannelUpsertBulk) UpdateMetadata() *AuthIdentityChannelUpsertBulk {
return u.Update(func(s *AuthIdentityChannelUpsert) {
s.UpdateMetadata()
})
}
// Exec executes the query.
func (u *AuthIdentityChannelUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
return u.create.err
}
for i, b := range u.create.builders {
if len(b.conflict) != 0 {
return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityChannelCreateBulk instead", i)
}
}
if len(u.create.conflict) == 0 {
return errors.New("ent: missing options for AuthIdentityChannelCreateBulk.OnConflict")
}
return u.create.Exec(ctx)
}
// ExecX is like Exec, but panics if an error occurs.
func (u *AuthIdentityChannelUpsertBulk) ExecX(ctx context.Context) {
if err := u.create.Exec(ctx); err != nil {
panic(err)
}
}

View File

@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// AuthIdentityChannelDelete is the builder for deleting a AuthIdentityChannel entity.
type AuthIdentityChannelDelete struct {
config
hooks []Hook
mutation *AuthIdentityChannelMutation
}
// Where appends a list predicates to the AuthIdentityChannelDelete builder.
func (_d *AuthIdentityChannelDelete) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *AuthIdentityChannelDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *AuthIdentityChannelDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *AuthIdentityChannelDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// AuthIdentityChannelDeleteOne is the builder for deleting a single AuthIdentityChannel entity.
type AuthIdentityChannelDeleteOne struct {
_d *AuthIdentityChannelDelete
}
// Where appends a list predicates to the AuthIdentityChannelDelete builder.
func (_d *AuthIdentityChannelDeleteOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *AuthIdentityChannelDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{authidentitychannel.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *AuthIdentityChannelDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@ -0,0 +1,643 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// AuthIdentityChannelQuery is the builder for querying AuthIdentityChannel entities.
type AuthIdentityChannelQuery struct {
config
ctx *QueryContext
order []authidentitychannel.OrderOption
inters []Interceptor
predicates []predicate.AuthIdentityChannel
withIdentity *AuthIdentityQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the AuthIdentityChannelQuery builder.
func (_q *AuthIdentityChannelQuery) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *AuthIdentityChannelQuery) Limit(limit int) *AuthIdentityChannelQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *AuthIdentityChannelQuery) Offset(offset int) *AuthIdentityChannelQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *AuthIdentityChannelQuery) Unique(unique bool) *AuthIdentityChannelQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *AuthIdentityChannelQuery) Order(o ...authidentitychannel.OrderOption) *AuthIdentityChannelQuery {
_q.order = append(_q.order, o...)
return _q
}
// QueryIdentity chains the current query on the "identity" edge.
func (_q *AuthIdentityChannelQuery) QueryIdentity() *AuthIdentityQuery {
query := (&AuthIdentityClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, selector),
sqlgraph.To(authidentity.Table, authidentity.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first AuthIdentityChannel entity from the query.
// Returns a *NotFoundError when no AuthIdentityChannel was found.
func (_q *AuthIdentityChannelQuery) First(ctx context.Context) (*AuthIdentityChannel, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{authidentitychannel.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *AuthIdentityChannelQuery) FirstX(ctx context.Context) *AuthIdentityChannel {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first AuthIdentityChannel ID from the query.
// Returns a *NotFoundError when no AuthIdentityChannel ID was found.
func (_q *AuthIdentityChannelQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{authidentitychannel.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *AuthIdentityChannelQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single AuthIdentityChannel entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one AuthIdentityChannel entity is found.
// Returns a *NotFoundError when no AuthIdentityChannel entities are found.
func (_q *AuthIdentityChannelQuery) Only(ctx context.Context) (*AuthIdentityChannel, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{authidentitychannel.Label}
default:
return nil, &NotSingularError{authidentitychannel.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *AuthIdentityChannelQuery) OnlyX(ctx context.Context) *AuthIdentityChannel {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only AuthIdentityChannel ID in the query.
// Returns a *NotSingularError when more than one AuthIdentityChannel ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *AuthIdentityChannelQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{authidentitychannel.Label}
default:
err = &NotSingularError{authidentitychannel.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *AuthIdentityChannelQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of AuthIdentityChannels.
func (_q *AuthIdentityChannelQuery) All(ctx context.Context) ([]*AuthIdentityChannel, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*AuthIdentityChannel, *AuthIdentityChannelQuery]()
return withInterceptors[[]*AuthIdentityChannel](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *AuthIdentityChannelQuery) AllX(ctx context.Context) []*AuthIdentityChannel {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of AuthIdentityChannel IDs.
func (_q *AuthIdentityChannelQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(authidentitychannel.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *AuthIdentityChannelQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *AuthIdentityChannelQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityChannelQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *AuthIdentityChannelQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *AuthIdentityChannelQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *AuthIdentityChannelQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the AuthIdentityChannelQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *AuthIdentityChannelQuery) Clone() *AuthIdentityChannelQuery {
if _q == nil {
return nil
}
return &AuthIdentityChannelQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]authidentitychannel.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.AuthIdentityChannel{}, _q.predicates...),
withIdentity: _q.withIdentity.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// WithIdentity tells the query-builder to eager-load the nodes that are connected to
// the "identity" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *AuthIdentityChannelQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *AuthIdentityChannelQuery {
query := (&AuthIdentityClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withIdentity = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.AuthIdentityChannel.Query().
// GroupBy(authidentitychannel.FieldCreatedAt).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *AuthIdentityChannelQuery) GroupBy(field string, fields ...string) *AuthIdentityChannelGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &AuthIdentityChannelGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = authidentitychannel.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// }
//
// client.AuthIdentityChannel.Query().
// Select(authidentitychannel.FieldCreatedAt).
// Scan(ctx, &v)
func (_q *AuthIdentityChannelQuery) Select(fields ...string) *AuthIdentityChannelSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &AuthIdentityChannelSelect{AuthIdentityChannelQuery: _q}
sbuild.label = authidentitychannel.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a AuthIdentityChannelSelect configured with the given aggregations.
func (_q *AuthIdentityChannelQuery) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *AuthIdentityChannelQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !authidentitychannel.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *AuthIdentityChannelQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentityChannel, error) {
var (
nodes = []*AuthIdentityChannel{}
_spec = _q.querySpec()
loadedTypes = [1]bool{
_q.withIdentity != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*AuthIdentityChannel).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &AuthIdentityChannel{config: _q.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
if query := _q.withIdentity; query != nil {
if err := _q.loadIdentity(ctx, query, nodes, nil,
func(n *AuthIdentityChannel, e *AuthIdentity) { n.Edges.Identity = e }); err != nil {
return nil, err
}
}
return nodes, nil
}
func (_q *AuthIdentityChannelQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*AuthIdentityChannel, init func(*AuthIdentityChannel), assign func(*AuthIdentityChannel, *AuthIdentity)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*AuthIdentityChannel)
for i := range nodes {
fk := nodes[i].IdentityID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(authidentity.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *AuthIdentityChannelQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *AuthIdentityChannelQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID)
for i := range fields {
if fields[i] != authidentitychannel.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
if _q.withIdentity != nil {
_spec.Node.AddColumnOnce(authidentitychannel.FieldIdentityID)
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *AuthIdentityChannelQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(authidentitychannel.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = authidentitychannel.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *AuthIdentityChannelQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityChannelQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *AuthIdentityChannelQuery) ForShare(opts ...sql.LockOption) *AuthIdentityChannelQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// AuthIdentityChannelGroupBy is the group-by builder for AuthIdentityChannel entities.
type AuthIdentityChannelGroupBy struct {
selector
build *AuthIdentityChannelQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *AuthIdentityChannelGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *AuthIdentityChannelGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *AuthIdentityChannelGroupBy) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// AuthIdentityChannelSelect is the builder for selecting fields of AuthIdentityChannel entities.
type AuthIdentityChannelSelect struct {
*AuthIdentityChannelQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *AuthIdentityChannelSelect) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *AuthIdentityChannelSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelSelect](ctx, _s.AuthIdentityChannelQuery, _s, _s.inters, v)
}
func (_s *AuthIdentityChannelSelect) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

View File

@ -0,0 +1,581 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// AuthIdentityChannelUpdate is the builder for updating AuthIdentityChannel entities.
type AuthIdentityChannelUpdate struct {
config
hooks []Hook
mutation *AuthIdentityChannelMutation
}
// Where appends a list predicates to the AuthIdentityChannelUpdate builder.
func (_u *AuthIdentityChannelUpdate) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *AuthIdentityChannelUpdate) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdate {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetIdentityID sets the "identity_id" field.
func (_u *AuthIdentityChannelUpdate) SetIdentityID(v int64) *AuthIdentityChannelUpdate {
_u.mutation.SetIdentityID(v)
return _u
}
// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
func (_u *AuthIdentityChannelUpdate) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdate {
if v != nil {
_u.SetIdentityID(*v)
}
return _u
}
// SetProviderType sets the "provider_type" field.
func (_u *AuthIdentityChannelUpdate) SetProviderType(v string) *AuthIdentityChannelUpdate {
_u.mutation.SetProviderType(v)
return _u
}
// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
func (_u *AuthIdentityChannelUpdate) SetNillableProviderType(v *string) *AuthIdentityChannelUpdate {
if v != nil {
_u.SetProviderType(*v)
}
return _u
}
// SetProviderKey sets the "provider_key" field.
func (_u *AuthIdentityChannelUpdate) SetProviderKey(v string) *AuthIdentityChannelUpdate {
_u.mutation.SetProviderKey(v)
return _u
}
// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
func (_u *AuthIdentityChannelUpdate) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdate {
if v != nil {
_u.SetProviderKey(*v)
}
return _u
}
// SetChannel sets the "channel" field.
func (_u *AuthIdentityChannelUpdate) SetChannel(v string) *AuthIdentityChannelUpdate {
_u.mutation.SetChannel(v)
return _u
}
// SetNillableChannel sets the "channel" field if the given value is not nil.
func (_u *AuthIdentityChannelUpdate) SetNillableChannel(v *string) *AuthIdentityChannelUpdate {
if v != nil {
_u.SetChannel(*v)
}
return _u
}
// SetChannelAppID sets the "channel_app_id" field.
func (_u *AuthIdentityChannelUpdate) SetChannelAppID(v string) *AuthIdentityChannelUpdate {
_u.mutation.SetChannelAppID(v)
return _u
}
// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil.
func (_u *AuthIdentityChannelUpdate) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdate {
if v != nil {
_u.SetChannelAppID(*v)
}
return _u
}
// SetChannelSubject sets the "channel_subject" field.
func (_u *AuthIdentityChannelUpdate) SetChannelSubject(v string) *AuthIdentityChannelUpdate {
_u.mutation.SetChannelSubject(v)
return _u
}
// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil.
func (_u *AuthIdentityChannelUpdate) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdate {
if v != nil {
_u.SetChannelSubject(*v)
}
return _u
}
// SetMetadata sets the "metadata" field.
func (_u *AuthIdentityChannelUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdate {
_u.mutation.SetMetadata(v)
return _u
}
// SetIdentity sets the "identity" edge to the AuthIdentity entity.
func (_u *AuthIdentityChannelUpdate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdate {
return _u.SetIdentityID(v.ID)
}
// Mutation returns the AuthIdentityChannelMutation object of the builder.
func (_u *AuthIdentityChannelUpdate) Mutation() *AuthIdentityChannelMutation {
return _u.mutation
}
// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
func (_u *AuthIdentityChannelUpdate) ClearIdentity() *AuthIdentityChannelUpdate {
_u.mutation.ClearIdentity()
return _u
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *AuthIdentityChannelUpdate) Save(ctx context.Context) (int, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *AuthIdentityChannelUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *AuthIdentityChannelUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *AuthIdentityChannelUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *AuthIdentityChannelUpdate) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := authidentitychannel.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *AuthIdentityChannelUpdate) check() error {
if v, ok := _u.mutation.ProviderType(); ok {
if err := authidentitychannel.ProviderTypeValidator(v); err != nil {
return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)}
}
}
if v, ok := _u.mutation.ProviderKey(); ok {
if err := authidentitychannel.ProviderKeyValidator(v); err != nil {
return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)}
}
}
if v, ok := _u.mutation.Channel(); ok {
if err := authidentitychannel.ChannelValidator(v); err != nil {
return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)}
}
}
if v, ok := _u.mutation.ChannelAppID(); ok {
if err := authidentitychannel.ChannelAppIDValidator(v); err != nil {
return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)}
}
}
if v, ok := _u.mutation.ChannelSubject(); ok {
if err := authidentitychannel.ChannelSubjectValidator(v); err != nil {
return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)}
}
}
if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`)
}
return nil
}
func (_u *AuthIdentityChannelUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.ProviderType(); ok {
_spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value)
}
if value, ok := _u.mutation.ProviderKey(); ok {
_spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value)
}
if value, ok := _u.mutation.Channel(); ok {
_spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value)
}
if value, ok := _u.mutation.ChannelAppID(); ok {
_spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value)
}
if value, ok := _u.mutation.ChannelSubject(); ok {
_spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value)
}
if value, ok := _u.mutation.Metadata(); ok {
_spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value)
}
if _u.mutation.IdentityCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: authidentitychannel.IdentityTable,
Columns: []string{authidentitychannel.IdentityColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: authidentitychannel.IdentityTable,
Columns: []string{authidentitychannel.IdentityColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{authidentitychannel.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// AuthIdentityChannelUpdateOne is the builder for updating a single AuthIdentityChannel entity.
type AuthIdentityChannelUpdateOne struct {
config
fields []string
hooks []Hook
mutation *AuthIdentityChannelMutation
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *AuthIdentityChannelUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdateOne {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetIdentityID sets the "identity_id" field.
func (_u *AuthIdentityChannelUpdateOne) SetIdentityID(v int64) *AuthIdentityChannelUpdateOne {
_u.mutation.SetIdentityID(v)
return _u
}
// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
func (_u *AuthIdentityChannelUpdateOne) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdateOne {
if v != nil {
_u.SetIdentityID(*v)
}
return _u
}
// SetProviderType sets the "provider_type" field.
func (_u *AuthIdentityChannelUpdateOne) SetProviderType(v string) *AuthIdentityChannelUpdateOne {
_u.mutation.SetProviderType(v)
return _u
}
// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderType(v *string) *AuthIdentityChannelUpdateOne {
if v != nil {
_u.SetProviderType(*v)
}
return _u
}
// SetProviderKey sets the "provider_key" field.
func (_u *AuthIdentityChannelUpdateOne) SetProviderKey(v string) *AuthIdentityChannelUpdateOne {
_u.mutation.SetProviderKey(v)
return _u
}
// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdateOne {
if v != nil {
_u.SetProviderKey(*v)
}
return _u
}
// SetChannel sets the "channel" field.
func (_u *AuthIdentityChannelUpdateOne) SetChannel(v string) *AuthIdentityChannelUpdateOne {
_u.mutation.SetChannel(v)
return _u
}
// SetNillableChannel sets the "channel" field if the given value is not nil.
func (_u *AuthIdentityChannelUpdateOne) SetNillableChannel(v *string) *AuthIdentityChannelUpdateOne {
if v != nil {
_u.SetChannel(*v)
}
return _u
}
// SetChannelAppID sets the "channel_app_id" field.
func (_u *AuthIdentityChannelUpdateOne) SetChannelAppID(v string) *AuthIdentityChannelUpdateOne {
_u.mutation.SetChannelAppID(v)
return _u
}
// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil.
func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdateOne {
if v != nil {
_u.SetChannelAppID(*v)
}
return _u
}
// SetChannelSubject sets the "channel_subject" field.
func (_u *AuthIdentityChannelUpdateOne) SetChannelSubject(v string) *AuthIdentityChannelUpdateOne {
_u.mutation.SetChannelSubject(v)
return _u
}
// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil.
func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdateOne {
if v != nil {
_u.SetChannelSubject(*v)
}
return _u
}
// SetMetadata sets the "metadata" field.
func (_u *AuthIdentityChannelUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdateOne {
_u.mutation.SetMetadata(v)
return _u
}
// SetIdentity sets the "identity" edge to the AuthIdentity entity.
func (_u *AuthIdentityChannelUpdateOne) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdateOne {
return _u.SetIdentityID(v.ID)
}
// Mutation returns the AuthIdentityChannelMutation object of the builder.
func (_u *AuthIdentityChannelUpdateOne) Mutation() *AuthIdentityChannelMutation {
return _u.mutation
}
// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
func (_u *AuthIdentityChannelUpdateOne) ClearIdentity() *AuthIdentityChannelUpdateOne {
_u.mutation.ClearIdentity()
return _u
}
// Where appends a list predicates to the AuthIdentityChannelUpdate builder.
func (_u *AuthIdentityChannelUpdateOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *AuthIdentityChannelUpdateOne) Select(field string, fields ...string) *AuthIdentityChannelUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated AuthIdentityChannel entity.
func (_u *AuthIdentityChannelUpdateOne) Save(ctx context.Context) (*AuthIdentityChannel, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *AuthIdentityChannelUpdateOne) SaveX(ctx context.Context) *AuthIdentityChannel {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *AuthIdentityChannelUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *AuthIdentityChannelUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *AuthIdentityChannelUpdateOne) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := authidentitychannel.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *AuthIdentityChannelUpdateOne) check() error {
if v, ok := _u.mutation.ProviderType(); ok {
if err := authidentitychannel.ProviderTypeValidator(v); err != nil {
return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)}
}
}
if v, ok := _u.mutation.ProviderKey(); ok {
if err := authidentitychannel.ProviderKeyValidator(v); err != nil {
return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)}
}
}
if v, ok := _u.mutation.Channel(); ok {
if err := authidentitychannel.ChannelValidator(v); err != nil {
return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)}
}
}
if v, ok := _u.mutation.ChannelAppID(); ok {
if err := authidentitychannel.ChannelAppIDValidator(v); err != nil {
return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)}
}
}
if v, ok := _u.mutation.ChannelSubject(); ok {
if err := authidentitychannel.ChannelSubjectValidator(v); err != nil {
return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)}
}
}
if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`)
}
return nil
}
func (_u *AuthIdentityChannelUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentityChannel, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentityChannel.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID)
for _, f := range fields {
if !authidentitychannel.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != authidentitychannel.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.ProviderType(); ok {
_spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value)
}
if value, ok := _u.mutation.ProviderKey(); ok {
_spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value)
}
if value, ok := _u.mutation.Channel(); ok {
_spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value)
}
if value, ok := _u.mutation.ChannelAppID(); ok {
_spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value)
}
if value, ok := _u.mutation.ChannelSubject(); ok {
_spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value)
}
if value, ok := _u.mutation.Metadata(); ok {
_spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value)
}
if _u.mutation.IdentityCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: authidentitychannel.IdentityTable,
Columns: []string{authidentitychannel.IdentityColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: authidentitychannel.IdentityTable,
Columns: []string{authidentitychannel.IdentityColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &AuthIdentityChannel{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{authidentitychannel.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}

View File

@ -20,12 +20,16 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
@ -60,18 +64,26 @@ type Client struct {
Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient
// AuthIdentity is the client for interacting with the AuthIdentity builders.
AuthIdentity *AuthIdentityClient
// AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders.
AuthIdentityChannel *AuthIdentityChannelClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// IdempotencyRecord is the client for interacting with the IdempotencyRecord builders.
IdempotencyRecord *IdempotencyRecordClient
// IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders.
IdentityAdoptionDecision *IdentityAdoptionDecisionClient
// PaymentAuditLog is the client for interacting with the PaymentAuditLog builders.
PaymentAuditLog *PaymentAuditLogClient
// PaymentOrder is the client for interacting with the PaymentOrder builders.
PaymentOrder *PaymentOrderClient
// PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders.
PaymentProviderInstance *PaymentProviderInstanceClient
// PendingAuthSession is the client for interacting with the PendingAuthSession builders.
PendingAuthSession *PendingAuthSessionClient
// PromoCode is the client for interacting with the PromoCode builders.
PromoCode *PromoCodeClient
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
@ -118,12 +130,16 @@ func (c *Client) init() {
c.AccountGroup = NewAccountGroupClient(c.config)
c.Announcement = NewAnnouncementClient(c.config)
c.AnnouncementRead = NewAnnouncementReadClient(c.config)
c.AuthIdentity = NewAuthIdentityClient(c.config)
c.AuthIdentityChannel = NewAuthIdentityChannelClient(c.config)
c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config)
c.Group = NewGroupClient(c.config)
c.IdempotencyRecord = NewIdempotencyRecordClient(c.config)
c.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(c.config)
c.PaymentAuditLog = NewPaymentAuditLogClient(c.config)
c.PaymentOrder = NewPaymentOrderClient(c.config)
c.PaymentProviderInstance = NewPaymentProviderInstanceClient(c.config)
c.PendingAuthSession = NewPendingAuthSessionClient(c.config)
c.PromoCode = NewPromoCodeClient(c.config)
c.PromoCodeUsage = NewPromoCodeUsageClient(c.config)
c.Proxy = NewProxyClient(c.config)
@ -229,34 +245,38 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
cfg := c.config
cfg.driver = tx
return &Tx{
ctx: ctx,
config: cfg,
APIKey: NewAPIKeyClient(cfg),
Account: NewAccountClient(cfg),
AccountGroup: NewAccountGroupClient(cfg),
Announcement: NewAnnouncementClient(cfg),
AnnouncementRead: NewAnnouncementReadClient(cfg),
ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
Group: NewGroupClient(cfg),
IdempotencyRecord: NewIdempotencyRecordClient(cfg),
PaymentAuditLog: NewPaymentAuditLogClient(cfg),
PaymentOrder: NewPaymentOrderClient(cfg),
PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
PromoCode: NewPromoCodeClient(cfg),
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg),
SecuritySecret: NewSecuritySecretClient(cfg),
Setting: NewSettingClient(cfg),
SubscriptionPlan: NewSubscriptionPlanClient(cfg),
TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
UsageLog: NewUsageLogClient(cfg),
User: NewUserClient(cfg),
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
UserAttributeValue: NewUserAttributeValueClient(cfg),
UserSubscription: NewUserSubscriptionClient(cfg),
ctx: ctx,
config: cfg,
APIKey: NewAPIKeyClient(cfg),
Account: NewAccountClient(cfg),
AccountGroup: NewAccountGroupClient(cfg),
Announcement: NewAnnouncementClient(cfg),
AnnouncementRead: NewAnnouncementReadClient(cfg),
AuthIdentity: NewAuthIdentityClient(cfg),
AuthIdentityChannel: NewAuthIdentityChannelClient(cfg),
ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
Group: NewGroupClient(cfg),
IdempotencyRecord: NewIdempotencyRecordClient(cfg),
IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg),
PaymentAuditLog: NewPaymentAuditLogClient(cfg),
PaymentOrder: NewPaymentOrderClient(cfg),
PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
PendingAuthSession: NewPendingAuthSessionClient(cfg),
PromoCode: NewPromoCodeClient(cfg),
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg),
SecuritySecret: NewSecuritySecretClient(cfg),
Setting: NewSettingClient(cfg),
SubscriptionPlan: NewSubscriptionPlanClient(cfg),
TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
UsageLog: NewUsageLogClient(cfg),
User: NewUserClient(cfg),
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
UserAttributeValue: NewUserAttributeValueClient(cfg),
UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@ -274,34 +294,38 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
cfg := c.config
cfg.driver = &txDriver{tx: tx, drv: c.driver}
return &Tx{
ctx: ctx,
config: cfg,
APIKey: NewAPIKeyClient(cfg),
Account: NewAccountClient(cfg),
AccountGroup: NewAccountGroupClient(cfg),
Announcement: NewAnnouncementClient(cfg),
AnnouncementRead: NewAnnouncementReadClient(cfg),
ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
Group: NewGroupClient(cfg),
IdempotencyRecord: NewIdempotencyRecordClient(cfg),
PaymentAuditLog: NewPaymentAuditLogClient(cfg),
PaymentOrder: NewPaymentOrderClient(cfg),
PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
PromoCode: NewPromoCodeClient(cfg),
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg),
SecuritySecret: NewSecuritySecretClient(cfg),
Setting: NewSettingClient(cfg),
SubscriptionPlan: NewSubscriptionPlanClient(cfg),
TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
UsageLog: NewUsageLogClient(cfg),
User: NewUserClient(cfg),
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
UserAttributeValue: NewUserAttributeValueClient(cfg),
UserSubscription: NewUserSubscriptionClient(cfg),
ctx: ctx,
config: cfg,
APIKey: NewAPIKeyClient(cfg),
Account: NewAccountClient(cfg),
AccountGroup: NewAccountGroupClient(cfg),
Announcement: NewAnnouncementClient(cfg),
AnnouncementRead: NewAnnouncementReadClient(cfg),
AuthIdentity: NewAuthIdentityClient(cfg),
AuthIdentityChannel: NewAuthIdentityChannelClient(cfg),
ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
Group: NewGroupClient(cfg),
IdempotencyRecord: NewIdempotencyRecordClient(cfg),
IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg),
PaymentAuditLog: NewPaymentAuditLogClient(cfg),
PaymentOrder: NewPaymentOrderClient(cfg),
PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
PendingAuthSession: NewPendingAuthSessionClient(cfg),
PromoCode: NewPromoCodeClient(cfg),
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg),
SecuritySecret: NewSecuritySecretClient(cfg),
Setting: NewSettingClient(cfg),
SubscriptionPlan: NewSubscriptionPlanClient(cfg),
TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
UsageLog: NewUsageLogClient(cfg),
User: NewUserClient(cfg),
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
UserAttributeValue: NewUserAttributeValueClient(cfg),
UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@ -332,11 +356,12 @@ func (c *Client) Close() error {
func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog,
c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage,
c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan,
c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User,
c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.AuthIdentity, c.AuthIdentityChannel, c.ErrorPassthroughRule, c.Group,
c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog,
c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode,
c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Use(hooks...)
@ -348,11 +373,12 @@ func (c *Client) Use(hooks ...Hook) {
func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog,
c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage,
c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan,
c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User,
c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.AuthIdentity, c.AuthIdentityChannel, c.ErrorPassthroughRule, c.Group,
c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog,
c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode,
c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Intercept(interceptors...)
@ -372,18 +398,26 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.Announcement.mutate(ctx, m)
case *AnnouncementReadMutation:
return c.AnnouncementRead.mutate(ctx, m)
case *AuthIdentityMutation:
return c.AuthIdentity.mutate(ctx, m)
case *AuthIdentityChannelMutation:
return c.AuthIdentityChannel.mutate(ctx, m)
case *ErrorPassthroughRuleMutation:
return c.ErrorPassthroughRule.mutate(ctx, m)
case *GroupMutation:
return c.Group.mutate(ctx, m)
case *IdempotencyRecordMutation:
return c.IdempotencyRecord.mutate(ctx, m)
case *IdentityAdoptionDecisionMutation:
return c.IdentityAdoptionDecision.mutate(ctx, m)
case *PaymentAuditLogMutation:
return c.PaymentAuditLog.mutate(ctx, m)
case *PaymentOrderMutation:
return c.PaymentOrder.mutate(ctx, m)
case *PaymentProviderInstanceMutation:
return c.PaymentProviderInstance.mutate(ctx, m)
case *PendingAuthSessionMutation:
return c.PendingAuthSession.mutate(ctx, m)
case *PromoCodeMutation:
return c.PromoCode.mutate(ctx, m)
case *PromoCodeUsageMutation:
@ -1231,6 +1265,336 @@ func (c *AnnouncementReadClient) mutate(ctx context.Context, m *AnnouncementRead
}
}
// AuthIdentityClient is a client for the AuthIdentity schema.
type AuthIdentityClient struct {
config
}
// NewAuthIdentityClient returns a client for the AuthIdentity from the given config.
func NewAuthIdentityClient(c config) *AuthIdentityClient {
return &AuthIdentityClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `authidentity.Hooks(f(g(h())))`.
func (c *AuthIdentityClient) Use(hooks ...Hook) {
c.hooks.AuthIdentity = append(c.hooks.AuthIdentity, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `authidentity.Intercept(f(g(h())))`.
func (c *AuthIdentityClient) Intercept(interceptors ...Interceptor) {
c.inters.AuthIdentity = append(c.inters.AuthIdentity, interceptors...)
}
// Create returns a builder for creating a AuthIdentity entity.
func (c *AuthIdentityClient) Create() *AuthIdentityCreate {
mutation := newAuthIdentityMutation(c.config, OpCreate)
return &AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of AuthIdentity entities.
func (c *AuthIdentityClient) CreateBulk(builders ...*AuthIdentityCreate) *AuthIdentityCreateBulk {
return &AuthIdentityCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *AuthIdentityClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityCreate, int)) *AuthIdentityCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &AuthIdentityCreateBulk{err: fmt.Errorf("calling to AuthIdentityClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*AuthIdentityCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &AuthIdentityCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for AuthIdentity.
func (c *AuthIdentityClient) Update() *AuthIdentityUpdate {
mutation := newAuthIdentityMutation(c.config, OpUpdate)
return &AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *AuthIdentityClient) UpdateOne(_m *AuthIdentity) *AuthIdentityUpdateOne {
mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentity(_m))
return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *AuthIdentityClient) UpdateOneID(id int64) *AuthIdentityUpdateOne {
mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentityID(id))
return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for AuthIdentity.
func (c *AuthIdentityClient) Delete() *AuthIdentityDelete {
mutation := newAuthIdentityMutation(c.config, OpDelete)
return &AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *AuthIdentityClient) DeleteOne(_m *AuthIdentity) *AuthIdentityDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *AuthIdentityClient) DeleteOneID(id int64) *AuthIdentityDeleteOne {
builder := c.Delete().Where(authidentity.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &AuthIdentityDeleteOne{builder}
}
// Query returns a query builder for AuthIdentity.
func (c *AuthIdentityClient) Query() *AuthIdentityQuery {
return &AuthIdentityQuery{
config: c.config,
ctx: &QueryContext{Type: TypeAuthIdentity},
inters: c.Interceptors(),
}
}
// Get returns a AuthIdentity entity by its id.
func (c *AuthIdentityClient) Get(ctx context.Context, id int64) (*AuthIdentity, error) {
return c.Query().Where(authidentity.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *AuthIdentityClient) GetX(ctx context.Context, id int64) *AuthIdentity {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// QueryUser queries the user edge of a AuthIdentity.
func (c *AuthIdentityClient) QueryUser(_m *AuthIdentity) *UserQuery {
query := (&UserClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(authidentity.Table, authidentity.FieldID, id),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryChannels queries the channels edge of a AuthIdentity.
func (c *AuthIdentityClient) QueryChannels(_m *AuthIdentity) *AuthIdentityChannelQuery {
query := (&AuthIdentityChannelClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(authidentity.Table, authidentity.FieldID, id),
sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryAdoptionDecisions queries the adoption_decisions edge of a AuthIdentity.
func (c *AuthIdentityClient) QueryAdoptionDecisions(_m *AuthIdentity) *IdentityAdoptionDecisionQuery {
query := (&IdentityAdoptionDecisionClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(authidentity.Table, authidentity.FieldID, id),
sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// Hooks returns the client hooks.
func (c *AuthIdentityClient) Hooks() []Hook {
return c.hooks.AuthIdentity
}
// Interceptors returns the client interceptors.
func (c *AuthIdentityClient) Interceptors() []Interceptor {
return c.inters.AuthIdentity
}
func (c *AuthIdentityClient) mutate(ctx context.Context, m *AuthIdentityMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown AuthIdentity mutation op: %q", m.Op())
}
}
// AuthIdentityChannelClient is a client for the AuthIdentityChannel schema.
type AuthIdentityChannelClient struct {
config
}
// NewAuthIdentityChannelClient returns a client for the AuthIdentityChannel from the given config.
func NewAuthIdentityChannelClient(c config) *AuthIdentityChannelClient {
return &AuthIdentityChannelClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `authidentitychannel.Hooks(f(g(h())))`.
func (c *AuthIdentityChannelClient) Use(hooks ...Hook) {
c.hooks.AuthIdentityChannel = append(c.hooks.AuthIdentityChannel, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `authidentitychannel.Intercept(f(g(h())))`.
func (c *AuthIdentityChannelClient) Intercept(interceptors ...Interceptor) {
c.inters.AuthIdentityChannel = append(c.inters.AuthIdentityChannel, interceptors...)
}
// Create returns a builder for creating a AuthIdentityChannel entity.
func (c *AuthIdentityChannelClient) Create() *AuthIdentityChannelCreate {
mutation := newAuthIdentityChannelMutation(c.config, OpCreate)
return &AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of AuthIdentityChannel entities.
func (c *AuthIdentityChannelClient) CreateBulk(builders ...*AuthIdentityChannelCreate) *AuthIdentityChannelCreateBulk {
return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *AuthIdentityChannelClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityChannelCreate, int)) *AuthIdentityChannelCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &AuthIdentityChannelCreateBulk{err: fmt.Errorf("calling to AuthIdentityChannelClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*AuthIdentityChannelCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for AuthIdentityChannel.
func (c *AuthIdentityChannelClient) Update() *AuthIdentityChannelUpdate {
mutation := newAuthIdentityChannelMutation(c.config, OpUpdate)
return &AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *AuthIdentityChannelClient) UpdateOne(_m *AuthIdentityChannel) *AuthIdentityChannelUpdateOne {
mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannel(_m))
return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *AuthIdentityChannelClient) UpdateOneID(id int64) *AuthIdentityChannelUpdateOne {
mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannelID(id))
return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for AuthIdentityChannel.
func (c *AuthIdentityChannelClient) Delete() *AuthIdentityChannelDelete {
mutation := newAuthIdentityChannelMutation(c.config, OpDelete)
return &AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *AuthIdentityChannelClient) DeleteOne(_m *AuthIdentityChannel) *AuthIdentityChannelDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *AuthIdentityChannelClient) DeleteOneID(id int64) *AuthIdentityChannelDeleteOne {
builder := c.Delete().Where(authidentitychannel.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &AuthIdentityChannelDeleteOne{builder}
}
// Query returns a query builder for AuthIdentityChannel.
func (c *AuthIdentityChannelClient) Query() *AuthIdentityChannelQuery {
return &AuthIdentityChannelQuery{
config: c.config,
ctx: &QueryContext{Type: TypeAuthIdentityChannel},
inters: c.Interceptors(),
}
}
// Get returns a AuthIdentityChannel entity by its id.
func (c *AuthIdentityChannelClient) Get(ctx context.Context, id int64) (*AuthIdentityChannel, error) {
return c.Query().Where(authidentitychannel.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *AuthIdentityChannelClient) GetX(ctx context.Context, id int64) *AuthIdentityChannel {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// QueryIdentity queries the identity edge of a AuthIdentityChannel.
func (c *AuthIdentityChannelClient) QueryIdentity(_m *AuthIdentityChannel) *AuthIdentityQuery {
query := (&AuthIdentityClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, id),
sqlgraph.To(authidentity.Table, authidentity.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// Hooks returns the client hooks.
func (c *AuthIdentityChannelClient) Hooks() []Hook {
return c.hooks.AuthIdentityChannel
}
// Interceptors returns the client interceptors.
func (c *AuthIdentityChannelClient) Interceptors() []Interceptor {
return c.inters.AuthIdentityChannel
}
func (c *AuthIdentityChannelClient) mutate(ctx context.Context, m *AuthIdentityChannelMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown AuthIdentityChannel mutation op: %q", m.Op())
}
}
// ErrorPassthroughRuleClient is a client for the ErrorPassthroughRule schema.
type ErrorPassthroughRuleClient struct {
config
@ -1760,6 +2124,171 @@ func (c *IdempotencyRecordClient) mutate(ctx context.Context, m *IdempotencyReco
}
}
// IdentityAdoptionDecisionClient is a client for the IdentityAdoptionDecision schema.
type IdentityAdoptionDecisionClient struct {
config
}
// NewIdentityAdoptionDecisionClient returns a client for the IdentityAdoptionDecision from the given config.
func NewIdentityAdoptionDecisionClient(c config) *IdentityAdoptionDecisionClient {
return &IdentityAdoptionDecisionClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `identityadoptiondecision.Hooks(f(g(h())))`.
func (c *IdentityAdoptionDecisionClient) Use(hooks ...Hook) {
c.hooks.IdentityAdoptionDecision = append(c.hooks.IdentityAdoptionDecision, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `identityadoptiondecision.Intercept(f(g(h())))`.
func (c *IdentityAdoptionDecisionClient) Intercept(interceptors ...Interceptor) {
c.inters.IdentityAdoptionDecision = append(c.inters.IdentityAdoptionDecision, interceptors...)
}
// Create returns a builder for creating a IdentityAdoptionDecision entity.
func (c *IdentityAdoptionDecisionClient) Create() *IdentityAdoptionDecisionCreate {
mutation := newIdentityAdoptionDecisionMutation(c.config, OpCreate)
return &IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of IdentityAdoptionDecision entities.
func (c *IdentityAdoptionDecisionClient) CreateBulk(builders ...*IdentityAdoptionDecisionCreate) *IdentityAdoptionDecisionCreateBulk {
return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *IdentityAdoptionDecisionClient) MapCreateBulk(slice any, setFunc func(*IdentityAdoptionDecisionCreate, int)) *IdentityAdoptionDecisionCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &IdentityAdoptionDecisionCreateBulk{err: fmt.Errorf("calling to IdentityAdoptionDecisionClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*IdentityAdoptionDecisionCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for IdentityAdoptionDecision.
func (c *IdentityAdoptionDecisionClient) Update() *IdentityAdoptionDecisionUpdate {
mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdate)
return &IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *IdentityAdoptionDecisionClient) UpdateOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne {
mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecision(_m))
return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *IdentityAdoptionDecisionClient) UpdateOneID(id int64) *IdentityAdoptionDecisionUpdateOne {
mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecisionID(id))
return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for IdentityAdoptionDecision.
func (c *IdentityAdoptionDecisionClient) Delete() *IdentityAdoptionDecisionDelete {
mutation := newIdentityAdoptionDecisionMutation(c.config, OpDelete)
return &IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *IdentityAdoptionDecisionClient) DeleteOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *IdentityAdoptionDecisionClient) DeleteOneID(id int64) *IdentityAdoptionDecisionDeleteOne {
builder := c.Delete().Where(identityadoptiondecision.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &IdentityAdoptionDecisionDeleteOne{builder}
}
// Query returns a query builder for IdentityAdoptionDecision.
func (c *IdentityAdoptionDecisionClient) Query() *IdentityAdoptionDecisionQuery {
return &IdentityAdoptionDecisionQuery{
config: c.config,
ctx: &QueryContext{Type: TypeIdentityAdoptionDecision},
inters: c.Interceptors(),
}
}
// Get returns a IdentityAdoptionDecision entity by its id.
func (c *IdentityAdoptionDecisionClient) Get(ctx context.Context, id int64) (*IdentityAdoptionDecision, error) {
return c.Query().Where(identityadoptiondecision.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *IdentityAdoptionDecisionClient) GetX(ctx context.Context, id int64) *IdentityAdoptionDecision {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// QueryPendingAuthSession queries the pending_auth_session edge of a IdentityAdoptionDecision.
func (c *IdentityAdoptionDecisionClient) QueryPendingAuthSession(_m *IdentityAdoptionDecision) *PendingAuthSessionQuery {
query := (&PendingAuthSessionClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id),
sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryIdentity queries the identity edge of a IdentityAdoptionDecision.
func (c *IdentityAdoptionDecisionClient) QueryIdentity(_m *IdentityAdoptionDecision) *AuthIdentityQuery {
query := (&AuthIdentityClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id),
sqlgraph.To(authidentity.Table, authidentity.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// Hooks returns the client hooks.
func (c *IdentityAdoptionDecisionClient) Hooks() []Hook {
return c.hooks.IdentityAdoptionDecision
}
// Interceptors returns the client interceptors.
func (c *IdentityAdoptionDecisionClient) Interceptors() []Interceptor {
return c.inters.IdentityAdoptionDecision
}
func (c *IdentityAdoptionDecisionClient) mutate(ctx context.Context, m *IdentityAdoptionDecisionMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown IdentityAdoptionDecision mutation op: %q", m.Op())
}
}
// PaymentAuditLogClient is a client for the PaymentAuditLog schema.
type PaymentAuditLogClient struct {
config
@ -2175,6 +2704,171 @@ func (c *PaymentProviderInstanceClient) mutate(ctx context.Context, m *PaymentPr
}
}
// PendingAuthSessionClient is a client for the PendingAuthSession schema.
type PendingAuthSessionClient struct {
config
}
// NewPendingAuthSessionClient returns a client for the PendingAuthSession from the given config.
func NewPendingAuthSessionClient(c config) *PendingAuthSessionClient {
return &PendingAuthSessionClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `pendingauthsession.Hooks(f(g(h())))`.
func (c *PendingAuthSessionClient) Use(hooks ...Hook) {
c.hooks.PendingAuthSession = append(c.hooks.PendingAuthSession, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `pendingauthsession.Intercept(f(g(h())))`.
func (c *PendingAuthSessionClient) Intercept(interceptors ...Interceptor) {
c.inters.PendingAuthSession = append(c.inters.PendingAuthSession, interceptors...)
}
// Create returns a builder for creating a PendingAuthSession entity.
func (c *PendingAuthSessionClient) Create() *PendingAuthSessionCreate {
mutation := newPendingAuthSessionMutation(c.config, OpCreate)
return &PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of PendingAuthSession entities.
func (c *PendingAuthSessionClient) CreateBulk(builders ...*PendingAuthSessionCreate) *PendingAuthSessionCreateBulk {
return &PendingAuthSessionCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *PendingAuthSessionClient) MapCreateBulk(slice any, setFunc func(*PendingAuthSessionCreate, int)) *PendingAuthSessionCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &PendingAuthSessionCreateBulk{err: fmt.Errorf("calling to PendingAuthSessionClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*PendingAuthSessionCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &PendingAuthSessionCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for PendingAuthSession.
func (c *PendingAuthSessionClient) Update() *PendingAuthSessionUpdate {
mutation := newPendingAuthSessionMutation(c.config, OpUpdate)
return &PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *PendingAuthSessionClient) UpdateOne(_m *PendingAuthSession) *PendingAuthSessionUpdateOne {
mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSession(_m))
return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *PendingAuthSessionClient) UpdateOneID(id int64) *PendingAuthSessionUpdateOne {
mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSessionID(id))
return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for PendingAuthSession.
func (c *PendingAuthSessionClient) Delete() *PendingAuthSessionDelete {
mutation := newPendingAuthSessionMutation(c.config, OpDelete)
return &PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *PendingAuthSessionClient) DeleteOne(_m *PendingAuthSession) *PendingAuthSessionDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *PendingAuthSessionClient) DeleteOneID(id int64) *PendingAuthSessionDeleteOne {
builder := c.Delete().Where(pendingauthsession.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &PendingAuthSessionDeleteOne{builder}
}
// Query returns a query builder for PendingAuthSession.
func (c *PendingAuthSessionClient) Query() *PendingAuthSessionQuery {
return &PendingAuthSessionQuery{
config: c.config,
ctx: &QueryContext{Type: TypePendingAuthSession},
inters: c.Interceptors(),
}
}
// Get returns a PendingAuthSession entity by its id.
func (c *PendingAuthSessionClient) Get(ctx context.Context, id int64) (*PendingAuthSession, error) {
return c.Query().Where(pendingauthsession.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *PendingAuthSessionClient) GetX(ctx context.Context, id int64) *PendingAuthSession {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// QueryTargetUser queries the target_user edge of a PendingAuthSession.
func (c *PendingAuthSessionClient) QueryTargetUser(_m *PendingAuthSession) *UserQuery {
query := (&UserClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryAdoptionDecision queries the adoption_decision edge of a PendingAuthSession.
func (c *PendingAuthSessionClient) QueryAdoptionDecision(_m *PendingAuthSession) *IdentityAdoptionDecisionQuery {
query := (&IdentityAdoptionDecisionClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id),
sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// Hooks returns the client hooks.
func (c *PendingAuthSessionClient) Hooks() []Hook {
return c.hooks.PendingAuthSession
}
// Interceptors returns the client interceptors.
func (c *PendingAuthSessionClient) Interceptors() []Interceptor {
return c.inters.PendingAuthSession
}
func (c *PendingAuthSessionClient) mutate(ctx context.Context, m *PendingAuthSessionMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown PendingAuthSession mutation op: %q", m.Op())
}
}
// PromoCodeClient is a client for the PromoCode schema.
type PromoCodeClient struct {
config
@ -3951,6 +4645,38 @@ func (c *UserClient) QueryPaymentOrders(_m *User) *PaymentOrderQuery {
return query
}
// QueryAuthIdentities queries the auth_identities edge of a User.
func (c *UserClient) QueryAuthIdentities(_m *User) *AuthIdentityQuery {
query := (&AuthIdentityClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, id),
sqlgraph.To(authidentity.Table, authidentity.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryPendingAuthSessions queries the pending_auth_sessions edge of a User.
func (c *UserClient) QueryPendingAuthSessions(_m *User) *PendingAuthSessionQuery {
query := (&PendingAuthSessionClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, id),
sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryUserAllowedGroups queries the user_allowed_groups edge of a User.
func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: c.config}).Query()
@ -4628,18 +5354,20 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
// hooks and interceptors per client, for fast access.
type (
hooks struct {
APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder,
PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity,
AuthIdentityChannel, ErrorPassthroughRule, Group, IdempotencyRecord,
IdentityAdoptionDecision, PaymentAuditLog, PaymentOrder,
PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy,
RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
UserAttributeValue, UserSubscription []ent.Hook
}
inters struct {
APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder,
PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity,
AuthIdentityChannel, ErrorPassthroughRule, Group, IdempotencyRecord,
IdentityAdoptionDecision, PaymentAuditLog, PaymentOrder,
PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy,
RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
UserAttributeValue, UserSubscription []ent.Interceptor
}

View File

@ -17,12 +17,16 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
@ -98,32 +102,36 @@ var (
func checkColumn(t, c string) error {
initCheck.Do(func() {
columnCheck = sql.NewColumnCheck(map[string]func(string) bool{
apikey.Table: apikey.ValidColumn,
account.Table: account.ValidColumn,
accountgroup.Table: accountgroup.ValidColumn,
announcement.Table: announcement.ValidColumn,
announcementread.Table: announcementread.ValidColumn,
errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
group.Table: group.ValidColumn,
idempotencyrecord.Table: idempotencyrecord.ValidColumn,
paymentauditlog.Table: paymentauditlog.ValidColumn,
paymentorder.Table: paymentorder.ValidColumn,
paymentproviderinstance.Table: paymentproviderinstance.ValidColumn,
promocode.Table: promocode.ValidColumn,
promocodeusage.Table: promocodeusage.ValidColumn,
proxy.Table: proxy.ValidColumn,
redeemcode.Table: redeemcode.ValidColumn,
securitysecret.Table: securitysecret.ValidColumn,
setting.Table: setting.ValidColumn,
subscriptionplan.Table: subscriptionplan.ValidColumn,
tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn,
usagecleanuptask.Table: usagecleanuptask.ValidColumn,
usagelog.Table: usagelog.ValidColumn,
user.Table: user.ValidColumn,
userallowedgroup.Table: userallowedgroup.ValidColumn,
userattributedefinition.Table: userattributedefinition.ValidColumn,
userattributevalue.Table: userattributevalue.ValidColumn,
usersubscription.Table: usersubscription.ValidColumn,
apikey.Table: apikey.ValidColumn,
account.Table: account.ValidColumn,
accountgroup.Table: accountgroup.ValidColumn,
announcement.Table: announcement.ValidColumn,
announcementread.Table: announcementread.ValidColumn,
authidentity.Table: authidentity.ValidColumn,
authidentitychannel.Table: authidentitychannel.ValidColumn,
errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
group.Table: group.ValidColumn,
idempotencyrecord.Table: idempotencyrecord.ValidColumn,
identityadoptiondecision.Table: identityadoptiondecision.ValidColumn,
paymentauditlog.Table: paymentauditlog.ValidColumn,
paymentorder.Table: paymentorder.ValidColumn,
paymentproviderinstance.Table: paymentproviderinstance.ValidColumn,
pendingauthsession.Table: pendingauthsession.ValidColumn,
promocode.Table: promocode.ValidColumn,
promocodeusage.Table: promocodeusage.ValidColumn,
proxy.Table: proxy.ValidColumn,
redeemcode.Table: redeemcode.ValidColumn,
securitysecret.Table: securitysecret.ValidColumn,
setting.Table: setting.ValidColumn,
subscriptionplan.Table: subscriptionplan.ValidColumn,
tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn,
usagecleanuptask.Table: usagecleanuptask.ValidColumn,
usagelog.Table: usagelog.ValidColumn,
user.Table: user.ValidColumn,
userallowedgroup.Table: userallowedgroup.ValidColumn,
userattributedefinition.Table: userattributedefinition.ValidColumn,
userattributevalue.Table: userattributevalue.ValidColumn,
usersubscription.Table: usersubscription.ValidColumn,
})
})
return columnCheck(t, c)

View File

@ -69,6 +69,30 @@ func (f AnnouncementReadFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.V
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementReadMutation", m)
}
// The AuthIdentityFunc type is an adapter to allow the use of ordinary
// function as AuthIdentity mutator.
type AuthIdentityFunc func(context.Context, *ent.AuthIdentityMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f AuthIdentityFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.AuthIdentityMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityMutation", m)
}
// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary
// function as AuthIdentityChannel mutator.
type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f AuthIdentityChannelFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.AuthIdentityChannelMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityChannelMutation", m)
}
// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary
// function as ErrorPassthroughRule mutator.
type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleMutation) (ent.Value, error)
@ -105,6 +129,18 @@ func (f IdempotencyRecordFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdempotencyRecordMutation", m)
}
// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary
// function as IdentityAdoptionDecision mutator.
type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f IdentityAdoptionDecisionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.IdentityAdoptionDecisionMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdentityAdoptionDecisionMutation", m)
}
// The PaymentAuditLogFunc type is an adapter to allow the use of ordinary
// function as PaymentAuditLog mutator.
type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogMutation) (ent.Value, error)
@ -141,6 +177,18 @@ func (f PaymentProviderInstanceFunc) Mutate(ctx context.Context, m ent.Mutation)
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PaymentProviderInstanceMutation", m)
}
// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary
// function as PendingAuthSession mutator.
type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f PendingAuthSessionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.PendingAuthSessionMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PendingAuthSessionMutation", m)
}
// The PromoCodeFunc type is an adapter to allow the use of ordinary
// function as PromoCode mutator.
type PromoCodeFunc func(context.Context, *ent.PromoCodeMutation) (ent.Value, error)

View File

@ -0,0 +1,223 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
)
// IdentityAdoptionDecision is the model entity for the IdentityAdoptionDecision schema.
type IdentityAdoptionDecision struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
UpdatedAt time.Time `json:"updated_at,omitempty"`
// PendingAuthSessionID holds the value of the "pending_auth_session_id" field.
PendingAuthSessionID int64 `json:"pending_auth_session_id,omitempty"`
// IdentityID holds the value of the "identity_id" field.
IdentityID *int64 `json:"identity_id,omitempty"`
// AdoptDisplayName holds the value of the "adopt_display_name" field.
AdoptDisplayName bool `json:"adopt_display_name,omitempty"`
// AdoptAvatar holds the value of the "adopt_avatar" field.
AdoptAvatar bool `json:"adopt_avatar,omitempty"`
// DecidedAt holds the value of the "decided_at" field.
DecidedAt time.Time `json:"decided_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the IdentityAdoptionDecisionQuery when eager-loading is set.
Edges IdentityAdoptionDecisionEdges `json:"edges"`
selectValues sql.SelectValues
}
// IdentityAdoptionDecisionEdges holds the relations/edges for other nodes in the graph.
type IdentityAdoptionDecisionEdges struct {
// PendingAuthSession holds the value of the pending_auth_session edge.
PendingAuthSession *PendingAuthSession `json:"pending_auth_session,omitempty"`
// Identity holds the value of the identity edge.
Identity *AuthIdentity `json:"identity,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [2]bool
}
// PendingAuthSessionOrErr returns the PendingAuthSession value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e IdentityAdoptionDecisionEdges) PendingAuthSessionOrErr() (*PendingAuthSession, error) {
if e.PendingAuthSession != nil {
return e.PendingAuthSession, nil
} else if e.loadedTypes[0] {
return nil, &NotFoundError{label: pendingauthsession.Label}
}
return nil, &NotLoadedError{edge: "pending_auth_session"}
}
// IdentityOrErr returns the Identity value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e IdentityAdoptionDecisionEdges) IdentityOrErr() (*AuthIdentity, error) {
if e.Identity != nil {
return e.Identity, nil
} else if e.loadedTypes[1] {
return nil, &NotFoundError{label: authidentity.Label}
}
return nil, &NotLoadedError{edge: "identity"}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*IdentityAdoptionDecision) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case identityadoptiondecision.FieldAdoptDisplayName, identityadoptiondecision.FieldAdoptAvatar:
values[i] = new(sql.NullBool)
case identityadoptiondecision.FieldID, identityadoptiondecision.FieldPendingAuthSessionID, identityadoptiondecision.FieldIdentityID:
values[i] = new(sql.NullInt64)
case identityadoptiondecision.FieldCreatedAt, identityadoptiondecision.FieldUpdatedAt, identityadoptiondecision.FieldDecidedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the IdentityAdoptionDecision fields.
func (_m *IdentityAdoptionDecision) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case identityadoptiondecision.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
case identityadoptiondecision.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
case identityadoptiondecision.FieldUpdatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
} else if value.Valid {
_m.UpdatedAt = value.Time
}
case identityadoptiondecision.FieldPendingAuthSessionID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field pending_auth_session_id", values[i])
} else if value.Valid {
_m.PendingAuthSessionID = value.Int64
}
case identityadoptiondecision.FieldIdentityID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field identity_id", values[i])
} else if value.Valid {
_m.IdentityID = new(int64)
*_m.IdentityID = value.Int64
}
case identityadoptiondecision.FieldAdoptDisplayName:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field adopt_display_name", values[i])
} else if value.Valid {
_m.AdoptDisplayName = value.Bool
}
case identityadoptiondecision.FieldAdoptAvatar:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field adopt_avatar", values[i])
} else if value.Valid {
_m.AdoptAvatar = value.Bool
}
case identityadoptiondecision.FieldDecidedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field decided_at", values[i])
} else if value.Valid {
_m.DecidedAt = value.Time
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the IdentityAdoptionDecision.
// This includes values selected through modifiers, order, etc.
func (_m *IdentityAdoptionDecision) Value(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// QueryPendingAuthSession queries the "pending_auth_session" edge of the IdentityAdoptionDecision entity.
func (_m *IdentityAdoptionDecision) QueryPendingAuthSession() *PendingAuthSessionQuery {
return NewIdentityAdoptionDecisionClient(_m.config).QueryPendingAuthSession(_m)
}
// QueryIdentity queries the "identity" edge of the IdentityAdoptionDecision entity.
func (_m *IdentityAdoptionDecision) QueryIdentity() *AuthIdentityQuery {
return NewIdentityAdoptionDecisionClient(_m.config).QueryIdentity(_m)
}
// Update returns a builder for updating this IdentityAdoptionDecision.
// Note that you need to call IdentityAdoptionDecision.Unwrap() before calling this method if this IdentityAdoptionDecision
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *IdentityAdoptionDecision) Update() *IdentityAdoptionDecisionUpdateOne {
return NewIdentityAdoptionDecisionClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the IdentityAdoptionDecision entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *IdentityAdoptionDecision) Unwrap() *IdentityAdoptionDecision {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("ent: IdentityAdoptionDecision is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *IdentityAdoptionDecision) String() string {
var builder strings.Builder
builder.WriteString("IdentityAdoptionDecision(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("updated_at=")
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("pending_auth_session_id=")
builder.WriteString(fmt.Sprintf("%v", _m.PendingAuthSessionID))
builder.WriteString(", ")
if v := _m.IdentityID; v != nil {
builder.WriteString("identity_id=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("adopt_display_name=")
builder.WriteString(fmt.Sprintf("%v", _m.AdoptDisplayName))
builder.WriteString(", ")
builder.WriteString("adopt_avatar=")
builder.WriteString(fmt.Sprintf("%v", _m.AdoptAvatar))
builder.WriteString(", ")
builder.WriteString("decided_at=")
builder.WriteString(_m.DecidedAt.Format(time.ANSIC))
builder.WriteByte(')')
return builder.String()
}
// IdentityAdoptionDecisions is a parsable slice of IdentityAdoptionDecision.
type IdentityAdoptionDecisions []*IdentityAdoptionDecision

View File

@ -0,0 +1,159 @@
// Code generated by ent, DO NOT EDIT.
package identityadoptiondecision
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
const (
// Label holds the string label denoting the identityadoptiondecision type in the database.
Label = "identity_adoption_decision"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at"
// FieldPendingAuthSessionID holds the string denoting the pending_auth_session_id field in the database.
FieldPendingAuthSessionID = "pending_auth_session_id"
// FieldIdentityID holds the string denoting the identity_id field in the database.
FieldIdentityID = "identity_id"
// FieldAdoptDisplayName holds the string denoting the adopt_display_name field in the database.
FieldAdoptDisplayName = "adopt_display_name"
// FieldAdoptAvatar holds the string denoting the adopt_avatar field in the database.
FieldAdoptAvatar = "adopt_avatar"
// FieldDecidedAt holds the string denoting the decided_at field in the database.
FieldDecidedAt = "decided_at"
// EdgePendingAuthSession holds the string denoting the pending_auth_session edge name in mutations.
EdgePendingAuthSession = "pending_auth_session"
// EdgeIdentity holds the string denoting the identity edge name in mutations.
EdgeIdentity = "identity"
// Table holds the table name of the identityadoptiondecision in the database.
Table = "identity_adoption_decisions"
// PendingAuthSessionTable is the table that holds the pending_auth_session relation/edge.
PendingAuthSessionTable = "identity_adoption_decisions"
// PendingAuthSessionInverseTable is the table name for the PendingAuthSession entity.
// It exists in this package in order to avoid circular dependency with the "pendingauthsession" package.
PendingAuthSessionInverseTable = "pending_auth_sessions"
// PendingAuthSessionColumn is the table column denoting the pending_auth_session relation/edge.
PendingAuthSessionColumn = "pending_auth_session_id"
// IdentityTable is the table that holds the identity relation/edge.
IdentityTable = "identity_adoption_decisions"
// IdentityInverseTable is the table name for the AuthIdentity entity.
// It exists in this package in order to avoid circular dependency with the "authidentity" package.
IdentityInverseTable = "auth_identities"
// IdentityColumn is the table column denoting the identity relation/edge.
IdentityColumn = "identity_id"
)
// Columns holds all SQL columns for identityadoptiondecision fields.
var Columns = []string{
FieldID,
FieldCreatedAt,
FieldUpdatedAt,
FieldPendingAuthSessionID,
FieldIdentityID,
FieldAdoptDisplayName,
FieldAdoptAvatar,
FieldDecidedAt,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
var (
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
DefaultUpdatedAt func() time.Time
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
UpdateDefaultUpdatedAt func() time.Time
// DefaultAdoptDisplayName holds the default value on creation for the "adopt_display_name" field.
DefaultAdoptDisplayName bool
// DefaultAdoptAvatar holds the default value on creation for the "adopt_avatar" field.
DefaultAdoptAvatar bool
// DefaultDecidedAt holds the default value on creation for the "decided_at" field.
DefaultDecidedAt func() time.Time
)
// OrderOption defines the ordering options for the IdentityAdoptionDecision queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByPendingAuthSessionID orders the results by the pending_auth_session_id field.
func ByPendingAuthSessionID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldPendingAuthSessionID, opts...).ToFunc()
}
// ByIdentityID orders the results by the identity_id field.
func ByIdentityID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldIdentityID, opts...).ToFunc()
}
// ByAdoptDisplayName orders the results by the adopt_display_name field.
func ByAdoptDisplayName(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAdoptDisplayName, opts...).ToFunc()
}
// ByAdoptAvatar orders the results by the adopt_avatar field.
func ByAdoptAvatar(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAdoptAvatar, opts...).ToFunc()
}
// ByDecidedAt orders the results by the decided_at field.
func ByDecidedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDecidedAt, opts...).ToFunc()
}
// ByPendingAuthSessionField orders the results by pending_auth_session field.
func ByPendingAuthSessionField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionStep(), sql.OrderByField(field, opts...))
}
}
// ByIdentityField orders the results by identity field.
func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...))
}
}
func newPendingAuthSessionStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(PendingAuthSessionInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn),
)
}
func newIdentityStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(IdentityInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
)
}

View File

@ -0,0 +1,342 @@
// Code generated by ent, DO NOT EDIT.
package identityadoptiondecision
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int64) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int64) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int64) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int64) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int64) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int64) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int64) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int64) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int64) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldID, id))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v))
}
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
func UpdatedAt(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v))
}
// PendingAuthSessionID applies equality check predicate on the "pending_auth_session_id" field. It's identical to PendingAuthSessionIDEQ.
func PendingAuthSessionID(v int64) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v))
}
// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ.
func IdentityID(v int64) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v))
}
// AdoptDisplayName applies equality check predicate on the "adopt_display_name" field. It's identical to AdoptDisplayNameEQ.
func AdoptDisplayName(v bool) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v))
}
// AdoptAvatar applies equality check predicate on the "adopt_avatar" field. It's identical to AdoptAvatarEQ.
func AdoptAvatar(v bool) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v))
}
// DecidedAt applies equality check predicate on the "decided_at" field. It's identical to DecidedAtEQ.
func DecidedAt(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldCreatedAt, v))
}
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
func UpdatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v))
}
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
func UpdatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldUpdatedAt, v))
}
// UpdatedAtIn applies the In predicate on the "updated_at" field.
func UpdatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldUpdatedAt, vs...))
}
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
func UpdatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldUpdatedAt, vs...))
}
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
func UpdatedAtGT(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldUpdatedAt, v))
}
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
func UpdatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldUpdatedAt, v))
}
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
func UpdatedAtLT(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldUpdatedAt, v))
}
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
func UpdatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldUpdatedAt, v))
}
// PendingAuthSessionIDEQ applies the EQ predicate on the "pending_auth_session_id" field.
func PendingAuthSessionIDEQ(v int64) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v))
}
// PendingAuthSessionIDNEQ applies the NEQ predicate on the "pending_auth_session_id" field.
func PendingAuthSessionIDNEQ(v int64) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldPendingAuthSessionID, v))
}
// PendingAuthSessionIDIn applies the In predicate on the "pending_auth_session_id" field.
func PendingAuthSessionIDIn(vs ...int64) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldPendingAuthSessionID, vs...))
}
// PendingAuthSessionIDNotIn applies the NotIn predicate on the "pending_auth_session_id" field.
func PendingAuthSessionIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldPendingAuthSessionID, vs...))
}
// IdentityIDEQ applies the EQ predicate on the "identity_id" field.
func IdentityIDEQ(v int64) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v))
}
// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field.
func IdentityIDNEQ(v int64) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldIdentityID, v))
}
// IdentityIDIn applies the In predicate on the "identity_id" field.
func IdentityIDIn(vs ...int64) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldIdentityID, vs...))
}
// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field.
func IdentityIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldIdentityID, vs...))
}
// IdentityIDIsNil applies the IsNil predicate on the "identity_id" field.
func IdentityIDIsNil() predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldIsNull(FieldIdentityID))
}
// IdentityIDNotNil applies the NotNil predicate on the "identity_id" field.
func IdentityIDNotNil() predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldNotNull(FieldIdentityID))
}
// AdoptDisplayNameEQ applies the EQ predicate on the "adopt_display_name" field.
func AdoptDisplayNameEQ(v bool) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v))
}
// AdoptDisplayNameNEQ applies the NEQ predicate on the "adopt_display_name" field.
func AdoptDisplayNameNEQ(v bool) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptDisplayName, v))
}
// AdoptAvatarEQ applies the EQ predicate on the "adopt_avatar" field.
func AdoptAvatarEQ(v bool) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v))
}
// AdoptAvatarNEQ applies the NEQ predicate on the "adopt_avatar" field.
func AdoptAvatarNEQ(v bool) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptAvatar, v))
}
// DecidedAtEQ applies the EQ predicate on the "decided_at" field.
func DecidedAtEQ(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v))
}
// DecidedAtNEQ applies the NEQ predicate on the "decided_at" field.
func DecidedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldDecidedAt, v))
}
// DecidedAtIn applies the In predicate on the "decided_at" field.
func DecidedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldDecidedAt, vs...))
}
// DecidedAtNotIn applies the NotIn predicate on the "decided_at" field.
func DecidedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldDecidedAt, vs...))
}
// DecidedAtGT applies the GT predicate on the "decided_at" field.
func DecidedAtGT(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldDecidedAt, v))
}
// DecidedAtGTE applies the GTE predicate on the "decided_at" field.
func DecidedAtGTE(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldDecidedAt, v))
}
// DecidedAtLT applies the LT predicate on the "decided_at" field.
func DecidedAtLT(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldDecidedAt, v))
}
// DecidedAtLTE applies the LTE predicate on the "decided_at" field.
func DecidedAtLTE(v time.Time) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldDecidedAt, v))
}
// HasPendingAuthSession applies the HasEdge predicate on the "pending_auth_session" edge.
func HasPendingAuthSession() predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasPendingAuthSessionWith applies the HasEdge predicate on the "pending_auth_session" edge with a given conditions (other predicates).
func HasPendingAuthSessionWith(preds ...predicate.PendingAuthSession) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
step := newPendingAuthSessionStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasIdentity applies the HasEdge predicate on the "identity" edge.
func HasIdentity() predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates).
func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
step := newIdentityStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision {
return predicate.IdentityAdoptionDecision(sql.NotPredicates(p))
}

View File

@ -0,0 +1,843 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
)
// IdentityAdoptionDecisionCreate is the builder for creating a IdentityAdoptionDecision entity.
type IdentityAdoptionDecisionCreate struct {
config
mutation *IdentityAdoptionDecisionMutation
hooks []Hook
conflict []sql.ConflictOption
}
// SetCreatedAt sets the "created_at" field.
func (_c *IdentityAdoptionDecisionCreate) SetCreatedAt(v time.Time) *IdentityAdoptionDecisionCreate {
_c.mutation.SetCreatedAt(v)
return _c
}
// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
func (_c *IdentityAdoptionDecisionCreate) SetNillableCreatedAt(v *time.Time) *IdentityAdoptionDecisionCreate {
if v != nil {
_c.SetCreatedAt(*v)
}
return _c
}
// SetUpdatedAt sets the "updated_at" field.
func (_c *IdentityAdoptionDecisionCreate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionCreate {
_c.mutation.SetUpdatedAt(v)
return _c
}
// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
func (_c *IdentityAdoptionDecisionCreate) SetNillableUpdatedAt(v *time.Time) *IdentityAdoptionDecisionCreate {
if v != nil {
_c.SetUpdatedAt(*v)
}
return _c
}
// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionCreate {
_c.mutation.SetPendingAuthSessionID(v)
return _c
}
// SetIdentityID sets the "identity_id" field.
func (_c *IdentityAdoptionDecisionCreate) SetIdentityID(v int64) *IdentityAdoptionDecisionCreate {
_c.mutation.SetIdentityID(v)
return _c
}
// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
func (_c *IdentityAdoptionDecisionCreate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionCreate {
if v != nil {
_c.SetIdentityID(*v)
}
return _c
}
// SetAdoptDisplayName sets the "adopt_display_name" field.
func (_c *IdentityAdoptionDecisionCreate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionCreate {
_c.mutation.SetAdoptDisplayName(v)
return _c
}
// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil.
func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionCreate {
if v != nil {
_c.SetAdoptDisplayName(*v)
}
return _c
}
// SetAdoptAvatar sets the "adopt_avatar" field.
func (_c *IdentityAdoptionDecisionCreate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionCreate {
_c.mutation.SetAdoptAvatar(v)
return _c
}
// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil.
func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionCreate {
if v != nil {
_c.SetAdoptAvatar(*v)
}
return _c
}
// SetDecidedAt sets the "decided_at" field.
func (_c *IdentityAdoptionDecisionCreate) SetDecidedAt(v time.Time) *IdentityAdoptionDecisionCreate {
_c.mutation.SetDecidedAt(v)
return _c
}
// SetNillableDecidedAt sets the "decided_at" field if the given value is not nil.
func (_c *IdentityAdoptionDecisionCreate) SetNillableDecidedAt(v *time.Time) *IdentityAdoptionDecisionCreate {
if v != nil {
_c.SetDecidedAt(*v)
}
return _c
}
// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity.
func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionCreate {
return _c.SetPendingAuthSessionID(v.ID)
}
// SetIdentity sets the "identity" edge to the AuthIdentity entity.
func (_c *IdentityAdoptionDecisionCreate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionCreate {
return _c.SetIdentityID(v.ID)
}
// Mutation returns the IdentityAdoptionDecisionMutation object of the builder.
func (_c *IdentityAdoptionDecisionCreate) Mutation() *IdentityAdoptionDecisionMutation {
return _c.mutation
}
// Save creates the IdentityAdoptionDecision in the database.
func (_c *IdentityAdoptionDecisionCreate) Save(ctx context.Context) (*IdentityAdoptionDecision, error) {
_c.defaults()
return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
}
// SaveX calls Save and panics if Save returns an error.
func (_c *IdentityAdoptionDecisionCreate) SaveX(ctx context.Context) *IdentityAdoptionDecision {
v, err := _c.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (_c *IdentityAdoptionDecisionCreate) Exec(ctx context.Context) error {
_, err := _c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_c *IdentityAdoptionDecisionCreate) ExecX(ctx context.Context) {
if err := _c.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_c *IdentityAdoptionDecisionCreate) defaults() {
if _, ok := _c.mutation.CreatedAt(); !ok {
v := identityadoptiondecision.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v)
}
if _, ok := _c.mutation.UpdatedAt(); !ok {
v := identityadoptiondecision.DefaultUpdatedAt()
_c.mutation.SetUpdatedAt(v)
}
if _, ok := _c.mutation.AdoptDisplayName(); !ok {
v := identityadoptiondecision.DefaultAdoptDisplayName
_c.mutation.SetAdoptDisplayName(v)
}
if _, ok := _c.mutation.AdoptAvatar(); !ok {
v := identityadoptiondecision.DefaultAdoptAvatar
_c.mutation.SetAdoptAvatar(v)
}
if _, ok := _c.mutation.DecidedAt(); !ok {
v := identityadoptiondecision.DefaultDecidedAt()
_c.mutation.SetDecidedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_c *IdentityAdoptionDecisionCreate) check() error {
if _, ok := _c.mutation.CreatedAt(); !ok {
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.created_at"`)}
}
if _, ok := _c.mutation.UpdatedAt(); !ok {
return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.updated_at"`)}
}
if _, ok := _c.mutation.PendingAuthSessionID(); !ok {
return &ValidationError{Name: "pending_auth_session_id", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.pending_auth_session_id"`)}
}
if _, ok := _c.mutation.AdoptDisplayName(); !ok {
return &ValidationError{Name: "adopt_display_name", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_display_name"`)}
}
if _, ok := _c.mutation.AdoptAvatar(); !ok {
return &ValidationError{Name: "adopt_avatar", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_avatar"`)}
}
if _, ok := _c.mutation.DecidedAt(); !ok {
return &ValidationError{Name: "decided_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.decided_at"`)}
}
if len(_c.mutation.PendingAuthSessionIDs()) == 0 {
return &ValidationError{Name: "pending_auth_session", err: errors.New(`ent: missing required edge "IdentityAdoptionDecision.pending_auth_session"`)}
}
return nil
}
func (_c *IdentityAdoptionDecisionCreate) sqlSave(ctx context.Context) (*IdentityAdoptionDecision, error) {
if err := _c.check(); err != nil {
return nil, err
}
_node, _spec := _c.createSpec()
if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
id := _spec.ID.Value.(int64)
_node.ID = int64(id)
_c.mutation.id = &_node.ID
_c.mutation.done = true
return _node, nil
}
func (_c *IdentityAdoptionDecisionCreate) createSpec() (*IdentityAdoptionDecision, *sqlgraph.CreateSpec) {
var (
_node = &IdentityAdoptionDecision{config: _c.config}
_spec = sqlgraph.NewCreateSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
)
_spec.OnConflict = _c.conflict
if value, ok := _c.mutation.CreatedAt(); ok {
_spec.SetField(identityadoptiondecision.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value
}
if value, ok := _c.mutation.UpdatedAt(); ok {
_spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value)
_node.UpdatedAt = value
}
if value, ok := _c.mutation.AdoptDisplayName(); ok {
_spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value)
_node.AdoptDisplayName = value
}
if value, ok := _c.mutation.AdoptAvatar(); ok {
_spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value)
_node.AdoptAvatar = value
}
if value, ok := _c.mutation.DecidedAt(); ok {
_spec.SetField(identityadoptiondecision.FieldDecidedAt, field.TypeTime, value)
_node.DecidedAt = value
}
if nodes := _c.mutation.PendingAuthSessionIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2O,
Inverse: true,
Table: identityadoptiondecision.PendingAuthSessionTable,
Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_node.PendingAuthSessionID = nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: identityadoptiondecision.IdentityTable,
Columns: []string{identityadoptiondecision.IdentityColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_node.IdentityID = &nodes[0]
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
// client.IdentityAdoptionDecision.Create().
// SetCreatedAt(v).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
// sql.ResolveWithNewValues(),
// ).
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.IdentityAdoptionDecisionUpsert) {
// SetCreatedAt(v+v).
// }).
// Exec(ctx)
func (_c *IdentityAdoptionDecisionCreate) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertOne {
_c.conflict = opts
return &IdentityAdoptionDecisionUpsertOne{
create: _c,
}
}
// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
// client.IdentityAdoptionDecision.Create().
// OnConflict(sql.ConflictColumns(columns...)).
// Exec(ctx)
func (_c *IdentityAdoptionDecisionCreate) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertOne {
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
return &IdentityAdoptionDecisionUpsertOne{
create: _c,
}
}
type (
// IdentityAdoptionDecisionUpsertOne is the builder for "upsert"-ing
// one IdentityAdoptionDecision node.
IdentityAdoptionDecisionUpsertOne struct {
create *IdentityAdoptionDecisionCreate
}
// IdentityAdoptionDecisionUpsert is the "OnConflict" setter.
IdentityAdoptionDecisionUpsert struct {
*sql.UpdateSet
}
)
// SetUpdatedAt sets the "updated_at" field.
func (u *IdentityAdoptionDecisionUpsert) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsert {
u.Set(identityadoptiondecision.FieldUpdatedAt, v)
return u
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *IdentityAdoptionDecisionUpsert) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsert {
u.SetExcluded(identityadoptiondecision.FieldUpdatedAt)
return u
}
// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
func (u *IdentityAdoptionDecisionUpsert) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsert {
u.Set(identityadoptiondecision.FieldPendingAuthSessionID, v)
return u
}
// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create.
func (u *IdentityAdoptionDecisionUpsert) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsert {
u.SetExcluded(identityadoptiondecision.FieldPendingAuthSessionID)
return u
}
// SetIdentityID sets the "identity_id" field.
func (u *IdentityAdoptionDecisionUpsert) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsert {
u.Set(identityadoptiondecision.FieldIdentityID, v)
return u
}
// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
func (u *IdentityAdoptionDecisionUpsert) UpdateIdentityID() *IdentityAdoptionDecisionUpsert {
u.SetExcluded(identityadoptiondecision.FieldIdentityID)
return u
}
// ClearIdentityID clears the value of the "identity_id" field.
func (u *IdentityAdoptionDecisionUpsert) ClearIdentityID() *IdentityAdoptionDecisionUpsert {
u.SetNull(identityadoptiondecision.FieldIdentityID)
return u
}
// SetAdoptDisplayName sets the "adopt_display_name" field.
func (u *IdentityAdoptionDecisionUpsert) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsert {
u.Set(identityadoptiondecision.FieldAdoptDisplayName, v)
return u
}
// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create.
func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsert {
u.SetExcluded(identityadoptiondecision.FieldAdoptDisplayName)
return u
}
// SetAdoptAvatar sets the "adopt_avatar" field.
func (u *IdentityAdoptionDecisionUpsert) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsert {
u.Set(identityadoptiondecision.FieldAdoptAvatar, v)
return u
}
// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create.
func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsert {
u.SetExcluded(identityadoptiondecision.FieldAdoptAvatar)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
// client.IdentityAdoptionDecision.Create().
// OnConflict(
// sql.ResolveWithNewValues(),
// ).
// Exec(ctx)
func (u *IdentityAdoptionDecisionUpsertOne) UpdateNewValues() *IdentityAdoptionDecisionUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
if _, exists := u.create.mutation.CreatedAt(); exists {
s.SetIgnore(identityadoptiondecision.FieldCreatedAt)
}
if _, exists := u.create.mutation.DecidedAt(); exists {
s.SetIgnore(identityadoptiondecision.FieldDecidedAt)
}
}))
return u
}
// Ignore sets each column to itself in case of conflict.
// Using this option is equivalent to using:
//
// client.IdentityAdoptionDecision.Create().
// OnConflict(sql.ResolveWithIgnore()).
// Exec(ctx)
func (u *IdentityAdoptionDecisionUpsertOne) Ignore() *IdentityAdoptionDecisionUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
return u
}
// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported only by SQLite and PostgreSQL.
func (u *IdentityAdoptionDecisionUpsertOne) DoNothing() *IdentityAdoptionDecisionUpsertOne {
u.create.conflict = append(u.create.conflict, sql.DoNothing())
return u
}
// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreate.OnConflict
// documentation for more info.
func (u *IdentityAdoptionDecisionUpsertOne) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
set(&IdentityAdoptionDecisionUpsert{UpdateSet: update})
}))
return u
}
// SetUpdatedAt sets the "updated_at" field.
func (u *IdentityAdoptionDecisionUpsertOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertOne {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.SetUpdatedAt(v)
})
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *IdentityAdoptionDecisionUpsertOne) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertOne {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.UpdateUpdatedAt()
})
}
// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
func (u *IdentityAdoptionDecisionUpsertOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertOne {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.SetPendingAuthSessionID(v)
})
}
// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create.
func (u *IdentityAdoptionDecisionUpsertOne) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertOne {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.UpdatePendingAuthSessionID()
})
}
// SetIdentityID sets the "identity_id" field.
func (u *IdentityAdoptionDecisionUpsertOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertOne {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.SetIdentityID(v)
})
}
// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
func (u *IdentityAdoptionDecisionUpsertOne) UpdateIdentityID() *IdentityAdoptionDecisionUpsertOne {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.UpdateIdentityID()
})
}
// ClearIdentityID clears the value of the "identity_id" field.
func (u *IdentityAdoptionDecisionUpsertOne) ClearIdentityID() *IdentityAdoptionDecisionUpsertOne {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.ClearIdentityID()
})
}
// SetAdoptDisplayName sets the "adopt_display_name" field.
func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertOne {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.SetAdoptDisplayName(v)
})
}
// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create.
func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertOne {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.UpdateAdoptDisplayName()
})
}
// SetAdoptAvatar sets the "adopt_avatar" field.
func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertOne {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.SetAdoptAvatar(v)
})
}
// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create.
func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertOne {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.UpdateAdoptAvatar()
})
}
// Exec executes the query.
func (u *IdentityAdoptionDecisionUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
return errors.New("ent: missing options for IdentityAdoptionDecisionCreate.OnConflict")
}
return u.create.Exec(ctx)
}
// ExecX is like Exec, but panics if an error occurs.
func (u *IdentityAdoptionDecisionUpsertOne) ExecX(ctx context.Context) {
if err := u.create.Exec(ctx); err != nil {
panic(err)
}
}
// Exec executes the UPSERT query and returns the inserted/updated ID.
func (u *IdentityAdoptionDecisionUpsertOne) ID(ctx context.Context) (id int64, err error) {
node, err := u.create.Save(ctx)
if err != nil {
return id, err
}
return node.ID, nil
}
// IDX is like ID, but panics if an error occurs.
func (u *IdentityAdoptionDecisionUpsertOne) IDX(ctx context.Context) int64 {
id, err := u.ID(ctx)
if err != nil {
panic(err)
}
return id
}
// IdentityAdoptionDecisionCreateBulk is the builder for creating many IdentityAdoptionDecision entities in bulk.
type IdentityAdoptionDecisionCreateBulk struct {
config
err error
builders []*IdentityAdoptionDecisionCreate
conflict []sql.ConflictOption
}
// Save creates the IdentityAdoptionDecision entities in the database.
func (_c *IdentityAdoptionDecisionCreateBulk) Save(ctx context.Context) ([]*IdentityAdoptionDecision, error) {
if _c.err != nil {
return nil, _c.err
}
specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
nodes := make([]*IdentityAdoptionDecision, len(_c.builders))
mutators := make([]Mutator, len(_c.builders))
for i := range _c.builders {
func(i int, root context.Context) {
builder := _c.builders[i]
builder.defaults()
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*IdentityAdoptionDecisionMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
if err := builder.check(); err != nil {
return nil, err
}
builder.mutation = mutation
var err error
nodes[i], specs[i] = builder.createSpec()
if i < len(mutators)-1 {
_, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
} else {
spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
spec.OnConflict = _c.conflict
// Invoke the actual operation on the latest mutation in the chain.
if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
}
}
if err != nil {
return nil, err
}
mutation.id = &nodes[i].ID
if specs[i].ID.Value != nil {
id := specs[i].ID.Value.(int64)
nodes[i].ID = int64(id)
}
mutation.done = true
return nodes[i], nil
})
for i := len(builder.hooks) - 1; i >= 0; i-- {
mut = builder.hooks[i](mut)
}
mutators[i] = mut
}(i, ctx)
}
if len(mutators) > 0 {
if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
return nil, err
}
}
return nodes, nil
}
// SaveX is like Save, but panics if an error occurs.
func (_c *IdentityAdoptionDecisionCreateBulk) SaveX(ctx context.Context) []*IdentityAdoptionDecision {
v, err := _c.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (_c *IdentityAdoptionDecisionCreateBulk) Exec(ctx context.Context) error {
_, err := _c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_c *IdentityAdoptionDecisionCreateBulk) ExecX(ctx context.Context) {
if err := _c.Exec(ctx); err != nil {
panic(err)
}
}
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
// client.IdentityAdoptionDecision.CreateBulk(builders...).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
// sql.ResolveWithNewValues(),
// ).
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.IdentityAdoptionDecisionUpsert) {
// SetCreatedAt(v+v).
// }).
// Exec(ctx)
func (_c *IdentityAdoptionDecisionCreateBulk) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertBulk {
_c.conflict = opts
return &IdentityAdoptionDecisionUpsertBulk{
create: _c,
}
}
// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
// client.IdentityAdoptionDecision.Create().
// OnConflict(sql.ConflictColumns(columns...)).
// Exec(ctx)
func (_c *IdentityAdoptionDecisionCreateBulk) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertBulk {
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
return &IdentityAdoptionDecisionUpsertBulk{
create: _c,
}
}
// IdentityAdoptionDecisionUpsertBulk is the builder for "upsert"-ing
// a bulk of IdentityAdoptionDecision nodes.
type IdentityAdoptionDecisionUpsertBulk struct {
create *IdentityAdoptionDecisionCreateBulk
}
// UpdateNewValues updates the mutable fields using the new values that
// were set on create. Using this option is equivalent to using:
//
// client.IdentityAdoptionDecision.Create().
// OnConflict(
// sql.ResolveWithNewValues(),
// ).
// Exec(ctx)
func (u *IdentityAdoptionDecisionUpsertBulk) UpdateNewValues() *IdentityAdoptionDecisionUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
for _, b := range u.create.builders {
if _, exists := b.mutation.CreatedAt(); exists {
s.SetIgnore(identityadoptiondecision.FieldCreatedAt)
}
if _, exists := b.mutation.DecidedAt(); exists {
s.SetIgnore(identityadoptiondecision.FieldDecidedAt)
}
}
}))
return u
}
// Ignore sets each column to itself in case of conflict.
// Using this option is equivalent to using:
//
// client.IdentityAdoptionDecision.Create().
// OnConflict(sql.ResolveWithIgnore()).
// Exec(ctx)
func (u *IdentityAdoptionDecisionUpsertBulk) Ignore() *IdentityAdoptionDecisionUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
return u
}
// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported only by SQLite and PostgreSQL.
func (u *IdentityAdoptionDecisionUpsertBulk) DoNothing() *IdentityAdoptionDecisionUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.DoNothing())
return u
}
// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreateBulk.OnConflict
// documentation for more info.
func (u *IdentityAdoptionDecisionUpsertBulk) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
set(&IdentityAdoptionDecisionUpsert{UpdateSet: update})
}))
return u
}
// SetUpdatedAt sets the "updated_at" field.
func (u *IdentityAdoptionDecisionUpsertBulk) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertBulk {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.SetUpdatedAt(v)
})
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *IdentityAdoptionDecisionUpsertBulk) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertBulk {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.UpdateUpdatedAt()
})
}
// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
func (u *IdentityAdoptionDecisionUpsertBulk) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertBulk {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.SetPendingAuthSessionID(v)
})
}
// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create.
func (u *IdentityAdoptionDecisionUpsertBulk) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertBulk {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.UpdatePendingAuthSessionID()
})
}
// SetIdentityID sets the "identity_id" field.
func (u *IdentityAdoptionDecisionUpsertBulk) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertBulk {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.SetIdentityID(v)
})
}
// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
func (u *IdentityAdoptionDecisionUpsertBulk) UpdateIdentityID() *IdentityAdoptionDecisionUpsertBulk {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.UpdateIdentityID()
})
}
// ClearIdentityID clears the value of the "identity_id" field.
func (u *IdentityAdoptionDecisionUpsertBulk) ClearIdentityID() *IdentityAdoptionDecisionUpsertBulk {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.ClearIdentityID()
})
}
// SetAdoptDisplayName sets the "adopt_display_name" field.
func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertBulk {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.SetAdoptDisplayName(v)
})
}
// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create.
func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertBulk {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.UpdateAdoptDisplayName()
})
}
// SetAdoptAvatar sets the "adopt_avatar" field.
func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertBulk {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.SetAdoptAvatar(v)
})
}
// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create.
func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertBulk {
return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
s.UpdateAdoptAvatar()
})
}
// Exec executes the query.
func (u *IdentityAdoptionDecisionUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
return u.create.err
}
for i, b := range u.create.builders {
if len(b.conflict) != 0 {
return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the IdentityAdoptionDecisionCreateBulk instead", i)
}
}
if len(u.create.conflict) == 0 {
return errors.New("ent: missing options for IdentityAdoptionDecisionCreateBulk.OnConflict")
}
return u.create.Exec(ctx)
}
// ExecX is like Exec, but panics if an error occurs.
func (u *IdentityAdoptionDecisionUpsertBulk) ExecX(ctx context.Context) {
if err := u.create.Exec(ctx); err != nil {
panic(err)
}
}

View File

@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// IdentityAdoptionDecisionDelete is the builder for deleting a IdentityAdoptionDecision entity.
type IdentityAdoptionDecisionDelete struct {
config
hooks []Hook
mutation *IdentityAdoptionDecisionMutation
}
// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder.
func (_d *IdentityAdoptionDecisionDelete) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *IdentityAdoptionDecisionDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *IdentityAdoptionDecisionDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *IdentityAdoptionDecisionDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// IdentityAdoptionDecisionDeleteOne is the builder for deleting a single IdentityAdoptionDecision entity.
type IdentityAdoptionDecisionDeleteOne struct {
_d *IdentityAdoptionDecisionDelete
}
// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder.
func (_d *IdentityAdoptionDecisionDeleteOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *IdentityAdoptionDecisionDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{identityadoptiondecision.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *IdentityAdoptionDecisionDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@ -0,0 +1,721 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// IdentityAdoptionDecisionQuery is the builder for querying IdentityAdoptionDecision entities.
type IdentityAdoptionDecisionQuery struct {
config
ctx *QueryContext
order []identityadoptiondecision.OrderOption
inters []Interceptor
predicates []predicate.IdentityAdoptionDecision
withPendingAuthSession *PendingAuthSessionQuery
withIdentity *AuthIdentityQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the IdentityAdoptionDecisionQuery builder.
func (_q *IdentityAdoptionDecisionQuery) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *IdentityAdoptionDecisionQuery) Limit(limit int) *IdentityAdoptionDecisionQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *IdentityAdoptionDecisionQuery) Offset(offset int) *IdentityAdoptionDecisionQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *IdentityAdoptionDecisionQuery) Unique(unique bool) *IdentityAdoptionDecisionQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *IdentityAdoptionDecisionQuery) Order(o ...identityadoptiondecision.OrderOption) *IdentityAdoptionDecisionQuery {
_q.order = append(_q.order, o...)
return _q
}
// QueryPendingAuthSession chains the current query on the "pending_auth_session" edge.
func (_q *IdentityAdoptionDecisionQuery) QueryPendingAuthSession() *PendingAuthSessionQuery {
query := (&PendingAuthSessionClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector),
sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryIdentity chains the current query on the "identity" edge.
func (_q *IdentityAdoptionDecisionQuery) QueryIdentity() *AuthIdentityQuery {
query := (&AuthIdentityClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector),
sqlgraph.To(authidentity.Table, authidentity.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first IdentityAdoptionDecision entity from the query.
// Returns a *NotFoundError when no IdentityAdoptionDecision was found.
func (_q *IdentityAdoptionDecisionQuery) First(ctx context.Context) (*IdentityAdoptionDecision, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{identityadoptiondecision.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *IdentityAdoptionDecisionQuery) FirstX(ctx context.Context) *IdentityAdoptionDecision {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first IdentityAdoptionDecision ID from the query.
// Returns a *NotFoundError when no IdentityAdoptionDecision ID was found.
func (_q *IdentityAdoptionDecisionQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{identityadoptiondecision.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *IdentityAdoptionDecisionQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single IdentityAdoptionDecision entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one IdentityAdoptionDecision entity is found.
// Returns a *NotFoundError when no IdentityAdoptionDecision entities are found.
func (_q *IdentityAdoptionDecisionQuery) Only(ctx context.Context) (*IdentityAdoptionDecision, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{identityadoptiondecision.Label}
default:
return nil, &NotSingularError{identityadoptiondecision.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *IdentityAdoptionDecisionQuery) OnlyX(ctx context.Context) *IdentityAdoptionDecision {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only IdentityAdoptionDecision ID in the query.
// Returns a *NotSingularError when more than one IdentityAdoptionDecision ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *IdentityAdoptionDecisionQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{identityadoptiondecision.Label}
default:
err = &NotSingularError{identityadoptiondecision.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *IdentityAdoptionDecisionQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of IdentityAdoptionDecisions.
func (_q *IdentityAdoptionDecisionQuery) All(ctx context.Context) ([]*IdentityAdoptionDecision, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*IdentityAdoptionDecision, *IdentityAdoptionDecisionQuery]()
return withInterceptors[[]*IdentityAdoptionDecision](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *IdentityAdoptionDecisionQuery) AllX(ctx context.Context) []*IdentityAdoptionDecision {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of IdentityAdoptionDecision IDs.
func (_q *IdentityAdoptionDecisionQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(identityadoptiondecision.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *IdentityAdoptionDecisionQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *IdentityAdoptionDecisionQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*IdentityAdoptionDecisionQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *IdentityAdoptionDecisionQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *IdentityAdoptionDecisionQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *IdentityAdoptionDecisionQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the IdentityAdoptionDecisionQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *IdentityAdoptionDecisionQuery) Clone() *IdentityAdoptionDecisionQuery {
if _q == nil {
return nil
}
return &IdentityAdoptionDecisionQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]identityadoptiondecision.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.IdentityAdoptionDecision{}, _q.predicates...),
withPendingAuthSession: _q.withPendingAuthSession.Clone(),
withIdentity: _q.withIdentity.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// WithPendingAuthSession tells the query-builder to eager-load the nodes that are connected to
// the "pending_auth_session" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *IdentityAdoptionDecisionQuery) WithPendingAuthSession(opts ...func(*PendingAuthSessionQuery)) *IdentityAdoptionDecisionQuery {
query := (&PendingAuthSessionClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withPendingAuthSession = query
return _q
}
// WithIdentity tells the query-builder to eager-load the nodes that are connected to
// the "identity" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *IdentityAdoptionDecisionQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *IdentityAdoptionDecisionQuery {
query := (&AuthIdentityClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withIdentity = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.IdentityAdoptionDecision.Query().
// GroupBy(identityadoptiondecision.FieldCreatedAt).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *IdentityAdoptionDecisionQuery) GroupBy(field string, fields ...string) *IdentityAdoptionDecisionGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &IdentityAdoptionDecisionGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = identityadoptiondecision.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// }
//
// client.IdentityAdoptionDecision.Query().
// Select(identityadoptiondecision.FieldCreatedAt).
// Scan(ctx, &v)
func (_q *IdentityAdoptionDecisionQuery) Select(fields ...string) *IdentityAdoptionDecisionSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &IdentityAdoptionDecisionSelect{IdentityAdoptionDecisionQuery: _q}
sbuild.label = identityadoptiondecision.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a IdentityAdoptionDecisionSelect configured with the given aggregations.
func (_q *IdentityAdoptionDecisionQuery) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *IdentityAdoptionDecisionQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !identityadoptiondecision.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *IdentityAdoptionDecisionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*IdentityAdoptionDecision, error) {
var (
nodes = []*IdentityAdoptionDecision{}
_spec = _q.querySpec()
loadedTypes = [2]bool{
_q.withPendingAuthSession != nil,
_q.withIdentity != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*IdentityAdoptionDecision).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &IdentityAdoptionDecision{config: _q.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
if query := _q.withPendingAuthSession; query != nil {
if err := _q.loadPendingAuthSession(ctx, query, nodes, nil,
func(n *IdentityAdoptionDecision, e *PendingAuthSession) { n.Edges.PendingAuthSession = e }); err != nil {
return nil, err
}
}
if query := _q.withIdentity; query != nil {
if err := _q.loadIdentity(ctx, query, nodes, nil,
func(n *IdentityAdoptionDecision, e *AuthIdentity) { n.Edges.Identity = e }); err != nil {
return nil, err
}
}
return nodes, nil
}
func (_q *IdentityAdoptionDecisionQuery) loadPendingAuthSession(ctx context.Context, query *PendingAuthSessionQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *PendingAuthSession)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*IdentityAdoptionDecision)
for i := range nodes {
fk := nodes[i].PendingAuthSessionID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(pendingauthsession.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "pending_auth_session_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *IdentityAdoptionDecisionQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *AuthIdentity)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*IdentityAdoptionDecision)
for i := range nodes {
if nodes[i].IdentityID == nil {
continue
}
fk := *nodes[i].IdentityID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(authidentity.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *IdentityAdoptionDecisionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *IdentityAdoptionDecisionQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID)
for i := range fields {
if fields[i] != identityadoptiondecision.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
if _q.withPendingAuthSession != nil {
_spec.Node.AddColumnOnce(identityadoptiondecision.FieldPendingAuthSessionID)
}
if _q.withIdentity != nil {
_spec.Node.AddColumnOnce(identityadoptiondecision.FieldIdentityID)
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *IdentityAdoptionDecisionQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(identityadoptiondecision.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = identityadoptiondecision.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *IdentityAdoptionDecisionQuery) ForUpdate(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *IdentityAdoptionDecisionQuery) ForShare(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// IdentityAdoptionDecisionGroupBy is the group-by builder for IdentityAdoptionDecision entities.
type IdentityAdoptionDecisionGroupBy struct {
selector
build *IdentityAdoptionDecisionQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *IdentityAdoptionDecisionGroupBy) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *IdentityAdoptionDecisionGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *IdentityAdoptionDecisionGroupBy) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// IdentityAdoptionDecisionSelect is the builder for selecting fields of IdentityAdoptionDecision entities.
type IdentityAdoptionDecisionSelect struct {
*IdentityAdoptionDecisionQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *IdentityAdoptionDecisionSelect) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *IdentityAdoptionDecisionSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionSelect](ctx, _s.IdentityAdoptionDecisionQuery, _s, _s.inters, v)
}
func (_s *IdentityAdoptionDecisionSelect) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

View File

@ -0,0 +1,532 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// IdentityAdoptionDecisionUpdate is the builder for updating IdentityAdoptionDecision entities.
type IdentityAdoptionDecisionUpdate struct {
config
hooks []Hook
mutation *IdentityAdoptionDecisionMutation
}
// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder.
func (_u *IdentityAdoptionDecisionUpdate) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *IdentityAdoptionDecisionUpdate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdate {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdate {
_u.mutation.SetPendingAuthSessionID(v)
return _u
}
// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil.
func (_u *IdentityAdoptionDecisionUpdate) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdate {
if v != nil {
_u.SetPendingAuthSessionID(*v)
}
return _u
}
// SetIdentityID sets the "identity_id" field.
func (_u *IdentityAdoptionDecisionUpdate) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdate {
_u.mutation.SetIdentityID(v)
return _u
}
// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
func (_u *IdentityAdoptionDecisionUpdate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdate {
if v != nil {
_u.SetIdentityID(*v)
}
return _u
}
// ClearIdentityID clears the value of the "identity_id" field.
func (_u *IdentityAdoptionDecisionUpdate) ClearIdentityID() *IdentityAdoptionDecisionUpdate {
_u.mutation.ClearIdentityID()
return _u
}
// SetAdoptDisplayName sets the "adopt_display_name" field.
func (_u *IdentityAdoptionDecisionUpdate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdate {
_u.mutation.SetAdoptDisplayName(v)
return _u
}
// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil.
func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdate {
if v != nil {
_u.SetAdoptDisplayName(*v)
}
return _u
}
// SetAdoptAvatar sets the "adopt_avatar" field.
func (_u *IdentityAdoptionDecisionUpdate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdate {
_u.mutation.SetAdoptAvatar(v)
return _u
}
// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil.
func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdate {
if v != nil {
_u.SetAdoptAvatar(*v)
}
return _u
}
// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity.
func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdate {
return _u.SetPendingAuthSessionID(v.ID)
}
// SetIdentity sets the "identity" edge to the AuthIdentity entity.
func (_u *IdentityAdoptionDecisionUpdate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdate {
return _u.SetIdentityID(v.ID)
}
// Mutation returns the IdentityAdoptionDecisionMutation object of the builder.
func (_u *IdentityAdoptionDecisionUpdate) Mutation() *IdentityAdoptionDecisionMutation {
return _u.mutation
}
// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity.
func (_u *IdentityAdoptionDecisionUpdate) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdate {
_u.mutation.ClearPendingAuthSession()
return _u
}
// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
func (_u *IdentityAdoptionDecisionUpdate) ClearIdentity() *IdentityAdoptionDecisionUpdate {
_u.mutation.ClearIdentity()
return _u
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *IdentityAdoptionDecisionUpdate) Save(ctx context.Context) (int, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *IdentityAdoptionDecisionUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *IdentityAdoptionDecisionUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *IdentityAdoptionDecisionUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *IdentityAdoptionDecisionUpdate) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := identityadoptiondecision.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *IdentityAdoptionDecisionUpdate) check() error {
if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`)
}
return nil
}
func (_u *IdentityAdoptionDecisionUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.AdoptDisplayName(); ok {
_spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value)
}
if value, ok := _u.mutation.AdoptAvatar(); ok {
_spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value)
}
if _u.mutation.PendingAuthSessionCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2O,
Inverse: true,
Table: identityadoptiondecision.PendingAuthSessionTable,
Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2O,
Inverse: true,
Table: identityadoptiondecision.PendingAuthSessionTable,
Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.IdentityCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: identityadoptiondecision.IdentityTable,
Columns: []string{identityadoptiondecision.IdentityColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: identityadoptiondecision.IdentityTable,
Columns: []string{identityadoptiondecision.IdentityColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{identityadoptiondecision.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// IdentityAdoptionDecisionUpdateOne is the builder for updating a single IdentityAdoptionDecision entity.
type IdentityAdoptionDecisionUpdateOne struct {
config
fields []string
hooks []Hook
mutation *IdentityAdoptionDecisionMutation
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *IdentityAdoptionDecisionUpdateOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdateOne {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdateOne {
_u.mutation.SetPendingAuthSessionID(v)
return _u
}
// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil.
func (_u *IdentityAdoptionDecisionUpdateOne) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdateOne {
if v != nil {
_u.SetPendingAuthSessionID(*v)
}
return _u
}
// SetIdentityID sets the "identity_id" field.
func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdateOne {
_u.mutation.SetIdentityID(v)
return _u
}
// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdateOne {
if v != nil {
_u.SetIdentityID(*v)
}
return _u
}
// ClearIdentityID clears the value of the "identity_id" field.
func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentityID() *IdentityAdoptionDecisionUpdateOne {
_u.mutation.ClearIdentityID()
return _u
}
// SetAdoptDisplayName sets the "adopt_display_name" field.
func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdateOne {
_u.mutation.SetAdoptDisplayName(v)
return _u
}
// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil.
func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdateOne {
if v != nil {
_u.SetAdoptDisplayName(*v)
}
return _u
}
// SetAdoptAvatar sets the "adopt_avatar" field.
func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdateOne {
_u.mutation.SetAdoptAvatar(v)
return _u
}
// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil.
func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdateOne {
if v != nil {
_u.SetAdoptAvatar(*v)
}
return _u
}
// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity.
func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdateOne {
return _u.SetPendingAuthSessionID(v.ID)
}
// SetIdentity sets the "identity" edge to the AuthIdentity entity.
func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdateOne {
return _u.SetIdentityID(v.ID)
}
// Mutation returns the IdentityAdoptionDecisionMutation object of the builder.
func (_u *IdentityAdoptionDecisionUpdateOne) Mutation() *IdentityAdoptionDecisionMutation {
return _u.mutation
}
// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity.
func (_u *IdentityAdoptionDecisionUpdateOne) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdateOne {
_u.mutation.ClearPendingAuthSession()
return _u
}
// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentity() *IdentityAdoptionDecisionUpdateOne {
_u.mutation.ClearIdentity()
return _u
}
// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder.
func (_u *IdentityAdoptionDecisionUpdateOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *IdentityAdoptionDecisionUpdateOne) Select(field string, fields ...string) *IdentityAdoptionDecisionUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated IdentityAdoptionDecision entity.
func (_u *IdentityAdoptionDecisionUpdateOne) Save(ctx context.Context) (*IdentityAdoptionDecision, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *IdentityAdoptionDecisionUpdateOne) SaveX(ctx context.Context) *IdentityAdoptionDecision {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *IdentityAdoptionDecisionUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *IdentityAdoptionDecisionUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *IdentityAdoptionDecisionUpdateOne) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := identityadoptiondecision.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *IdentityAdoptionDecisionUpdateOne) check() error {
if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`)
}
return nil
}
func (_u *IdentityAdoptionDecisionUpdateOne) sqlSave(ctx context.Context) (_node *IdentityAdoptionDecision, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "IdentityAdoptionDecision.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID)
for _, f := range fields {
if !identityadoptiondecision.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != identityadoptiondecision.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.AdoptDisplayName(); ok {
_spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value)
}
if value, ok := _u.mutation.AdoptAvatar(); ok {
_spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value)
}
if _u.mutation.PendingAuthSessionCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2O,
Inverse: true,
Table: identityadoptiondecision.PendingAuthSessionTable,
Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2O,
Inverse: true,
Table: identityadoptiondecision.PendingAuthSessionTable,
Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.IdentityCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: identityadoptiondecision.IdentityTable,
Columns: []string{identityadoptiondecision.IdentityColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: identityadoptiondecision.IdentityTable,
Columns: []string{identityadoptiondecision.IdentityColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &IdentityAdoptionDecision{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{identityadoptiondecision.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}

View File

@ -13,12 +13,16 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
@ -228,6 +232,60 @@ func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) err
return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q)
}
// The AuthIdentityFunc type is an adapter to allow the use of ordinary function as a Querier.
type AuthIdentityFunc func(context.Context, *ent.AuthIdentityQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f AuthIdentityFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.AuthIdentityQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q)
}
// The TraverseAuthIdentity type is an adapter to allow the use of ordinary function as Traverser.
type TraverseAuthIdentity func(context.Context, *ent.AuthIdentityQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseAuthIdentity) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseAuthIdentity) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.AuthIdentityQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q)
}
// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary function as a Querier.
type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f AuthIdentityChannelFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.AuthIdentityChannelQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q)
}
// The TraverseAuthIdentityChannel type is an adapter to allow the use of ordinary function as Traverser.
type TraverseAuthIdentityChannel func(context.Context, *ent.AuthIdentityChannelQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseAuthIdentityChannel) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseAuthIdentityChannel) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.AuthIdentityChannelQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q)
}
// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary function as a Querier.
type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleQuery) (ent.Value, error)
@ -309,6 +367,33 @@ func (f TraverseIdempotencyRecord) Traverse(ctx context.Context, q ent.Query) er
return fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q)
}
// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary function as a Querier.
type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f IdentityAdoptionDecisionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q)
}
// The TraverseIdentityAdoptionDecision type is an adapter to allow the use of ordinary function as Traverser.
type TraverseIdentityAdoptionDecision func(context.Context, *ent.IdentityAdoptionDecisionQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseIdentityAdoptionDecision) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseIdentityAdoptionDecision) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q)
}
// The PaymentAuditLogFunc type is an adapter to allow the use of ordinary function as a Querier.
type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogQuery) (ent.Value, error)
@ -390,6 +475,33 @@ func (f TraversePaymentProviderInstance) Traverse(ctx context.Context, q ent.Que
return fmt.Errorf("unexpected query type %T. expect *ent.PaymentProviderInstanceQuery", q)
}
// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary function as a Querier.
type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f PendingAuthSessionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.PendingAuthSessionQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q)
}
// The TraversePendingAuthSession type is an adapter to allow the use of ordinary function as Traverser.
type TraversePendingAuthSession func(context.Context, *ent.PendingAuthSessionQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraversePendingAuthSession) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraversePendingAuthSession) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.PendingAuthSessionQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q)
}
// The PromoCodeFunc type is an adapter to allow the use of ordinary function as a Querier.
type PromoCodeFunc func(context.Context, *ent.PromoCodeQuery) (ent.Value, error)
@ -808,18 +920,26 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil
case *ent.AnnouncementReadQuery:
return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil
case *ent.AuthIdentityQuery:
return &query[*ent.AuthIdentityQuery, predicate.AuthIdentity, authidentity.OrderOption]{typ: ent.TypeAuthIdentity, tq: q}, nil
case *ent.AuthIdentityChannelQuery:
return &query[*ent.AuthIdentityChannelQuery, predicate.AuthIdentityChannel, authidentitychannel.OrderOption]{typ: ent.TypeAuthIdentityChannel, tq: q}, nil
case *ent.ErrorPassthroughRuleQuery:
return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil
case *ent.GroupQuery:
return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil
case *ent.IdempotencyRecordQuery:
return &query[*ent.IdempotencyRecordQuery, predicate.IdempotencyRecord, idempotencyrecord.OrderOption]{typ: ent.TypeIdempotencyRecord, tq: q}, nil
case *ent.IdentityAdoptionDecisionQuery:
return &query[*ent.IdentityAdoptionDecisionQuery, predicate.IdentityAdoptionDecision, identityadoptiondecision.OrderOption]{typ: ent.TypeIdentityAdoptionDecision, tq: q}, nil
case *ent.PaymentAuditLogQuery:
return &query[*ent.PaymentAuditLogQuery, predicate.PaymentAuditLog, paymentauditlog.OrderOption]{typ: ent.TypePaymentAuditLog, tq: q}, nil
case *ent.PaymentOrderQuery:
return &query[*ent.PaymentOrderQuery, predicate.PaymentOrder, paymentorder.OrderOption]{typ: ent.TypePaymentOrder, tq: q}, nil
case *ent.PaymentProviderInstanceQuery:
return &query[*ent.PaymentProviderInstanceQuery, predicate.PaymentProviderInstance, paymentproviderinstance.OrderOption]{typ: ent.TypePaymentProviderInstance, tq: q}, nil
case *ent.PendingAuthSessionQuery:
return &query[*ent.PendingAuthSessionQuery, predicate.PendingAuthSession, pendingauthsession.OrderOption]{typ: ent.TypePendingAuthSession, tq: q}, nil
case *ent.PromoCodeQuery:
return &query[*ent.PromoCodeQuery, predicate.PromoCode, promocode.OrderOption]{typ: ent.TypePromoCode, tq: q}, nil
case *ent.PromoCodeUsageQuery:

View File

@ -338,6 +338,89 @@ var (
},
},
}
// AuthIdentitiesColumns holds the columns for the "auth_identities" table.
AuthIdentitiesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "provider_type", Type: field.TypeString, Size: 20},
{Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
{Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
{Name: "verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "issuer", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "user_id", Type: field.TypeInt64},
}
// AuthIdentitiesTable holds the schema information for the "auth_identities" table.
AuthIdentitiesTable = &schema.Table{
Name: "auth_identities",
Columns: AuthIdentitiesColumns,
PrimaryKey: []*schema.Column{AuthIdentitiesColumns[0]},
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "auth_identities_users_auth_identities",
Columns: []*schema.Column{AuthIdentitiesColumns[9]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
},
Indexes: []*schema.Index{
{
Name: "authidentity_provider_type_provider_key_provider_subject",
Unique: true,
Columns: []*schema.Column{AuthIdentitiesColumns[3], AuthIdentitiesColumns[4], AuthIdentitiesColumns[5]},
},
{
Name: "authidentity_user_id",
Unique: false,
Columns: []*schema.Column{AuthIdentitiesColumns[9]},
},
{
Name: "authidentity_user_id_provider_type",
Unique: false,
Columns: []*schema.Column{AuthIdentitiesColumns[9], AuthIdentitiesColumns[3]},
},
},
}
// AuthIdentityChannelsColumns holds the columns for the "auth_identity_channels" table.
AuthIdentityChannelsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "provider_type", Type: field.TypeString, Size: 20},
{Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
{Name: "channel", Type: field.TypeString, Size: 20},
{Name: "channel_app_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
{Name: "channel_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
{Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "identity_id", Type: field.TypeInt64},
}
// AuthIdentityChannelsTable holds the schema information for the "auth_identity_channels" table.
AuthIdentityChannelsTable = &schema.Table{
Name: "auth_identity_channels",
Columns: AuthIdentityChannelsColumns,
PrimaryKey: []*schema.Column{AuthIdentityChannelsColumns[0]},
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "auth_identity_channels_auth_identities_channels",
Columns: []*schema.Column{AuthIdentityChannelsColumns[9]},
RefColumns: []*schema.Column{AuthIdentitiesColumns[0]},
OnDelete: schema.NoAction,
},
},
Indexes: []*schema.Index{
{
Name: "authidentitychannel_provider_type_provider_key_channel_channel_app_id_channel_subject",
Unique: true,
Columns: []*schema.Column{AuthIdentityChannelsColumns[3], AuthIdentityChannelsColumns[4], AuthIdentityChannelsColumns[5], AuthIdentityChannelsColumns[6], AuthIdentityChannelsColumns[7]},
},
{
Name: "authidentitychannel_identity_id",
Unique: false,
Columns: []*schema.Column{AuthIdentityChannelsColumns[9]},
},
},
}
// ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table.
ErrorPassthroughRulesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@ -485,6 +568,49 @@ var (
},
},
}
// IdentityAdoptionDecisionsColumns holds the columns for the "identity_adoption_decisions" table.
IdentityAdoptionDecisionsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "adopt_display_name", Type: field.TypeBool, Default: false},
{Name: "adopt_avatar", Type: field.TypeBool, Default: false},
{Name: "decided_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "identity_id", Type: field.TypeInt64, Nullable: true},
{Name: "pending_auth_session_id", Type: field.TypeInt64, Unique: true},
}
// IdentityAdoptionDecisionsTable holds the schema information for the "identity_adoption_decisions" table.
IdentityAdoptionDecisionsTable = &schema.Table{
Name: "identity_adoption_decisions",
Columns: IdentityAdoptionDecisionsColumns,
PrimaryKey: []*schema.Column{IdentityAdoptionDecisionsColumns[0]},
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "identity_adoption_decisions_auth_identities_adoption_decisions",
Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]},
RefColumns: []*schema.Column{AuthIdentitiesColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "identity_adoption_decisions_pending_auth_sessions_adoption_decision",
Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]},
RefColumns: []*schema.Column{PendingAuthSessionsColumns[0]},
OnDelete: schema.NoAction,
},
},
Indexes: []*schema.Index{
{
Name: "identityadoptiondecision_pending_auth_session_id",
Unique: true,
Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]},
},
{
Name: "identityadoptiondecision_identity_id",
Unique: false,
Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]},
},
},
}
// PaymentAuditLogsColumns holds the columns for the "payment_audit_logs" table.
PaymentAuditLogsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@ -638,6 +764,72 @@ var (
},
},
}
// PendingAuthSessionsColumns holds the columns for the "pending_auth_sessions" table.
PendingAuthSessionsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "session_token", Type: field.TypeString, Size: 255},
{Name: "intent", Type: field.TypeString, Size: 40},
{Name: "provider_type", Type: field.TypeString, Size: 20},
{Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
{Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
{Name: "redirect_to", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
{Name: "resolved_email", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
{Name: "registration_password_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
{Name: "upstream_identity_claims", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "local_flow_state", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "browser_session_key", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
{Name: "completion_code_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
{Name: "completion_code_expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "email_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "password_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "totp_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "expires_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "consumed_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "target_user_id", Type: field.TypeInt64, Nullable: true},
}
// PendingAuthSessionsTable holds the schema information for the "pending_auth_sessions" table.
PendingAuthSessionsTable = &schema.Table{
Name: "pending_auth_sessions",
Columns: PendingAuthSessionsColumns,
PrimaryKey: []*schema.Column{PendingAuthSessionsColumns[0]},
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "pending_auth_sessions_users_pending_auth_sessions",
Columns: []*schema.Column{PendingAuthSessionsColumns[21]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.SetNull,
},
},
Indexes: []*schema.Index{
{
Name: "pendingauthsession_session_token",
Unique: true,
Columns: []*schema.Column{PendingAuthSessionsColumns[3]},
},
{
Name: "pendingauthsession_target_user_id",
Unique: false,
Columns: []*schema.Column{PendingAuthSessionsColumns[21]},
},
{
Name: "pendingauthsession_expires_at",
Unique: false,
Columns: []*schema.Column{PendingAuthSessionsColumns[19]},
},
{
Name: "pendingauthsession_provider_type_provider_key_provider_subject",
Unique: false,
Columns: []*schema.Column{PendingAuthSessionsColumns[5], PendingAuthSessionsColumns[6], PendingAuthSessionsColumns[7]},
},
{
Name: "pendingauthsession_completion_code_hash",
Unique: false,
Columns: []*schema.Column{PendingAuthSessionsColumns[14]},
},
},
}
// PromoCodesColumns holds the columns for the "promo_codes" table.
PromoCodesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@ -1079,6 +1271,9 @@ var (
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
{Name: "signup_source", Type: field.TypeString, Size: 20, Default: "email"},
{Name: "last_login_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "last_active_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "balance_notify_enabled", Type: field.TypeBool, Default: true},
{Name: "balance_notify_threshold_type", Type: field.TypeString, Default: "fixed"},
{Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
@ -1318,12 +1513,16 @@ var (
AccountGroupsTable,
AnnouncementsTable,
AnnouncementReadsTable,
AuthIdentitiesTable,
AuthIdentityChannelsTable,
ErrorPassthroughRulesTable,
GroupsTable,
IdempotencyRecordsTable,
IdentityAdoptionDecisionsTable,
PaymentAuditLogsTable,
PaymentOrdersTable,
PaymentProviderInstancesTable,
PendingAuthSessionsTable,
PromoCodesTable,
PromoCodeUsagesTable,
ProxiesTable,
@ -1365,6 +1564,14 @@ func init() {
AnnouncementReadsTable.Annotation = &entsql.Annotation{
Table: "announcement_reads",
}
AuthIdentitiesTable.ForeignKeys[0].RefTable = UsersTable
AuthIdentitiesTable.Annotation = &entsql.Annotation{
Table: "auth_identities",
}
AuthIdentityChannelsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable
AuthIdentityChannelsTable.Annotation = &entsql.Annotation{
Table: "auth_identity_channels",
}
ErrorPassthroughRulesTable.Annotation = &entsql.Annotation{
Table: "error_passthrough_rules",
}
@ -1374,6 +1581,11 @@ func init() {
IdempotencyRecordsTable.Annotation = &entsql.Annotation{
Table: "idempotency_records",
}
IdentityAdoptionDecisionsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable
IdentityAdoptionDecisionsTable.ForeignKeys[1].RefTable = PendingAuthSessionsTable
IdentityAdoptionDecisionsTable.Annotation = &entsql.Annotation{
Table: "identity_adoption_decisions",
}
PaymentAuditLogsTable.Annotation = &entsql.Annotation{
Table: "payment_audit_logs",
}
@ -1384,6 +1596,10 @@ func init() {
PaymentProviderInstancesTable.Annotation = &entsql.Annotation{
Table: "payment_provider_instances",
}
PendingAuthSessionsTable.ForeignKeys[0].RefTable = UsersTable
PendingAuthSessionsTable.Annotation = &entsql.Annotation{
Table: "pending_auth_sessions",
}
PromoCodesTable.Annotation = &entsql.Annotation{
Table: "promo_codes",
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,399 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"encoding/json"
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/user"
)
// PendingAuthSession is the model entity for the PendingAuthSession schema.
type PendingAuthSession struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
UpdatedAt time.Time `json:"updated_at,omitempty"`
// SessionToken holds the value of the "session_token" field.
SessionToken string `json:"session_token,omitempty"`
// Intent holds the value of the "intent" field.
Intent string `json:"intent,omitempty"`
// ProviderType holds the value of the "provider_type" field.
ProviderType string `json:"provider_type,omitempty"`
// ProviderKey holds the value of the "provider_key" field.
ProviderKey string `json:"provider_key,omitempty"`
// ProviderSubject holds the value of the "provider_subject" field.
ProviderSubject string `json:"provider_subject,omitempty"`
// TargetUserID holds the value of the "target_user_id" field.
TargetUserID *int64 `json:"target_user_id,omitempty"`
// RedirectTo holds the value of the "redirect_to" field.
RedirectTo string `json:"redirect_to,omitempty"`
// ResolvedEmail holds the value of the "resolved_email" field.
ResolvedEmail string `json:"resolved_email,omitempty"`
// RegistrationPasswordHash holds the value of the "registration_password_hash" field.
RegistrationPasswordHash string `json:"registration_password_hash,omitempty"`
// UpstreamIdentityClaims holds the value of the "upstream_identity_claims" field.
UpstreamIdentityClaims map[string]interface{} `json:"upstream_identity_claims,omitempty"`
// LocalFlowState holds the value of the "local_flow_state" field.
LocalFlowState map[string]interface{} `json:"local_flow_state,omitempty"`
// BrowserSessionKey holds the value of the "browser_session_key" field.
BrowserSessionKey string `json:"browser_session_key,omitempty"`
// CompletionCodeHash holds the value of the "completion_code_hash" field.
CompletionCodeHash string `json:"completion_code_hash,omitempty"`
// CompletionCodeExpiresAt holds the value of the "completion_code_expires_at" field.
CompletionCodeExpiresAt *time.Time `json:"completion_code_expires_at,omitempty"`
// EmailVerifiedAt holds the value of the "email_verified_at" field.
EmailVerifiedAt *time.Time `json:"email_verified_at,omitempty"`
// PasswordVerifiedAt holds the value of the "password_verified_at" field.
PasswordVerifiedAt *time.Time `json:"password_verified_at,omitempty"`
// TotpVerifiedAt holds the value of the "totp_verified_at" field.
TotpVerifiedAt *time.Time `json:"totp_verified_at,omitempty"`
// ExpiresAt holds the value of the "expires_at" field.
ExpiresAt time.Time `json:"expires_at,omitempty"`
// ConsumedAt holds the value of the "consumed_at" field.
ConsumedAt *time.Time `json:"consumed_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the PendingAuthSessionQuery when eager-loading is set.
Edges PendingAuthSessionEdges `json:"edges"`
selectValues sql.SelectValues
}
// PendingAuthSessionEdges holds the relations/edges for other nodes in the graph.
type PendingAuthSessionEdges struct {
// TargetUser holds the value of the target_user edge.
TargetUser *User `json:"target_user,omitempty"`
// AdoptionDecision holds the value of the adoption_decision edge.
AdoptionDecision *IdentityAdoptionDecision `json:"adoption_decision,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [2]bool
}
// TargetUserOrErr returns the TargetUser value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e PendingAuthSessionEdges) TargetUserOrErr() (*User, error) {
if e.TargetUser != nil {
return e.TargetUser, nil
} else if e.loadedTypes[0] {
return nil, &NotFoundError{label: user.Label}
}
return nil, &NotLoadedError{edge: "target_user"}
}
// AdoptionDecisionOrErr returns the AdoptionDecision value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e PendingAuthSessionEdges) AdoptionDecisionOrErr() (*IdentityAdoptionDecision, error) {
if e.AdoptionDecision != nil {
return e.AdoptionDecision, nil
} else if e.loadedTypes[1] {
return nil, &NotFoundError{label: identityadoptiondecision.Label}
}
return nil, &NotLoadedError{edge: "adoption_decision"}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*PendingAuthSession) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case pendingauthsession.FieldUpstreamIdentityClaims, pendingauthsession.FieldLocalFlowState:
values[i] = new([]byte)
case pendingauthsession.FieldID, pendingauthsession.FieldTargetUserID:
values[i] = new(sql.NullInt64)
case pendingauthsession.FieldSessionToken, pendingauthsession.FieldIntent, pendingauthsession.FieldProviderType, pendingauthsession.FieldProviderKey, pendingauthsession.FieldProviderSubject, pendingauthsession.FieldRedirectTo, pendingauthsession.FieldResolvedEmail, pendingauthsession.FieldRegistrationPasswordHash, pendingauthsession.FieldBrowserSessionKey, pendingauthsession.FieldCompletionCodeHash:
values[i] = new(sql.NullString)
case pendingauthsession.FieldCreatedAt, pendingauthsession.FieldUpdatedAt, pendingauthsession.FieldCompletionCodeExpiresAt, pendingauthsession.FieldEmailVerifiedAt, pendingauthsession.FieldPasswordVerifiedAt, pendingauthsession.FieldTotpVerifiedAt, pendingauthsession.FieldExpiresAt, pendingauthsession.FieldConsumedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the PendingAuthSession fields.
func (_m *PendingAuthSession) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case pendingauthsession.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
case pendingauthsession.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
case pendingauthsession.FieldUpdatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
} else if value.Valid {
_m.UpdatedAt = value.Time
}
case pendingauthsession.FieldSessionToken:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field session_token", values[i])
} else if value.Valid {
_m.SessionToken = value.String
}
case pendingauthsession.FieldIntent:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field intent", values[i])
} else if value.Valid {
_m.Intent = value.String
}
case pendingauthsession.FieldProviderType:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field provider_type", values[i])
} else if value.Valid {
_m.ProviderType = value.String
}
case pendingauthsession.FieldProviderKey:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field provider_key", values[i])
} else if value.Valid {
_m.ProviderKey = value.String
}
case pendingauthsession.FieldProviderSubject:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field provider_subject", values[i])
} else if value.Valid {
_m.ProviderSubject = value.String
}
case pendingauthsession.FieldTargetUserID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field target_user_id", values[i])
} else if value.Valid {
_m.TargetUserID = new(int64)
*_m.TargetUserID = value.Int64
}
case pendingauthsession.FieldRedirectTo:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field redirect_to", values[i])
} else if value.Valid {
_m.RedirectTo = value.String
}
case pendingauthsession.FieldResolvedEmail:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field resolved_email", values[i])
} else if value.Valid {
_m.ResolvedEmail = value.String
}
case pendingauthsession.FieldRegistrationPasswordHash:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field registration_password_hash", values[i])
} else if value.Valid {
_m.RegistrationPasswordHash = value.String
}
case pendingauthsession.FieldUpstreamIdentityClaims:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field upstream_identity_claims", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.UpstreamIdentityClaims); err != nil {
return fmt.Errorf("unmarshal field upstream_identity_claims: %w", err)
}
}
case pendingauthsession.FieldLocalFlowState:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field local_flow_state", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.LocalFlowState); err != nil {
return fmt.Errorf("unmarshal field local_flow_state: %w", err)
}
}
case pendingauthsession.FieldBrowserSessionKey:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field browser_session_key", values[i])
} else if value.Valid {
_m.BrowserSessionKey = value.String
}
case pendingauthsession.FieldCompletionCodeHash:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field completion_code_hash", values[i])
} else if value.Valid {
_m.CompletionCodeHash = value.String
}
case pendingauthsession.FieldCompletionCodeExpiresAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field completion_code_expires_at", values[i])
} else if value.Valid {
_m.CompletionCodeExpiresAt = new(time.Time)
*_m.CompletionCodeExpiresAt = value.Time
}
case pendingauthsession.FieldEmailVerifiedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field email_verified_at", values[i])
} else if value.Valid {
_m.EmailVerifiedAt = new(time.Time)
*_m.EmailVerifiedAt = value.Time
}
case pendingauthsession.FieldPasswordVerifiedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field password_verified_at", values[i])
} else if value.Valid {
_m.PasswordVerifiedAt = new(time.Time)
*_m.PasswordVerifiedAt = value.Time
}
case pendingauthsession.FieldTotpVerifiedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field totp_verified_at", values[i])
} else if value.Valid {
_m.TotpVerifiedAt = new(time.Time)
*_m.TotpVerifiedAt = value.Time
}
case pendingauthsession.FieldExpiresAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field expires_at", values[i])
} else if value.Valid {
_m.ExpiresAt = value.Time
}
case pendingauthsession.FieldConsumedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field consumed_at", values[i])
} else if value.Valid {
_m.ConsumedAt = new(time.Time)
*_m.ConsumedAt = value.Time
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the PendingAuthSession.
// This includes values selected through modifiers, order, etc.
func (_m *PendingAuthSession) Value(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// QueryTargetUser queries the "target_user" edge of the PendingAuthSession entity.
func (_m *PendingAuthSession) QueryTargetUser() *UserQuery {
return NewPendingAuthSessionClient(_m.config).QueryTargetUser(_m)
}
// QueryAdoptionDecision queries the "adoption_decision" edge of the PendingAuthSession entity.
func (_m *PendingAuthSession) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery {
return NewPendingAuthSessionClient(_m.config).QueryAdoptionDecision(_m)
}
// Update returns a builder for updating this PendingAuthSession.
// Note that you need to call PendingAuthSession.Unwrap() before calling this method if this PendingAuthSession
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *PendingAuthSession) Update() *PendingAuthSessionUpdateOne {
return NewPendingAuthSessionClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the PendingAuthSession entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *PendingAuthSession) Unwrap() *PendingAuthSession {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("ent: PendingAuthSession is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *PendingAuthSession) String() string {
var builder strings.Builder
builder.WriteString("PendingAuthSession(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("updated_at=")
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("session_token=")
builder.WriteString(_m.SessionToken)
builder.WriteString(", ")
builder.WriteString("intent=")
builder.WriteString(_m.Intent)
builder.WriteString(", ")
builder.WriteString("provider_type=")
builder.WriteString(_m.ProviderType)
builder.WriteString(", ")
builder.WriteString("provider_key=")
builder.WriteString(_m.ProviderKey)
builder.WriteString(", ")
builder.WriteString("provider_subject=")
builder.WriteString(_m.ProviderSubject)
builder.WriteString(", ")
if v := _m.TargetUserID; v != nil {
builder.WriteString("target_user_id=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("redirect_to=")
builder.WriteString(_m.RedirectTo)
builder.WriteString(", ")
builder.WriteString("resolved_email=")
builder.WriteString(_m.ResolvedEmail)
builder.WriteString(", ")
builder.WriteString("registration_password_hash=")
builder.WriteString(_m.RegistrationPasswordHash)
builder.WriteString(", ")
builder.WriteString("upstream_identity_claims=")
builder.WriteString(fmt.Sprintf("%v", _m.UpstreamIdentityClaims))
builder.WriteString(", ")
builder.WriteString("local_flow_state=")
builder.WriteString(fmt.Sprintf("%v", _m.LocalFlowState))
builder.WriteString(", ")
builder.WriteString("browser_session_key=")
builder.WriteString(_m.BrowserSessionKey)
builder.WriteString(", ")
builder.WriteString("completion_code_hash=")
builder.WriteString(_m.CompletionCodeHash)
builder.WriteString(", ")
if v := _m.CompletionCodeExpiresAt; v != nil {
builder.WriteString("completion_code_expires_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.EmailVerifiedAt; v != nil {
builder.WriteString("email_verified_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.PasswordVerifiedAt; v != nil {
builder.WriteString("password_verified_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.TotpVerifiedAt; v != nil {
builder.WriteString("totp_verified_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
builder.WriteString("expires_at=")
builder.WriteString(_m.ExpiresAt.Format(time.ANSIC))
builder.WriteString(", ")
if v := _m.ConsumedAt; v != nil {
builder.WriteString("consumed_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteByte(')')
return builder.String()
}
// PendingAuthSessions is a parsable slice of PendingAuthSession.
type PendingAuthSessions []*PendingAuthSession

View File

@ -0,0 +1,279 @@
// Code generated by ent, DO NOT EDIT.
package pendingauthsession
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
const (
// Label holds the string label denoting the pendingauthsession type in the database.
Label = "pending_auth_session"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at"
// FieldSessionToken holds the string denoting the session_token field in the database.
FieldSessionToken = "session_token"
// FieldIntent holds the string denoting the intent field in the database.
FieldIntent = "intent"
// FieldProviderType holds the string denoting the provider_type field in the database.
FieldProviderType = "provider_type"
// FieldProviderKey holds the string denoting the provider_key field in the database.
FieldProviderKey = "provider_key"
// FieldProviderSubject holds the string denoting the provider_subject field in the database.
FieldProviderSubject = "provider_subject"
// FieldTargetUserID holds the string denoting the target_user_id field in the database.
FieldTargetUserID = "target_user_id"
// FieldRedirectTo holds the string denoting the redirect_to field in the database.
FieldRedirectTo = "redirect_to"
// FieldResolvedEmail holds the string denoting the resolved_email field in the database.
FieldResolvedEmail = "resolved_email"
// FieldRegistrationPasswordHash holds the string denoting the registration_password_hash field in the database.
FieldRegistrationPasswordHash = "registration_password_hash"
// FieldUpstreamIdentityClaims holds the string denoting the upstream_identity_claims field in the database.
FieldUpstreamIdentityClaims = "upstream_identity_claims"
// FieldLocalFlowState holds the string denoting the local_flow_state field in the database.
FieldLocalFlowState = "local_flow_state"
// FieldBrowserSessionKey holds the string denoting the browser_session_key field in the database.
FieldBrowserSessionKey = "browser_session_key"
// FieldCompletionCodeHash holds the string denoting the completion_code_hash field in the database.
FieldCompletionCodeHash = "completion_code_hash"
// FieldCompletionCodeExpiresAt holds the string denoting the completion_code_expires_at field in the database.
FieldCompletionCodeExpiresAt = "completion_code_expires_at"
// FieldEmailVerifiedAt holds the string denoting the email_verified_at field in the database.
FieldEmailVerifiedAt = "email_verified_at"
// FieldPasswordVerifiedAt holds the string denoting the password_verified_at field in the database.
FieldPasswordVerifiedAt = "password_verified_at"
// FieldTotpVerifiedAt holds the string denoting the totp_verified_at field in the database.
FieldTotpVerifiedAt = "totp_verified_at"
// FieldExpiresAt holds the string denoting the expires_at field in the database.
FieldExpiresAt = "expires_at"
// FieldConsumedAt holds the string denoting the consumed_at field in the database.
FieldConsumedAt = "consumed_at"
// EdgeTargetUser holds the string denoting the target_user edge name in mutations.
EdgeTargetUser = "target_user"
// EdgeAdoptionDecision holds the string denoting the adoption_decision edge name in mutations.
EdgeAdoptionDecision = "adoption_decision"
// Table holds the table name of the pendingauthsession in the database.
Table = "pending_auth_sessions"
// TargetUserTable is the table that holds the target_user relation/edge.
TargetUserTable = "pending_auth_sessions"
// TargetUserInverseTable is the table name for the User entity.
// It exists in this package in order to avoid circular dependency with the "user" package.
TargetUserInverseTable = "users"
// TargetUserColumn is the table column denoting the target_user relation/edge.
TargetUserColumn = "target_user_id"
// AdoptionDecisionTable is the table that holds the adoption_decision relation/edge.
AdoptionDecisionTable = "identity_adoption_decisions"
// AdoptionDecisionInverseTable is the table name for the IdentityAdoptionDecision entity.
// It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package.
AdoptionDecisionInverseTable = "identity_adoption_decisions"
// AdoptionDecisionColumn is the table column denoting the adoption_decision relation/edge.
AdoptionDecisionColumn = "pending_auth_session_id"
)
// Columns holds all SQL columns for pendingauthsession fields.
var Columns = []string{
FieldID,
FieldCreatedAt,
FieldUpdatedAt,
FieldSessionToken,
FieldIntent,
FieldProviderType,
FieldProviderKey,
FieldProviderSubject,
FieldTargetUserID,
FieldRedirectTo,
FieldResolvedEmail,
FieldRegistrationPasswordHash,
FieldUpstreamIdentityClaims,
FieldLocalFlowState,
FieldBrowserSessionKey,
FieldCompletionCodeHash,
FieldCompletionCodeExpiresAt,
FieldEmailVerifiedAt,
FieldPasswordVerifiedAt,
FieldTotpVerifiedAt,
FieldExpiresAt,
FieldConsumedAt,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
var (
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
DefaultUpdatedAt func() time.Time
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
UpdateDefaultUpdatedAt func() time.Time
// SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save.
SessionTokenValidator func(string) error
// IntentValidator is a validator for the "intent" field. It is called by the builders before save.
IntentValidator func(string) error
// ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
ProviderTypeValidator func(string) error
// ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
ProviderKeyValidator func(string) error
// ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
ProviderSubjectValidator func(string) error
// DefaultRedirectTo holds the default value on creation for the "redirect_to" field.
DefaultRedirectTo string
// DefaultResolvedEmail holds the default value on creation for the "resolved_email" field.
DefaultResolvedEmail string
// DefaultRegistrationPasswordHash holds the default value on creation for the "registration_password_hash" field.
DefaultRegistrationPasswordHash string
// DefaultUpstreamIdentityClaims holds the default value on creation for the "upstream_identity_claims" field.
DefaultUpstreamIdentityClaims func() map[string]interface{}
// DefaultLocalFlowState holds the default value on creation for the "local_flow_state" field.
DefaultLocalFlowState func() map[string]interface{}
// DefaultBrowserSessionKey holds the default value on creation for the "browser_session_key" field.
DefaultBrowserSessionKey string
// DefaultCompletionCodeHash holds the default value on creation for the "completion_code_hash" field.
DefaultCompletionCodeHash string
)
// OrderOption defines the ordering options for the PendingAuthSession queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// BySessionToken orders the results by the session_token field.
func BySessionToken(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSessionToken, opts...).ToFunc()
}
// ByIntent orders the results by the intent field.
func ByIntent(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldIntent, opts...).ToFunc()
}
// ByProviderType orders the results by the provider_type field.
func ByProviderType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldProviderType, opts...).ToFunc()
}
// ByProviderKey orders the results by the provider_key field.
func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
}
// ByProviderSubject orders the results by the provider_subject field.
func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldProviderSubject, opts...).ToFunc()
}
// ByTargetUserID orders the results by the target_user_id field.
func ByTargetUserID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTargetUserID, opts...).ToFunc()
}
// ByRedirectTo orders the results by the redirect_to field.
func ByRedirectTo(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRedirectTo, opts...).ToFunc()
}
// ByResolvedEmail orders the results by the resolved_email field.
func ByResolvedEmail(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldResolvedEmail, opts...).ToFunc()
}
// ByRegistrationPasswordHash orders the results by the registration_password_hash field.
func ByRegistrationPasswordHash(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRegistrationPasswordHash, opts...).ToFunc()
}
// ByBrowserSessionKey orders the results by the browser_session_key field.
func ByBrowserSessionKey(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBrowserSessionKey, opts...).ToFunc()
}
// ByCompletionCodeHash orders the results by the completion_code_hash field.
func ByCompletionCodeHash(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCompletionCodeHash, opts...).ToFunc()
}
// ByCompletionCodeExpiresAt orders the results by the completion_code_expires_at field.
func ByCompletionCodeExpiresAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCompletionCodeExpiresAt, opts...).ToFunc()
}
// ByEmailVerifiedAt orders the results by the email_verified_at field.
func ByEmailVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldEmailVerifiedAt, opts...).ToFunc()
}
// ByPasswordVerifiedAt orders the results by the password_verified_at field.
func ByPasswordVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldPasswordVerifiedAt, opts...).ToFunc()
}
// ByTotpVerifiedAt orders the results by the totp_verified_at field.
func ByTotpVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpVerifiedAt, opts...).ToFunc()
}
// ByExpiresAt orders the results by the expires_at field.
func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
}
// ByConsumedAt orders the results by the consumed_at field.
func ByConsumedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldConsumedAt, opts...).ToFunc()
}
// ByTargetUserField orders the results by target_user field.
func ByTargetUserField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newTargetUserStep(), sql.OrderByField(field, opts...))
}
}
// ByAdoptionDecisionField orders the results by adoption_decision field.
func ByAdoptionDecisionField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionStep(), sql.OrderByField(field, opts...))
}
}
func newTargetUserStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(TargetUserInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn),
)
}
func newAdoptionDecisionStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(AdoptionDecisionInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn),
)
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// PendingAuthSessionDelete is the builder for deleting a PendingAuthSession entity.
type PendingAuthSessionDelete struct {
config
hooks []Hook
mutation *PendingAuthSessionMutation
}
// Where appends a list predicates to the PendingAuthSessionDelete builder.
func (_d *PendingAuthSessionDelete) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *PendingAuthSessionDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *PendingAuthSessionDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *PendingAuthSessionDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// PendingAuthSessionDeleteOne is the builder for deleting a single PendingAuthSession entity.
type PendingAuthSessionDeleteOne struct {
_d *PendingAuthSessionDelete
}
// Where appends a list predicates to the PendingAuthSessionDelete builder.
func (_d *PendingAuthSessionDeleteOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *PendingAuthSessionDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{pendingauthsession.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *PendingAuthSessionDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@ -0,0 +1,717 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"database/sql/driver"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/user"
)
// PendingAuthSessionQuery is the builder for querying PendingAuthSession entities.
type PendingAuthSessionQuery struct {
config
ctx *QueryContext
order []pendingauthsession.OrderOption
inters []Interceptor
predicates []predicate.PendingAuthSession
withTargetUser *UserQuery
withAdoptionDecision *IdentityAdoptionDecisionQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the PendingAuthSessionQuery builder.
func (_q *PendingAuthSessionQuery) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *PendingAuthSessionQuery) Limit(limit int) *PendingAuthSessionQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *PendingAuthSessionQuery) Offset(offset int) *PendingAuthSessionQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *PendingAuthSessionQuery) Unique(unique bool) *PendingAuthSessionQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *PendingAuthSessionQuery) Order(o ...pendingauthsession.OrderOption) *PendingAuthSessionQuery {
_q.order = append(_q.order, o...)
return _q
}
// QueryTargetUser chains the current query on the "target_user" edge.
func (_q *PendingAuthSessionQuery) QueryTargetUser() *UserQuery {
query := (&UserClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryAdoptionDecision chains the current query on the "adoption_decision" edge.
func (_q *PendingAuthSessionQuery) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery {
query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector),
sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first PendingAuthSession entity from the query.
// Returns a *NotFoundError when no PendingAuthSession was found.
func (_q *PendingAuthSessionQuery) First(ctx context.Context) (*PendingAuthSession, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{pendingauthsession.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *PendingAuthSessionQuery) FirstX(ctx context.Context) *PendingAuthSession {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first PendingAuthSession ID from the query.
// Returns a *NotFoundError when no PendingAuthSession ID was found.
func (_q *PendingAuthSessionQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{pendingauthsession.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *PendingAuthSessionQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single PendingAuthSession entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one PendingAuthSession entity is found.
// Returns a *NotFoundError when no PendingAuthSession entities are found.
func (_q *PendingAuthSessionQuery) Only(ctx context.Context) (*PendingAuthSession, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{pendingauthsession.Label}
default:
return nil, &NotSingularError{pendingauthsession.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *PendingAuthSessionQuery) OnlyX(ctx context.Context) *PendingAuthSession {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only PendingAuthSession ID in the query.
// Returns a *NotSingularError when more than one PendingAuthSession ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *PendingAuthSessionQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{pendingauthsession.Label}
default:
err = &NotSingularError{pendingauthsession.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *PendingAuthSessionQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of PendingAuthSessions.
func (_q *PendingAuthSessionQuery) All(ctx context.Context) ([]*PendingAuthSession, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*PendingAuthSession, *PendingAuthSessionQuery]()
return withInterceptors[[]*PendingAuthSession](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *PendingAuthSessionQuery) AllX(ctx context.Context) []*PendingAuthSession {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of PendingAuthSession IDs.
func (_q *PendingAuthSessionQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(pendingauthsession.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *PendingAuthSessionQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *PendingAuthSessionQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*PendingAuthSessionQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *PendingAuthSessionQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *PendingAuthSessionQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *PendingAuthSessionQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the PendingAuthSessionQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *PendingAuthSessionQuery) Clone() *PendingAuthSessionQuery {
if _q == nil {
return nil
}
return &PendingAuthSessionQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]pendingauthsession.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.PendingAuthSession{}, _q.predicates...),
withTargetUser: _q.withTargetUser.Clone(),
withAdoptionDecision: _q.withAdoptionDecision.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// WithTargetUser tells the query-builder to eager-load the nodes that are connected to
// the "target_user" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *PendingAuthSessionQuery) WithTargetUser(opts ...func(*UserQuery)) *PendingAuthSessionQuery {
query := (&UserClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withTargetUser = query
return _q
}
// WithAdoptionDecision tells the query-builder to eager-load the nodes that are connected to
// the "adoption_decision" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *PendingAuthSessionQuery) WithAdoptionDecision(opts ...func(*IdentityAdoptionDecisionQuery)) *PendingAuthSessionQuery {
query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withAdoptionDecision = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.PendingAuthSession.Query().
// GroupBy(pendingauthsession.FieldCreatedAt).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *PendingAuthSessionQuery) GroupBy(field string, fields ...string) *PendingAuthSessionGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &PendingAuthSessionGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = pendingauthsession.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// }
//
// client.PendingAuthSession.Query().
// Select(pendingauthsession.FieldCreatedAt).
// Scan(ctx, &v)
func (_q *PendingAuthSessionQuery) Select(fields ...string) *PendingAuthSessionSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &PendingAuthSessionSelect{PendingAuthSessionQuery: _q}
sbuild.label = pendingauthsession.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a PendingAuthSessionSelect configured with the given aggregations.
func (_q *PendingAuthSessionQuery) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *PendingAuthSessionQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !pendingauthsession.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *PendingAuthSessionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*PendingAuthSession, error) {
var (
nodes = []*PendingAuthSession{}
_spec = _q.querySpec()
loadedTypes = [2]bool{
_q.withTargetUser != nil,
_q.withAdoptionDecision != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*PendingAuthSession).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &PendingAuthSession{config: _q.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
if query := _q.withTargetUser; query != nil {
if err := _q.loadTargetUser(ctx, query, nodes, nil,
func(n *PendingAuthSession, e *User) { n.Edges.TargetUser = e }); err != nil {
return nil, err
}
}
if query := _q.withAdoptionDecision; query != nil {
if err := _q.loadAdoptionDecision(ctx, query, nodes, nil,
func(n *PendingAuthSession, e *IdentityAdoptionDecision) { n.Edges.AdoptionDecision = e }); err != nil {
return nil, err
}
}
return nodes, nil
}
func (_q *PendingAuthSessionQuery) loadTargetUser(ctx context.Context, query *UserQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *User)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*PendingAuthSession)
for i := range nodes {
if nodes[i].TargetUserID == nil {
continue
}
fk := *nodes[i].TargetUserID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(user.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "target_user_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *PendingAuthSessionQuery) loadAdoptionDecision(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *IdentityAdoptionDecision)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*PendingAuthSession)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(identityadoptiondecision.FieldPendingAuthSessionID)
}
query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(pendingauthsession.AdoptionDecisionColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.PendingAuthSessionID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "pending_auth_session_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *PendingAuthSessionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *PendingAuthSessionQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID)
for i := range fields {
if fields[i] != pendingauthsession.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
if _q.withTargetUser != nil {
_spec.Node.AddColumnOnce(pendingauthsession.FieldTargetUserID)
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *PendingAuthSessionQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(pendingauthsession.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = pendingauthsession.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *PendingAuthSessionQuery) ForUpdate(opts ...sql.LockOption) *PendingAuthSessionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *PendingAuthSessionQuery) ForShare(opts ...sql.LockOption) *PendingAuthSessionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// PendingAuthSessionGroupBy is the group-by builder for PendingAuthSession entities.
type PendingAuthSessionGroupBy struct {
selector
build *PendingAuthSessionQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *PendingAuthSessionGroupBy) Aggregate(fns ...AggregateFunc) *PendingAuthSessionGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *PendingAuthSessionGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *PendingAuthSessionGroupBy) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// PendingAuthSessionSelect is the builder for selecting fields of PendingAuthSession entities.
type PendingAuthSessionSelect struct {
*PendingAuthSessionQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *PendingAuthSessionSelect) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *PendingAuthSessionSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionSelect](ctx, _s.PendingAuthSessionQuery, _s, _s.inters, v)
}
func (_s *PendingAuthSessionSelect) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

File diff suppressed because it is too large Load Diff

View File

@ -21,6 +21,12 @@ type Announcement func(*sql.Selector)
// AnnouncementRead is the predicate function for announcementread builders.
type AnnouncementRead func(*sql.Selector)
// AuthIdentity is the predicate function for authidentity builders.
type AuthIdentity func(*sql.Selector)
// AuthIdentityChannel is the predicate function for authidentitychannel builders.
type AuthIdentityChannel func(*sql.Selector)
// ErrorPassthroughRule is the predicate function for errorpassthroughrule builders.
type ErrorPassthroughRule func(*sql.Selector)
@ -30,6 +36,9 @@ type Group func(*sql.Selector)
// IdempotencyRecord is the predicate function for idempotencyrecord builders.
type IdempotencyRecord func(*sql.Selector)
// IdentityAdoptionDecision is the predicate function for identityadoptiondecision builders.
type IdentityAdoptionDecision func(*sql.Selector)
// PaymentAuditLog is the predicate function for paymentauditlog builders.
type PaymentAuditLog func(*sql.Selector)
@ -39,6 +48,9 @@ type PaymentOrder func(*sql.Selector)
// PaymentProviderInstance is the predicate function for paymentproviderinstance builders.
type PaymentProviderInstance func(*sql.Selector)
// PendingAuthSession is the predicate function for pendingauthsession builders.
type PendingAuthSession func(*sql.Selector)
// PromoCode is the predicate function for promocode builders.
type PromoCode func(*sql.Selector)

View File

@ -10,12 +10,16 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
@ -309,6 +313,120 @@ func init() {
announcementreadDescCreatedAt := announcementreadFields[3].Descriptor()
// announcementread.DefaultCreatedAt holds the default value on creation for the created_at field.
announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time)
authidentityMixin := schema.AuthIdentity{}.Mixin()
authidentityMixinFields0 := authidentityMixin[0].Fields()
_ = authidentityMixinFields0
authidentityFields := schema.AuthIdentity{}.Fields()
_ = authidentityFields
// authidentityDescCreatedAt is the schema descriptor for created_at field.
authidentityDescCreatedAt := authidentityMixinFields0[0].Descriptor()
// authidentity.DefaultCreatedAt holds the default value on creation for the created_at field.
authidentity.DefaultCreatedAt = authidentityDescCreatedAt.Default.(func() time.Time)
// authidentityDescUpdatedAt is the schema descriptor for updated_at field.
authidentityDescUpdatedAt := authidentityMixinFields0[1].Descriptor()
// authidentity.DefaultUpdatedAt holds the default value on creation for the updated_at field.
authidentity.DefaultUpdatedAt = authidentityDescUpdatedAt.Default.(func() time.Time)
// authidentity.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
authidentity.UpdateDefaultUpdatedAt = authidentityDescUpdatedAt.UpdateDefault.(func() time.Time)
// authidentityDescProviderType is the schema descriptor for provider_type field.
authidentityDescProviderType := authidentityFields[1].Descriptor()
// authidentity.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
authidentity.ProviderTypeValidator = func() func(string) error {
validators := authidentityDescProviderType.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
validators[2].(func(string) error),
}
return func(provider_type string) error {
for _, fn := range fns {
if err := fn(provider_type); err != nil {
return err
}
}
return nil
}
}()
// authidentityDescProviderKey is the schema descriptor for provider_key field.
authidentityDescProviderKey := authidentityFields[2].Descriptor()
// authidentity.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
authidentity.ProviderKeyValidator = authidentityDescProviderKey.Validators[0].(func(string) error)
// authidentityDescProviderSubject is the schema descriptor for provider_subject field.
authidentityDescProviderSubject := authidentityFields[3].Descriptor()
// authidentity.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
authidentity.ProviderSubjectValidator = authidentityDescProviderSubject.Validators[0].(func(string) error)
// authidentityDescMetadata is the schema descriptor for metadata field.
authidentityDescMetadata := authidentityFields[6].Descriptor()
// authidentity.DefaultMetadata holds the default value on creation for the metadata field.
authidentity.DefaultMetadata = authidentityDescMetadata.Default.(func() map[string]interface{})
authidentitychannelMixin := schema.AuthIdentityChannel{}.Mixin()
authidentitychannelMixinFields0 := authidentitychannelMixin[0].Fields()
_ = authidentitychannelMixinFields0
authidentitychannelFields := schema.AuthIdentityChannel{}.Fields()
_ = authidentitychannelFields
// authidentitychannelDescCreatedAt is the schema descriptor for created_at field.
authidentitychannelDescCreatedAt := authidentitychannelMixinFields0[0].Descriptor()
// authidentitychannel.DefaultCreatedAt holds the default value on creation for the created_at field.
authidentitychannel.DefaultCreatedAt = authidentitychannelDescCreatedAt.Default.(func() time.Time)
// authidentitychannelDescUpdatedAt is the schema descriptor for updated_at field.
authidentitychannelDescUpdatedAt := authidentitychannelMixinFields0[1].Descriptor()
// authidentitychannel.DefaultUpdatedAt holds the default value on creation for the updated_at field.
authidentitychannel.DefaultUpdatedAt = authidentitychannelDescUpdatedAt.Default.(func() time.Time)
// authidentitychannel.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
authidentitychannel.UpdateDefaultUpdatedAt = authidentitychannelDescUpdatedAt.UpdateDefault.(func() time.Time)
// authidentitychannelDescProviderType is the schema descriptor for provider_type field.
authidentitychannelDescProviderType := authidentitychannelFields[1].Descriptor()
// authidentitychannel.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
authidentitychannel.ProviderTypeValidator = func() func(string) error {
validators := authidentitychannelDescProviderType.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
validators[2].(func(string) error),
}
return func(provider_type string) error {
for _, fn := range fns {
if err := fn(provider_type); err != nil {
return err
}
}
return nil
}
}()
// authidentitychannelDescProviderKey is the schema descriptor for provider_key field.
authidentitychannelDescProviderKey := authidentitychannelFields[2].Descriptor()
// authidentitychannel.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
authidentitychannel.ProviderKeyValidator = authidentitychannelDescProviderKey.Validators[0].(func(string) error)
// authidentitychannelDescChannel is the schema descriptor for channel field.
authidentitychannelDescChannel := authidentitychannelFields[3].Descriptor()
// authidentitychannel.ChannelValidator is a validator for the "channel" field. It is called by the builders before save.
authidentitychannel.ChannelValidator = func() func(string) error {
validators := authidentitychannelDescChannel.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
}
return func(channel string) error {
for _, fn := range fns {
if err := fn(channel); err != nil {
return err
}
}
return nil
}
}()
// authidentitychannelDescChannelAppID is the schema descriptor for channel_app_id field.
authidentitychannelDescChannelAppID := authidentitychannelFields[4].Descriptor()
// authidentitychannel.ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save.
authidentitychannel.ChannelAppIDValidator = authidentitychannelDescChannelAppID.Validators[0].(func(string) error)
// authidentitychannelDescChannelSubject is the schema descriptor for channel_subject field.
authidentitychannelDescChannelSubject := authidentitychannelFields[5].Descriptor()
// authidentitychannel.ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save.
authidentitychannel.ChannelSubjectValidator = authidentitychannelDescChannelSubject.Validators[0].(func(string) error)
// authidentitychannelDescMetadata is the schema descriptor for metadata field.
authidentitychannelDescMetadata := authidentitychannelFields[6].Descriptor()
// authidentitychannel.DefaultMetadata holds the default value on creation for the metadata field.
authidentitychannel.DefaultMetadata = authidentitychannelDescMetadata.Default.(func() map[string]interface{})
errorpassthroughruleMixin := schema.ErrorPassthroughRule{}.Mixin()
errorpassthroughruleMixinFields0 := errorpassthroughruleMixin[0].Fields()
_ = errorpassthroughruleMixinFields0
@ -512,6 +630,33 @@ func init() {
idempotencyrecordDescErrorReason := idempotencyrecordFields[6].Descriptor()
// idempotencyrecord.ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save.
idempotencyrecord.ErrorReasonValidator = idempotencyrecordDescErrorReason.Validators[0].(func(string) error)
identityadoptiondecisionMixin := schema.IdentityAdoptionDecision{}.Mixin()
identityadoptiondecisionMixinFields0 := identityadoptiondecisionMixin[0].Fields()
_ = identityadoptiondecisionMixinFields0
identityadoptiondecisionFields := schema.IdentityAdoptionDecision{}.Fields()
_ = identityadoptiondecisionFields
// identityadoptiondecisionDescCreatedAt is the schema descriptor for created_at field.
identityadoptiondecisionDescCreatedAt := identityadoptiondecisionMixinFields0[0].Descriptor()
// identityadoptiondecision.DefaultCreatedAt holds the default value on creation for the created_at field.
identityadoptiondecision.DefaultCreatedAt = identityadoptiondecisionDescCreatedAt.Default.(func() time.Time)
// identityadoptiondecisionDescUpdatedAt is the schema descriptor for updated_at field.
identityadoptiondecisionDescUpdatedAt := identityadoptiondecisionMixinFields0[1].Descriptor()
// identityadoptiondecision.DefaultUpdatedAt holds the default value on creation for the updated_at field.
identityadoptiondecision.DefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.Default.(func() time.Time)
// identityadoptiondecision.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
identityadoptiondecision.UpdateDefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.UpdateDefault.(func() time.Time)
// identityadoptiondecisionDescAdoptDisplayName is the schema descriptor for adopt_display_name field.
identityadoptiondecisionDescAdoptDisplayName := identityadoptiondecisionFields[2].Descriptor()
// identityadoptiondecision.DefaultAdoptDisplayName holds the default value on creation for the adopt_display_name field.
identityadoptiondecision.DefaultAdoptDisplayName = identityadoptiondecisionDescAdoptDisplayName.Default.(bool)
// identityadoptiondecisionDescAdoptAvatar is the schema descriptor for adopt_avatar field.
identityadoptiondecisionDescAdoptAvatar := identityadoptiondecisionFields[3].Descriptor()
// identityadoptiondecision.DefaultAdoptAvatar holds the default value on creation for the adopt_avatar field.
identityadoptiondecision.DefaultAdoptAvatar = identityadoptiondecisionDescAdoptAvatar.Default.(bool)
// identityadoptiondecisionDescDecidedAt is the schema descriptor for decided_at field.
identityadoptiondecisionDescDecidedAt := identityadoptiondecisionFields[4].Descriptor()
// identityadoptiondecision.DefaultDecidedAt holds the default value on creation for the decided_at field.
identityadoptiondecision.DefaultDecidedAt = identityadoptiondecisionDescDecidedAt.Default.(func() time.Time)
paymentauditlogFields := schema.PaymentAuditLog{}.Fields()
_ = paymentauditlogFields
// paymentauditlogDescOrderID is the schema descriptor for order_id field.
@ -682,6 +827,113 @@ func init() {
paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time)
// paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
paymentproviderinstance.UpdateDefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.UpdateDefault.(func() time.Time)
pendingauthsessionMixin := schema.PendingAuthSession{}.Mixin()
pendingauthsessionMixinFields0 := pendingauthsessionMixin[0].Fields()
_ = pendingauthsessionMixinFields0
pendingauthsessionFields := schema.PendingAuthSession{}.Fields()
_ = pendingauthsessionFields
// pendingauthsessionDescCreatedAt is the schema descriptor for created_at field.
pendingauthsessionDescCreatedAt := pendingauthsessionMixinFields0[0].Descriptor()
// pendingauthsession.DefaultCreatedAt holds the default value on creation for the created_at field.
pendingauthsession.DefaultCreatedAt = pendingauthsessionDescCreatedAt.Default.(func() time.Time)
// pendingauthsessionDescUpdatedAt is the schema descriptor for updated_at field.
pendingauthsessionDescUpdatedAt := pendingauthsessionMixinFields0[1].Descriptor()
// pendingauthsession.DefaultUpdatedAt holds the default value on creation for the updated_at field.
pendingauthsession.DefaultUpdatedAt = pendingauthsessionDescUpdatedAt.Default.(func() time.Time)
// pendingauthsession.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
pendingauthsession.UpdateDefaultUpdatedAt = pendingauthsessionDescUpdatedAt.UpdateDefault.(func() time.Time)
// pendingauthsessionDescSessionToken is the schema descriptor for session_token field.
pendingauthsessionDescSessionToken := pendingauthsessionFields[0].Descriptor()
// pendingauthsession.SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save.
pendingauthsession.SessionTokenValidator = func() func(string) error {
validators := pendingauthsessionDescSessionToken.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
}
return func(session_token string) error {
for _, fn := range fns {
if err := fn(session_token); err != nil {
return err
}
}
return nil
}
}()
// pendingauthsessionDescIntent is the schema descriptor for intent field.
pendingauthsessionDescIntent := pendingauthsessionFields[1].Descriptor()
// pendingauthsession.IntentValidator is a validator for the "intent" field. It is called by the builders before save.
pendingauthsession.IntentValidator = func() func(string) error {
validators := pendingauthsessionDescIntent.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
validators[2].(func(string) error),
}
return func(intent string) error {
for _, fn := range fns {
if err := fn(intent); err != nil {
return err
}
}
return nil
}
}()
// pendingauthsessionDescProviderType is the schema descriptor for provider_type field.
pendingauthsessionDescProviderType := pendingauthsessionFields[2].Descriptor()
// pendingauthsession.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
pendingauthsession.ProviderTypeValidator = func() func(string) error {
validators := pendingauthsessionDescProviderType.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
validators[2].(func(string) error),
}
return func(provider_type string) error {
for _, fn := range fns {
if err := fn(provider_type); err != nil {
return err
}
}
return nil
}
}()
// pendingauthsessionDescProviderKey is the schema descriptor for provider_key field.
pendingauthsessionDescProviderKey := pendingauthsessionFields[3].Descriptor()
// pendingauthsession.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
pendingauthsession.ProviderKeyValidator = pendingauthsessionDescProviderKey.Validators[0].(func(string) error)
// pendingauthsessionDescProviderSubject is the schema descriptor for provider_subject field.
pendingauthsessionDescProviderSubject := pendingauthsessionFields[4].Descriptor()
// pendingauthsession.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
pendingauthsession.ProviderSubjectValidator = pendingauthsessionDescProviderSubject.Validators[0].(func(string) error)
// pendingauthsessionDescRedirectTo is the schema descriptor for redirect_to field.
pendingauthsessionDescRedirectTo := pendingauthsessionFields[6].Descriptor()
// pendingauthsession.DefaultRedirectTo holds the default value on creation for the redirect_to field.
pendingauthsession.DefaultRedirectTo = pendingauthsessionDescRedirectTo.Default.(string)
// pendingauthsessionDescResolvedEmail is the schema descriptor for resolved_email field.
pendingauthsessionDescResolvedEmail := pendingauthsessionFields[7].Descriptor()
// pendingauthsession.DefaultResolvedEmail holds the default value on creation for the resolved_email field.
pendingauthsession.DefaultResolvedEmail = pendingauthsessionDescResolvedEmail.Default.(string)
// pendingauthsessionDescRegistrationPasswordHash is the schema descriptor for registration_password_hash field.
pendingauthsessionDescRegistrationPasswordHash := pendingauthsessionFields[8].Descriptor()
// pendingauthsession.DefaultRegistrationPasswordHash holds the default value on creation for the registration_password_hash field.
pendingauthsession.DefaultRegistrationPasswordHash = pendingauthsessionDescRegistrationPasswordHash.Default.(string)
// pendingauthsessionDescUpstreamIdentityClaims is the schema descriptor for upstream_identity_claims field.
pendingauthsessionDescUpstreamIdentityClaims := pendingauthsessionFields[9].Descriptor()
// pendingauthsession.DefaultUpstreamIdentityClaims holds the default value on creation for the upstream_identity_claims field.
pendingauthsession.DefaultUpstreamIdentityClaims = pendingauthsessionDescUpstreamIdentityClaims.Default.(func() map[string]interface{})
// pendingauthsessionDescLocalFlowState is the schema descriptor for local_flow_state field.
pendingauthsessionDescLocalFlowState := pendingauthsessionFields[10].Descriptor()
// pendingauthsession.DefaultLocalFlowState holds the default value on creation for the local_flow_state field.
pendingauthsession.DefaultLocalFlowState = pendingauthsessionDescLocalFlowState.Default.(func() map[string]interface{})
// pendingauthsessionDescBrowserSessionKey is the schema descriptor for browser_session_key field.
pendingauthsessionDescBrowserSessionKey := pendingauthsessionFields[11].Descriptor()
// pendingauthsession.DefaultBrowserSessionKey holds the default value on creation for the browser_session_key field.
pendingauthsession.DefaultBrowserSessionKey = pendingauthsessionDescBrowserSessionKey.Default.(string)
// pendingauthsessionDescCompletionCodeHash is the schema descriptor for completion_code_hash field.
pendingauthsessionDescCompletionCodeHash := pendingauthsessionFields[12].Descriptor()
// pendingauthsession.DefaultCompletionCodeHash holds the default value on creation for the completion_code_hash field.
pendingauthsession.DefaultCompletionCodeHash = pendingauthsessionDescCompletionCodeHash.Default.(string)
promocodeFields := schema.PromoCode{}.Fields()
_ = promocodeFields
// promocodeDescCode is the schema descriptor for code field.
@ -1297,20 +1549,26 @@ func init() {
userDescTotpEnabled := userFields[9].Descriptor()
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
// userDescSignupSource is the schema descriptor for signup_source field.
userDescSignupSource := userFields[11].Descriptor()
// user.DefaultSignupSource holds the default value on creation for the signup_source field.
user.DefaultSignupSource = userDescSignupSource.Default.(string)
// user.SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save.
user.SignupSourceValidator = userDescSignupSource.Validators[0].(func(string) error)
// userDescBalanceNotifyEnabled is the schema descriptor for balance_notify_enabled field.
userDescBalanceNotifyEnabled := userFields[11].Descriptor()
userDescBalanceNotifyEnabled := userFields[14].Descriptor()
// user.DefaultBalanceNotifyEnabled holds the default value on creation for the balance_notify_enabled field.
user.DefaultBalanceNotifyEnabled = userDescBalanceNotifyEnabled.Default.(bool)
// userDescBalanceNotifyThresholdType is the schema descriptor for balance_notify_threshold_type field.
userDescBalanceNotifyThresholdType := userFields[12].Descriptor()
userDescBalanceNotifyThresholdType := userFields[15].Descriptor()
// user.DefaultBalanceNotifyThresholdType holds the default value on creation for the balance_notify_threshold_type field.
user.DefaultBalanceNotifyThresholdType = userDescBalanceNotifyThresholdType.Default.(string)
// userDescBalanceNotifyExtraEmails is the schema descriptor for balance_notify_extra_emails field.
userDescBalanceNotifyExtraEmails := userFields[14].Descriptor()
userDescBalanceNotifyExtraEmails := userFields[17].Descriptor()
// user.DefaultBalanceNotifyExtraEmails holds the default value on creation for the balance_notify_extra_emails field.
user.DefaultBalanceNotifyExtraEmails = userDescBalanceNotifyExtraEmails.Default.(string)
// userDescTotalRecharged is the schema descriptor for total_recharged field.
userDescTotalRecharged := userFields[15].Descriptor()
userDescTotalRecharged := userFields[18].Descriptor()
// user.DefaultTotalRecharged holds the default value on creation for the total_recharged field.
user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()

View File

@ -0,0 +1,93 @@
package schema
import (
"fmt"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
var authProviderTypes = map[string]struct{}{
"email": {},
"linuxdo": {},
"oidc": {},
"wechat": {},
}
func validateAuthProviderType(value string) error {
if _, ok := authProviderTypes[value]; ok {
return nil
}
return fmt.Errorf("invalid auth provider type %q", value)
}
// AuthIdentity stores the canonical login identity for an account.
type AuthIdentity struct {
ent.Schema
}
func (AuthIdentity) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "auth_identities"},
}
}
func (AuthIdentity) Mixin() []ent.Mixin {
return []ent.Mixin{
mixins.TimeMixin{},
}
}
func (AuthIdentity) Fields() []ent.Field {
return []ent.Field{
field.Int64("user_id"),
field.String("provider_type").
MaxLen(20).
NotEmpty().
Validate(validateAuthProviderType),
field.String("provider_key").
NotEmpty().
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.String("provider_subject").
NotEmpty().
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.Time("verified_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.String("issuer").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.JSON("metadata", map[string]any{}).
Default(func() map[string]any { return map[string]any{} }).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
}
}
func (AuthIdentity) Edges() []ent.Edge {
return []ent.Edge{
edge.From("user", User.Type).
Ref("auth_identities").
Field("user_id").
Required().
Unique(),
edge.To("channels", AuthIdentityChannel.Type),
edge.To("adoption_decisions", IdentityAdoptionDecision.Type),
}
}
func (AuthIdentity) Indexes() []ent.Index {
return []ent.Index{
index.Fields("provider_type", "provider_key", "provider_subject").Unique(),
index.Fields("user_id"),
index.Fields("user_id", "provider_type"),
}
}

View File

@ -0,0 +1,72 @@
package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// AuthIdentityChannel stores channel-scoped identifiers for a canonical identity.
type AuthIdentityChannel struct {
ent.Schema
}
func (AuthIdentityChannel) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "auth_identity_channels"},
}
}
func (AuthIdentityChannel) Mixin() []ent.Mixin {
return []ent.Mixin{
mixins.TimeMixin{},
}
}
func (AuthIdentityChannel) Fields() []ent.Field {
return []ent.Field{
field.Int64("identity_id"),
field.String("provider_type").
MaxLen(20).
NotEmpty().
Validate(validateAuthProviderType),
field.String("provider_key").
NotEmpty().
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.String("channel").
MaxLen(20).
NotEmpty(),
field.String("channel_app_id").
NotEmpty().
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.String("channel_subject").
NotEmpty().
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.JSON("metadata", map[string]any{}).
Default(func() map[string]any { return map[string]any{} }).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
}
}
func (AuthIdentityChannel) Edges() []ent.Edge {
return []ent.Edge{
edge.From("identity", AuthIdentity.Type).
Ref("channels").
Field("identity_id").
Required().
Unique(),
}
}
func (AuthIdentityChannel) Indexes() []ent.Index {
return []ent.Index{
index.Fields("provider_type", "provider_key", "channel", "channel_app_id", "channel_subject").Unique(),
index.Fields("identity_id"),
}
}

View File

@ -0,0 +1,124 @@
package schema
import (
"testing"
"entgo.io/ent/entc/load"
"github.com/stretchr/testify/require"
)
func TestAuthIdentityFoundationSchemas(t *testing.T) {
spec, err := (&load.Config{Path: "."}).Load()
require.NoError(t, err)
schemas := map[string]*load.Schema{}
for _, schema := range spec.Schemas {
schemas[schema.Name] = schema
}
authIdentity := requireSchema(t, schemas, "AuthIdentity")
requireSchemaFields(t, authIdentity,
"user_id",
"provider_type",
"provider_key",
"provider_subject",
"verified_at",
"issuer",
"metadata",
)
requireHasUniqueIndex(t, authIdentity, "provider_type", "provider_key", "provider_subject")
authIdentityChannel := requireSchema(t, schemas, "AuthIdentityChannel")
requireSchemaFields(t, authIdentityChannel,
"identity_id",
"provider_type",
"provider_key",
"channel",
"channel_app_id",
"channel_subject",
"metadata",
)
requireHasUniqueIndex(t, authIdentityChannel, "provider_type", "provider_key", "channel", "channel_app_id", "channel_subject")
pendingAuthSession := requireSchema(t, schemas, "PendingAuthSession")
requireSchemaFields(t, pendingAuthSession,
"intent",
"provider_type",
"provider_key",
"provider_subject",
"target_user_id",
"redirect_to",
"resolved_email",
"registration_password_hash",
"upstream_identity_claims",
"local_flow_state",
"browser_session_key",
"completion_code_hash",
"completion_code_expires_at",
"email_verified_at",
"password_verified_at",
"totp_verified_at",
"expires_at",
"consumed_at",
)
adoptionDecision := requireSchema(t, schemas, "IdentityAdoptionDecision")
requireSchemaFields(t, adoptionDecision,
"pending_auth_session_id",
"identity_id",
"adopt_display_name",
"adopt_avatar",
"decided_at",
)
requireHasUniqueIndex(t, adoptionDecision, "pending_auth_session_id")
userSchema := requireSchema(t, schemas, "User")
requireSchemaFields(t, userSchema, "signup_source", "last_login_at", "last_active_at")
}
func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema {
t.Helper()
schema, ok := schemas[name]
require.True(t, ok, "schema %s should exist", name)
return schema
}
func requireSchemaFields(t *testing.T, schema *load.Schema, names ...string) {
t.Helper()
fields := map[string]struct{}{}
for _, field := range schema.Fields {
fields[field.Name] = struct{}{}
}
for _, name := range names {
_, ok := fields[name]
require.True(t, ok, "schema %s should include field %s", schema.Name, name)
}
}
func requireHasUniqueIndex(t *testing.T, schema *load.Schema, fields ...string) {
t.Helper()
for _, index := range schema.Indexes {
if !index.Unique {
continue
}
if len(index.Fields) != len(fields) {
continue
}
match := true
for i := range fields {
if index.Fields[i] != fields[i] {
match = false
break
}
}
if match {
return
}
}
require.Failf(t, "missing unique index", "schema %s should include unique index on %v", schema.Name, fields)
}

View File

@ -0,0 +1,70 @@
package schema
import (
"time"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// IdentityAdoptionDecision stores the one-time profile adoption choice captured during a pending auth flow.
type IdentityAdoptionDecision struct {
ent.Schema
}
func (IdentityAdoptionDecision) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "identity_adoption_decisions"},
}
}
func (IdentityAdoptionDecision) Mixin() []ent.Mixin {
return []ent.Mixin{
mixins.TimeMixin{},
}
}
func (IdentityAdoptionDecision) Fields() []ent.Field {
return []ent.Field{
field.Int64("pending_auth_session_id"),
field.Int64("identity_id").
Optional().
Nillable(),
field.Bool("adopt_display_name").
Default(false),
field.Bool("adopt_avatar").
Default(false),
field.Time("decided_at").
Immutable().
Default(time.Now).
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
}
}
func (IdentityAdoptionDecision) Edges() []ent.Edge {
return []ent.Edge{
edge.From("pending_auth_session", PendingAuthSession.Type).
Ref("adoption_decision").
Field("pending_auth_session_id").
Required().
Unique(),
edge.From("identity", AuthIdentity.Type).
Ref("adoption_decisions").
Field("identity_id").
Unique(),
}
}
func (IdentityAdoptionDecision) Indexes() []ent.Index {
return []ent.Index{
index.Fields("pending_auth_session_id").Unique(),
index.Fields("identity_id"),
}
}

View File

@ -0,0 +1,134 @@
package schema
import (
"fmt"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
var pendingAuthIntents = map[string]struct{}{
"login": {},
"bind_current_user": {},
"adopt_existing_user_by_email": {},
}
func validatePendingAuthIntent(value string) error {
if _, ok := pendingAuthIntents[value]; ok {
return nil
}
return fmt.Errorf("invalid pending auth intent %q", value)
}
// PendingAuthSession stores a short-lived post-auth decision session.
type PendingAuthSession struct {
ent.Schema
}
func (PendingAuthSession) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "pending_auth_sessions"},
}
}
func (PendingAuthSession) Mixin() []ent.Mixin {
return []ent.Mixin{
mixins.TimeMixin{},
}
}
func (PendingAuthSession) Fields() []ent.Field {
return []ent.Field{
field.String("session_token").
MaxLen(255).
NotEmpty(),
field.String("intent").
MaxLen(40).
NotEmpty().
Validate(validatePendingAuthIntent),
field.String("provider_type").
MaxLen(20).
NotEmpty().
Validate(validateAuthProviderType),
field.String("provider_key").
NotEmpty().
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.String("provider_subject").
NotEmpty().
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.Int64("target_user_id").
Optional().
Nillable(),
field.String("redirect_to").
Default("").
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.String("resolved_email").
Default("").
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.String("registration_password_hash").
Default("").
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.JSON("upstream_identity_claims", map[string]any{}).
Default(func() map[string]any { return map[string]any{} }).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
field.JSON("local_flow_state", map[string]any{}).
Default(func() map[string]any { return map[string]any{} }).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
field.String("browser_session_key").
Default("").
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.String("completion_code_hash").
Default("").
SchemaType(map[string]string{dialect.Postgres: "text"}),
field.Time("completion_code_expires_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("email_verified_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("password_verified_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("totp_verified_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("expires_at").
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("consumed_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
}
}
func (PendingAuthSession) Edges() []ent.Edge {
return []ent.Edge{
edge.From("target_user", User.Type).
Ref("pending_auth_sessions").
Field("target_user_id").
Unique(),
edge.To("adoption_decision", IdentityAdoptionDecision.Type).
Unique(),
}
}
func (PendingAuthSession) Indexes() []ent.Index {
return []ent.Index{
index.Fields("session_token").Unique(),
index.Fields("target_user_id"),
index.Fields("expires_at"),
index.Fields("provider_type", "provider_key", "provider_subject"),
index.Fields("completion_code_hash"),
}
}

View File

@ -72,6 +72,17 @@ func (User) Fields() []ent.Field {
field.Time("totp_enabled_at").
Optional().
Nillable(),
field.String("signup_source").
MaxLen(20).
Default("email"),
field.Time("last_login_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("last_active_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
// 余额不足通知
field.Bool("balance_notify_enabled").
@ -104,6 +115,8 @@ func (User) Edges() []ent.Edge {
edge.To("attribute_values", UserAttributeValue.Type),
edge.To("promo_code_usages", PromoCodeUsage.Type),
edge.To("payment_orders", PaymentOrder.Type),
edge.To("auth_identities", AuthIdentity.Type),
edge.To("pending_auth_sessions", PendingAuthSession.Type),
}
}

View File

@ -24,18 +24,26 @@ type Tx struct {
Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient
// AuthIdentity is the client for interacting with the AuthIdentity builders.
AuthIdentity *AuthIdentityClient
// AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders.
AuthIdentityChannel *AuthIdentityChannelClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// IdempotencyRecord is the client for interacting with the IdempotencyRecord builders.
IdempotencyRecord *IdempotencyRecordClient
// IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders.
IdentityAdoptionDecision *IdentityAdoptionDecisionClient
// PaymentAuditLog is the client for interacting with the PaymentAuditLog builders.
PaymentAuditLog *PaymentAuditLogClient
// PaymentOrder is the client for interacting with the PaymentOrder builders.
PaymentOrder *PaymentOrderClient
// PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders.
PaymentProviderInstance *PaymentProviderInstanceClient
// PendingAuthSession is the client for interacting with the PendingAuthSession builders.
PendingAuthSession *PendingAuthSessionClient
// PromoCode is the client for interacting with the PromoCode builders.
PromoCode *PromoCodeClient
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
@ -202,12 +210,16 @@ func (tx *Tx) init() {
tx.AccountGroup = NewAccountGroupClient(tx.config)
tx.Announcement = NewAnnouncementClient(tx.config)
tx.AnnouncementRead = NewAnnouncementReadClient(tx.config)
tx.AuthIdentity = NewAuthIdentityClient(tx.config)
tx.AuthIdentityChannel = NewAuthIdentityChannelClient(tx.config)
tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config)
tx.Group = NewGroupClient(tx.config)
tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config)
tx.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(tx.config)
tx.PaymentAuditLog = NewPaymentAuditLogClient(tx.config)
tx.PaymentOrder = NewPaymentOrderClient(tx.config)
tx.PaymentProviderInstance = NewPaymentProviderInstanceClient(tx.config)
tx.PendingAuthSession = NewPendingAuthSessionClient(tx.config)
tx.PromoCode = NewPromoCodeClient(tx.config)
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
tx.Proxy = NewProxyClient(tx.config)

View File

@ -45,6 +45,12 @@ type User struct {
TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
// SignupSource holds the value of the "signup_source" field.
SignupSource string `json:"signup_source,omitempty"`
// LastLoginAt holds the value of the "last_login_at" field.
LastLoginAt *time.Time `json:"last_login_at,omitempty"`
// LastActiveAt holds the value of the "last_active_at" field.
LastActiveAt *time.Time `json:"last_active_at,omitempty"`
// BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field.
BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"`
// BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field.
@ -83,11 +89,15 @@ type UserEdges struct {
PromoCodeUsages []*PromoCodeUsage `json:"promo_code_usages,omitempty"`
// PaymentOrders holds the value of the payment_orders edge.
PaymentOrders []*PaymentOrder `json:"payment_orders,omitempty"`
// AuthIdentities holds the value of the auth_identities edge.
AuthIdentities []*AuthIdentity `json:"auth_identities,omitempty"`
// PendingAuthSessions holds the value of the pending_auth_sessions edge.
PendingAuthSessions []*PendingAuthSession `json:"pending_auth_sessions,omitempty"`
// UserAllowedGroups holds the value of the user_allowed_groups edge.
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [11]bool
loadedTypes [13]bool
}
// APIKeysOrErr returns the APIKeys value or an error if the edge
@ -180,10 +190,28 @@ func (e UserEdges) PaymentOrdersOrErr() ([]*PaymentOrder, error) {
return nil, &NotLoadedError{edge: "payment_orders"}
}
// AuthIdentitiesOrErr returns the AuthIdentities value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) AuthIdentitiesOrErr() ([]*AuthIdentity, error) {
if e.loadedTypes[10] {
return e.AuthIdentities, nil
}
return nil, &NotLoadedError{edge: "auth_identities"}
}
// PendingAuthSessionsOrErr returns the PendingAuthSessions value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) PendingAuthSessionsOrErr() ([]*PendingAuthSession, error) {
if e.loadedTypes[11] {
return e.PendingAuthSessions, nil
}
return nil, &NotLoadedError{edge: "pending_auth_sessions"}
}
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
if e.loadedTypes[10] {
if e.loadedTypes[12] {
return e.UserAllowedGroups, nil
}
return nil, &NotLoadedError{edge: "user_allowed_groups"}
@ -200,9 +228,9 @@ func (*User) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency:
values[i] = new(sql.NullInt64)
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldSignupSource, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
values[i] = new(sql.NullString)
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt, user.FieldLastLoginAt, user.FieldLastActiveAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@ -312,6 +340,26 @@ func (_m *User) assignValues(columns []string, values []any) error {
_m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time
}
case user.FieldSignupSource:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field signup_source", values[i])
} else if value.Valid {
_m.SignupSource = value.String
}
case user.FieldLastLoginAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field last_login_at", values[i])
} else if value.Valid {
_m.LastLoginAt = new(time.Time)
*_m.LastLoginAt = value.Time
}
case user.FieldLastActiveAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field last_active_at", values[i])
} else if value.Valid {
_m.LastActiveAt = new(time.Time)
*_m.LastActiveAt = value.Time
}
case user.FieldBalanceNotifyEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field balance_notify_enabled", values[i])
@ -406,6 +454,16 @@ func (_m *User) QueryPaymentOrders() *PaymentOrderQuery {
return NewUserClient(_m.config).QueryPaymentOrders(_m)
}
// QueryAuthIdentities queries the "auth_identities" edge of the User entity.
func (_m *User) QueryAuthIdentities() *AuthIdentityQuery {
return NewUserClient(_m.config).QueryAuthIdentities(_m)
}
// QueryPendingAuthSessions queries the "pending_auth_sessions" edge of the User entity.
func (_m *User) QueryPendingAuthSessions() *PendingAuthSessionQuery {
return NewUserClient(_m.config).QueryPendingAuthSessions(_m)
}
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery {
return NewUserClient(_m.config).QueryUserAllowedGroups(_m)
@ -482,6 +540,19 @@ func (_m *User) String() string {
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
builder.WriteString("signup_source=")
builder.WriteString(_m.SignupSource)
builder.WriteString(", ")
if v := _m.LastLoginAt; v != nil {
builder.WriteString("last_login_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.LastActiveAt; v != nil {
builder.WriteString("last_active_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
builder.WriteString("balance_notify_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled))
builder.WriteString(", ")

View File

@ -43,6 +43,12 @@ const (
FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at"
// FieldSignupSource holds the string denoting the signup_source field in the database.
FieldSignupSource = "signup_source"
// FieldLastLoginAt holds the string denoting the last_login_at field in the database.
FieldLastLoginAt = "last_login_at"
// FieldLastActiveAt holds the string denoting the last_active_at field in the database.
FieldLastActiveAt = "last_active_at"
// FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database.
FieldBalanceNotifyEnabled = "balance_notify_enabled"
// FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database.
@ -73,6 +79,10 @@ const (
EdgePromoCodeUsages = "promo_code_usages"
// EdgePaymentOrders holds the string denoting the payment_orders edge name in mutations.
EdgePaymentOrders = "payment_orders"
// EdgeAuthIdentities holds the string denoting the auth_identities edge name in mutations.
EdgeAuthIdentities = "auth_identities"
// EdgePendingAuthSessions holds the string denoting the pending_auth_sessions edge name in mutations.
EdgePendingAuthSessions = "pending_auth_sessions"
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
EdgeUserAllowedGroups = "user_allowed_groups"
// Table holds the table name of the user in the database.
@ -145,6 +155,20 @@ const (
PaymentOrdersInverseTable = "payment_orders"
// PaymentOrdersColumn is the table column denoting the payment_orders relation/edge.
PaymentOrdersColumn = "user_id"
// AuthIdentitiesTable is the table that holds the auth_identities relation/edge.
AuthIdentitiesTable = "auth_identities"
// AuthIdentitiesInverseTable is the table name for the AuthIdentity entity.
// It exists in this package in order to avoid circular dependency with the "authidentity" package.
AuthIdentitiesInverseTable = "auth_identities"
// AuthIdentitiesColumn is the table column denoting the auth_identities relation/edge.
AuthIdentitiesColumn = "user_id"
// PendingAuthSessionsTable is the table that holds the pending_auth_sessions relation/edge.
PendingAuthSessionsTable = "pending_auth_sessions"
// PendingAuthSessionsInverseTable is the table name for the PendingAuthSession entity.
// It exists in this package in order to avoid circular dependency with the "pendingauthsession" package.
PendingAuthSessionsInverseTable = "pending_auth_sessions"
// PendingAuthSessionsColumn is the table column denoting the pending_auth_sessions relation/edge.
PendingAuthSessionsColumn = "target_user_id"
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
UserAllowedGroupsTable = "user_allowed_groups"
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
@ -171,6 +195,9 @@ var Columns = []string{
FieldTotpSecretEncrypted,
FieldTotpEnabled,
FieldTotpEnabledAt,
FieldSignupSource,
FieldLastLoginAt,
FieldLastActiveAt,
FieldBalanceNotifyEnabled,
FieldBalanceNotifyThresholdType,
FieldBalanceNotifyThreshold,
@ -232,6 +259,10 @@ var (
DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool
// DefaultSignupSource holds the default value on creation for the "signup_source" field.
DefaultSignupSource string
// SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save.
SignupSourceValidator func(string) error
// DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field.
DefaultBalanceNotifyEnabled bool
// DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field.
@ -320,6 +351,21 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
}
// BySignupSource orders the results by the signup_source field.
func BySignupSource(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSignupSource, opts...).ToFunc()
}
// ByLastLoginAt orders the results by the last_login_at field.
func ByLastLoginAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldLastLoginAt, opts...).ToFunc()
}
// ByLastActiveAt orders the results by the last_active_at field.
func ByLastActiveAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldLastActiveAt, opts...).ToFunc()
}
// ByBalanceNotifyEnabled orders the results by the balance_notify_enabled field.
func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc()
@ -485,6 +531,34 @@ func ByPaymentOrders(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
}
}
// ByAuthIdentitiesCount orders the results by auth_identities count.
func ByAuthIdentitiesCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newAuthIdentitiesStep(), opts...)
}
}
// ByAuthIdentities orders the results by auth_identities terms.
func ByAuthIdentities(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newAuthIdentitiesStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
// ByPendingAuthSessionsCount orders the results by pending_auth_sessions count.
func ByPendingAuthSessionsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newPendingAuthSessionsStep(), opts...)
}
}
// ByPendingAuthSessions orders the results by pending_auth_sessions terms.
func ByPendingAuthSessions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
@ -568,6 +642,20 @@ func newPaymentOrdersStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, PaymentOrdersTable, PaymentOrdersColumn),
)
}
func newAuthIdentitiesStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(AuthIdentitiesInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn),
)
}
func newPendingAuthSessionsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(PendingAuthSessionsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn),
)
}
func newUserAllowedGroupsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),

View File

@ -125,6 +125,21 @@ func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
// SignupSource applies equality check predicate on the "signup_source" field. It's identical to SignupSourceEQ.
func SignupSource(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldSignupSource, v))
}
// LastLoginAt applies equality check predicate on the "last_login_at" field. It's identical to LastLoginAtEQ.
func LastLoginAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldLastLoginAt, v))
}
// LastActiveAt applies equality check predicate on the "last_active_at" field. It's identical to LastActiveAtEQ.
func LastActiveAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldLastActiveAt, v))
}
// BalanceNotifyEnabled applies equality check predicate on the "balance_notify_enabled" field. It's identical to BalanceNotifyEnabledEQ.
func BalanceNotifyEnabled(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
@ -885,6 +900,171 @@ func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
}
// SignupSourceEQ applies the EQ predicate on the "signup_source" field.
func SignupSourceEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldSignupSource, v))
}
// SignupSourceNEQ applies the NEQ predicate on the "signup_source" field.
func SignupSourceNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldSignupSource, v))
}
// SignupSourceIn applies the In predicate on the "signup_source" field.
func SignupSourceIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldSignupSource, vs...))
}
// SignupSourceNotIn applies the NotIn predicate on the "signup_source" field.
func SignupSourceNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldSignupSource, vs...))
}
// SignupSourceGT applies the GT predicate on the "signup_source" field.
func SignupSourceGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldSignupSource, v))
}
// SignupSourceGTE applies the GTE predicate on the "signup_source" field.
func SignupSourceGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldSignupSource, v))
}
// SignupSourceLT applies the LT predicate on the "signup_source" field.
func SignupSourceLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldSignupSource, v))
}
// SignupSourceLTE applies the LTE predicate on the "signup_source" field.
func SignupSourceLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldSignupSource, v))
}
// SignupSourceContains applies the Contains predicate on the "signup_source" field.
func SignupSourceContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldSignupSource, v))
}
// SignupSourceHasPrefix applies the HasPrefix predicate on the "signup_source" field.
func SignupSourceHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldSignupSource, v))
}
// SignupSourceHasSuffix applies the HasSuffix predicate on the "signup_source" field.
func SignupSourceHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldSignupSource, v))
}
// SignupSourceEqualFold applies the EqualFold predicate on the "signup_source" field.
func SignupSourceEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldSignupSource, v))
}
// SignupSourceContainsFold applies the ContainsFold predicate on the "signup_source" field.
func SignupSourceContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldSignupSource, v))
}
// LastLoginAtEQ applies the EQ predicate on the "last_login_at" field.
func LastLoginAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldLastLoginAt, v))
}
// LastLoginAtNEQ applies the NEQ predicate on the "last_login_at" field.
func LastLoginAtNEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldNEQ(FieldLastLoginAt, v))
}
// LastLoginAtIn applies the In predicate on the "last_login_at" field.
func LastLoginAtIn(vs ...time.Time) predicate.User {
return predicate.User(sql.FieldIn(FieldLastLoginAt, vs...))
}
// LastLoginAtNotIn applies the NotIn predicate on the "last_login_at" field.
func LastLoginAtNotIn(vs ...time.Time) predicate.User {
return predicate.User(sql.FieldNotIn(FieldLastLoginAt, vs...))
}
// LastLoginAtGT applies the GT predicate on the "last_login_at" field.
func LastLoginAtGT(v time.Time) predicate.User {
return predicate.User(sql.FieldGT(FieldLastLoginAt, v))
}
// LastLoginAtGTE applies the GTE predicate on the "last_login_at" field.
func LastLoginAtGTE(v time.Time) predicate.User {
return predicate.User(sql.FieldGTE(FieldLastLoginAt, v))
}
// LastLoginAtLT applies the LT predicate on the "last_login_at" field.
func LastLoginAtLT(v time.Time) predicate.User {
return predicate.User(sql.FieldLT(FieldLastLoginAt, v))
}
// LastLoginAtLTE applies the LTE predicate on the "last_login_at" field.
func LastLoginAtLTE(v time.Time) predicate.User {
return predicate.User(sql.FieldLTE(FieldLastLoginAt, v))
}
// LastLoginAtIsNil applies the IsNil predicate on the "last_login_at" field.
func LastLoginAtIsNil() predicate.User {
return predicate.User(sql.FieldIsNull(FieldLastLoginAt))
}
// LastLoginAtNotNil applies the NotNil predicate on the "last_login_at" field.
func LastLoginAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldLastLoginAt))
}
// LastActiveAtEQ applies the EQ predicate on the "last_active_at" field.
func LastActiveAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldLastActiveAt, v))
}
// LastActiveAtNEQ applies the NEQ predicate on the "last_active_at" field.
func LastActiveAtNEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldNEQ(FieldLastActiveAt, v))
}
// LastActiveAtIn applies the In predicate on the "last_active_at" field.
func LastActiveAtIn(vs ...time.Time) predicate.User {
return predicate.User(sql.FieldIn(FieldLastActiveAt, vs...))
}
// LastActiveAtNotIn applies the NotIn predicate on the "last_active_at" field.
func LastActiveAtNotIn(vs ...time.Time) predicate.User {
return predicate.User(sql.FieldNotIn(FieldLastActiveAt, vs...))
}
// LastActiveAtGT applies the GT predicate on the "last_active_at" field.
func LastActiveAtGT(v time.Time) predicate.User {
return predicate.User(sql.FieldGT(FieldLastActiveAt, v))
}
// LastActiveAtGTE applies the GTE predicate on the "last_active_at" field.
func LastActiveAtGTE(v time.Time) predicate.User {
return predicate.User(sql.FieldGTE(FieldLastActiveAt, v))
}
// LastActiveAtLT applies the LT predicate on the "last_active_at" field.
func LastActiveAtLT(v time.Time) predicate.User {
return predicate.User(sql.FieldLT(FieldLastActiveAt, v))
}
// LastActiveAtLTE applies the LTE predicate on the "last_active_at" field.
func LastActiveAtLTE(v time.Time) predicate.User {
return predicate.User(sql.FieldLTE(FieldLastActiveAt, v))
}
// LastActiveAtIsNil applies the IsNil predicate on the "last_active_at" field.
func LastActiveAtIsNil() predicate.User {
return predicate.User(sql.FieldIsNull(FieldLastActiveAt))
}
// LastActiveAtNotNil applies the NotNil predicate on the "last_active_at" field.
func LastActiveAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldLastActiveAt))
}
// BalanceNotifyEnabledEQ applies the EQ predicate on the "balance_notify_enabled" field.
func BalanceNotifyEnabledEQ(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
@ -1345,6 +1525,52 @@ func HasPaymentOrdersWith(preds ...predicate.PaymentOrder) predicate.User {
})
}
// HasAuthIdentities applies the HasEdge predicate on the "auth_identities" edge.
func HasAuthIdentities() predicate.User {
return predicate.User(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasAuthIdentitiesWith applies the HasEdge predicate on the "auth_identities" edge with a given conditions (other predicates).
func HasAuthIdentitiesWith(preds ...predicate.AuthIdentity) predicate.User {
return predicate.User(func(s *sql.Selector) {
step := newAuthIdentitiesStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasPendingAuthSessions applies the HasEdge predicate on the "pending_auth_sessions" edge.
func HasPendingAuthSessions() predicate.User {
return predicate.User(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasPendingAuthSessionsWith applies the HasEdge predicate on the "pending_auth_sessions" edge with a given conditions (other predicates).
func HasPendingAuthSessionsWith(preds ...predicate.PendingAuthSession) predicate.User {
return predicate.User(func(s *sql.Selector) {
step := newPendingAuthSessionsStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
func HasUserAllowedGroups() predicate.User {
return predicate.User(func(s *sql.Selector) {

View File

@ -13,8 +13,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
@ -211,6 +213,48 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
return _c
}
// SetSignupSource sets the "signup_source" field.
func (_c *UserCreate) SetSignupSource(v string) *UserCreate {
_c.mutation.SetSignupSource(v)
return _c
}
// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
func (_c *UserCreate) SetNillableSignupSource(v *string) *UserCreate {
if v != nil {
_c.SetSignupSource(*v)
}
return _c
}
// SetLastLoginAt sets the "last_login_at" field.
func (_c *UserCreate) SetLastLoginAt(v time.Time) *UserCreate {
_c.mutation.SetLastLoginAt(v)
return _c
}
// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
func (_c *UserCreate) SetNillableLastLoginAt(v *time.Time) *UserCreate {
if v != nil {
_c.SetLastLoginAt(*v)
}
return _c
}
// SetLastActiveAt sets the "last_active_at" field.
func (_c *UserCreate) SetLastActiveAt(v time.Time) *UserCreate {
_c.mutation.SetLastActiveAt(v)
return _c
}
// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
func (_c *UserCreate) SetNillableLastActiveAt(v *time.Time) *UserCreate {
if v != nil {
_c.SetLastActiveAt(*v)
}
return _c
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_c *UserCreate) SetBalanceNotifyEnabled(v bool) *UserCreate {
_c.mutation.SetBalanceNotifyEnabled(v)
@ -431,6 +475,36 @@ func (_c *UserCreate) AddPaymentOrders(v ...*PaymentOrder) *UserCreate {
return _c.AddPaymentOrderIDs(ids...)
}
// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
func (_c *UserCreate) AddAuthIdentityIDs(ids ...int64) *UserCreate {
_c.mutation.AddAuthIdentityIDs(ids...)
return _c
}
// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
func (_c *UserCreate) AddAuthIdentities(v ...*AuthIdentity) *UserCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddAuthIdentityIDs(ids...)
}
// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
func (_c *UserCreate) AddPendingAuthSessionIDs(ids ...int64) *UserCreate {
_c.mutation.AddPendingAuthSessionIDs(ids...)
return _c
}
// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
func (_c *UserCreate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddPendingAuthSessionIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_c *UserCreate) Mutation() *UserMutation {
return _c.mutation
@ -510,6 +584,10 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v)
}
if _, ok := _c.mutation.SignupSource(); !ok {
v := user.DefaultSignupSource
_c.mutation.SetSignupSource(v)
}
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
v := user.DefaultBalanceNotifyEnabled
_c.mutation.SetBalanceNotifyEnabled(v)
@ -589,6 +667,14 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
}
if _, ok := _c.mutation.SignupSource(); !ok {
return &ValidationError{Name: "signup_source", err: errors.New(`ent: missing required field "User.signup_source"`)}
}
if v, ok := _c.mutation.SignupSource(); ok {
if err := user.SignupSourceValidator(v); err != nil {
return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
}
}
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)}
}
@ -684,6 +770,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value
}
if value, ok := _c.mutation.SignupSource(); ok {
_spec.SetField(user.FieldSignupSource, field.TypeString, value)
_node.SignupSource = value
}
if value, ok := _c.mutation.LastLoginAt(); ok {
_spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
_node.LastLoginAt = &value
}
if value, ok := _c.mutation.LastActiveAt(); ok {
_spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
_node.LastActiveAt = &value
}
if value, ok := _c.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
_node.BalanceNotifyEnabled = value
@ -868,6 +966,38 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
}
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}
@ -1106,6 +1236,54 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
return u
}
// SetSignupSource sets the "signup_source" field.
func (u *UserUpsert) SetSignupSource(v string) *UserUpsert {
u.Set(user.FieldSignupSource, v)
return u
}
// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
func (u *UserUpsert) UpdateSignupSource() *UserUpsert {
u.SetExcluded(user.FieldSignupSource)
return u
}
// SetLastLoginAt sets the "last_login_at" field.
func (u *UserUpsert) SetLastLoginAt(v time.Time) *UserUpsert {
u.Set(user.FieldLastLoginAt, v)
return u
}
// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
func (u *UserUpsert) UpdateLastLoginAt() *UserUpsert {
u.SetExcluded(user.FieldLastLoginAt)
return u
}
// ClearLastLoginAt clears the value of the "last_login_at" field.
func (u *UserUpsert) ClearLastLoginAt() *UserUpsert {
u.SetNull(user.FieldLastLoginAt)
return u
}
// SetLastActiveAt sets the "last_active_at" field.
func (u *UserUpsert) SetLastActiveAt(v time.Time) *UserUpsert {
u.Set(user.FieldLastActiveAt, v)
return u
}
// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
func (u *UserUpsert) UpdateLastActiveAt() *UserUpsert {
u.SetExcluded(user.FieldLastActiveAt)
return u
}
// ClearLastActiveAt clears the value of the "last_active_at" field.
func (u *UserUpsert) ClearLastActiveAt() *UserUpsert {
u.SetNull(user.FieldLastActiveAt)
return u
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsert) SetBalanceNotifyEnabled(v bool) *UserUpsert {
u.Set(user.FieldBalanceNotifyEnabled, v)
@ -1446,6 +1624,62 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
})
}
// SetSignupSource sets the "signup_source" field.
func (u *UserUpsertOne) SetSignupSource(v string) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetSignupSource(v)
})
}
// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateSignupSource() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateSignupSource()
})
}
// SetLastLoginAt sets the "last_login_at" field.
func (u *UserUpsertOne) SetLastLoginAt(v time.Time) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetLastLoginAt(v)
})
}
// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateLastLoginAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateLastLoginAt()
})
}
// ClearLastLoginAt clears the value of the "last_login_at" field.
func (u *UserUpsertOne) ClearLastLoginAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.ClearLastLoginAt()
})
}
// SetLastActiveAt sets the "last_active_at" field.
func (u *UserUpsertOne) SetLastActiveAt(v time.Time) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetLastActiveAt(v)
})
}
// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateLastActiveAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateLastActiveAt()
})
}
// ClearLastActiveAt clears the value of the "last_active_at" field.
func (u *UserUpsertOne) ClearLastActiveAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.ClearLastActiveAt()
})
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsertOne) SetBalanceNotifyEnabled(v bool) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
@ -1965,6 +2199,62 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
})
}
// SetSignupSource sets the "signup_source" field.
func (u *UserUpsertBulk) SetSignupSource(v string) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetSignupSource(v)
})
}
// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateSignupSource() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateSignupSource()
})
}
// SetLastLoginAt sets the "last_login_at" field.
func (u *UserUpsertBulk) SetLastLoginAt(v time.Time) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetLastLoginAt(v)
})
}
// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateLastLoginAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateLastLoginAt()
})
}
// ClearLastLoginAt clears the value of the "last_login_at" field.
func (u *UserUpsertBulk) ClearLastLoginAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.ClearLastLoginAt()
})
}
// SetLastActiveAt sets the "last_active_at" field.
func (u *UserUpsertBulk) SetLastActiveAt(v time.Time) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetLastActiveAt(v)
})
}
// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateLastActiveAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateLastActiveAt()
})
}
// ClearLastActiveAt clears the value of the "last_active_at" field.
func (u *UserUpsertBulk) ClearLastActiveAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.ClearLastActiveAt()
})
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsertBulk) SetBalanceNotifyEnabled(v bool) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {

View File

@ -15,8 +15,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
@ -44,6 +46,8 @@ type UserQuery struct {
withAttributeValues *UserAttributeValueQuery
withPromoCodeUsages *PromoCodeUsageQuery
withPaymentOrders *PaymentOrderQuery
withAuthIdentities *AuthIdentityQuery
withPendingAuthSessions *PendingAuthSessionQuery
withUserAllowedGroups *UserAllowedGroupQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
@ -302,6 +306,50 @@ func (_q *UserQuery) QueryPaymentOrders() *PaymentOrderQuery {
return query
}
// QueryAuthIdentities chains the current query on the "auth_identities" edge.
func (_q *UserQuery) QueryAuthIdentities() *AuthIdentityQuery {
query := (&AuthIdentityClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, selector),
sqlgraph.To(authidentity.Table, authidentity.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryPendingAuthSessions chains the current query on the "pending_auth_sessions" edge.
func (_q *UserQuery) QueryPendingAuthSessions() *PendingAuthSessionQuery {
query := (&PendingAuthSessionClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, selector),
sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: _q.config}).Query()
@ -526,6 +574,8 @@ func (_q *UserQuery) Clone() *UserQuery {
withAttributeValues: _q.withAttributeValues.Clone(),
withPromoCodeUsages: _q.withPromoCodeUsages.Clone(),
withPaymentOrders: _q.withPaymentOrders.Clone(),
withAuthIdentities: _q.withAuthIdentities.Clone(),
withPendingAuthSessions: _q.withPendingAuthSessions.Clone(),
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
@ -643,6 +693,28 @@ func (_q *UserQuery) WithPaymentOrders(opts ...func(*PaymentOrderQuery)) *UserQu
return _q
}
// WithAuthIdentities tells the query-builder to eager-load the nodes that are connected to
// the "auth_identities" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithAuthIdentities(opts ...func(*AuthIdentityQuery)) *UserQuery {
query := (&AuthIdentityClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withAuthIdentities = query
return _q
}
// WithPendingAuthSessions tells the query-builder to eager-load the nodes that are connected to
// the "pending_auth_sessions" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithPendingAuthSessions(opts ...func(*PendingAuthSessionQuery)) *UserQuery {
query := (&PendingAuthSessionClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withPendingAuthSessions = query
return _q
}
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
@ -732,7 +804,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var (
nodes = []*User{}
_spec = _q.querySpec()
loadedTypes = [11]bool{
loadedTypes = [13]bool{
_q.withAPIKeys != nil,
_q.withRedeemCodes != nil,
_q.withSubscriptions != nil,
@ -743,6 +815,8 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
_q.withAttributeValues != nil,
_q.withPromoCodeUsages != nil,
_q.withPaymentOrders != nil,
_q.withAuthIdentities != nil,
_q.withPendingAuthSessions != nil,
_q.withUserAllowedGroups != nil,
}
)
@ -839,6 +913,22 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err
}
}
if query := _q.withAuthIdentities; query != nil {
if err := _q.loadAuthIdentities(ctx, query, nodes,
func(n *User) { n.Edges.AuthIdentities = []*AuthIdentity{} },
func(n *User, e *AuthIdentity) { n.Edges.AuthIdentities = append(n.Edges.AuthIdentities, e) }); err != nil {
return nil, err
}
}
if query := _q.withPendingAuthSessions; query != nil {
if err := _q.loadPendingAuthSessions(ctx, query, nodes,
func(n *User) { n.Edges.PendingAuthSessions = []*PendingAuthSession{} },
func(n *User, e *PendingAuthSession) {
n.Edges.PendingAuthSessions = append(n.Edges.PendingAuthSessions, e)
}); err != nil {
return nil, err
}
}
if query := _q.withUserAllowedGroups; query != nil {
if err := _q.loadUserAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
@ -1186,6 +1276,69 @@ func (_q *UserQuery) loadPaymentOrders(ctx context.Context, query *PaymentOrderQ
}
return nil
}
func (_q *UserQuery) loadAuthIdentities(ctx context.Context, query *AuthIdentityQuery, nodes []*User, init func(*User), assign func(*User, *AuthIdentity)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(authidentity.FieldUserID)
}
query.Where(predicate.AuthIdentity(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(user.AuthIdentitiesColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.UserID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserQuery) loadPendingAuthSessions(ctx context.Context, query *PendingAuthSessionQuery, nodes []*User, init func(*User), assign func(*User, *PendingAuthSession)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(pendingauthsession.FieldTargetUserID)
}
query.Where(predicate.PendingAuthSession(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(user.PendingAuthSessionsColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.TargetUserID
if fk == nil {
return fmt.Errorf(`foreign-key "target_user_id" is nil for node %v`, n.ID)
}
node, ok := nodeids[*fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "target_user_id" returned %v for node %v`, *fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)

View File

@ -13,8 +13,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
@ -243,6 +245,60 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
return _u
}
// SetSignupSource sets the "signup_source" field.
func (_u *UserUpdate) SetSignupSource(v string) *UserUpdate {
_u.mutation.SetSignupSource(v)
return _u
}
// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
func (_u *UserUpdate) SetNillableSignupSource(v *string) *UserUpdate {
if v != nil {
_u.SetSignupSource(*v)
}
return _u
}
// SetLastLoginAt sets the "last_login_at" field.
func (_u *UserUpdate) SetLastLoginAt(v time.Time) *UserUpdate {
_u.mutation.SetLastLoginAt(v)
return _u
}
// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
func (_u *UserUpdate) SetNillableLastLoginAt(v *time.Time) *UserUpdate {
if v != nil {
_u.SetLastLoginAt(*v)
}
return _u
}
// ClearLastLoginAt clears the value of the "last_login_at" field.
func (_u *UserUpdate) ClearLastLoginAt() *UserUpdate {
_u.mutation.ClearLastLoginAt()
return _u
}
// SetLastActiveAt sets the "last_active_at" field.
func (_u *UserUpdate) SetLastActiveAt(v time.Time) *UserUpdate {
_u.mutation.SetLastActiveAt(v)
return _u
}
// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
func (_u *UserUpdate) SetNillableLastActiveAt(v *time.Time) *UserUpdate {
if v != nil {
_u.SetLastActiveAt(*v)
}
return _u
}
// ClearLastActiveAt clears the value of the "last_active_at" field.
func (_u *UserUpdate) ClearLastActiveAt() *UserUpdate {
_u.mutation.ClearLastActiveAt()
return _u
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_u *UserUpdate) SetBalanceNotifyEnabled(v bool) *UserUpdate {
_u.mutation.SetBalanceNotifyEnabled(v)
@ -483,6 +539,36 @@ func (_u *UserUpdate) AddPaymentOrders(v ...*PaymentOrder) *UserUpdate {
return _u.AddPaymentOrderIDs(ids...)
}
// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
func (_u *UserUpdate) AddAuthIdentityIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAuthIdentityIDs(ids...)
return _u
}
// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
func (_u *UserUpdate) AddAuthIdentities(v ...*AuthIdentity) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddAuthIdentityIDs(ids...)
}
// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
func (_u *UserUpdate) AddPendingAuthSessionIDs(ids ...int64) *UserUpdate {
_u.mutation.AddPendingAuthSessionIDs(ids...)
return _u
}
// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
func (_u *UserUpdate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddPendingAuthSessionIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdate) Mutation() *UserMutation {
return _u.mutation
@ -698,6 +784,48 @@ func (_u *UserUpdate) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdate {
return _u.RemovePaymentOrderIDs(ids...)
}
// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity.
func (_u *UserUpdate) ClearAuthIdentities() *UserUpdate {
_u.mutation.ClearAuthIdentities()
return _u
}
// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs.
func (_u *UserUpdate) RemoveAuthIdentityIDs(ids ...int64) *UserUpdate {
_u.mutation.RemoveAuthIdentityIDs(ids...)
return _u
}
// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities.
func (_u *UserUpdate) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveAuthIdentityIDs(ids...)
}
// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity.
func (_u *UserUpdate) ClearPendingAuthSessions() *UserUpdate {
_u.mutation.ClearPendingAuthSessions()
return _u
}
// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs.
func (_u *UserUpdate) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdate {
_u.mutation.RemovePendingAuthSessionIDs(ids...)
return _u
}
// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities.
func (_u *UserUpdate) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemovePendingAuthSessionIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
@ -767,6 +895,11 @@ func (_u *UserUpdate) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
if v, ok := _u.mutation.SignupSource(); ok {
if err := user.SignupSourceValidator(v); err != nil {
return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
}
}
return nil
}
@ -836,6 +969,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
if value, ok := _u.mutation.SignupSource(); ok {
_spec.SetField(user.FieldSignupSource, field.TypeString, value)
}
if value, ok := _u.mutation.LastLoginAt(); ok {
_spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
}
if _u.mutation.LastLoginAtCleared() {
_spec.ClearField(user.FieldLastLoginAt, field.TypeTime)
}
if value, ok := _u.mutation.LastActiveAt(); ok {
_spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
}
if _u.mutation.LastActiveAtCleared() {
_spec.ClearField(user.FieldLastActiveAt, field.TypeTime)
}
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
}
@ -1322,6 +1470,96 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AuthIdentitiesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.PendingAuthSessionsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label}
@ -1548,6 +1786,60 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
return _u
}
// SetSignupSource sets the "signup_source" field.
func (_u *UserUpdateOne) SetSignupSource(v string) *UserUpdateOne {
_u.mutation.SetSignupSource(v)
return _u
}
// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableSignupSource(v *string) *UserUpdateOne {
if v != nil {
_u.SetSignupSource(*v)
}
return _u
}
// SetLastLoginAt sets the "last_login_at" field.
func (_u *UserUpdateOne) SetLastLoginAt(v time.Time) *UserUpdateOne {
_u.mutation.SetLastLoginAt(v)
return _u
}
// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableLastLoginAt(v *time.Time) *UserUpdateOne {
if v != nil {
_u.SetLastLoginAt(*v)
}
return _u
}
// ClearLastLoginAt clears the value of the "last_login_at" field.
func (_u *UserUpdateOne) ClearLastLoginAt() *UserUpdateOne {
_u.mutation.ClearLastLoginAt()
return _u
}
// SetLastActiveAt sets the "last_active_at" field.
func (_u *UserUpdateOne) SetLastActiveAt(v time.Time) *UserUpdateOne {
_u.mutation.SetLastActiveAt(v)
return _u
}
// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableLastActiveAt(v *time.Time) *UserUpdateOne {
if v != nil {
_u.SetLastActiveAt(*v)
}
return _u
}
// ClearLastActiveAt clears the value of the "last_active_at" field.
func (_u *UserUpdateOne) ClearLastActiveAt() *UserUpdateOne {
_u.mutation.ClearLastActiveAt()
return _u
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_u *UserUpdateOne) SetBalanceNotifyEnabled(v bool) *UserUpdateOne {
_u.mutation.SetBalanceNotifyEnabled(v)
@ -1788,6 +2080,36 @@ func (_u *UserUpdateOne) AddPaymentOrders(v ...*PaymentOrder) *UserUpdateOne {
return _u.AddPaymentOrderIDs(ids...)
}
// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
func (_u *UserUpdateOne) AddAuthIdentityIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAuthIdentityIDs(ids...)
return _u
}
// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
func (_u *UserUpdateOne) AddAuthIdentities(v ...*AuthIdentity) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddAuthIdentityIDs(ids...)
}
// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
func (_u *UserUpdateOne) AddPendingAuthSessionIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddPendingAuthSessionIDs(ids...)
return _u
}
// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
func (_u *UserUpdateOne) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddPendingAuthSessionIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdateOne) Mutation() *UserMutation {
return _u.mutation
@ -2003,6 +2325,48 @@ func (_u *UserUpdateOne) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdateOne
return _u.RemovePaymentOrderIDs(ids...)
}
// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity.
func (_u *UserUpdateOne) ClearAuthIdentities() *UserUpdateOne {
_u.mutation.ClearAuthIdentities()
return _u
}
// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs.
func (_u *UserUpdateOne) RemoveAuthIdentityIDs(ids ...int64) *UserUpdateOne {
_u.mutation.RemoveAuthIdentityIDs(ids...)
return _u
}
// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities.
func (_u *UserUpdateOne) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveAuthIdentityIDs(ids...)
}
// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity.
func (_u *UserUpdateOne) ClearPendingAuthSessions() *UserUpdateOne {
_u.mutation.ClearPendingAuthSessions()
return _u
}
// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs.
func (_u *UserUpdateOne) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdateOne {
_u.mutation.RemovePendingAuthSessionIDs(ids...)
return _u
}
// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities.
func (_u *UserUpdateOne) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemovePendingAuthSessionIDs(ids...)
}
// Where appends a list predicates to the UserUpdate builder.
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
_u.mutation.Where(ps...)
@ -2085,6 +2449,11 @@ func (_u *UserUpdateOne) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
if v, ok := _u.mutation.SignupSource(); ok {
if err := user.SignupSourceValidator(v); err != nil {
return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
}
}
return nil
}
@ -2171,6 +2540,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
if value, ok := _u.mutation.SignupSource(); ok {
_spec.SetField(user.FieldSignupSource, field.TypeString, value)
}
if value, ok := _u.mutation.LastLoginAt(); ok {
_spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
}
if _u.mutation.LastLoginAtCleared() {
_spec.ClearField(user.FieldLastLoginAt, field.TypeTime)
}
if value, ok := _u.mutation.LastActiveAt(); ok {
_spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
}
if _u.mutation.LastActiveAtCleared() {
_spec.ClearField(user.FieldLastActiveAt, field.TypeTime)
}
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
}
@ -2657,6 +3041,96 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.AuthIdentitiesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.AuthIdentitiesTable,
Columns: []string{user.AuthIdentitiesColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.PendingAuthSessionsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PendingAuthSessionsTable,
Columns: []string{user.PendingAuthSessionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &User{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues

View File

@ -1608,6 +1608,9 @@ func (c *Config) Validate() error {
return fmt.Errorf("security.csp.policy is required when CSP is enabled")
}
if c.LinuxDo.Enabled {
if !c.LinuxDo.UsePKCE {
return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.enabled=true")
}
if strings.TrimSpace(c.LinuxDo.ClientID) == "" {
return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true")
}
@ -1629,9 +1632,6 @@ func (c *Config) Validate() error {
default:
return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
}
if method == "none" && !c.LinuxDo.UsePKCE {
return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none")
}
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
@ -1663,6 +1663,12 @@ func (c *Config) Validate() error {
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL)
}
if c.OIDC.Enabled {
if !c.OIDC.UsePKCE {
return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.enabled=true")
}
if !c.OIDC.ValidateIDToken {
return fmt.Errorf("oidc_connect.validate_id_token must be true when oidc_connect.enabled=true")
}
if strings.TrimSpace(c.OIDC.ClientID) == "" {
return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true")
}
@ -1685,9 +1691,6 @@ func (c *Config) Validate() error {
default:
return fmt.Errorf("oidc_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
}
if method == "none" && !c.OIDC.UsePKCE {
return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.token_auth_method=none")
}
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
strings.TrimSpace(c.OIDC.ClientSecret) == "" {
return fmt.Errorf("oidc_connect.client_secret is required when oidc_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")

View File

@ -73,6 +73,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
authSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
// Check if ops monitoring is enabled (respects config.ops.enabled)
opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context())
@ -93,7 +98,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
paymentCfg = &service.PaymentConfig{}
}
response.Success(c, dto.SystemSettings{
payload := dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
@ -200,7 +205,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow,
PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit,
PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode,
})
}
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
}
// UpdateSettingsRequest 更新设置请求
@ -276,9 +282,30 @@ type UpdateSettingsRequest struct {
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"`
AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"`
AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"`
AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"`
AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"`
AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"`
AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"`
AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"`
AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"`
AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"`
AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"`
AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"`
AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"`
AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"`
AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"`
AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@ -357,6 +384,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
previousAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
// 验证参数
if req.DefaultConcurrency < 1 {
@ -381,6 +413,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.SMTPPort = 587
}
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
req.AuthSourceDefaultEmailSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultEmailSubscriptions)
req.AuthSourceDefaultLinuxDoSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultLinuxDoSubscriptions)
req.AuthSourceDefaultOIDCSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultOIDCSubscriptions)
req.AuthSourceDefaultWeChatSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultWeChatSubscriptions)
// SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置
// 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置
@ -538,25 +574,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.BadRequest(c, "OIDC scopes must contain openid")
return
}
if !req.OIDCConnectUsePKCE {
response.BadRequest(c, "OIDC PKCE must be enabled")
return
}
if !req.OIDCConnectValidateIDToken {
response.BadRequest(c, "OIDC ID Token validation must be enabled")
return
}
switch req.OIDCConnectTokenAuthMethod {
case "", "client_secret_post", "client_secret_basic", "none":
default:
response.BadRequest(c, "OIDC Token Auth Method must be one of client_secret_post/client_secret_basic/none")
return
}
if req.OIDCConnectTokenAuthMethod == "none" && !req.OIDCConnectUsePKCE {
response.BadRequest(c, "OIDC PKCE must be enabled when token_auth_method=none")
return
}
if req.OIDCConnectClockSkewSeconds < 0 || req.OIDCConnectClockSkewSeconds > 600 {
response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600")
return
}
if req.OIDCConnectValidateIDToken {
if req.OIDCConnectAllowedSigningAlgs == "" {
response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
return
}
if req.OIDCConnectAllowedSigningAlgs == "" {
response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
return
}
if req.OIDCConnectJWKSURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectJWKSURL); err != nil {
@ -933,6 +971,41 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
authSourceDefaults := &service.AuthSourceDefaultSettings{
Email: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance),
Concurrency: intValueOrDefault(req.AuthSourceDefaultEmailConcurrency, previousAuthSourceDefaults.Email.Concurrency),
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultEmailSubscriptions, previousAuthSourceDefaults.Email.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnSignup, previousAuthSourceDefaults.Email.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnFirstBind, previousAuthSourceDefaults.Email.GrantOnFirstBind),
},
LinuxDo: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultLinuxDoBalance, previousAuthSourceDefaults.LinuxDo.Balance),
Concurrency: intValueOrDefault(req.AuthSourceDefaultLinuxDoConcurrency, previousAuthSourceDefaults.LinuxDo.Concurrency),
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultLinuxDoSubscriptions, previousAuthSourceDefaults.LinuxDo.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnSignup, previousAuthSourceDefaults.LinuxDo.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnFirstBind, previousAuthSourceDefaults.LinuxDo.GrantOnFirstBind),
},
OIDC: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultOIDCBalance, previousAuthSourceDefaults.OIDC.Balance),
Concurrency: intValueOrDefault(req.AuthSourceDefaultOIDCConcurrency, previousAuthSourceDefaults.OIDC.Concurrency),
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultOIDCSubscriptions, previousAuthSourceDefaults.OIDC.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnSignup, previousAuthSourceDefaults.OIDC.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnFirstBind, previousAuthSourceDefaults.OIDC.GrantOnFirstBind),
},
WeChat: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultWeChatBalance, previousAuthSourceDefaults.WeChat.Balance),
Concurrency: intValueOrDefault(req.AuthSourceDefaultWeChatConcurrency, previousAuthSourceDefaults.WeChat.Concurrency),
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultWeChatSubscriptions, previousAuthSourceDefaults.WeChat.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind),
},
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
}
if err := h.settingService.UpdateAuthSourceDefaultSettings(c.Request.Context(), authSourceDefaults); err != nil {
response.ErrorFrom(c, err)
return
}
// Update payment configuration (integrated into system settings).
// Skip if no payment fields were provided (prevents accidental wipe).
@ -977,6 +1050,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
updatedAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions))
for _, sub := range updatedSettings.DefaultSubscriptions {
updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{
@ -994,7 +1072,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
updatedPaymentCfg = &service.PaymentConfig{}
}
response.Success(c, dto.SystemSettings{
payload := dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
@ -1100,7 +1178,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow,
PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit,
PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode,
})
}
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
}
// hasPaymentFields returns true if any payment-related field was explicitly provided.
@ -1412,6 +1491,84 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto
return normalized
}
func normalizeOptionalDefaultSubscriptions(input *[]dto.DefaultSubscriptionSetting) *[]dto.DefaultSubscriptionSetting {
if input == nil {
return nil
}
normalized := normalizeDefaultSubscriptions(*input)
return &normalized
}
func float64ValueOrDefault(value *float64, fallback float64) float64 {
if value == nil {
return fallback
}
return *value
}
func intValueOrDefault(value *int, fallback int) int {
if value == nil {
return fallback
}
return *value
}
func boolValueOrDefault(value *bool, fallback bool) bool {
if value == nil {
return fallback
}
return *value
}
func defaultSubscriptionsValueOrDefault(input *[]dto.DefaultSubscriptionSetting, fallback []service.DefaultSubscriptionSetting) []service.DefaultSubscriptionSetting {
if input == nil {
return fallback
}
result := make([]service.DefaultSubscriptionSetting, 0, len(*input))
for _, item := range *input {
result = append(result, service.DefaultSubscriptionSetting{
GroupID: item.GroupID,
ValidityDays: item.ValidityDays,
})
}
return result
}
func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults *service.AuthSourceDefaultSettings) map[string]any {
data := make(map[string]any)
raw, err := json.Marshal(settings)
if err == nil {
_ = json.Unmarshal(raw, &data)
}
if authSourceDefaults == nil {
authSourceDefaults = &service.AuthSourceDefaultSettings{}
}
data["auth_source_default_email_balance"] = authSourceDefaults.Email.Balance
data["auth_source_default_email_concurrency"] = authSourceDefaults.Email.Concurrency
data["auth_source_default_email_subscriptions"] = authSourceDefaults.Email.Subscriptions
data["auth_source_default_email_grant_on_signup"] = authSourceDefaults.Email.GrantOnSignup
data["auth_source_default_email_grant_on_first_bind"] = authSourceDefaults.Email.GrantOnFirstBind
data["auth_source_default_linuxdo_balance"] = authSourceDefaults.LinuxDo.Balance
data["auth_source_default_linuxdo_concurrency"] = authSourceDefaults.LinuxDo.Concurrency
data["auth_source_default_linuxdo_subscriptions"] = authSourceDefaults.LinuxDo.Subscriptions
data["auth_source_default_linuxdo_grant_on_signup"] = authSourceDefaults.LinuxDo.GrantOnSignup
data["auth_source_default_linuxdo_grant_on_first_bind"] = authSourceDefaults.LinuxDo.GrantOnFirstBind
data["auth_source_default_oidc_balance"] = authSourceDefaults.OIDC.Balance
data["auth_source_default_oidc_concurrency"] = authSourceDefaults.OIDC.Concurrency
data["auth_source_default_oidc_subscriptions"] = authSourceDefaults.OIDC.Subscriptions
data["auth_source_default_oidc_grant_on_signup"] = authSourceDefaults.OIDC.GrantOnSignup
data["auth_source_default_oidc_grant_on_first_bind"] = authSourceDefaults.OIDC.GrantOnFirstBind
data["auth_source_default_wechat_balance"] = authSourceDefaults.WeChat.Balance
data["auth_source_default_wechat_concurrency"] = authSourceDefaults.WeChat.Concurrency
data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions
data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup
data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind
data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup
return data
}
func equalStringSlice(a, b []string) bool {
if len(a) != len(b) {
return false

View File

@ -0,0 +1,149 @@
package admin
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type settingHandlerRepoStub struct {
values map[string]string
lastUpdates map[string]string
}
func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
panic("unexpected Get call")
}
func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) {
panic("unexpected GetValue call")
}
func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *settingHandlerRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
if value, ok := s.values[key]; ok {
out[key] = value
}
}
return out, nil
}
func (s *settingHandlerRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
s.lastUpdates = make(map[string]string, len(settings))
for key, value := range settings {
s.lastUpdates[key] = value
if s.values == nil {
s.values = map[string]string{}
}
s.values[key] = value
}
return nil
}
func (s *settingHandlerRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
out := make(map[string]string, len(s.values))
for key, value := range s.values {
out[key] = value
}
return out, nil
}
func (s *settingHandlerRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
values: map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
service.SettingKeyForceEmailOnThirdPartySignup: "true",
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil)
handler.GetSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
var resp response.Response
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
require.Equal(t, 9.5, data["auth_source_default_email_balance"])
require.Equal(t, float64(8), data["auth_source_default_email_concurrency"])
require.Equal(t, true, data["force_email_on_third_party_signup"])
subscriptions, ok := data["auth_source_default_email_subscriptions"].([]any)
require.True(t, ok)
require.Len(t, subscriptions, 1)
}
func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
values: map[string]string{
service.SettingKeyRegistrationEnabled: "false",
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
service.SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false",
service.SettingKeyForceEmailOnThirdPartySignup: "true",
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
body := map[string]any{
"registration_enabled": true,
"promo_code_enabled": true,
"auth_source_default_email_balance": 12.75,
}
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "12.75000000", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance])
require.Equal(t, "8", repo.values[service.SettingKeyAuthSourceDefaultEmailConcurrency])
require.Equal(t, `[{"group_id":31,"validity_days":15}]`, repo.values[service.SettingKeyAuthSourceDefaultEmailSubscriptions])
require.Equal(t, "true", repo.values[service.SettingKeyForceEmailOnThirdPartySignup])
var resp response.Response
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
require.Equal(t, 12.75, data["auth_source_default_email_balance"])
require.Equal(t, float64(8), data["auth_source_default_email_concurrency"])
require.Equal(t, true, data["force_email_on_third_party_signup"])
}

View File

@ -219,7 +219,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
}
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
@ -262,6 +262,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
ProviderKey: "linuxdo",
ProviderSubject: subject,
},
TargetUserID: &user.ID,
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
@ -287,7 +288,9 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
}
type completeLinuxDoOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"`
InvitationCode string `json:"invitation_code" binding:"required"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
@ -335,11 +338,23 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
return
}
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
}
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
AdoptDisplayName: req.AdoptDisplayName,
AdoptAvatar: req.AdoptAvatar,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)

View File

@ -1,10 +1,21 @@
package handler
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@ -110,3 +121,79 @@ func TestSingleLineStripsWhitespace(t *testing.T) {
require.Equal(t, "hello world", singleLine("hello\r\nworld"))
require.Equal(t, "", singleLine("\n\t\r"))
}
func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("linuxdo-complete-session").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-subject-1").
SetResolvedEmail("linuxdo-subject-1@linuxdo-connect.invalid").
SetBrowserSessionKey("linuxdo-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
"suggested_display_name": "LinuxDo Display",
"suggested_avatar_url": "https://cdn.example/linuxdo.png",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
_, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
AdoptAvatar: true,
})
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser")})
c.Request = req
handler.CompleteLinuxDoOAuthRegistration(c)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.NotEmpty(t, responseData["access_token"])
userEntity, err := client.User.Query().
Where(dbuser.EmailEQ(session.ResolvedEmail)).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "LinuxDo Display", userEntity.Username)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("linuxdo"),
authidentity.ProviderKeyEQ("linuxdo"),
authidentity.ProviderSubjectEQ("linuxdo-subject-1"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, userEntity.ID, identity.UserID)
require.Equal(t, "LinuxDo Display", identity.Metadata["display_name"])
require.Equal(t, "https://cdn.example/linuxdo.png", identity.Metadata["avatar_url"])
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, decision.IdentityID)
require.Equal(t, identity.ID, *decision.IdentityID)
require.True(t, decision.AdoptDisplayName)
require.True(t, decision.AdoptAvatar)
consumed, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
}

View File

@ -1,10 +1,17 @@
package handler
import (
"context"
"errors"
"io"
"net/http"
"net/url"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@ -26,6 +33,7 @@ const (
type oauthPendingSessionPayload struct {
Intent string
Identity service.PendingAuthIdentityKey
TargetUserID *int64
ResolvedEmail string
RedirectTo string
BrowserSessionKey string
@ -33,6 +41,11 @@ type oauthPendingSessionPayload struct {
CompletionResponse map[string]any
}
type oauthAdoptionDecisionRequest struct {
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) {
if h == nil || h.authService == nil || h.authService.EntClient() == nil {
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
@ -125,6 +138,7 @@ func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPen
session, err := svc.CreatePendingSession(c.Request.Context(), service.CreatePendingAuthSessionInput{
Intent: strings.TrimSpace(payload.Intent),
Identity: payload.Identity,
TargetUserID: payload.TargetUserID,
ResolvedEmail: strings.TrimSpace(payload.ResolvedEmail),
RedirectTo: strings.TrimSpace(payload.RedirectTo),
BrowserSessionKey: strings.TrimSpace(payload.BrowserSessionKey),
@ -175,6 +189,291 @@ func pendingSessionWantsInvitation(payload map[string]any) bool {
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required")
}
func (r oauthAdoptionDecisionRequest) hasDecision() bool {
return r.AdoptDisplayName != nil || r.AdoptAvatar != nil
}
func (r oauthAdoptionDecisionRequest) toServiceInput(sessionID int64) service.PendingIdentityAdoptionDecisionInput {
input := service.PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: sessionID,
}
if r.AdoptDisplayName != nil {
input.AdoptDisplayName = *r.AdoptDisplayName
}
if r.AdoptAvatar != nil {
input.AdoptAvatar = *r.AdoptAvatar
}
return input
}
func bindOptionalOAuthAdoptionDecision(c *gin.Context) (oauthAdoptionDecisionRequest, error) {
var req oauthAdoptionDecisionRequest
if c == nil || c.Request == nil || c.Request.Body == nil {
return req, nil
}
if err := c.ShouldBindJSON(&req); err != nil {
if errors.Is(err, io.EOF) {
return req, nil
}
return req, err
}
return req, nil
}
func persistPendingOAuthAdoptionDecision(
c *gin.Context,
svc *service.AuthPendingIdentityService,
sessionID int64,
req oauthAdoptionDecisionRequest,
) error {
if !req.hasDecision() {
return nil
}
if svc == nil {
return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
if _, err := svc.UpsertAdoptionDecision(c.Request.Context(), req.toServiceInput(sessionID)); err != nil {
return infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
}
return nil
}
func cloneOAuthMetadata(values map[string]any) map[string]any {
if len(values) == 0 {
return map[string]any{}
}
cloned := make(map[string]any, len(values))
for key, value := range values {
cloned[key] = value
}
return cloned
}
func normalizeAdoptedOAuthDisplayName(value string) string {
value = strings.TrimSpace(value)
if len([]rune(value)) > 100 {
value = string([]rune(value)[:100])
}
return value
}
func (h *AuthHandler) entClient() *dbent.Client {
if h == nil || h.authService == nil {
return nil
}
return h.authService.EntClient()
}
func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
c *gin.Context,
sessionID int64,
req oauthAdoptionDecisionRequest,
) (*dbent.IdentityAdoptionDecision, error) {
client := h.entClient()
if client == nil {
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
existing, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(sessionID)).
Only(c.Request.Context())
if err != nil && !dbent.IsNotFound(err) {
return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_LOAD_FAILED", "failed to load oauth profile adoption decision").WithCause(err)
}
if existing != nil && !req.hasDecision() {
return existing, nil
}
if existing == nil && !req.hasDecision() {
return nil, nil
}
input := service.PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: sessionID,
}
if existing != nil {
input.AdoptDisplayName = existing.AdoptDisplayName
input.AdoptAvatar = existing.AdoptAvatar
input.IdentityID = existing.IdentityID
}
if req.AdoptDisplayName != nil {
input.AdoptDisplayName = *req.AdoptDisplayName
}
if req.AdoptAvatar != nil {
input.AdoptAvatar = *req.AdoptAvatar
}
svc, err := h.pendingIdentityService()
if err != nil {
return nil, err
}
decision, err := svc.UpsertAdoptionDecision(c.Request.Context(), input)
if err != nil {
return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
}
return decision, nil
}
func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) {
if session == nil {
return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
}
if session.TargetUserID != nil && *session.TargetUserID > 0 {
return *session.TargetUserID, nil
}
email := strings.TrimSpace(session.ResolvedEmail)
if email == "" {
return 0, infraerrors.BadRequest("PENDING_AUTH_TARGET_USER_MISSING", "pending auth target user is missing")
}
userEntity, err := client.User.Query().
Where(dbuser.EmailEQ(email)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return 0, infraerrors.InternalServer("PENDING_AUTH_TARGET_USER_NOT_FOUND", "pending auth target user was not found")
}
return 0, err
}
return userEntity.ID, nil
}
func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
if session == nil {
return nil
}
switch strings.TrimSpace(session.ProviderType) {
case "oidc":
issuer := strings.TrimSpace(session.ProviderKey)
if issuer == "" {
issuer = pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer")
}
if issuer == "" {
return nil
}
return &issuer
default:
issuer := pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer")
if issuer == "" {
return nil
}
return &issuer
}
}
func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) {
client := tx.Client()
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)),
authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)),
authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)),
).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return nil, err
}
if identity != nil {
if identity.UserID != userID {
return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
}
return identity, nil
}
create := client.AuthIdentity.Create().
SetUserID(userID).
SetProviderType(strings.TrimSpace(session.ProviderType)).
SetProviderKey(strings.TrimSpace(session.ProviderKey)).
SetProviderSubject(strings.TrimSpace(session.ProviderSubject)).
SetMetadata(cloneOAuthMetadata(session.UpstreamIdentityClaims))
if issuer := oauthIdentityIssuer(session); issuer != nil {
create = create.SetIssuer(strings.TrimSpace(*issuer))
}
return create.Save(ctx)
}
func applyPendingOAuthAdoption(
ctx context.Context,
client *dbent.Client,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
overrideUserID *int64,
) error {
if client == nil || session == nil || decision == nil {
return nil
}
if !decision.AdoptDisplayName && !decision.AdoptAvatar {
return nil
}
targetUserID := int64(0)
if overrideUserID != nil && *overrideUserID > 0 {
targetUserID = *overrideUserID
} else {
resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session)
if err != nil {
return err
}
targetUserID = resolvedUserID
}
adoptedDisplayName := ""
if decision.AdoptDisplayName {
adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name"))
}
adoptedAvatarURL := ""
if decision.AdoptAvatar {
adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url")
}
tx, err := client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
if decision.AdoptDisplayName && adoptedDisplayName != "" {
if err := tx.Client().User.UpdateOneID(targetUserID).
SetUsername(adoptedDisplayName).
Exec(ctx); err != nil {
return err
}
}
identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID)
if err != nil {
return err
}
metadata := cloneOAuthMetadata(identity.Metadata)
for key, value := range session.UpstreamIdentityClaims {
metadata[key] = value
}
if decision.AdoptDisplayName && adoptedDisplayName != "" {
metadata["display_name"] = adoptedDisplayName
}
if decision.AdoptAvatar && adoptedAvatarURL != "" {
metadata["avatar_url"] = adoptedAvatarURL
}
updateIdentity := tx.Client().AuthIdentity.UpdateOneID(identity.ID).SetMetadata(metadata)
if issuer := oauthIdentityIssuer(session); issuer != nil {
updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer))
}
if _, err := updateIdentity.Save(ctx); err != nil {
return err
}
if decision.IdentityID == nil || *decision.IdentityID != identity.ID {
if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
SetIdentityID(identity.ID).
Save(ctx); err != nil {
return err
}
}
return tx.Commit()
}
func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) {
if len(payload) == 0 || len(upstream) == 0 {
return
@ -206,6 +505,11 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
}
adoptionDecision, err := bindOptionalOAuthAdoptionDecision(c)
if err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil || strings.TrimSpace(sessionToken) == "" {
@ -248,9 +552,30 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims)
if pendingSessionWantsInvitation(payload) {
if adoptionDecision.hasDecision() {
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision)
if err != nil {
response.ErrorFrom(c, err)
return
}
_ = decision
}
response.Success(c, payload)
return
}
if !adoptionDecision.hasDecision() {
response.Success(c, payload)
return
}
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision)
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, session.TargetUserID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
if _, err := svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearCookies()

View File

@ -1,9 +1,30 @@
package handler
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
func TestApplySuggestedProfileToCompletionResponse(t *testing.T) {
@ -38,3 +59,439 @@ func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t *
require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"])
require.Equal(t, true, payload["adoption_required"])
}
func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
userEntity, err := client.User.Create().
SetEmail("linuxdo-123@linuxdo-connect.invalid").
SetUsername("legacy-name").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("pending-session-token").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("123").
SetTargetUserID(userEntity.ID).
SetResolvedEmail(userEntity.Email).
SetBrowserSessionKey("browser-session-key").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
"suggested_display_name": "Alice Example",
"suggested_avatar_url": "https://cdn.example/alice.png",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"access_token": "access-token",
"redirect": "/dashboard",
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
previewRecorder := httptest.NewRecorder()
previewCtx, _ := gin.CreateTestContext(previewRecorder)
previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")})
previewCtx.Request = previewReq
handler.ExchangePendingOAuthCompletion(previewCtx)
require.Equal(t, http.StatusOK, previewRecorder.Code)
previewData := decodeJSONResponseData(t, previewRecorder)
require.Equal(t, "Alice Example", previewData["suggested_display_name"])
require.Equal(t, "https://cdn.example/alice.png", previewData["suggested_avatar_url"])
require.Equal(t, true, previewData["adoption_required"])
storedUser, err := client.User.Get(ctx, userEntity.ID)
require.NoError(t, err)
require.Equal(t, "legacy-name", storedUser.Username)
previewSession, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.Nil(t, previewSession.ConsumedAt)
body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`)
finalizeRecorder := httptest.NewRecorder()
finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder)
finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
finalizeReq.Header.Set("Content-Type", "application/json")
finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")})
finalizeCtx.Request = finalizeReq
handler.ExchangePendingOAuthCompletion(finalizeCtx)
require.Equal(t, http.StatusOK, finalizeRecorder.Code)
storedUser, err = client.User.Get(ctx, userEntity.ID)
require.NoError(t, err)
require.Equal(t, "Alice Example", storedUser.Username)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("linuxdo"),
authidentity.ProviderKeyEQ("linuxdo"),
authidentity.ProviderSubjectEQ("123"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, userEntity.ID, identity.UserID)
require.Equal(t, "Alice Example", identity.Metadata["display_name"])
require.Equal(t, "https://cdn.example/alice.png", identity.Metadata["avatar_url"])
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, decision.IdentityID)
require.Equal(t, identity.ID, *decision.IdentityID)
require.True(t, decision.AdoptDisplayName)
require.True(t, decision.AdoptAvatar)
consumed, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
}
func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
AccessTokenExpireMinutes: 60,
RefreshTokenExpireDays: 7,
},
Default: config.DefaultConfig{
UserBalance: 0,
UserConcurrency: 1,
},
}
settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{
values: map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
},
}, cfg)
authSvc := service.NewAuthService(
client,
&oauthPendingFlowUserRepo{client: client},
nil,
&oauthPendingFlowRefreshTokenCacheStub{},
cfg,
settingSvc,
nil,
nil,
nil,
nil,
nil,
)
return &AuthHandler{
authService: authSvc,
settingSvc: settingSvc,
}, client
}
func boolSettingValue(v bool) string {
if v {
return "true"
}
return "false"
}
func boolPtr(v bool) *bool {
return &v
}
type oauthPendingFlowSettingRepoStub struct {
values map[string]string
}
func (s *oauthPendingFlowSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
return nil, service.ErrSettingNotFound
}
func (s *oauthPendingFlowSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
value, ok := s.values[key]
if !ok {
return "", service.ErrSettingNotFound
}
return value, nil
}
func (s *oauthPendingFlowSettingRepoStub) Set(context.Context, string, string) error {
return nil
}
func (s *oauthPendingFlowSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
result := make(map[string]string, len(keys))
for _, key := range keys {
if value, ok := s.values[key]; ok {
result[key] = value
}
}
return result, nil
}
func (s *oauthPendingFlowSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
return nil
}
func (s *oauthPendingFlowSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
result := make(map[string]string, len(s.values))
for key, value := range s.values {
result[key] = value
}
return result, nil
}
func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error {
return nil
}
type oauthPendingFlowRefreshTokenCacheStub struct{}
func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
return nil
}
func (s *oauthPendingFlowRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
return nil, service.ErrRefreshTokenNotFound
}
func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
return nil
}
func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
return nil
}
func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
return nil
}
func (s *oauthPendingFlowRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
return nil
}
func (s *oauthPendingFlowRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
return nil
}
func (s *oauthPendingFlowRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
return nil, nil
}
func (s *oauthPendingFlowRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
return nil, nil
}
func (s *oauthPendingFlowRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
return false, nil
}
func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
t.Helper()
var envelope struct {
Data map[string]any `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &envelope))
return envelope.Data
}
func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
t.Helper()
var payload map[string]any
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
return payload
}
type oauthPendingFlowUserRepo struct {
client *dbent.Client
}
func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.User) error {
entity, err := r.client.User.Create().
SetEmail(user.Email).
SetUsername(user.Username).
SetNotes(user.Notes).
SetPasswordHash(user.PasswordHash).
SetRole(user.Role).
SetBalance(user.Balance).
SetConcurrency(user.Concurrency).
SetStatus(user.Status).
SetSignupSource(user.SignupSource).
SetNillableLastLoginAt(user.LastLoginAt).
SetNillableLastActiveAt(user.LastActiveAt).
Save(ctx)
if err != nil {
return err
}
user.ID = entity.ID
user.CreatedAt = entity.CreatedAt
user.UpdatedAt = entity.UpdatedAt
return nil
}
func (r *oauthPendingFlowUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) {
entity, err := r.client.User.Get(ctx, id)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrUserNotFound
}
return nil, err
}
return oauthPendingFlowServiceUser(entity), nil
}
func (r *oauthPendingFlowUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) {
entity, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrUserNotFound
}
return nil, err
}
return oauthPendingFlowServiceUser(entity), nil
}
func (r *oauthPendingFlowUserRepo) GetFirstAdmin(context.Context) (*service.User, error) {
panic("unexpected GetFirstAdmin call")
}
func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.User) error {
entity, err := r.client.User.UpdateOneID(user.ID).
SetEmail(user.Email).
SetUsername(user.Username).
SetNotes(user.Notes).
SetPasswordHash(user.PasswordHash).
SetRole(user.Role).
SetBalance(user.Balance).
SetConcurrency(user.Concurrency).
SetStatus(user.Status).
SetSignupSource(user.SignupSource).
SetNillableLastLoginAt(user.LastLoginAt).
SetNillableLastActiveAt(user.LastActiveAt).
Save(ctx)
if err != nil {
return err
}
user.UpdatedAt = entity.UpdatedAt
return nil
}
func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error {
return r.client.User.DeleteOneID(id).Exec(ctx)
}
func (r *oauthPendingFlowUserRepo) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
return nil, service.ErrUserNotFound
}
func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
panic("unexpected UpsertUserAvatar call")
}
func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(context.Context, int64) error {
return nil
}
func (r *oauthPendingFlowUserRepo) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (r *oauthPendingFlowUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (r *oauthPendingFlowUserRepo) UpdateBalance(context.Context, int64, float64) error {
panic("unexpected UpdateBalance call")
}
func (r *oauthPendingFlowUserRepo) DeductBalance(context.Context, int64, float64) error {
panic("unexpected DeductBalance call")
}
func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int) error {
panic("unexpected UpdateConcurrency call")
}
func (r *oauthPendingFlowUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
count, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Count(ctx)
return count > 0, err
}
func (r *oauthPendingFlowUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
panic("unexpected RemoveGroupFromAllowedGroups call")
}
func (r *oauthPendingFlowUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error {
panic("unexpected AddGroupToAllowedGroups call")
}
func (r *oauthPendingFlowUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
panic("unexpected RemoveGroupFromUserAllowedGroups call")
}
func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(context.Context, int64, *string) error {
panic("unexpected UpdateTotpSecret call")
}
func (r *oauthPendingFlowUserRepo) EnableTotp(context.Context, int64) error {
panic("unexpected EnableTotp call")
}
func (r *oauthPendingFlowUserRepo) DisableTotp(context.Context, int64) error {
panic("unexpected DisableTotp call")
}
func oauthPendingFlowServiceUser(entity *dbent.User) *service.User {
if entity == nil {
return nil
}
return &service.User{
ID: entity.ID,
Email: entity.Email,
Username: entity.Username,
Notes: entity.Notes,
PasswordHash: entity.PasswordHash,
Role: entity.Role,
Balance: entity.Balance,
Concurrency: entity.Concurrency,
Status: entity.Status,
SignupSource: entity.SignupSource,
LastLoginAt: entity.LastLoginAt,
LastActiveAt: entity.LastActiveAt,
CreatedAt: entity.CreatedAt,
UpdatedAt: entity.UpdatedAt,
}
}

View File

@ -326,7 +326,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
)
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
@ -371,6 +371,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
ProviderKey: issuer,
ProviderSubject: subject,
},
TargetUserID: &user.ID,
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
@ -399,7 +400,9 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
}
type completeOIDCOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"`
InvitationCode string `json:"invitation_code" binding:"required"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
// CompleteOIDCOAuthRegistration completes a pending OAuth registration by validating
@ -447,11 +450,23 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
return
}
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
}
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
AdoptDisplayName: req.AdoptDisplayName,
AdoptAvatar: req.AdoptAvatar,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)

View File

@ -1,6 +1,7 @@
package handler
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
@ -12,7 +13,13 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
)
@ -123,3 +130,80 @@ func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK {
E: e,
}
}
func TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("oidc-complete-session").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example.com").
SetProviderSubject("oidc-subject-1").
SetResolvedEmail("93a310f4c1944c5bbd2e246df1f76485@oidc-connect.invalid").
SetBrowserSessionKey("oidc-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
"issuer": "https://issuer.example.com",
"suggested_display_name": "OIDC Display",
"suggested_avatar_url": "https://cdn.example/oidc.png",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
_, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
AdoptAvatar: true,
})
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser")})
c.Request = req
handler.CompleteOIDCOAuthRegistration(c)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.NotEmpty(t, responseData["access_token"])
userEntity, err := client.User.Query().
Where(dbuser.EmailEQ(session.ResolvedEmail)).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "OIDC Display", userEntity.Username)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("oidc"),
authidentity.ProviderKeyEQ("https://issuer.example.com"),
authidentity.ProviderSubjectEQ("oidc-subject-1"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, userEntity.ID, identity.UserID)
require.Equal(t, "OIDC Display", identity.Metadata["display_name"])
require.Equal(t, "https://cdn.example/oidc.png", identity.Metadata["avatar_url"])
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, decision.IdentityID)
require.Equal(t, identity.ID, *decision.IdentityID)
require.True(t, decision.AdoptDisplayName)
require.True(t, decision.AdoptAvatar)
consumed, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
}

View File

@ -0,0 +1,618 @@
package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
wechatOAuthCookiePath = "/api/v1/auth/oauth/wechat"
wechatOAuthCookieMaxAgeSec = 10 * 60
wechatOAuthStateCookieName = "wechat_oauth_state"
wechatOAuthRedirectCookieName = "wechat_oauth_redirect"
wechatOAuthIntentCookieName = "wechat_oauth_intent"
wechatOAuthModeCookieName = "wechat_oauth_mode"
wechatOAuthDefaultRedirectTo = "/dashboard"
wechatOAuthDefaultFrontendCB = "/auth/wechat/callback"
wechatOAuthProviderKey = "wechat-main"
wechatOAuthIntentLogin = "login"
wechatOAuthIntentBind = "bind_current_user"
wechatOAuthIntentAdoptEmail = "adopt_existing_user_by_email"
)
var (
wechatOAuthAccessTokenURL = "https://api.weixin.qq.com/sns/oauth2/access_token"
wechatOAuthUserInfoURL = "https://api.weixin.qq.com/sns/userinfo"
)
type wechatOAuthConfig struct {
mode string
appID string
appSecret string
authorizeURL string
scope string
redirectURI string
frontendCallback string
}
type wechatOAuthTokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token"`
OpenID string `json:"openid"`
Scope string `json:"scope"`
UnionID string `json:"unionid"`
ErrCode int64 `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
type wechatOAuthUserInfoResponse struct {
OpenID string `json:"openid"`
Nickname string `json:"nickname"`
HeadImgURL string `json:"headimgurl"`
UnionID string `json:"unionid"`
ErrCode int64 `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
// WeChatOAuthStart starts the WeChat OAuth login flow and stores the short-lived
// browser cookies required by the rebuild pending-auth bridge.
func (h *AuthHandler) WeChatOAuthStart(c *gin.Context) {
cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), c.Query("mode"), c)
if err != nil {
response.ErrorFrom(c, err)
return
}
state, err := oauth.GenerateState()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
return
}
redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
if redirectTo == "" {
redirectTo = wechatOAuthDefaultRedirectTo
}
browserSessionKey, err := generateOAuthPendingBrowserSession()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err))
return
}
intent := normalizeWeChatOAuthIntent(c.Query("intent"))
secureCookie := isRequestHTTPS(c)
wechatSetCookie(c, wechatOAuthStateCookieName, encodeCookieValue(state), wechatOAuthCookieMaxAgeSec, secureCookie)
wechatSetCookie(c, wechatOAuthRedirectCookieName, encodeCookieValue(redirectTo), wechatOAuthCookieMaxAgeSec, secureCookie)
wechatSetCookie(c, wechatOAuthIntentCookieName, encodeCookieValue(intent), wechatOAuthCookieMaxAgeSec, secureCookie)
wechatSetCookie(c, wechatOAuthModeCookieName, encodeCookieValue(cfg.mode), wechatOAuthCookieMaxAgeSec, secureCookie)
setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie)
clearOAuthPendingSessionCookie(c, secureCookie)
authURL, err := buildWeChatAuthorizeURL(cfg, state)
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
return
}
c.Redirect(http.StatusFound, authURL)
}
// WeChatOAuthCallback exchanges the code with WeChat, resolves openid/unionid,
// and stores the result in the unified pending-auth flow.
func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
frontendCallback := wechatOAuthFrontendCallback()
if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
return
}
code := strings.TrimSpace(c.Query("code"))
state := strings.TrimSpace(c.Query("state"))
if code == "" || state == "" {
redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
return
}
secureCookie := isRequestHTTPS(c)
defer func() {
wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie)
}()
expectedState, err := readCookieDecoded(c, wechatOAuthStateCookieName)
if err != nil || expectedState == "" || state != expectedState {
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
return
}
redirectTo, _ := readCookieDecoded(c, wechatOAuthRedirectCookieName)
redirectTo = sanitizeFrontendRedirectPath(redirectTo)
if redirectTo == "" {
redirectTo = wechatOAuthDefaultRedirectTo
}
browserSessionKey, _ := readOAuthPendingBrowserCookie(c)
if strings.TrimSpace(browserSessionKey) == "" {
redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "")
return
}
intent, _ := readCookieDecoded(c, wechatOAuthIntentCookieName)
mode, err := readCookieDecoded(c, wechatOAuthModeCookieName)
if err != nil || strings.TrimSpace(mode) == "" {
redirectOAuthError(c, frontendCallback, "invalid_state", "missing oauth mode", "")
return
}
cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), mode, c)
if err != nil {
redirectOAuthError(c, frontendCallback, "provider_error", infraerrors.Reason(err), infraerrors.Message(err))
return
}
tokenResp, userInfo, err := fetchWeChatOAuthIdentity(c.Request.Context(), cfg, code)
if err != nil {
redirectOAuthError(c, frontendCallback, "provider_error", "wechat_identity_fetch_failed", singleLine(err.Error()))
return
}
unionid := strings.TrimSpace(firstNonEmpty(userInfo.UnionID, tokenResp.UnionID))
openid := strings.TrimSpace(firstNonEmpty(userInfo.OpenID, tokenResp.OpenID))
providerSubject := firstNonEmpty(unionid, openid)
if providerSubject == "" {
redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_subject", "")
return
}
username := firstNonEmpty(userInfo.Nickname, wechatFallbackUsername(providerSubject))
email := wechatSyntheticEmail(providerSubject)
upstreamClaims := map[string]any{
"email": email,
"username": username,
"subject": providerSubject,
"openid": openid,
"unionid": unionid,
"mode": cfg.mode,
"suggested_display_name": strings.TrimSpace(userInfo.Nickname),
"suggested_avatar_url": strings.TrimSpace(userInfo.HeadImgURL),
}
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if err := h.createWeChatPendingSession(c, normalizeWeChatOAuthIntent(intent), providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, err); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
redirectToFrontendCallback(c, frontendCallback)
return
}
if err := h.createWeChatPendingSession(c, normalizeWeChatOAuthIntent(intent), providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
redirectToFrontendCallback(c, frontendCallback)
}
type completeWeChatOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
// CompleteWeChatOAuthRegistration completes a pending WeChat OAuth registration by
// validating the invitation code and consuming the current pending browser session.
// POST /api/v1/auth/oauth/wechat/complete-registration
func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
var req completeWeChatOAuthRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
return
}
secureCookie := isRequestHTTPS(c)
sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
return
}
browserSessionKey, err := readOAuthPendingBrowserCookie(c)
if err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
return
}
pendingSvc, err := h.pendingIdentityService()
if err != nil {
response.ErrorFrom(c, err)
return
}
session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
if err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
response.ErrorFrom(c, err)
return
}
email := strings.TrimSpace(session.ResolvedEmail)
username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
if email == "" || username == "" {
response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid"))
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
}
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
AdoptDisplayName: req.AdoptDisplayName,
AdoptAvatar: req.AdoptAvatar,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
response.ErrorFrom(c, err)
return
}
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
"expires_in": tokenPair.ExpiresIn,
"token_type": "Bearer",
})
}
func (h *AuthHandler) createWeChatPendingSession(
c *gin.Context,
intent string,
providerSubject string,
email string,
redirectTo string,
browserSessionKey string,
upstreamClaims map[string]any,
tokenPair *service.TokenPair,
authErr error,
) error {
completionResponse := map[string]any{
"redirect": redirectTo,
}
if authErr != nil {
if errors.Is(authErr, service.ErrOAuthInvitationRequired) {
completionResponse["error"] = "invitation_required"
} else {
return authErr
}
} else if tokenPair != nil {
completionResponse["access_token"] = tokenPair.AccessToken
completionResponse["refresh_token"] = tokenPair.RefreshToken
completionResponse["expires_in"] = tokenPair.ExpiresIn
completionResponse["token_type"] = "Bearer"
}
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: intent,
Identity: service.PendingAuthIdentityKey{
ProviderType: "wechat",
ProviderKey: wechatOAuthProviderKey,
ProviderSubject: providerSubject,
},
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: completionResponse,
})
}
func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, c *gin.Context) (wechatOAuthConfig, error) {
mode, err := resolveWeChatOAuthMode(rawMode, c)
if err != nil {
return wechatOAuthConfig{}, err
}
apiBaseURL := ""
if h != nil && h.settingSvc != nil {
settings, err := h.settingSvc.GetAllSettings(ctx)
if err == nil && settings != nil {
apiBaseURL = strings.TrimSpace(settings.APIBaseURL)
}
}
cfg := wechatOAuthConfig{
mode: mode,
redirectURI: resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/callback"),
frontendCallback: wechatOAuthFrontendCallback(),
}
switch mode {
case "mp":
cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID"))
cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET"))
cfg.authorizeURL = "https://open.weixin.qq.com/connect/oauth2/authorize"
cfg.scope = "snsapi_userinfo"
default:
cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID"))
cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET"))
cfg.authorizeURL = "https://open.weixin.qq.com/connect/qrconnect"
cfg.scope = "snsapi_login"
}
if cfg.appID == "" || cfg.appSecret == "" {
return wechatOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled")
}
if strings.TrimSpace(cfg.redirectURI) == "" {
return wechatOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured")
}
return cfg, nil
}
func wechatOAuthFrontendCallback() string {
return firstNonEmpty(strings.TrimSpace(os.Getenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL")), wechatOAuthDefaultFrontendCB)
}
func resolveWeChatOAuthMode(rawMode string, c *gin.Context) (string, error) {
mode := strings.ToLower(strings.TrimSpace(rawMode))
if mode == "" {
if isWeChatBrowserRequest(c) {
return "mp", nil
}
return "open", nil
}
if mode != "open" && mode != "mp" {
return "", infraerrors.BadRequest("INVALID_MODE", "wechat oauth mode must be open or mp")
}
return mode, nil
}
func isWeChatBrowserRequest(c *gin.Context) bool {
if c == nil || c.Request == nil {
return false
}
return strings.Contains(strings.ToLower(strings.TrimSpace(c.GetHeader("User-Agent"))), "micromessenger")
}
func normalizeWeChatOAuthIntent(raw string) string {
switch strings.ToLower(strings.TrimSpace(raw)) {
case "", "login":
return wechatOAuthIntentLogin
case "bind", "bind_current_user":
return wechatOAuthIntentBind
case "adopt", "adopt_existing_user_by_email":
return wechatOAuthIntentAdoptEmail
default:
return wechatOAuthIntentLogin
}
}
func buildWeChatAuthorizeURL(cfg wechatOAuthConfig, state string) (string, error) {
u, err := url.Parse(cfg.authorizeURL)
if err != nil {
return "", fmt.Errorf("parse authorize url: %w", err)
}
query := u.Query()
query.Set("appid", cfg.appID)
query.Set("redirect_uri", cfg.redirectURI)
query.Set("response_type", "code")
query.Set("scope", cfg.scope)
query.Set("state", state)
u.RawQuery = query.Encode()
u.Fragment = "wechat_redirect"
return u.String(), nil
}
func resolveWeChatOAuthAbsoluteURL(apiBaseURL string, c *gin.Context, callbackPath string) string {
callbackPath = strings.TrimSpace(callbackPath)
if callbackPath == "" {
return ""
}
if raw := strings.TrimSpace(apiBaseURL); raw != "" {
if parsed, err := url.Parse(raw); err == nil && parsed.Scheme != "" && parsed.Host != "" {
basePath := strings.TrimRight(parsed.EscapedPath(), "/")
targetPath := callbackPath
if basePath != "" && strings.HasSuffix(basePath, "/api/v1") && strings.HasPrefix(callbackPath, "/api/v1") {
targetPath = basePath + strings.TrimPrefix(callbackPath, "/api/v1")
} else if basePath != "" {
targetPath = basePath + callbackPath
}
return parsed.Scheme + "://" + parsed.Host + targetPath
}
}
if c == nil || c.Request == nil {
return ""
}
scheme := "http"
if isRequestHTTPS(c) {
scheme = "https"
}
host := strings.TrimSpace(c.Request.Host)
if forwardedHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); forwardedHost != "" {
host = forwardedHost
}
if host == "" {
return ""
}
return scheme + "://" + host + callbackPath
}
func fetchWeChatOAuthIdentity(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, *wechatOAuthUserInfoResponse, error) {
tokenResp, err := exchangeWeChatOAuthCode(ctx, cfg, code)
if err != nil {
return nil, nil, err
}
userInfo, err := fetchWeChatUserInfo(ctx, tokenResp)
if err != nil {
return nil, nil, err
}
return tokenResp, userInfo, nil
}
func exchangeWeChatOAuthCode(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, error) {
endpoint, err := url.Parse(wechatOAuthAccessTokenURL)
if err != nil {
return nil, fmt.Errorf("parse wechat access token url: %w", err)
}
query := endpoint.Query()
query.Set("appid", cfg.appID)
query.Set("secret", cfg.appSecret)
query.Set("code", strings.TrimSpace(code))
query.Set("grant_type", "authorization_code")
endpoint.RawQuery = query.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil)
if err != nil {
return nil, fmt.Errorf("build wechat access token request: %w", err)
}
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request wechat access token: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read wechat access token response: %w", err)
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return nil, fmt.Errorf("wechat access token status=%d", resp.StatusCode)
}
var tokenResp wechatOAuthTokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("decode wechat access token response: %w", err)
}
if tokenResp.ErrCode != 0 {
return nil, fmt.Errorf("wechat access token error=%d %s", tokenResp.ErrCode, strings.TrimSpace(tokenResp.ErrMsg))
}
if strings.TrimSpace(tokenResp.AccessToken) == "" {
return nil, fmt.Errorf("wechat access token missing access_token")
}
return &tokenResp, nil
}
func fetchWeChatUserInfo(ctx context.Context, tokenResp *wechatOAuthTokenResponse) (*wechatOAuthUserInfoResponse, error) {
if tokenResp == nil {
return nil, fmt.Errorf("wechat token response is nil")
}
endpoint, err := url.Parse(wechatOAuthUserInfoURL)
if err != nil {
return nil, fmt.Errorf("parse wechat userinfo url: %w", err)
}
query := endpoint.Query()
query.Set("access_token", strings.TrimSpace(tokenResp.AccessToken))
query.Set("openid", strings.TrimSpace(tokenResp.OpenID))
query.Set("lang", "zh_CN")
endpoint.RawQuery = query.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil)
if err != nil {
return nil, fmt.Errorf("build wechat userinfo request: %w", err)
}
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request wechat userinfo: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read wechat userinfo response: %w", err)
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return nil, fmt.Errorf("wechat userinfo status=%d", resp.StatusCode)
}
var userInfo wechatOAuthUserInfoResponse
if err := json.Unmarshal(body, &userInfo); err != nil {
return nil, fmt.Errorf("decode wechat userinfo response: %w", err)
}
if userInfo.ErrCode != 0 {
return nil, fmt.Errorf("wechat userinfo error=%d %s", userInfo.ErrCode, strings.TrimSpace(userInfo.ErrMsg))
}
return &userInfo, nil
}
func wechatSyntheticEmail(subject string) string {
subject = strings.TrimSpace(subject)
if subject == "" {
return ""
}
return "wechat-" + subject + service.WeChatConnectSyntheticEmailDomain
}
func wechatFallbackUsername(subject string) string {
subject = strings.TrimSpace(subject)
if subject == "" {
return "wechat_user"
}
return "wechat_" + truncateFragmentValue(subject)
}
func wechatSetCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: name,
Value: value,
Path: wechatOAuthCookiePath,
MaxAge: maxAgeSec,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func wechatClearCookie(c *gin.Context, name string, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: name,
Value: "",
Path: wechatOAuthCookiePath,
MaxAge: -1,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}

View File

@ -0,0 +1,411 @@
//go:build unit
package handler
import (
"bytes"
"context"
"database/sql"
"encoding/base64"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) {
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil)
c.Request.Host = "api.example.com"
handler := &AuthHandler{}
handler.WeChatOAuthStart(c)
require.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
require.NotEmpty(t, location)
require.Contains(t, location, "open.weixin.qq.com")
require.Contains(t, location, "appid=wx-open-app")
require.Contains(t, location, "scope=snsapi_login")
cookies := recorder.Result().Cookies()
require.NotEmpty(t, findCookie(cookies, wechatOAuthStateCookieName))
require.NotEmpty(t, findCookie(cookies, wechatOAuthRedirectCookieName))
require.NotEmpty(t, findCookie(cookies, wechatOAuthModeCookieName))
require.NotEmpty(t, findCookie(cookies, oauthPendingBrowserCookieName))
}
func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) {
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
wechatOAuthAccessTokenURL = originalAccessTokenURL
wechatOAuthUserInfoURL = originalUserInfoURL
})
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
case strings.Contains(r.URL.Path, "/sns/userinfo"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
handler, client := newWeChatOAuthTestHandler(t, false)
defer client.Close()
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
req.Host = "api.example.com"
req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
c.Request = req
handler.WeChatOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
require.NotNil(t, sessionCookie)
ctx := context.Background()
session, err := client.PendingAuthSession.Query().
Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "wechat", session.ProviderType)
require.Equal(t, "wechat-main", session.ProviderKey)
require.Equal(t, "union-456", session.ProviderSubject)
require.Equal(t, "wechat-union-456@wechat-connect.invalid", session.ResolvedEmail)
require.Equal(t, "WeChat Nick", session.UpstreamIdentityClaims["suggested_display_name"])
require.Equal(t, "https://cdn.example/avatar.png", session.UpstreamIdentityClaims["suggested_avatar_url"])
require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"])
require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"])
}
func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) {
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
wechatOAuthAccessTokenURL = originalAccessTokenURL
wechatOAuthUserInfoURL = originalUserInfoURL
})
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
case strings.Contains(r.URL.Path, "/sns/userinfo"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat.png"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
handler, client := newWeChatOAuthTestHandler(t, true)
defer client.Close()
ctx := context.Background()
redeemRepo := repository.NewRedeemCodeRepository(client)
require.NoError(t, redeemRepo.Create(ctx, &service.RedeemCode{
Code: "invite-1",
Type: service.RedeemTypeInvitation,
Status: service.StatusUnused,
}))
callbackRecorder := httptest.NewRecorder()
callbackCtx, _ := gin.CreateTestContext(callbackRecorder)
callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
callbackReq.Host = "api.example.com"
callbackReq.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
callbackReq.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
callbackReq.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
callbackReq.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
callbackCtx.Request = callbackReq
handler.WeChatOAuthCallback(callbackCtx)
require.Equal(t, http.StatusFound, callbackRecorder.Code)
require.Equal(t, "/auth/wechat/callback", callbackRecorder.Header().Get("Location"))
sessionCookie := findCookie(callbackRecorder.Result().Cookies(), oauthPendingSessionCookieName)
require.NotNil(t, sessionCookie)
sessionToken := decodeCookieValueForTest(t, sessionCookie.Value)
pendingSession, err := client.PendingAuthSession.Query().
Where(pendingauthsession.SessionTokenEQ(sessionToken)).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "invitation_required", pendingSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)["error"])
body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true,"adopt_avatar":true}`)
completeRecorder := httptest.NewRecorder()
completeCtx, _ := gin.CreateTestContext(completeRecorder)
completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
completeReq.Header.Set("Content-Type", "application/json")
completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(sessionToken)})
completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-123")})
completeCtx.Request = completeReq
handler.CompleteWeChatOAuthRegistration(completeCtx)
require.Equal(t, http.StatusOK, completeRecorder.Code)
responseData := decodeJSONBody(t, completeRecorder)
require.NotEmpty(t, responseData["access_token"])
userEntity, err := client.User.Query().
Where(dbuser.EmailEQ("wechat-union-456@wechat-connect.invalid")).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "WeChat Display", userEntity.Username)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("wechat"),
authidentity.ProviderKeyEQ("wechat-main"),
authidentity.ProviderSubjectEQ("union-456"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, userEntity.ID, identity.UserID)
require.Equal(t, "WeChat Display", identity.Metadata["display_name"])
require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"])
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, decision.IdentityID)
require.Equal(t, identity.ID, *decision.IdentityID)
require.True(t, decision.AdoptDisplayName)
require.True(t, decision.AdoptAvatar)
consumed, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(pendingSession.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
}
func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:auth_wechat_oauth?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
userRepo := &oauthPendingFlowUserRepo{client: client}
redeemRepo := repository.NewRedeemCodeRepository(client)
settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{
values: map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
},
}, &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
AccessTokenExpireMinutes: 60,
RefreshTokenExpireDays: 7,
},
Default: config.DefaultConfig{
UserBalance: 0,
UserConcurrency: 1,
},
})
authSvc := service.NewAuthService(
client,
userRepo,
redeemRepo,
&wechatOAuthRefreshTokenCacheStub{},
&config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
AccessTokenExpireMinutes: 60,
RefreshTokenExpireDays: 7,
},
Default: config.DefaultConfig{
UserBalance: 0,
UserConcurrency: 1,
},
},
settingSvc,
nil,
nil,
nil,
nil,
nil,
)
return &AuthHandler{
authService: authSvc,
settingSvc: settingSvc,
}, client
}
func encodedCookie(name, value string) *http.Cookie {
return &http.Cookie{
Name: name,
Value: encodeCookieValue(value),
Path: "/",
}
}
func findCookie(cookies []*http.Cookie, name string) *http.Cookie {
for _, cookie := range cookies {
if cookie.Name == name {
return cookie
}
}
return nil
}
func decodeCookieValueForTest(t *testing.T, value string) string {
t.Helper()
raw, err := base64.RawURLEncoding.DecodeString(value)
require.NoError(t, err)
return string(raw)
}
type wechatOAuthSettingRepoStub struct {
values map[string]string
}
func (s *wechatOAuthSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
return nil, service.ErrSettingNotFound
}
func (s *wechatOAuthSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
value, ok := s.values[key]
if !ok {
return "", service.ErrSettingNotFound
}
return value, nil
}
func (s *wechatOAuthSettingRepoStub) Set(context.Context, string, string) error {
return nil
}
func (s *wechatOAuthSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
result := make(map[string]string, len(keys))
for _, key := range keys {
if value, ok := s.values[key]; ok {
result[key] = value
}
}
return result, nil
}
func (s *wechatOAuthSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
return nil
}
func (s *wechatOAuthSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
result := make(map[string]string, len(s.values))
for key, value := range s.values {
result[key] = value
}
return result, nil
}
func (s *wechatOAuthSettingRepoStub) Delete(context.Context, string) error {
return nil
}
type wechatOAuthRefreshTokenCacheStub struct{}
func (s *wechatOAuthRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
return nil
}
func (s *wechatOAuthRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
return nil, service.ErrRefreshTokenNotFound
}
func (s *wechatOAuthRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
return nil
}
func (s *wechatOAuthRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
return nil
}
func (s *wechatOAuthRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
return nil
}
func (s *wechatOAuthRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
return nil
}
func (s *wechatOAuthRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
return nil
}
func (s *wechatOAuthRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
return nil, nil
}
func (s *wechatOAuthRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
return nil, nil
}
func (s *wechatOAuthRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
return false, nil
}

View File

@ -189,6 +189,7 @@ type PublicSettings struct {
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
SoraClientEnabled bool `json:"sora_client_enabled"`

View File

@ -120,7 +120,7 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
// This allows looking up the correct provider instance before verification.
func extractOutTradeNo(rawBody, providerKey string) string {
switch providerKey {
case payment.TypeEasyPay:
case payment.TypeEasyPay, payment.TypeAlipay:
values, err := url.ParseQuery(rawBody)
if err == nil {
return values.Get("out_trade_no")

View File

@ -97,3 +97,37 @@ func TestWebhookConstants(t *testing.T) {
assert.Equal(t, 200, webhookLogTruncateLen)
})
}
func TestExtractOutTradeNo(t *testing.T) {
tests := []struct {
name string
providerKey string
rawBody string
want string
}{
{
name: "easypay query payload",
providerKey: "easypay",
rawBody: "out_trade_no=sub2_123&trade_status=TRADE_SUCCESS",
want: "sub2_123",
},
{
name: "alipay query payload",
providerKey: "alipay",
rawBody: "notify_time=2026-04-20+12%3A00%3A00&out_trade_no=sub2_456",
want: "sub2_456",
},
{
name: "unknown provider",
providerKey: "wxpay",
rawBody: "{}",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, extractOutTradeNo(tt.rawBody, tt.providerKey))
})
}
}

View File

@ -56,6 +56,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
BackendModeEnabled: settings.BackendModeEnabled,

View File

@ -34,10 +34,16 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct {
Username *string `json:"username"`
AvatarURL *string `json:"avatar_url"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
type userProfileResponse struct {
dto.User
AvatarURL string `json:"avatar_url,omitempty"`
}
// GetProfile handles getting user profile
// GET /api/v1/users/me
func (h *UserHandler) GetProfile(c *gin.Context) {
@ -47,13 +53,13 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
return
}
userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
userData, err := h.userService.GetProfile(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserFromService(userData))
response.Success(c, userProfileResponseFromService(userData))
}
// ChangePassword handles changing user password
@ -101,6 +107,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
svcReq := service.UpdateProfileRequest{
Username: req.Username,
AvatarURL: req.AvatarURL,
BalanceNotifyEnabled: req.BalanceNotifyEnabled,
BalanceNotifyThreshold: req.BalanceNotifyThreshold,
}
@ -110,7 +117,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(updatedUser))
response.Success(c, userProfileResponseFromService(updatedUser))
}
// SendNotifyEmailCodeRequest represents the request to send notify email verification code
@ -176,7 +183,7 @@ func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(updatedUser))
response.Success(c, userProfileResponseFromService(updatedUser))
}
// RemoveNotifyEmailRequest represents the request to remove a notify email
@ -212,7 +219,7 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(updatedUser))
response.Success(c, userProfileResponseFromService(updatedUser))
}
// ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state
@ -248,5 +255,16 @@ func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) {
return
}
response.Success(c, dto.UserFromService(updatedUser))
response.Success(c, userProfileResponseFromService(updatedUser))
}
func userProfileResponseFromService(user *service.User) userProfileResponse {
base := dto.UserFromService(user)
if base == nil {
return userProfileResponse{}
}
return userProfileResponse{
User: *base,
AvatarURL: user.AvatarURL,
}
}

View File

@ -0,0 +1,136 @@
//go:build unit
package handler
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type userHandlerRepoStub struct {
user *service.User
}
func (s *userHandlerRepoStub) Create(context.Context, *service.User) error { return nil }
func (s *userHandlerRepoStub) GetByID(context.Context, int64) (*service.User, error) {
cloned := *s.user
return &cloned, nil
}
func (s *userHandlerRepoStub) GetByEmail(context.Context, string) (*service.User, error) {
cloned := *s.user
return &cloned, nil
}
func (s *userHandlerRepoStub) GetFirstAdmin(context.Context) (*service.User, error) {
cloned := *s.user
return &cloned, nil
}
func (s *userHandlerRepoStub) Update(_ context.Context, user *service.User) error {
cloned := *user
s.user = &cloned
return nil
}
func (s *userHandlerRepoStub) Delete(context.Context, int64) error { return nil }
func (s *userHandlerRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
if s.user == nil || s.user.AvatarURL == "" {
return nil, nil
}
return &service.UserAvatar{
StorageProvider: s.user.AvatarSource,
URL: s.user.AvatarURL,
ContentType: s.user.AvatarMIME,
ByteSize: s.user.AvatarByteSize,
SHA256: s.user.AvatarSHA256,
}, nil
}
func (s *userHandlerRepoStub) UpsertUserAvatar(_ context.Context, _ int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
s.user.AvatarURL = input.URL
s.user.AvatarSource = input.StorageProvider
s.user.AvatarMIME = input.ContentType
s.user.AvatarByteSize = input.ByteSize
s.user.AvatarSHA256 = input.SHA256
return &service.UserAvatar{
StorageProvider: input.StorageProvider,
URL: input.URL,
ContentType: input.ContentType,
ByteSize: input.ByteSize,
SHA256: input.SHA256,
}, nil
}
func (s *userHandlerRepoStub) DeleteUserAvatar(context.Context, int64) error {
s.user.AvatarURL = ""
s.user.AvatarSource = ""
s.user.AvatarMIME = ""
s.user.AvatarByteSize = 0
s.user.AvatarSHA256 = ""
return nil
}
func (s *userHandlerRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
func (s *userHandlerRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error {
return nil
}
func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil
}
func (s *userHandlerRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (s *userHandlerRepoStub) EnableTotp(context.Context, int64) error { return nil }
func (s *userHandlerRepoStub) DisableTotp(context.Context, int64) error { return nil }
func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 11,
Email: "handler-avatar@example.com",
Username: "handler-avatar",
Role: service.RoleUser,
Status: service.StatusActive,
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil)
body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/user", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
handler.UpdateProfile(c)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data struct {
AvatarURL string `json:"avatar_url"`
Username string `json:"username"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, "https://cdn.example.com/avatar.png", resp.Data.AvatarURL)
require.Equal(t, "handler-avatar", resp.Data.Username)
}

View File

@ -149,6 +149,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
user.FieldBalanceNotifyThreshold,
user.FieldBalanceNotifyExtraEmails,
user.FieldTotalRecharged,
user.FieldSignupSource,
user.FieldLastLoginAt,
user.FieldLastActiveAt,
)
}).
WithGroup(func(q *dbent.GroupQuery) {
@ -656,6 +659,9 @@ func userEntityToService(u *dbent.User) *service.User {
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
SignupSource: u.SignupSource,
LastLoginAt: u.LastLoginAt,
LastActiveAt: u.LastActiveAt,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,

View File

@ -0,0 +1,148 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
)
type AuthIdentityMigrationReport struct {
ID int64
ReportType string
ReportKey string
Details map[string]any
CreatedAt time.Time
}
type AuthIdentityMigrationReportQuery struct {
ReportType string
Limit int
Offset int
}
type AuthIdentityMigrationReportSummary struct {
Total int64
ByType map[string]int64
}
func (r *userRepository) ListAuthIdentityMigrationReports(ctx context.Context, query AuthIdentityMigrationReportQuery) ([]AuthIdentityMigrationReport, error) {
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
if exec == nil {
return nil, fmt.Errorf("sql executor is not configured")
}
limit := query.Limit
if limit <= 0 {
limit = 100
}
rows, err := exec.QueryContext(ctx, `
SELECT id, report_type, report_key, details, created_at
FROM auth_identity_migration_reports
WHERE ($1 = '' OR report_type = $1)
ORDER BY created_at DESC, id DESC
LIMIT $2 OFFSET $3`,
strings.TrimSpace(query.ReportType),
limit,
query.Offset,
)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
reports := make([]AuthIdentityMigrationReport, 0)
for rows.Next() {
report, scanErr := scanAuthIdentityMigrationReport(rows)
if scanErr != nil {
return nil, scanErr
}
reports = append(reports, report)
}
if err := rows.Err(); err != nil {
return nil, err
}
return reports, nil
}
func (r *userRepository) GetAuthIdentityMigrationReport(ctx context.Context, reportType, reportKey string) (*AuthIdentityMigrationReport, error) {
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
if exec == nil {
return nil, fmt.Errorf("sql executor is not configured")
}
rows, err := exec.QueryContext(ctx, `
SELECT id, report_type, report_key, details, created_at
FROM auth_identity_migration_reports
WHERE report_type = $1 AND report_key = $2
LIMIT 1`,
strings.TrimSpace(reportType),
strings.TrimSpace(reportKey),
)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
return nil, sql.ErrNoRows
}
report, err := scanAuthIdentityMigrationReport(rows)
if err != nil {
return nil, err
}
return &report, rows.Err()
}
func (r *userRepository) SummarizeAuthIdentityMigrationReports(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) {
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
if exec == nil {
return nil, fmt.Errorf("sql executor is not configured")
}
rows, err := exec.QueryContext(ctx, `
SELECT report_type, COUNT(*)
FROM auth_identity_migration_reports
GROUP BY report_type
ORDER BY report_type ASC`)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
summary := &AuthIdentityMigrationReportSummary{
ByType: make(map[string]int64),
}
for rows.Next() {
var reportType string
var count int64
if err := rows.Scan(&reportType, &count); err != nil {
return nil, err
}
summary.ByType[reportType] = count
summary.Total += count
}
if err := rows.Err(); err != nil {
return nil, err
}
return summary, nil
}
func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) {
var (
report AuthIdentityMigrationReport
details []byte
)
if err := scanner.Scan(&report.ID, &report.ReportType, &report.ReportKey, &details, &report.CreatedAt); err != nil {
return AuthIdentityMigrationReport{}, err
}
report.Details = map[string]any{}
if len(details) > 0 {
if err := json.Unmarshal(details, &report.Details); err != nil {
return AuthIdentityMigrationReport{}, err
}
}
return report, nil
}

View File

@ -0,0 +1,544 @@
package repository
import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
"time"
"unsafe"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/service"
)
var (
ErrAuthIdentityOwnershipConflict = infraerrors.Conflict(
"AUTH_IDENTITY_OWNERSHIP_CONFLICT",
"auth identity already belongs to another user",
)
ErrAuthIdentityChannelOwnershipConflict = infraerrors.Conflict(
"AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT",
"auth identity channel already belongs to another user",
)
)
type ProviderGrantReason string
const (
ProviderGrantReasonSignup ProviderGrantReason = "signup"
ProviderGrantReasonFirstBind ProviderGrantReason = "first_bind"
)
type AuthIdentityKey struct {
ProviderType string
ProviderKey string
ProviderSubject string
}
type AuthIdentityChannelKey struct {
ProviderType string
ProviderKey string
Channel string
ChannelAppID string
ChannelSubject string
}
type CreateAuthIdentityInput struct {
UserID int64
Canonical AuthIdentityKey
Channel *AuthIdentityChannelKey
Issuer *string
VerifiedAt *time.Time
Metadata map[string]any
ChannelMetadata map[string]any
}
type BindAuthIdentityInput = CreateAuthIdentityInput
type CreateAuthIdentityResult struct {
Identity *dbent.AuthIdentity
Channel *dbent.AuthIdentityChannel
}
func (r *CreateAuthIdentityResult) IdentityRef() AuthIdentityKey {
if r == nil || r.Identity == nil {
return AuthIdentityKey{}
}
return AuthIdentityKey{
ProviderType: r.Identity.ProviderType,
ProviderKey: r.Identity.ProviderKey,
ProviderSubject: r.Identity.ProviderSubject,
}
}
func (r *CreateAuthIdentityResult) ChannelRef() *AuthIdentityChannelKey {
if r == nil || r.Channel == nil {
return nil
}
return &AuthIdentityChannelKey{
ProviderType: r.Channel.ProviderType,
ProviderKey: r.Channel.ProviderKey,
Channel: r.Channel.Channel,
ChannelAppID: r.Channel.ChannelAppID,
ChannelSubject: r.Channel.ChannelSubject,
}
}
type UserAuthIdentityLookup struct {
User *dbent.User
Identity *dbent.AuthIdentity
Channel *dbent.AuthIdentityChannel
}
type ProviderGrantRecordInput struct {
UserID int64
ProviderType string
GrantReason ProviderGrantReason
}
type IdentityAdoptionDecisionInput struct {
PendingAuthSessionID int64
IdentityID *int64
AdoptDisplayName bool
AdoptAvatar bool
}
type sqlQueryExecutor interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
if dbent.TxFromContext(ctx) != nil {
return fn(ctx)
}
tx, err := r.client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := fn(txCtx); err != nil {
return err
}
return tx.Commit()
}
func (r *userRepository) CreateAuthIdentity(ctx context.Context, input CreateAuthIdentityInput) (*CreateAuthIdentityResult, error) {
client := clientFromContext(ctx, r.client)
create := client.AuthIdentity.Create().
SetUserID(input.UserID).
SetProviderType(strings.TrimSpace(input.Canonical.ProviderType)).
SetProviderKey(strings.TrimSpace(input.Canonical.ProviderKey)).
SetProviderSubject(strings.TrimSpace(input.Canonical.ProviderSubject)).
SetMetadata(copyMetadata(input.Metadata)).
SetNillableIssuer(input.Issuer).
SetNillableVerifiedAt(input.VerifiedAt)
identity, err := create.Save(ctx)
if err != nil {
return nil, err
}
var channel *dbent.AuthIdentityChannel
if input.Channel != nil {
channel, err = client.AuthIdentityChannel.Create().
SetIdentityID(identity.ID).
SetProviderType(strings.TrimSpace(input.Channel.ProviderType)).
SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)).
SetChannel(strings.TrimSpace(input.Channel.Channel)).
SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)).
SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)).
SetMetadata(copyMetadata(input.ChannelMetadata)).
Save(ctx)
if err != nil {
return nil, err
}
}
return &CreateAuthIdentityResult{Identity: identity, Channel: channel}, nil
}
func (r *userRepository) GetUserByCanonicalIdentity(ctx context.Context, key AuthIdentityKey) (*UserAuthIdentityLookup, error) {
identity, err := clientFromContext(ctx, r.client).AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)),
authidentity.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)),
authidentity.ProviderSubjectEQ(strings.TrimSpace(key.ProviderSubject)),
).
WithUser().
Only(ctx)
if err != nil {
return nil, err
}
return &UserAuthIdentityLookup{
User: identity.Edges.User,
Identity: identity,
}, nil
}
func (r *userRepository) GetUserByChannelIdentity(ctx context.Context, key AuthIdentityChannelKey) (*UserAuthIdentityLookup, error) {
channel, err := clientFromContext(ctx, r.client).AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)),
authidentitychannel.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)),
authidentitychannel.ChannelEQ(strings.TrimSpace(key.Channel)),
authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(key.ChannelAppID)),
authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(key.ChannelSubject)),
).
WithIdentity(func(q *dbent.AuthIdentityQuery) {
q.WithUser()
}).
Only(ctx)
if err != nil {
return nil, err
}
return &UserAuthIdentityLookup{
User: channel.Edges.Identity.Edges.User,
Identity: channel.Edges.Identity,
Channel: channel,
}, nil
}
func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) {
var result *CreateAuthIdentityResult
err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
client := clientFromContext(txCtx, r.client)
canonical := input.Canonical
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(strings.TrimSpace(canonical.ProviderType)),
authidentity.ProviderKeyEQ(strings.TrimSpace(canonical.ProviderKey)),
authidentity.ProviderSubjectEQ(strings.TrimSpace(canonical.ProviderSubject)),
).
Only(txCtx)
if err != nil && !dbent.IsNotFound(err) {
return err
}
if identity != nil && identity.UserID != input.UserID {
return ErrAuthIdentityOwnershipConflict
}
if identity == nil {
identity, err = client.AuthIdentity.Create().
SetUserID(input.UserID).
SetProviderType(strings.TrimSpace(canonical.ProviderType)).
SetProviderKey(strings.TrimSpace(canonical.ProviderKey)).
SetProviderSubject(strings.TrimSpace(canonical.ProviderSubject)).
SetMetadata(copyMetadata(input.Metadata)).
SetNillableIssuer(input.Issuer).
SetNillableVerifiedAt(input.VerifiedAt).
Save(txCtx)
if err != nil {
return err
}
} else {
update := client.AuthIdentity.UpdateOneID(identity.ID)
if input.Metadata != nil {
update = update.SetMetadata(copyMetadata(input.Metadata))
}
if input.Issuer != nil {
update = update.SetIssuer(strings.TrimSpace(*input.Issuer))
}
if input.VerifiedAt != nil {
update = update.SetVerifiedAt(*input.VerifiedAt)
}
identity, err = update.Save(txCtx)
if err != nil {
return err
}
}
var channel *dbent.AuthIdentityChannel
if input.Channel != nil {
channel, err = client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ(strings.TrimSpace(input.Channel.ProviderType)),
authidentitychannel.ProviderKeyEQ(strings.TrimSpace(input.Channel.ProviderKey)),
authidentitychannel.ChannelEQ(strings.TrimSpace(input.Channel.Channel)),
authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(input.Channel.ChannelAppID)),
authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(input.Channel.ChannelSubject)),
).
WithIdentity().
Only(txCtx)
if err != nil && !dbent.IsNotFound(err) {
return err
}
if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != input.UserID {
return ErrAuthIdentityChannelOwnershipConflict
}
if channel == nil {
channel, err = client.AuthIdentityChannel.Create().
SetIdentityID(identity.ID).
SetProviderType(strings.TrimSpace(input.Channel.ProviderType)).
SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)).
SetChannel(strings.TrimSpace(input.Channel.Channel)).
SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)).
SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)).
SetMetadata(copyMetadata(input.ChannelMetadata)).
Save(txCtx)
if err != nil {
return err
}
} else {
update := client.AuthIdentityChannel.UpdateOneID(channel.ID).
SetIdentityID(identity.ID)
if input.ChannelMetadata != nil {
update = update.SetMetadata(copyMetadata(input.ChannelMetadata))
}
channel, err = update.Save(txCtx)
if err != nil {
return err
}
}
}
result = &CreateAuthIdentityResult{Identity: identity, Channel: channel}
return nil
})
if err != nil {
return nil, err
}
return result, nil
}
func (r *userRepository) RecordProviderGrant(ctx context.Context, input ProviderGrantRecordInput) (bool, error) {
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
if exec == nil {
return false, fmt.Errorf("sql executor is not configured")
}
result, err := exec.ExecContext(ctx, `
INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason)
VALUES ($1, $2, $3)
ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
input.UserID,
strings.TrimSpace(input.ProviderType),
string(input.GrantReason),
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
client := clientFromContext(ctx, r.client)
current, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return nil, err
}
now := time.Now().UTC()
if current == nil {
create := client.IdentityAdoptionDecision.Create().
SetPendingAuthSessionID(input.PendingAuthSessionID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar).
SetDecidedAt(now)
if input.IdentityID != nil {
create = create.SetIdentityID(*input.IdentityID)
}
return create.Save(ctx)
}
update := client.IdentityAdoptionDecision.UpdateOneID(current.ID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar)
if input.IdentityID != nil {
update = update.SetIdentityID(*input.IdentityID)
}
return update.Save(ctx)
}
func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) {
return clientFromContext(ctx, r.client).IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingAuthSessionID)).
Only(ctx)
}
func (r *userRepository) UpdateUserLastLoginAt(ctx context.Context, userID int64, loginAt time.Time) error {
_, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID).
SetLastLoginAt(loginAt).
Save(ctx)
return err
}
func (r *userRepository) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
_, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID).
SetLastActiveAt(activeAt).
Save(ctx)
return err
}
func (r *userRepository) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
exec, err := r.userProfileIdentitySQL(ctx)
if err != nil {
return nil, err
}
rows, err := exec.QueryContext(ctx, `
SELECT storage_provider, storage_key, url, content_type, byte_size, sha256
FROM user_avatars
WHERE user_id = $1`, userID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
return nil, rows.Err()
}
var avatar service.UserAvatar
if err := rows.Scan(
&avatar.StorageProvider,
&avatar.StorageKey,
&avatar.URL,
&avatar.ContentType,
&avatar.ByteSize,
&avatar.SHA256,
); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return &avatar, nil
}
func (r *userRepository) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
exec, err := r.userProfileIdentitySQL(ctx)
if err != nil {
return nil, err
}
_, err = exec.ExecContext(ctx, `
INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
ON CONFLICT (user_id) DO UPDATE SET
storage_provider = EXCLUDED.storage_provider,
storage_key = EXCLUDED.storage_key,
url = EXCLUDED.url,
content_type = EXCLUDED.content_type,
byte_size = EXCLUDED.byte_size,
sha256 = EXCLUDED.sha256,
updated_at = NOW()`,
userID,
strings.TrimSpace(input.StorageProvider),
strings.TrimSpace(input.StorageKey),
strings.TrimSpace(input.URL),
strings.TrimSpace(input.ContentType),
input.ByteSize,
strings.TrimSpace(input.SHA256),
)
if err != nil {
return nil, err
}
return &service.UserAvatar{
StorageProvider: strings.TrimSpace(input.StorageProvider),
StorageKey: strings.TrimSpace(input.StorageKey),
URL: strings.TrimSpace(input.URL),
ContentType: strings.TrimSpace(input.ContentType),
ByteSize: input.ByteSize,
SHA256: strings.TrimSpace(input.SHA256),
}, nil
}
func (r *userRepository) DeleteUserAvatar(ctx context.Context, userID int64) error {
exec, err := r.userProfileIdentitySQL(ctx)
if err != nil {
return err
}
_, err = exec.ExecContext(ctx, `DELETE FROM user_avatars WHERE user_id = $1`, userID)
return err
}
func (r *userRepository) attachUserAvatar(ctx context.Context, user *service.User) error {
if user == nil {
return nil
}
avatar, err := r.GetUserAvatar(ctx, user.ID)
if err != nil {
return err
}
if avatar == nil {
return nil
}
user.AvatarURL = avatar.URL
user.AvatarSource = avatar.StorageProvider
user.AvatarMIME = avatar.ContentType
user.AvatarByteSize = avatar.ByteSize
user.AvatarSHA256 = avatar.SHA256
return nil
}
func copyMetadata(in map[string]any) map[string]any {
if len(in) == 0 {
return map[string]any{}
}
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func txAwareSQLExecutor(ctx context.Context, fallback sqlExecutor, client *dbent.Client) sqlQueryExecutor {
if tx := dbent.TxFromContext(ctx); tx != nil {
if exec := sqlExecutorFromEntClient(tx.Client()); exec != nil {
return exec
}
}
if fallback != nil {
return fallback
}
return sqlExecutorFromEntClient(client)
}
func (r *userRepository) userProfileIdentitySQL(ctx context.Context) (sqlQueryExecutor, error) {
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
if exec == nil {
return nil, fmt.Errorf("sql executor is not configured")
}
return exec, nil
}
func sqlExecutorFromEntClient(client *dbent.Client) sqlQueryExecutor {
if client == nil {
return nil
}
clientValue := reflect.ValueOf(client).Elem()
configValue := clientValue.FieldByName("config")
driverValue := configValue.FieldByName("driver")
if !driverValue.IsValid() {
return nil
}
driver := reflect.NewAt(driverValue.Type(), unsafe.Pointer(driverValue.UnsafeAddr())).Elem().Interface()
exec, ok := driver.(sqlQueryExecutor)
if !ok {
return nil
}
return exec
}

View File

@ -0,0 +1,428 @@
//go:build integration
package repository
import (
"context"
"errors"
"fmt"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
)
type UserProfileIdentityRepoSuite struct {
suite.Suite
ctx context.Context
client *dbent.Client
repo *userRepository
}
func TestUserProfileIdentityRepoSuite(t *testing.T) {
suite.Run(t, new(UserProfileIdentityRepoSuite))
}
func (s *UserProfileIdentityRepoSuite) SetupTest() {
s.ctx = context.Background()
s.client = testEntClient(s.T())
s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
_, err := integrationDB.ExecContext(s.ctx, `
TRUNCATE TABLE
identity_adoption_decisions,
auth_identity_channels,
auth_identities,
pending_auth_sessions,
auth_identity_migration_reports,
user_provider_default_grants,
user_avatars
RESTART IDENTITY`)
s.Require().NoError(err)
}
func (s *UserProfileIdentityRepoSuite) mustCreateUser(label string) *dbent.User {
s.T().Helper()
user, err := s.client.User.Create().
SetEmail(fmt.Sprintf("%s-%d@example.com", label, time.Now().UnixNano())).
SetPasswordHash("test-password-hash").
SetRole("user").
SetStatus("active").
Save(s.ctx)
s.Require().NoError(err)
return user
}
func (s *UserProfileIdentityRepoSuite) mustCreatePendingAuthSession(key AuthIdentityKey) *dbent.PendingAuthSession {
s.T().Helper()
session, err := s.client.PendingAuthSession.Create().
SetSessionToken(fmt.Sprintf("pending-%d", time.Now().UnixNano())).
SetIntent("bind_current_user").
SetProviderType(key.ProviderType).
SetProviderKey(key.ProviderKey).
SetProviderSubject(key.ProviderSubject).
SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)).
SetUpstreamIdentityClaims(map[string]any{"provider_subject": key.ProviderSubject}).
SetLocalFlowState(map[string]any{"step": "pending"}).
Save(s.ctx)
s.Require().NoError(err)
return session
}
func (s *UserProfileIdentityRepoSuite) TestCreateAndLookupCanonicalAndChannelIdentity() {
user := s.mustCreateUser("canonical-channel")
verifiedAt := time.Now().UTC().Truncate(time.Second)
created, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
UserID: user.ID,
Canonical: AuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-open",
ProviderSubject: "union-123",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "wechat",
ProviderKey: "wechat-open",
Channel: "mp",
ChannelAppID: "wx-app",
ChannelSubject: "openid-123",
},
Issuer: stringPtr("https://issuer.example"),
VerifiedAt: &verifiedAt,
Metadata: map[string]any{"unionid": "union-123"},
ChannelMetadata: map[string]any{"openid": "openid-123"},
})
s.Require().NoError(err)
s.Require().NotNil(created.Identity)
s.Require().NotNil(created.Channel)
canonical, err := s.repo.GetUserByCanonicalIdentity(s.ctx, created.IdentityRef())
s.Require().NoError(err)
s.Require().Equal(user.ID, canonical.User.ID)
s.Require().Equal(created.Identity.ID, canonical.Identity.ID)
s.Require().Equal("union-123", canonical.Identity.ProviderSubject)
channel, err := s.repo.GetUserByChannelIdentity(s.ctx, *created.ChannelRef())
s.Require().NoError(err)
s.Require().Equal(user.ID, channel.User.ID)
s.Require().Equal(created.Identity.ID, channel.Identity.ID)
s.Require().Equal(created.Channel.ID, channel.Channel.ID)
}
func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAndRejectsOtherOwners() {
owner := s.mustCreateUser("owner")
other := s.mustCreateUser("other")
first, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
UserID: owner.ID,
Canonical: AuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
ProviderSubject: "subject-1",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
Channel: "oauth",
ChannelAppID: "linuxdo-web",
ChannelSubject: "subject-1",
},
Metadata: map[string]any{"username": "first"},
ChannelMetadata: map[string]any{"scope": "read"},
})
s.Require().NoError(err)
second, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
UserID: owner.ID,
Canonical: AuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
ProviderSubject: "subject-1",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
Channel: "oauth",
ChannelAppID: "linuxdo-web",
ChannelSubject: "subject-1",
},
Metadata: map[string]any{"username": "second"},
ChannelMetadata: map[string]any{"scope": "write"},
})
s.Require().NoError(err)
s.Require().Equal(first.Identity.ID, second.Identity.ID)
s.Require().Equal(first.Channel.ID, second.Channel.ID)
s.Require().Equal("second", second.Identity.Metadata["username"])
s.Require().Equal("write", second.Channel.Metadata["scope"])
_, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
UserID: other.ID,
Canonical: AuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
ProviderSubject: "subject-1",
},
})
s.Require().ErrorIs(err, ErrAuthIdentityOwnershipConflict)
_, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
UserID: other.ID,
Canonical: AuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
ProviderSubject: "subject-2",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
Channel: "oauth",
ChannelAppID: "linuxdo-web",
ChannelSubject: "subject-1",
},
})
s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict)
}
func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_RollsBackIdentityAndGrantOnError() {
user := s.mustCreateUser("tx-rollback")
expectedErr := errors.New("rollback")
err := s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error {
_, err := s.repo.CreateAuthIdentity(txCtx, CreateAuthIdentityInput{
UserID: user.ID,
Canonical: AuthIdentityKey{
ProviderType: "oidc",
ProviderKey: "https://issuer.example",
ProviderSubject: "subject-rollback",
},
})
s.Require().NoError(err)
inserted, err := s.repo.RecordProviderGrant(txCtx, ProviderGrantRecordInput{
UserID: user.ID,
ProviderType: "oidc",
GrantReason: ProviderGrantReasonFirstBind,
})
s.Require().NoError(err)
s.Require().True(inserted)
return expectedErr
})
s.Require().ErrorIs(err, expectedErr)
_, err = s.repo.GetUserByCanonicalIdentity(s.ctx, AuthIdentityKey{
ProviderType: "oidc",
ProviderKey: "https://issuer.example",
ProviderSubject: "subject-rollback",
})
s.Require().True(dbent.IsNotFound(err))
var count int
s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
SELECT COUNT(*)
FROM user_provider_default_grants
WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3`,
user.ID,
"oidc",
string(ProviderGrantReasonFirstBind),
).Scan(&count))
s.Require().Zero(count)
}
func (s *UserProfileIdentityRepoSuite) TestRecordProviderGrant_IsIdempotentPerReason() {
user := s.mustCreateUser("grant")
inserted, err := s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
UserID: user.ID,
ProviderType: "wechat",
GrantReason: ProviderGrantReasonFirstBind,
})
s.Require().NoError(err)
s.Require().True(inserted)
inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
UserID: user.ID,
ProviderType: "wechat",
GrantReason: ProviderGrantReasonFirstBind,
})
s.Require().NoError(err)
s.Require().False(inserted)
inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
UserID: user.ID,
ProviderType: "wechat",
GrantReason: ProviderGrantReasonSignup,
})
s.Require().NoError(err)
s.Require().True(inserted)
var count int
s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
SELECT COUNT(*)
FROM user_provider_default_grants
WHERE user_id = $1 AND provider_type = $2`,
user.ID,
"wechat",
).Scan(&count))
s.Require().Equal(2, count)
}
func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_PersistsAndLinksIdentity() {
user := s.mustCreateUser("adoption")
identity, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
UserID: user.ID,
Canonical: AuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-open",
ProviderSubject: "union-adoption",
},
})
s.Require().NoError(err)
session := s.mustCreatePendingAuthSession(identity.IdentityRef())
first, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
AdoptDisplayName: true,
AdoptAvatar: false,
})
s.Require().NoError(err)
s.Require().True(first.AdoptDisplayName)
s.Require().False(first.AdoptAvatar)
s.Require().Nil(first.IdentityID)
second, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
IdentityID: &identity.Identity.ID,
AdoptDisplayName: true,
AdoptAvatar: true,
})
s.Require().NoError(err)
s.Require().Equal(first.ID, second.ID)
s.Require().NotNil(second.IdentityID)
s.Require().Equal(identity.Identity.ID, *second.IdentityID)
s.Require().True(second.AdoptAvatar)
loaded, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, session.ID)
s.Require().NoError(err)
s.Require().Equal(second.ID, loaded.ID)
s.Require().Equal(identity.Identity.ID, *loaded.IdentityID)
}
func (s *UserProfileIdentityRepoSuite) TestUserAvatarCRUDAndUserLookup() {
user := s.mustCreateUser("avatar")
inlineAvatar, err := s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{
StorageProvider: "inline",
URL: "data:image/png;base64,QUJD",
ContentType: "image/png",
ByteSize: 3,
SHA256: "902fbdd2b1df0c4f70b4a5d23525e932",
})
s.Require().NoError(err)
s.Require().Equal("inline", inlineAvatar.StorageProvider)
s.Require().Equal("data:image/png;base64,QUJD", inlineAvatar.URL)
loadedAvatar, err := s.repo.GetUserAvatar(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().NotNil(loadedAvatar)
s.Require().Equal("image/png", loadedAvatar.ContentType)
s.Require().Equal(3, loadedAvatar.ByteSize)
_, err = s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{
StorageProvider: "remote_url",
URL: "https://cdn.example.com/avatar.png",
})
s.Require().NoError(err)
loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().NotNil(loadedAvatar)
s.Require().Equal("remote_url", loadedAvatar.StorageProvider)
s.Require().Equal("https://cdn.example.com/avatar.png", loadedAvatar.URL)
s.Require().Zero(loadedAvatar.ByteSize)
s.Require().NoError(s.repo.DeleteUserAvatar(s.ctx, user.ID))
loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Nil(loadedAvatar)
}
func (s *UserProfileIdentityRepoSuite) TestAuthIdentityMigrationReportHelpers_ListAndSummarize() {
_, err := integrationDB.ExecContext(s.ctx, `
INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at)
VALUES
('wechat_openid_only_requires_remediation', 'u-1', '{"user_id":1}'::jsonb, '2026-04-20T10:00:00Z'),
('wechat_openid_only_requires_remediation', 'u-2', '{"user_id":2}'::jsonb, '2026-04-20T11:00:00Z'),
('oidc_synthetic_email_requires_manual_recovery', 'u-3', '{"user_id":3}'::jsonb, '2026-04-20T12:00:00Z')`)
s.Require().NoError(err)
summary, err := s.repo.SummarizeAuthIdentityMigrationReports(s.ctx)
s.Require().NoError(err)
s.Require().Equal(int64(3), summary.Total)
s.Require().Equal(int64(2), summary.ByType["wechat_openid_only_requires_remediation"])
s.Require().Equal(int64(1), summary.ByType["oidc_synthetic_email_requires_manual_recovery"])
reports, err := s.repo.ListAuthIdentityMigrationReports(s.ctx, AuthIdentityMigrationReportQuery{
ReportType: "wechat_openid_only_requires_remediation",
Limit: 10,
})
s.Require().NoError(err)
s.Require().Len(reports, 2)
s.Require().Equal("u-2", reports[0].ReportKey)
s.Require().Equal(float64(2), reports[0].Details["user_id"])
report, err := s.repo.GetAuthIdentityMigrationReport(s.ctx, "oidc_synthetic_email_requires_manual_recovery", "u-3")
s.Require().NoError(err)
s.Require().Equal("u-3", report.ReportKey)
s.Require().Equal(float64(3), report.Details["user_id"])
}
func (s *UserProfileIdentityRepoSuite) TestUpdateUserLastLoginAndActiveAt_UsesDedicatedColumns() {
user := s.mustCreateUser("activity")
loginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC)
activeAt := loginAt.Add(5 * time.Minute)
s.Require().NoError(s.repo.UpdateUserLastLoginAt(s.ctx, user.ID, loginAt))
s.Require().NoError(s.repo.UpdateUserLastActiveAt(s.ctx, user.ID, activeAt))
var storedLoginAt sqlNullTime
var storedActiveAt sqlNullTime
s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
SELECT last_login_at, last_active_at
FROM users
WHERE id = $1`,
user.ID,
).Scan(&storedLoginAt, &storedActiveAt))
s.Require().True(storedLoginAt.Valid)
s.Require().True(storedActiveAt.Valid)
s.Require().True(storedLoginAt.Time.Equal(loginAt))
s.Require().True(storedActiveAt.Time.Equal(activeAt))
}
type sqlNullTime struct {
Time time.Time
Valid bool
}
func (t *sqlNullTime) Scan(value any) error {
switch v := value.(type) {
case time.Time:
t.Time = v
t.Valid = true
return nil
case nil:
t.Time = time.Time{}
t.Valid = false
return nil
default:
return fmt.Errorf("unsupported scan type %T", value)
}
}
func stringPtr(v string) *string {
return &v
}

View File

@ -64,6 +64,9 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
SetNillableLastLoginAt(userIn.LastLoginAt).
SetNillableLastActiveAt(userIn.LastActiveAt).
Save(ctx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
@ -151,6 +154,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
SetTotalRecharged(userIn.TotalRecharged)
if userIn.SignupSource != "" {
updateOp = updateOp.SetSignupSource(userIn.SignupSource)
}
if userIn.LastLoginAt != nil {
updateOp = updateOp.SetLastLoginAt(*userIn.LastLoginAt)
}
if userIn.LastActiveAt != nil {
updateOp = updateOp.SetLastActiveAt(*userIn.LastActiveAt)
}
if userIn.BalanceNotifyThreshold == nil {
updateOp = updateOp.ClearBalanceNotifyThreshold()
}
@ -300,6 +312,7 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
var field string
defaultField := true
nullsLastField := false
switch sortBy {
case "email":
field = dbuser.FieldEmail
@ -322,6 +335,14 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
case "created_at":
field = dbuser.FieldCreatedAt
defaultField = false
case "last_login_at":
field = dbuser.FieldLastLoginAt
defaultField = false
nullsLastField = true
case "last_active_at":
field = dbuser.FieldLastActiveAt
defaultField = false
nullsLastField = true
default:
field = dbuser.FieldID
}
@ -330,11 +351,23 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)}
}
if nullsLastField {
return []func(*entsql.Selector){
entsql.OrderByField(field, entsql.OrderNullsLast()).ToFunc(),
dbent.Asc(dbuser.FieldID),
}
}
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)}
}
if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)}
}
if nullsLastField {
return []func(*entsql.Selector){
entsql.OrderByField(field, entsql.OrderDesc(), entsql.OrderNullsLast()).ToFunc(),
dbent.Desc(dbuser.FieldID),
}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)}
}
@ -558,10 +591,21 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
return
}
dst.ID = src.ID
dst.SignupSource = src.SignupSource
dst.LastLoginAt = src.LastLoginAt
dst.LastActiveAt = src.LastActiveAt
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}
func userSignupSourceOrDefault(signupSource string) string {
signupSource = strings.TrimSpace(signupSource)
if signupSource == "" {
return "email"
}
return signupSource
}
// marshalExtraEmails serializes notify email entries to JSON for storage.
func marshalExtraEmails(entries []service.NotifyEmailEntry) string {
return service.MarshalNotifyEmails(entries)

View File

@ -4,6 +4,7 @@ package repository
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
@ -36,4 +37,86 @@ func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() {
s.Require().Equal(first.ID, users[1].ID)
}
func (s *UserRepoSuite) TestCreateAndRead_PreservesSignupSourceAndActivityTimestamps() {
lastLoginAt := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Microsecond)
lastActiveAt := time.Now().Add(-30 * time.Minute).UTC().Truncate(time.Microsecond)
created := s.mustCreateUser(&service.User{
Email: "identity-meta@example.com",
SignupSource: "github",
LastLoginAt: &lastLoginAt,
LastActiveAt: &lastActiveAt,
})
got, err := s.repo.GetByID(s.ctx, created.ID)
s.Require().NoError(err)
s.Require().Equal("github", got.SignupSource)
s.Require().NotNil(got.LastLoginAt)
s.Require().NotNil(got.LastActiveAt)
s.Require().True(got.LastLoginAt.Equal(lastLoginAt))
s.Require().True(got.LastActiveAt.Equal(lastActiveAt))
}
func (s *UserRepoSuite) TestUpdate_PersistsSignupSourceAndActivityTimestamps() {
created := s.mustCreateUser(&service.User{Email: "identity-update@example.com"})
lastLoginAt := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Microsecond)
lastActiveAt := time.Now().Add(-15 * time.Minute).UTC().Truncate(time.Microsecond)
created.SignupSource = "oidc"
created.LastLoginAt = &lastLoginAt
created.LastActiveAt = &lastActiveAt
s.Require().NoError(s.repo.Update(s.ctx, created))
got, err := s.repo.GetByID(s.ctx, created.ID)
s.Require().NoError(err)
s.Require().Equal("oidc", got.SignupSource)
s.Require().NotNil(got.LastLoginAt)
s.Require().NotNil(got.LastActiveAt)
s.Require().True(got.LastLoginAt.Equal(lastLoginAt))
s.Require().True(got.LastActiveAt.Equal(lastActiveAt))
}
func (s *UserRepoSuite) TestListWithFilters_SortByLastLoginAtDesc() {
older := time.Now().Add(-4 * time.Hour).UTC().Truncate(time.Microsecond)
newer := time.Now().Add(-1 * time.Hour).UTC().Truncate(time.Microsecond)
s.mustCreateUser(&service.User{Email: "nil-login@example.com"})
s.mustCreateUser(&service.User{Email: "older-login@example.com", LastLoginAt: &older})
s.mustCreateUser(&service.User{Email: "newer-login@example.com", LastLoginAt: &newer})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "last_login_at",
SortOrder: "desc",
}, service.UserListFilters{})
s.Require().NoError(err)
s.Require().Len(users, 3)
s.Require().Equal("newer-login@example.com", users[0].Email)
s.Require().Equal("older-login@example.com", users[1].Email)
s.Require().Equal("nil-login@example.com", users[2].Email)
}
func (s *UserRepoSuite) TestListWithFilters_SortByLastActiveAtAsc() {
earlier := time.Now().Add(-3 * time.Hour).UTC().Truncate(time.Microsecond)
later := time.Now().Add(-45 * time.Minute).UTC().Truncate(time.Microsecond)
s.mustCreateUser(&service.User{Email: "nil-active@example.com"})
s.mustCreateUser(&service.User{Email: "later-active@example.com", LastActiveAt: &later})
s.mustCreateUser(&service.User{Email: "earlier-active@example.com", LastActiveAt: &earlier})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "last_active_at",
SortOrder: "asc",
}, service.UserListFilters{})
s.Require().NoError(err)
s.Require().Len(users, 3)
s.Require().Equal("earlier-active@example.com", users[0].Email)
s.Require().Equal("later-active@example.com", users[1].Email)
s.Require().Equal("nil-active@example.com", users[2].Email)
}
func TestUserRepoSortSuiteSmoke(_ *testing.T) {}

View File

@ -479,7 +479,7 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyOIDCConnectRedirectURL: "",
service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
service.SettingKeyOIDCConnectUsePKCE: "false",
service.SettingKeyOIDCConnectUsePKCE: "true",
service.SettingKeyOIDCConnectValidateIDToken: "true",
service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
service.SettingKeyOIDCConnectClockSkewSeconds: "120",
@ -549,7 +549,7 @@ func TestAPIContracts(t *testing.T) {
"oidc_connect_redirect_url": "",
"oidc_connect_frontend_redirect_url": "/auth/oidc/callback",
"oidc_connect_token_auth_method": "client_secret_post",
"oidc_connect_use_pkce": false,
"oidc_connect_use_pkce": true,
"oidc_connect_validate_id_token": true,
"oidc_connect_allowed_signing_algs": "RS256,ES256,PS256",
"oidc_connect_clock_skew_seconds": 120,

View File

@ -64,12 +64,26 @@ func RegisterAuthRoutes(
}), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
auth.GET("/oauth/wechat/start", h.Auth.WeChatOAuthStart)
auth.GET("/oauth/wechat/callback", h.Auth.WeChatOAuthCallback)
auth.POST("/oauth/pending/exchange",
rateLimiter.LimitWithOptions("oauth-pending-exchange", 20, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.ExchangePendingOAuthCompletion,
)
auth.POST("/oauth/linuxdo/complete-registration",
rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteLinuxDoOAuthRegistration,
)
auth.POST("/oauth/wechat/complete-registration",
rateLimiter.LimitWithOptions("oauth-wechat-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteWeChatOAuthRegistration,
)
auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart)
auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback)
auth.POST("/oauth/oidc/complete-registration",

View File

@ -44,6 +44,15 @@ func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, erro
}
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetUserAvatar(context.Context, int64) (*UserAvatar, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) DeleteUserAvatar(context.Context, int64) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected")
}

View File

@ -62,6 +62,18 @@ func (s *userRepoStub) Delete(ctx context.Context, id int64) error {
return s.deleteErr
}
func (s *userRepoStub) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
panic("unexpected GetUserAvatar call")
}
func (s *userRepoStub) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
panic("unexpected UpsertUserAvatar call")
}
func (s *userRepoStub) DeleteUserAvatar(ctx context.Context, userID int64) error {
panic("unexpected DeleteUserAvatar call")
}
func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}

View File

@ -0,0 +1,326 @@
package service
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
var (
ErrPendingAuthSessionNotFound = infraerrors.NotFound("PENDING_AUTH_SESSION_NOT_FOUND", "pending auth session not found")
ErrPendingAuthSessionExpired = infraerrors.Unauthorized("PENDING_AUTH_SESSION_EXPIRED", "pending auth session has expired")
ErrPendingAuthSessionConsumed = infraerrors.Unauthorized("PENDING_AUTH_SESSION_CONSUMED", "pending auth session has already been used")
ErrPendingAuthCodeInvalid = infraerrors.Unauthorized("PENDING_AUTH_CODE_INVALID", "pending auth completion code is invalid")
ErrPendingAuthCodeExpired = infraerrors.Unauthorized("PENDING_AUTH_CODE_EXPIRED", "pending auth completion code has expired")
ErrPendingAuthCodeConsumed = infraerrors.Unauthorized("PENDING_AUTH_CODE_CONSUMED", "pending auth completion code has already been used")
ErrPendingAuthBrowserMismatch = infraerrors.Unauthorized("PENDING_AUTH_BROWSER_MISMATCH", "pending auth completion code does not match this browser session")
)
const (
defaultPendingAuthTTL = 15 * time.Minute
defaultPendingAuthCompletionTTL = 5 * time.Minute
)
type PendingAuthIdentityKey struct {
ProviderType string
ProviderKey string
ProviderSubject string
}
type CreatePendingAuthSessionInput struct {
SessionToken string
Intent string
Identity PendingAuthIdentityKey
TargetUserID *int64
RedirectTo string
ResolvedEmail string
RegistrationPasswordHash string
BrowserSessionKey string
UpstreamIdentityClaims map[string]any
LocalFlowState map[string]any
ExpiresAt time.Time
}
type IssuePendingAuthCompletionCodeInput struct {
PendingAuthSessionID int64
BrowserSessionKey string
TTL time.Duration
}
type IssuePendingAuthCompletionCodeResult struct {
Code string
ExpiresAt time.Time
}
type PendingIdentityAdoptionDecisionInput struct {
PendingAuthSessionID int64
IdentityID *int64
AdoptDisplayName bool
AdoptAvatar bool
}
type AuthPendingIdentityService struct {
entClient *dbent.Client
}
func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService {
return &AuthPendingIdentityService{entClient: entClient}
}
func (s *AuthPendingIdentityService) CreatePendingSession(ctx context.Context, input CreatePendingAuthSessionInput) (*dbent.PendingAuthSession, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
sessionToken := strings.TrimSpace(input.SessionToken)
if sessionToken == "" {
var err error
sessionToken, err = randomOpaqueToken(24)
if err != nil {
return nil, err
}
}
expiresAt := input.ExpiresAt.UTC()
if expiresAt.IsZero() {
expiresAt = time.Now().UTC().Add(defaultPendingAuthTTL)
}
create := s.entClient.PendingAuthSession.Create().
SetSessionToken(sessionToken).
SetIntent(strings.TrimSpace(input.Intent)).
SetProviderType(strings.TrimSpace(input.Identity.ProviderType)).
SetProviderKey(strings.TrimSpace(input.Identity.ProviderKey)).
SetProviderSubject(strings.TrimSpace(input.Identity.ProviderSubject)).
SetRedirectTo(strings.TrimSpace(input.RedirectTo)).
SetResolvedEmail(strings.TrimSpace(input.ResolvedEmail)).
SetRegistrationPasswordHash(strings.TrimSpace(input.RegistrationPasswordHash)).
SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)).
SetUpstreamIdentityClaims(copyPendingMap(input.UpstreamIdentityClaims)).
SetLocalFlowState(copyPendingMap(input.LocalFlowState)).
SetExpiresAt(expiresAt)
if input.TargetUserID != nil {
create = create.SetTargetUserID(*input.TargetUserID)
}
return create.Save(ctx)
}
func (s *AuthPendingIdentityService) IssueCompletionCode(ctx context.Context, input IssuePendingAuthCompletionCodeInput) (*IssuePendingAuthCompletionCodeResult, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
session, err := s.entClient.PendingAuthSession.Get(ctx, input.PendingAuthSessionID)
if err != nil {
if dbent.IsNotFound(err) {
return nil, ErrPendingAuthSessionNotFound
}
return nil, err
}
code, err := randomOpaqueToken(24)
if err != nil {
return nil, err
}
ttl := input.TTL
if ttl <= 0 {
ttl = defaultPendingAuthCompletionTTL
}
expiresAt := time.Now().UTC().Add(ttl)
update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
SetCompletionCodeHash(hashPendingAuthCode(code)).
SetCompletionCodeExpiresAt(expiresAt)
if strings.TrimSpace(input.BrowserSessionKey) != "" {
update = update.SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey))
}
if _, err := update.Save(ctx); err != nil {
return nil, err
}
return &IssuePendingAuthCompletionCodeResult{
Code: code,
ExpiresAt: expiresAt,
}, nil
}
func (s *AuthPendingIdentityService) ConsumeCompletionCode(ctx context.Context, rawCode, browserSessionKey string) (*dbent.PendingAuthSession, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
codeHash := hashPendingAuthCode(strings.TrimSpace(rawCode))
session, err := s.entClient.PendingAuthSession.Query().
Where(pendingauthsession.CompletionCodeHashEQ(codeHash)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, ErrPendingAuthCodeInvalid
}
return nil, err
}
return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthCodeExpired, ErrPendingAuthCodeConsumed)
}
func (s *AuthPendingIdentityService) ConsumeBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
session, err := s.getBrowserSession(ctx, sessionToken)
if err != nil {
return nil, err
}
return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
}
func (s *AuthPendingIdentityService) GetBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
session, err := s.getBrowserSession(ctx, sessionToken)
if err != nil {
return nil, err
}
if err := validatePendingSessionState(session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed); err != nil {
return nil, err
}
return session, nil
}
func (s *AuthPendingIdentityService) getBrowserSession(ctx context.Context, sessionToken string) (*dbent.PendingAuthSession, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
sessionToken = strings.TrimSpace(sessionToken)
if sessionToken == "" {
return nil, ErrPendingAuthSessionNotFound
}
session, err := s.entClient.PendingAuthSession.Query().
Where(pendingauthsession.SessionTokenEQ(sessionToken)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, ErrPendingAuthSessionNotFound
}
return nil, err
}
return session, nil
}
func (s *AuthPendingIdentityService) consumeSession(
ctx context.Context,
session *dbent.PendingAuthSession,
browserSessionKey string,
expiredErr error,
consumedErr error,
) (*dbent.PendingAuthSession, error) {
if err := validatePendingSessionState(session, browserSessionKey, expiredErr, consumedErr); err != nil {
return nil, err
}
now := time.Now().UTC()
updated, err := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
SetConsumedAt(now).
SetCompletionCodeHash("").
ClearCompletionCodeExpiresAt().
Save(ctx)
if err != nil {
return nil, err
}
return updated, nil
}
func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error {
if session == nil {
return ErrPendingAuthSessionNotFound
}
now := time.Now().UTC()
if session.ConsumedAt != nil {
return consumedErr
}
if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
return expiredErr
}
if session.CompletionCodeExpiresAt != nil && now.After(*session.CompletionCodeExpiresAt) {
return expiredErr
}
if strings.TrimSpace(session.BrowserSessionKey) != "" && strings.TrimSpace(browserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) {
return ErrPendingAuthBrowserMismatch
}
return nil
}
func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, input PendingIdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
existing, err := s.entClient.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return nil, err
}
if existing == nil {
create := s.entClient.IdentityAdoptionDecision.Create().
SetPendingAuthSessionID(input.PendingAuthSessionID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar).
SetDecidedAt(time.Now().UTC())
if input.IdentityID != nil {
create = create.SetIdentityID(*input.IdentityID)
}
return create.Save(ctx)
}
update := s.entClient.IdentityAdoptionDecision.UpdateOneID(existing.ID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar)
if input.IdentityID != nil {
update = update.SetIdentityID(*input.IdentityID)
}
return update.Save(ctx)
}
func copyPendingMap(in map[string]any) map[string]any {
if len(in) == 0 {
return map[string]any{}
}
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func randomOpaqueToken(byteLen int) (string, error) {
if byteLen <= 0 {
byteLen = 16
}
buf := make([]byte, byteLen)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return hex.EncodeToString(buf), nil
}
func hashPendingAuthCode(code string) string {
sum := sha256.Sum256([]byte(code))
return hex.EncodeToString(sum[:])
}

View File

@ -0,0 +1,224 @@
//go:build unit
package service
import (
"context"
"database/sql"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
func newAuthPendingIdentityServiceTestClient(t *testing.T) (*AuthPendingIdentityService, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:auth_pending_identity_service?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
return NewAuthPendingIdentityService(client), client
}
func TestAuthPendingIdentityService_CreatePendingSessionStoresSeparatedState(t *testing.T) {
svc, client := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
targetUser, err := client.User.Create().
SetEmail("pending-target@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "bind_current_user",
Identity: PendingAuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-open",
ProviderSubject: "union-123",
},
TargetUserID: &targetUser.ID,
RedirectTo: "/profile",
ResolvedEmail: "user@example.com",
BrowserSessionKey: "browser-1",
UpstreamIdentityClaims: map[string]any{"nickname": "wx-user", "avatar_url": "https://cdn.example/avatar.png"},
LocalFlowState: map[string]any{"step": "email_required"},
})
require.NoError(t, err)
require.NotEmpty(t, session.SessionToken)
require.Equal(t, "bind_current_user", session.Intent)
require.Equal(t, "wechat", session.ProviderType)
require.NotNil(t, session.TargetUserID)
require.Equal(t, targetUser.ID, *session.TargetUserID)
require.Equal(t, "wx-user", session.UpstreamIdentityClaims["nickname"])
require.Equal(t, "email_required", session.LocalFlowState["step"])
}
func TestAuthPendingIdentityService_CompletionCodeIsBrowserBoundAndOneTime(t *testing.T) {
svc, _ := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "login",
Identity: PendingAuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
ProviderSubject: "subject-1",
},
BrowserSessionKey: "browser-expected",
UpstreamIdentityClaims: map[string]any{"nickname": "linux-user"},
LocalFlowState: map[string]any{"step": "pending"},
})
require.NoError(t, err)
issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{
PendingAuthSessionID: session.ID,
BrowserSessionKey: "browser-expected",
})
require.NoError(t, err)
require.NotEmpty(t, issued.Code)
_, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-other")
require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch)
consumed, err := svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected")
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
require.Empty(t, consumed.CompletionCodeHash)
require.Nil(t, consumed.CompletionCodeExpiresAt)
_, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected")
require.ErrorIs(t, err, ErrPendingAuthCodeInvalid)
}
func TestAuthPendingIdentityService_CompletionCodeExpires(t *testing.T) {
svc, client := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "login",
Identity: PendingAuthIdentityKey{
ProviderType: "oidc",
ProviderKey: "https://issuer.example",
ProviderSubject: "subject-1",
},
BrowserSessionKey: "browser-expired",
})
require.NoError(t, err)
issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{
PendingAuthSessionID: session.ID,
BrowserSessionKey: "browser-expired",
TTL: time.Second,
})
require.NoError(t, err)
_, err = client.PendingAuthSession.UpdateOneID(session.ID).
SetCompletionCodeExpiresAt(time.Now().UTC().Add(-time.Minute)).
Save(ctx)
require.NoError(t, err)
_, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expired")
require.ErrorIs(t, err, ErrPendingAuthCodeExpired)
}
func TestAuthPendingIdentityService_UpsertAdoptionDecision(t *testing.T) {
svc, client := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("adoption@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
identity, err := client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("wechat").
SetProviderKey("wechat-open").
SetProviderSubject("union-adoption").
SetMetadata(map[string]any{}).
Save(ctx)
require.NoError(t, err)
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "bind_current_user",
Identity: PendingAuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-open",
ProviderSubject: "union-adoption",
},
})
require.NoError(t, err)
first, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
AdoptDisplayName: true,
AdoptAvatar: false,
})
require.NoError(t, err)
require.True(t, first.AdoptDisplayName)
require.False(t, first.AdoptAvatar)
require.Nil(t, first.IdentityID)
second, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
IdentityID: &identity.ID,
AdoptDisplayName: true,
AdoptAvatar: true,
})
require.NoError(t, err)
require.Equal(t, first.ID, second.ID)
require.NotNil(t, second.IdentityID)
require.Equal(t, identity.ID, *second.IdentityID)
require.True(t, second.AdoptAvatar)
}
func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) {
svc, _ := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "login",
Identity: PendingAuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "subject-session-token",
},
BrowserSessionKey: "browser-session",
LocalFlowState: map[string]any{
"completion_response": map[string]any{
"access_token": "token",
},
},
})
require.NoError(t, err)
_, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-other")
require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch)
consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
_, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
require.ErrorIs(t, err, ErrPendingAuthSessionConsumed)
}

View File

@ -13,6 +13,7 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@ -106,6 +107,13 @@ func NewAuthService(
}
}
func (s *AuthService) EntClient() *dbent.Client {
if s == nil {
return nil
}
return s.entClient
}
// Register 用户注册返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "", "", "")
@ -205,6 +213,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable
}
s.postAuthUserBootstrap(ctx, user, "email", true)
s.assignDefaultSubscriptions(ctx, user.ID)
// 标记邀请码为已使用(如果使用了邀请码)
@ -421,6 +430,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
if !user.IsActive() {
return "", nil, ErrUserNotActive
}
s.touchUserLogin(ctx, user.ID)
// 生成JWT token
token, err := s.GenerateToken(user)
@ -501,6 +511,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
}
} else {
user = newUser
s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
s.assignDefaultSubscriptions(ctx, user.ID)
}
} else {
@ -520,6 +531,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
s.touchUserLogin(ctx, user.ID)
token, err := s.GenerateToken(user)
if err != nil {
@ -630,6 +642,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, ErrServiceUnavailable
}
user = newUser
s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
s.assignDefaultSubscriptions(ctx, user.ID)
}
} else {
@ -646,6 +659,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
}
} else {
user = newUser
s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
s.assignDefaultSubscriptions(ctx, user.ID)
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
@ -670,6 +684,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
s.touchUserLogin(ctx, user.ID)
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
@ -678,63 +693,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return tokenPair, user, nil
}
// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
const pendingOAuthTokenTTL = 10 * time.Minute
// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
const pendingOAuthPurpose = "pending_oauth_registration"
type pendingOAuthClaims struct {
Email string `json:"email"`
Username string `json:"username"`
Purpose string `json:"purpose"`
jwt.RegisteredClaims
}
// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity
// while waiting for the user to supply an invitation code.
func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) {
now := time.Now()
claims := &pendingOAuthClaims{
Email: email,
Username: username,
Purpose: pendingOAuthPurpose,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(s.cfg.JWT.Secret))
}
// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity.
// Returns ErrInvalidToken when the token is invalid or expired.
func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) {
if len(tokenStr) > maxTokenLength {
return "", "", ErrInvalidToken
}
parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return []byte(s.cfg.JWT.Secret), nil
})
if parseErr != nil {
return "", "", ErrInvalidToken
}
claims, ok := token.Claims.(*pendingOAuthClaims)
if !ok || !token.Valid {
return "", "", ErrInvalidToken
}
if claims.Purpose != pendingOAuthPurpose {
return "", "", ErrInvalidToken
}
return claims.Email, claims.Username, nil
}
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
return
@ -752,6 +710,95 @@ func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int
}
}
func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
if user == nil || user.ID <= 0 {
return
}
if strings.TrimSpace(signupSource) == "" {
signupSource = "email"
}
s.updateUserSignupSource(ctx, user.ID, signupSource)
if signupSource == "email" {
s.ensureEmailAuthIdentity(ctx, user)
}
if touchLogin {
s.touchUserLogin(ctx, user.ID)
}
}
func (s *AuthService) updateUserSignupSource(ctx context.Context, userID int64, signupSource string) {
if s == nil || s.entClient == nil || userID <= 0 {
return
}
if strings.TrimSpace(signupSource) == "" {
return
}
if err := s.entClient.User.UpdateOneID(userID).
SetSignupSource(signupSource).
Exec(ctx); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to update signup source: user_id=%d source=%s err=%v", userID, signupSource, err)
}
}
func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) {
if s == nil || s.entClient == nil || userID <= 0 {
return
}
now := time.Now().UTC()
if err := s.entClient.User.UpdateOneID(userID).
SetLastLoginAt(now).
SetLastActiveAt(now).
Exec(ctx); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to touch login timestamps: user_id=%d err=%v", userID, err)
}
}
func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) {
if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
return
}
email := strings.ToLower(strings.TrimSpace(user.Email))
if email == "" || isReservedEmail(email) {
return
}
if err := s.entClient.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("email").
SetProviderKey("email").
SetProviderSubject(email).
SetVerifiedAt(time.Now().UTC()).
SetMetadata(map[string]any{
"source": "auth_service_dual_write",
}).
OnConflictColumns(
authidentity.FieldProviderType,
authidentity.FieldProviderKey,
authidentity.FieldProviderSubject,
).
DoNothing().
Exec(ctx); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
}
}
func inferLegacySignupSource(email string) string {
normalized := strings.ToLower(strings.TrimSpace(email))
switch {
case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain):
return "linuxdo"
case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain):
return "oidc"
case strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain):
return "wechat"
default:
return "email"
}
}
func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error {
if s.settingService == nil {
return nil
@ -834,7 +881,8 @@ func randomHexString(byteLength int) (string, error) {
func isReservedEmail(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain)
strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain)
}
// GenerateToken 生成JWT access token

View File

@ -0,0 +1,153 @@
//go:build unit
package service_test
import (
"context"
"database/sql"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
type authIdentitySettingRepoStub struct {
values map[string]string
}
func (s *authIdentitySettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
panic("unexpected Get call")
}
func (s *authIdentitySettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
if v, ok := s.values[key]; ok {
return v, nil
}
return "", service.ErrSettingNotFound
}
func (s *authIdentitySettingRepoStub) Set(context.Context, string, string) error {
panic("unexpected Set call")
}
func (s *authIdentitySettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *authIdentitySettingRepoStub) GetAll(context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *authIdentitySettingRepoStub) Delete(context.Context, string) error {
panic("unexpected Delete call")
}
func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepository, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:auth_service_identity_sync?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
repo := repository.NewUserRepository(client, db)
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-auth-identity-secret",
ExpireHour: 1,
},
Default: config.DefaultConfig{
UserBalance: 3.5,
UserConcurrency: 2,
},
}
settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{
values: map[string]string{
service.SettingKeyRegistrationEnabled: "true",
},
}, cfg)
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, nil)
return svc, repo, client
}
func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
svc, _, client := newAuthServiceWithEnt(t)
ctx := context.Background()
token, user, err := svc.Register(ctx, "user@example.com", "password")
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, user)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, "email", storedUser.SignupSource)
require.NotNil(t, storedUser.LastLoginAt)
require.NotNil(t, storedUser.LastActiveAt)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("user@example.com"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, user.ID, identity.UserID)
require.NotNil(t, identity.VerifiedAt)
}
func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
svc, repo, client := newAuthServiceWithEnt(t)
ctx := context.Background()
user := &service.User{
Email: "login@example.com",
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 1,
Concurrency: 1,
}
require.NoError(t, user.SetPassword("password"))
require.NoError(t, repo.Create(ctx, user))
old := time.Now().Add(-2 * time.Hour).UTC().Round(time.Second)
_, err := client.User.UpdateOneID(user.ID).
SetLastLoginAt(old).
SetLastActiveAt(old).
Save(ctx)
require.NoError(t, err)
token, gotUser, err := svc.Login(ctx, user.Email, "password")
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.NotNil(t, storedUser.LastLoginAt)
require.NotNil(t, storedUser.LastActiveAt)
require.True(t, storedUser.LastLoginAt.After(old))
require.True(t, storedUser.LastActiveAt.After(old))
}

View File

@ -1,146 +0,0 @@
//go:build unit
package service
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
)
func newAuthServiceForPendingOAuthTest() *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret-pending-oauth",
ExpireHour: 1,
},
}
return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
}
// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。
func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
token, err := svc.CreatePendingOAuthToken("user@example.com", "alice")
require.NoError(t, err)
require.NotEmpty(t, token)
email, username, err := svc.VerifyPendingOAuthToken(token)
require.NoError(t, err)
require.Equal(t, "user@example.com", email)
require.Equal(t, "alice", username)
}
// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
// 签发一个普通 access tokenJWTClaims无 Purpose 字段)
accessToken, err := svc.GenerateToken(&User{
ID: 1,
Email: "user@example.com",
Role: RoleUser,
})
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(accessToken)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
now := time.Now()
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: "some_other_purpose",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT模拟旧 token应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
now := time.Now()
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: "", // 旧 token 无此字段,反序列化后为零值
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
past := time.Now().Add(-1 * time.Hour)
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: pendingOAuthPurpose,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(past),
IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) {
other := NewAuthService(nil, nil, nil, nil, &config.Config{
JWT: config.JWTConfig{Secret: "other-secret"},
}, nil, nil, nil, nil, nil, nil)
token, err := other.CreatePendingOAuthToken("user@example.com", "alice")
require.NoError(t, err)
svc := newAuthServiceForPendingOAuthTest()
_, _, err = svc.VerifyPendingOAuthToken(token)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_TooLong(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
giant := make([]byte, maxTokenLength+1)
for i := range giant {
giant[i] = 'a'
}
_, _, err := svc.VerifyPendingOAuthToken(string(giant))
require.ErrorIs(t, err, ErrInvalidToken)
}

View File

@ -74,6 +74,9 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀RFC 保留域名)。
const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid"
// WeChatConnectSyntheticEmailDomain 是 WeChat Connect 用户的合成邮箱后缀RFC 保留域名)。
const WeChatConnectSyntheticEmailDomain = "@wechat-connect.invalid"
// Setting keys
const (
// 注册设置
@ -153,6 +156,29 @@ const (
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表JSON
// 第三方认证来源默认授予配置
SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance"
SettingKeyAuthSourceDefaultEmailConcurrency = "auth_source_default_email_concurrency"
SettingKeyAuthSourceDefaultEmailSubscriptions = "auth_source_default_email_subscriptions"
SettingKeyAuthSourceDefaultEmailGrantOnSignup = "auth_source_default_email_grant_on_signup"
SettingKeyAuthSourceDefaultEmailGrantOnFirstBind = "auth_source_default_email_grant_on_first_bind"
SettingKeyAuthSourceDefaultLinuxDoBalance = "auth_source_default_linuxdo_balance"
SettingKeyAuthSourceDefaultLinuxDoConcurrency = "auth_source_default_linuxdo_concurrency"
SettingKeyAuthSourceDefaultLinuxDoSubscriptions = "auth_source_default_linuxdo_subscriptions"
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup = "auth_source_default_linuxdo_grant_on_signup"
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind = "auth_source_default_linuxdo_grant_on_first_bind"
SettingKeyAuthSourceDefaultOIDCBalance = "auth_source_default_oidc_balance"
SettingKeyAuthSourceDefaultOIDCConcurrency = "auth_source_default_oidc_concurrency"
SettingKeyAuthSourceDefaultOIDCSubscriptions = "auth_source_default_oidc_subscriptions"
SettingKeyAuthSourceDefaultOIDCGrantOnSignup = "auth_source_default_oidc_grant_on_signup"
SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind = "auth_source_default_oidc_grant_on_first_bind"
SettingKeyAuthSourceDefaultWeChatBalance = "auth_source_default_wechat_balance"
SettingKeyAuthSourceDefaultWeChatConcurrency = "auth_source_default_wechat_concurrency"
SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions"
SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup"
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind"
SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup"
// 管理员 API Key
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成

View File

@ -13,14 +13,30 @@ import (
"sync"
"sync/atomic"
"time"
"golang.org/x/sync/singleflight"
)
const (
openAIAccountScheduleLayerPreviousResponse = "previous_response_id"
openAIAccountScheduleLayerSessionSticky = "session_hash"
openAIAccountScheduleLayerLoadBalance = "load_balance"
openAIAdvancedSchedulerSettingKey = "openai_advanced_scheduler_enabled"
)
const (
openAIAdvancedSchedulerSettingCacheTTL = 5 * time.Second
openAIAdvancedSchedulerSettingDBTimeout = 2 * time.Second
)
type cachedOpenAIAdvancedSchedulerSetting struct {
enabled bool
expiresAt int64
}
var openAIAdvancedSchedulerSettingCache atomic.Value // *cachedOpenAIAdvancedSchedulerSetting
var openAIAdvancedSchedulerSettingSF singleflight.Group
type OpenAIAccountScheduleRequest struct {
GroupID *int64
SessionHash string
@ -805,10 +821,56 @@ func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountScheduler
return snapshot
}
func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler {
func (s *OpenAIGatewayService) openAIAdvancedSchedulerSettingRepo() SettingRepository {
if s == nil || s.rateLimitService == nil || s.rateLimitService.settingService == nil {
return nil
}
return s.rateLimitService.settingService.settingRepo
}
func (s *OpenAIGatewayService) isOpenAIAdvancedSchedulerEnabled(ctx context.Context) bool {
if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.enabled
}
}
result, _, _ := openAIAdvancedSchedulerSettingSF.Do(openAIAdvancedSchedulerSettingKey, func() (any, error) {
if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.enabled, nil
}
}
enabled := false
if repo := s.openAIAdvancedSchedulerSettingRepo(); repo != nil {
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAIAdvancedSchedulerSettingDBTimeout)
defer cancel()
value, err := repo.GetValue(dbCtx, openAIAdvancedSchedulerSettingKey)
if err == nil {
enabled = strings.EqualFold(strings.TrimSpace(value), "true")
}
}
openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
enabled: enabled,
expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
})
return enabled, nil
})
enabled, _ := result.(bool)
return enabled
}
func (s *OpenAIGatewayService) getOpenAIAccountScheduler(ctx context.Context) OpenAIAccountScheduler {
if s == nil {
return nil
}
if !s.isOpenAIAdvancedSchedulerEnabled(ctx) {
return nil
}
s.openaiSchedulerOnce.Do(func() {
if s.openaiAccountStats == nil {
s.openaiAccountStats = newOpenAIAccountRuntimeStats()
@ -820,6 +882,11 @@ func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountSchedule
return s.openaiScheduler
}
func resetOpenAIAdvancedSchedulerSettingCacheForTest() {
openAIAdvancedSchedulerSettingCache = atomic.Value{}
openAIAdvancedSchedulerSettingSF = singleflight.Group{}
}
func (s *OpenAIGatewayService) SelectAccountWithScheduler(
ctx context.Context,
groupID *int64,
@ -830,7 +897,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
requiredTransport OpenAIUpstreamTransport,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
decision := OpenAIAccountScheduleDecision{}
scheduler := s.getOpenAIAccountScheduler()
scheduler := s.getOpenAIAccountScheduler(ctx)
if scheduler == nil {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
decision.Layer = openAIAccountScheduleLayerLoadBalance
@ -856,7 +923,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
}
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
scheduler := s.getOpenAIAccountScheduler()
scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return
}
@ -864,7 +931,7 @@ func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64
}
func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
scheduler := s.getOpenAIAccountScheduler()
scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return
}
@ -872,7 +939,7 @@ func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
}
func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot {
scheduler := s.getOpenAIAccountScheduler()
scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return OpenAIAccountSchedulerMetricsSnapshot{}
}

View File

@ -2,6 +2,7 @@ package service
import (
"context"
"errors"
"fmt"
"math"
"sync"
@ -18,6 +19,202 @@ type openAISnapshotCacheStub struct {
accountsByID map[int64]*Account
}
type schedulerTestOpenAIAccountRepo struct {
AccountRepository
accounts []Account
}
func (r schedulerTestOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) {
for i := range r.accounts {
if r.accounts[i].ID == id {
return &r.accounts[i], nil
}
}
return nil, errors.New("account not found")
}
func (r schedulerTestOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
var result []Account
for _, acc := range r.accounts {
if acc.Platform == platform {
result = append(result, acc)
}
}
return result, nil
}
func (r schedulerTestOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
var result []Account
for _, acc := range r.accounts {
if acc.Platform == platform {
result = append(result, acc)
}
}
return result, nil
}
func (r schedulerTestOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
return r.ListSchedulableByPlatform(ctx, platform)
}
type schedulerTestConcurrencyCache struct {
ConcurrencyCache
loadBatchErr error
loadMap map[int64]*AccountLoadInfo
acquireResults map[int64]bool
waitCounts map[int64]int
skipDefaultLoad bool
}
func (c schedulerTestConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
if c.acquireResults != nil {
if result, ok := c.acquireResults[accountID]; ok {
return result, nil
}
}
return true, nil
}
func (c schedulerTestConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
return nil
}
func (c schedulerTestConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if c.loadBatchErr != nil {
return nil, c.loadBatchErr
}
out := make(map[int64]*AccountLoadInfo, len(accounts))
if c.skipDefaultLoad && c.loadMap != nil {
for _, acc := range accounts {
if load, ok := c.loadMap[acc.ID]; ok {
out[acc.ID] = load
}
}
return out, nil
}
for _, acc := range accounts {
if c.loadMap != nil {
if load, ok := c.loadMap[acc.ID]; ok {
out[acc.ID] = load
continue
}
}
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
}
return out, nil
}
func (c schedulerTestConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if c.waitCounts != nil {
if count, ok := c.waitCounts[accountID]; ok {
return count, nil
}
}
return 0, nil
}
type schedulerTestGatewayCache struct {
sessionBindings map[string]int64
deletedSessions map[string]int
}
func (c *schedulerTestGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
if id, ok := c.sessionBindings[sessionHash]; ok {
return id, nil
}
return 0, errors.New("not found")
}
func (c *schedulerTestGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
if c.sessionBindings == nil {
c.sessionBindings = make(map[string]int64)
}
c.sessionBindings[sessionHash] = accountID
return nil
}
func (c *schedulerTestGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
return nil
}
func (c *schedulerTestGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
if c.sessionBindings == nil {
return nil
}
if c.deletedSessions == nil {
c.deletedSessions = make(map[string]int)
}
c.deletedSessions[sessionHash]++
delete(c.sessionBindings, sessionHash)
return nil
}
func newSchedulerTestOpenAIWSV2Config() *config.Config {
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
return cfg
}
type openAIAdvancedSchedulerSettingRepoStub struct {
values map[string]string
}
func (s *openAIAdvancedSchedulerSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
value, err := s.GetValue(ctx, key)
if err != nil {
return nil, err
}
return &Setting{Key: key, Value: value}, nil
}
func (s *openAIAdvancedSchedulerSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
if s == nil || s.values == nil {
return "", ErrSettingNotFound
}
value, ok := s.values[key]
if !ok {
return "", ErrSettingNotFound
}
return value, nil
}
func (s *openAIAdvancedSchedulerSettingRepoStub) Set(context.Context, string, string) error {
panic("unexpected call to Set")
}
func (s *openAIAdvancedSchedulerSettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
panic("unexpected call to GetMultiple")
}
func (s *openAIAdvancedSchedulerSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
panic("unexpected call to SetMultiple")
}
func (s *openAIAdvancedSchedulerSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
panic("unexpected call to GetAll")
}
func (s *openAIAdvancedSchedulerSettingRepoStub) Delete(context.Context, string) error {
panic("unexpected call to Delete")
}
func newOpenAIAdvancedSchedulerRateLimitService(enabled string) *RateLimitService {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
repo := &openAIAdvancedSchedulerSettingRepoStub{
values: map[string]string{},
}
if enabled != "" {
repo.values[openAIAdvancedSchedulerSettingKey] = enabled
}
return &RateLimitService{
settingService: NewSettingService(repo, &config.Config{}),
}
}
func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) {
if len(s.snapshotAccounts) == 0 {
return nil, false, nil
@ -45,6 +242,138 @@ func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int6
return &cloned, nil
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLegacyLoadAwareness(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(10106)
accounts := []Account{
{
ID: 36001,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 5,
},
{
ID: 36002,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
},
}
cfg := &config.Config{}
cfg.Gateway.Scheduling.LoadBatchEnabled = false
cache := &schedulerTestGatewayCache{}
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
store := svc.getOpenAIWSStateStore()
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_disabled_001", 36001, time.Hour))
require.False(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx))
selection, decision, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"resp_disabled_001",
"",
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(36002), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
require.False(t, decision.StickyPreviousHit)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(10107)
accounts := []Account{
{
ID: 37001,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 5,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
},
{
ID: 37002,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
},
}
cfg := &config.Config{}
cfg.Gateway.Scheduling.LoadBatchEnabled = false
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
store := svc.getOpenAIWSStateStore()
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_enabled_001", 37001, time.Hour))
require.True(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx))
selection, decision, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"resp_enabled_001",
"",
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(37001), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer)
require.True(t, decision.StickyPreviousHit)
}
func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics_DisabledNoOp(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
svc := &OpenAIGatewayService{}
ttft := 120
svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
svc.RecordOpenAIAccountSwitch()
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) {
ctx := context.Background()
groupID := int64(10101)
@ -53,10 +382,17 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite
staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})}
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}},
cache: cache,
cfg: &config.Config{},
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
require.NoError(t, err)
@ -76,7 +412,12 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService}
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}},
cfg: &config.Config{},
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil)
require.NoError(t, err)
@ -92,18 +433,19 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR
staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
snapshotCache := &openAISnapshotCacheStub{
snapshotAccounts: []*Account{staleSticky, staleBackup},
accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup},
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
cache: cache,
cfg: &config.Config{},
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
@ -128,8 +470,9 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeReche
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
cfg: &config.Config{},
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
}
@ -153,7 +496,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
cache := &stubGatewayCache{}
cache := &schedulerTestGatewayCache{}
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
@ -163,10 +506,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
store := svc.getOpenAIWSStateStore()
@ -204,17 +548,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin
Schedulable: true,
Concurrency: 1,
}
cache := &stubGatewayCache{
cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_abc": account.ID,
},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
@ -260,7 +605,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
Priority: 9,
},
}
cache := &stubGatewayCache{
cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_sticky_busy": 21001,
},
@ -273,7 +618,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
concurrencyCache := stubConcurrencyCache{
concurrencyCache := schedulerTestConcurrencyCache{
acquireResults: map[int64]bool{
21001: false, // sticky 账号已满
21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换)
@ -288,9 +633,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache,
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@ -328,17 +674,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP
"openai_ws_force_http": true,
},
}
cache := &stubGatewayCache{
cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_force_http": account.ID,
},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
@ -387,15 +734,15 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
},
},
}
cache := &stubGatewayCache{
cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_ws_only": 2201,
},
}
cfg := newOpenAIWSV2TestConfig()
cfg := newSchedulerTestOpenAIWSV2Config()
// 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。
concurrencyCache := stubConcurrencyCache{
concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0},
2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5},
@ -403,9 +750,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache,
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@ -445,10 +793,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
cache: &stubGatewayCache{},
cfg: newOpenAIWSV2TestConfig(),
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: newSchedulerTestOpenAIWSV2Config(),
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
@ -507,7 +856,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1
concurrencyCache := stubConcurrencyCache{
concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8},
3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1},
@ -520,9 +869,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
cache: &stubGatewayCache{},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@ -559,16 +909,17 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
Schedulable: true,
Concurrency: 1,
}
cache := &stubGatewayCache{
cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_metrics": account.ID,
},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
@ -749,7 +1100,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
concurrencyCache := stubConcurrencyCache{
concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1},
5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1},
@ -757,9 +1108,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{sessionBindings: map[string]int64{}},
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@ -905,12 +1257,14 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) {
}
func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
svc := &OpenAIGatewayService{}
ttft := 120
svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
svc.RecordOpenAIAccountSwitch()
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot)
require.Equal(t, 7, svc.openAIWSLBTopK())
require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL())
@ -947,7 +1301,7 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *
require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE))
require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2))
cfg := newOpenAIWSV2TestConfig()
cfg := newSchedulerTestOpenAIWSV2Config()
scheduler.service = &OpenAIGatewayService{cfg: cfg}
account := &Account{
ID: 8801,

View File

@ -38,11 +38,12 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{*account}},
cache: &stubGatewayCache{},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*account}},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(

View File

@ -196,12 +196,25 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo
SettingHelpImageURL, SettingHelpText,
SettingCancelRateLimitOn, SettingCancelRateLimitMax,
SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode,
SettingPaymentVisibleMethodAlipayEnabled, SettingPaymentVisibleMethodAlipaySource,
SettingPaymentVisibleMethodWxpayEnabled, SettingPaymentVisibleMethodWxpaySource,
}
vals, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return nil, fmt.Errorf("get payment config settings: %w", err)
}
cfg := s.parsePaymentConfig(vals)
if s.entClient != nil {
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.EnabledEQ(true)).
All(ctx)
if err != nil {
return nil, fmt.Errorf("list enabled provider instances: %w", err)
}
cfg.EnabledTypes = applyVisibleMethodRoutingToEnabledTypes(cfg.EnabledTypes, vals, buildVisibleMethodSourceAvailability(instances))
} else {
cfg.EnabledTypes = applyVisibleMethodRoutingToEnabledTypes(cfg.EnabledTypes, vals, nil)
}
// Load Stripe publishable key from the first enabled Stripe provider instance
cfg.StripePublishableKey = s.getStripePublishableKey(ctx)
return cfg, nil
@ -234,18 +247,23 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme
cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy
}
if raw := vals[SettingEnabledPaymentTypes]; raw != "" {
types := make([]string, 0, len(strings.Split(raw, ",")))
for _, t := range strings.Split(raw, ",") {
t = strings.TrimSpace(t)
if t != "" {
cfg.EnabledTypes = append(cfg.EnabledTypes, t)
types = append(types, t)
}
}
cfg.EnabledTypes = NormalizeVisibleMethods(types)
}
return cfg
}
// getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance.
func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string {
if s.entClient == nil {
return ""
}
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(
paymentproviderinstance.EnabledEQ(true),
@ -385,3 +403,79 @@ func pcParseInt(s string, defaultVal int) int {
}
return v
}
func buildVisibleMethodSourceAvailability(instances []*dbent.PaymentProviderInstance) map[string]bool {
available := make(map[string]bool, 4)
for _, inst := range instances {
switch inst.ProviderKey {
case payment.TypeAlipay:
if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipayDirect) {
available[VisibleMethodSourceOfficialAlipay] = true
}
case payment.TypeWxpay:
if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpayDirect) {
available[VisibleMethodSourceOfficialWechat] = true
}
case payment.TypeEasyPay:
for _, supportedType := range splitTypes(inst.SupportedTypes) {
switch NormalizeVisibleMethod(supportedType) {
case payment.TypeAlipay:
available[VisibleMethodSourceEasyPayAlipay] = true
case payment.TypeWxpay:
available[VisibleMethodSourceEasyPayWechat] = true
}
}
}
}
return available
}
func applyVisibleMethodRoutingToEnabledTypes(base []string, vals map[string]string, available map[string]bool) []string {
shouldExpose := map[string]bool{
payment.TypeAlipay: visibleMethodShouldBeExposed(payment.TypeAlipay, vals, available),
payment.TypeWxpay: visibleMethodShouldBeExposed(payment.TypeWxpay, vals, available),
}
seen := make(map[string]struct{}, len(base)+2)
out := make([]string, 0, len(base)+2)
appendType := func(paymentType string) {
paymentType = NormalizeVisibleMethod(paymentType)
if paymentType == "" {
return
}
if _, ok := seen[paymentType]; ok {
return
}
seen[paymentType] = struct{}{}
out = append(out, paymentType)
}
for _, paymentType := range base {
visibleMethod := NormalizeVisibleMethod(paymentType)
switch visibleMethod {
case payment.TypeAlipay, payment.TypeWxpay:
if shouldExpose[visibleMethod] {
appendType(visibleMethod)
}
default:
appendType(visibleMethod)
}
}
for _, visibleMethod := range []string{payment.TypeAlipay, payment.TypeWxpay} {
if shouldExpose[visibleMethod] {
appendType(visibleMethod)
}
}
return out
}
func visibleMethodShouldBeExposed(method string, vals map[string]string, available map[string]bool) bool {
enabledKey := visibleMethodEnabledSettingKey(method)
sourceKey := visibleMethodSourceSettingKey(method)
if enabledKey == "" || sourceKey == "" || vals[enabledKey] != "true" {
return false
}
source := NormalizeVisibleMethodSource(method, vals[sourceKey])
return source != "" && available[source]
}

View File

@ -1,9 +1,17 @@
package service
import (
"context"
"database/sql"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/payment"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
func TestPcParseFloat(t *testing.T) {
@ -163,6 +171,20 @@ func TestParsePaymentConfig(t *testing.T) {
}
})
t.Run("enabled types are normalized to visible methods and deduplicated", func(t *testing.T) {
t.Parallel()
vals := map[string]string{
SettingEnabledPaymentTypes: "alipay_direct, alipay, wxpay_direct, wxpay",
}
cfg := svc.parsePaymentConfig(vals)
if len(cfg.EnabledTypes) != 2 {
t.Fatalf("EnabledTypes len = %d, want 2", len(cfg.EnabledTypes))
}
if cfg.EnabledTypes[0] != "alipay" || cfg.EnabledTypes[1] != "wxpay" {
t.Fatalf("EnabledTypes = %v, want [alipay wxpay]", cfg.EnabledTypes)
}
})
t.Run("empty enabled types string", func(t *testing.T) {
t.Parallel()
vals := map[string]string{
@ -204,3 +226,167 @@ func TestGetBasePaymentType(t *testing.T) {
})
}
}
func TestApplyVisibleMethodRoutingToEnabledTypes(t *testing.T) {
t.Parallel()
base := []string{"alipay", "wxpay", "stripe"}
vals := map[string]string{
SettingPaymentVisibleMethodAlipayEnabled: "true",
SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay,
SettingPaymentVisibleMethodWxpayEnabled: "true",
SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
}
available := map[string]bool{
VisibleMethodSourceOfficialAlipay: true,
VisibleMethodSourceOfficialWechat: false,
}
got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available)
want := []string{"alipay", "stripe"}
if len(got) != len(want) {
t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
}
}
}
func TestApplyVisibleMethodRoutingAddsConfiguredVisibleMethod(t *testing.T) {
t.Parallel()
base := []string{"stripe"}
vals := map[string]string{
SettingPaymentVisibleMethodAlipayEnabled: "true",
SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceEasyPayAlipay,
}
available := map[string]bool{
VisibleMethodSourceEasyPayAlipay: true,
}
got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available)
want := []string{"stripe", "alipay"}
if len(got) != len(want) {
t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
}
}
}
func TestBuildVisibleMethodSourceAvailability(t *testing.T) {
t.Parallel()
instances := []*dbent.PaymentProviderInstance{
{ProviderKey: payment.TypeAlipay, SupportedTypes: "alipay"},
{ProviderKey: payment.TypeEasyPay, SupportedTypes: "wxpay_direct, alipay"},
{ProviderKey: payment.TypeWxpay, SupportedTypes: "wxpay_direct"},
}
got := buildVisibleMethodSourceAvailability(instances)
if !got[VisibleMethodSourceOfficialAlipay] {
t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialAlipay)
}
if !got[VisibleMethodSourceEasyPayAlipay] {
t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayAlipay)
}
if !got[VisibleMethodSourceOfficialWechat] {
t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialWechat)
}
if !got[VisibleMethodSourceEasyPayWechat] {
t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayWechat)
}
}
func TestGetPaymentConfigAppliesVisibleMethodRouting(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName("EasyPay Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create easypay instance: %v", err)
}
svc := &PaymentConfigService{
entClient: client,
settingRepo: &paymentConfigSettingRepoStub{
values: map[string]string{
SettingEnabledPaymentTypes: "alipay,wxpay,stripe",
SettingPaymentVisibleMethodAlipayEnabled: "true",
SettingPaymentVisibleMethodAlipaySource: "easypay",
SettingPaymentVisibleMethodWxpayEnabled: "true",
SettingPaymentVisibleMethodWxpaySource: "wxpay",
},
},
}
cfg, err := svc.GetPaymentConfig(ctx)
if err != nil {
t.Fatalf("GetPaymentConfig returned error: %v", err)
}
want := []string{payment.TypeAlipay, payment.TypeStripe}
if len(cfg.EnabledTypes) != len(want) {
t.Fatalf("EnabledTypes len = %d, want %d (%v)", len(cfg.EnabledTypes), len(want), cfg.EnabledTypes)
}
for i := range want {
if cfg.EnabledTypes[i] != want[i] {
t.Fatalf("EnabledTypes[%d] = %q, want %q (full=%v)", i, cfg.EnabledTypes[i], want[i], cfg.EnabledTypes)
}
}
}
func newPaymentConfigServiceTestClient(t *testing.T) *dbent.Client {
t.Helper()
db, err := sql.Open("sqlite", "file:payment_config_service?mode=memory&cache=shared")
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
t.Fatalf("enable foreign keys: %v", err)
}
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
return client
}
type paymentConfigSettingRepoStub struct {
values map[string]string
}
func (s *paymentConfigSettingRepoStub) Get(context.Context, string) (*Setting, error) {
return nil, nil
}
func (s *paymentConfigSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
return s.values[key], nil
}
func (s *paymentConfigSettingRepoStub) Set(context.Context, string, string) error { return nil }
func (s *paymentConfigSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
out[key] = s.values[key]
}
return out, nil
}
func (s *paymentConfigSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
return nil
}
func (s *paymentConfigSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
return s.values, nil
}
func (s *paymentConfigSettingRepoStub) Delete(context.Context, string) error { return nil }

View File

@ -0,0 +1,248 @@
package service
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
const (
PaymentSourceHostedRedirect = "hosted_redirect"
PaymentSourceWechatInAppResume = "wechat_in_app_resume"
paymentResumeFallbackSigningKey = "sub2api-payment-resume"
SettingPaymentVisibleMethodAlipaySource = "payment_visible_method_alipay_source"
SettingPaymentVisibleMethodWxpaySource = "payment_visible_method_wxpay_source"
SettingPaymentVisibleMethodAlipayEnabled = "payment_visible_method_alipay_enabled"
SettingPaymentVisibleMethodWxpayEnabled = "payment_visible_method_wxpay_enabled"
VisibleMethodSourceOfficialAlipay = "official_alipay"
VisibleMethodSourceEasyPayAlipay = "easypay_alipay"
VisibleMethodSourceOfficialWechat = "official_wxpay"
VisibleMethodSourceEasyPayWechat = "easypay_wxpay"
)
type ResumeTokenClaims struct {
OrderID int64 `json:"oid"`
UserID int64 `json:"uid,omitempty"`
ProviderInstanceID string `json:"pi,omitempty"`
ProviderKey string `json:"pk,omitempty"`
PaymentType string `json:"pt,omitempty"`
CanonicalReturnURL string `json:"ru,omitempty"`
IssuedAt int64 `json:"iat"`
}
type PaymentResumeService struct {
signingKey []byte
}
type visibleMethodLoadBalancer struct {
inner payment.LoadBalancer
configService *PaymentConfigService
}
func NewPaymentResumeService(signingKey []byte) *PaymentResumeService {
return &PaymentResumeService{signingKey: signingKey}
}
func NormalizeVisibleMethod(method string) string {
return payment.GetBasePaymentType(strings.TrimSpace(method))
}
func NormalizeVisibleMethods(methods []string) []string {
if len(methods) == 0 {
return nil
}
seen := make(map[string]struct{}, len(methods))
out := make([]string, 0, len(methods))
for _, method := range methods {
normalized := NormalizeVisibleMethod(method)
if normalized == "" {
continue
}
if _, ok := seen[normalized]; ok {
continue
}
seen[normalized] = struct{}{}
out = append(out, normalized)
}
return out
}
func NormalizePaymentSource(source string) string {
switch strings.TrimSpace(strings.ToLower(source)) {
case "", PaymentSourceHostedRedirect:
return PaymentSourceHostedRedirect
case "wechat_in_app", "wxpay_resume", PaymentSourceWechatInAppResume:
return PaymentSourceWechatInAppResume
default:
return strings.TrimSpace(strings.ToLower(source))
}
}
func NormalizeVisibleMethodSource(method, source string) string {
switch NormalizeVisibleMethod(method) {
case payment.TypeAlipay:
switch strings.TrimSpace(strings.ToLower(source)) {
case VisibleMethodSourceOfficialAlipay, payment.TypeAlipay, payment.TypeAlipayDirect, "official":
return VisibleMethodSourceOfficialAlipay
case VisibleMethodSourceEasyPayAlipay, payment.TypeEasyPay:
return VisibleMethodSourceEasyPayAlipay
}
case payment.TypeWxpay:
switch strings.TrimSpace(strings.ToLower(source)) {
case VisibleMethodSourceOfficialWechat, payment.TypeWxpay, payment.TypeWxpayDirect, "wechat", "official":
return VisibleMethodSourceOfficialWechat
case VisibleMethodSourceEasyPayWechat, payment.TypeEasyPay:
return VisibleMethodSourceEasyPayWechat
}
}
return ""
}
func VisibleMethodProviderKeyForSource(method, source string) (string, bool) {
switch NormalizeVisibleMethodSource(method, source) {
case VisibleMethodSourceOfficialAlipay:
return payment.TypeAlipay, NormalizeVisibleMethod(method) == payment.TypeAlipay
case VisibleMethodSourceEasyPayAlipay:
return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeAlipay
case VisibleMethodSourceOfficialWechat:
return payment.TypeWxpay, NormalizeVisibleMethod(method) == payment.TypeWxpay
case VisibleMethodSourceEasyPayWechat:
return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeWxpay
default:
return "", false
}
}
func newVisibleMethodLoadBalancer(inner payment.LoadBalancer, configService *PaymentConfigService) payment.LoadBalancer {
if inner == nil || configService == nil || configService.settingRepo == nil {
return inner
}
return &visibleMethodLoadBalancer{inner: inner, configService: configService}
}
func (lb *visibleMethodLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) {
return lb.inner.GetInstanceConfig(ctx, instanceID)
}
func (lb *visibleMethodLoadBalancer) SelectInstance(ctx context.Context, providerKey string, paymentType payment.PaymentType, strategy payment.Strategy, orderAmount float64) (*payment.InstanceSelection, error) {
visibleMethod := NormalizeVisibleMethod(paymentType)
if providerKey != "" || (visibleMethod != payment.TypeAlipay && visibleMethod != payment.TypeWxpay) {
return lb.inner.SelectInstance(ctx, providerKey, paymentType, strategy, orderAmount)
}
enabledKey := visibleMethodEnabledSettingKey(visibleMethod)
sourceKey := visibleMethodSourceSettingKey(visibleMethod)
vals, err := lb.configService.settingRepo.GetMultiple(ctx, []string{enabledKey, sourceKey})
if err != nil {
return nil, fmt.Errorf("load visible method routing for %s: %w", visibleMethod, err)
}
if vals[enabledKey] != "true" {
return nil, fmt.Errorf("visible payment method %s is disabled", visibleMethod)
}
targetProviderKey, ok := VisibleMethodProviderKeyForSource(visibleMethod, vals[sourceKey])
if !ok {
return nil, fmt.Errorf("visible payment method %s has no valid source", visibleMethod)
}
return lb.inner.SelectInstance(ctx, targetProviderKey, paymentType, strategy, orderAmount)
}
func visibleMethodEnabledSettingKey(method string) string {
switch NormalizeVisibleMethod(method) {
case payment.TypeAlipay:
return SettingPaymentVisibleMethodAlipayEnabled
case payment.TypeWxpay:
return SettingPaymentVisibleMethodWxpayEnabled
default:
return ""
}
}
func visibleMethodSourceSettingKey(method string) string {
switch NormalizeVisibleMethod(method) {
case payment.TypeAlipay:
return SettingPaymentVisibleMethodAlipaySource
case payment.TypeWxpay:
return SettingPaymentVisibleMethodWxpaySource
default:
return ""
}
}
func CanonicalizeReturnURL(raw string) (string, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", nil
}
parsed, err := url.Parse(raw)
if err != nil || !parsed.IsAbs() || parsed.Host == "" {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be an absolute http/https URL")
}
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use http or https")
}
parsed.Fragment = ""
if parsed.Path == "" {
parsed.Path = "/"
}
return parsed.String(), nil
}
func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) {
if claims.OrderID <= 0 {
return "", fmt.Errorf("resume token requires order id")
}
if claims.IssuedAt == 0 {
claims.IssuedAt = time.Now().Unix()
}
payload, err := json.Marshal(claims)
if err != nil {
return "", fmt.Errorf("marshal resume claims: %w", err)
}
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
return encodedPayload + "." + s.sign(encodedPayload), nil
}
func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) {
parts := strings.Split(token, ".")
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
}
if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is malformed")
}
var claims ResumeTokenClaims
if err := json.Unmarshal(payload, &claims); err != nil {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid")
}
if claims.OrderID <= 0 {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id")
}
return &claims, nil
}
func (s *PaymentResumeService) sign(payload string) string {
key := s.signingKey
if len(key) == 0 {
key = []byte(paymentResumeFallbackSigningKey)
}
mac := hmac.New(sha256.New, key)
_, _ = mac.Write([]byte(payload))
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}

View File

@ -0,0 +1,240 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
func TestNormalizeVisibleMethods(t *testing.T) {
t.Parallel()
got := NormalizeVisibleMethods([]string{
"alipay_direct",
"alipay",
" wxpay_direct ",
"wxpay",
"stripe",
})
want := []string{"alipay", "wxpay", "stripe"}
if len(got) != len(want) {
t.Fatalf("NormalizeVisibleMethods len = %d, want %d (%v)", len(got), len(want), got)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("NormalizeVisibleMethods[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
}
}
}
func TestNormalizePaymentSource(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
expect string
}{
{name: "empty uses default", input: "", expect: PaymentSourceHostedRedirect},
{name: "wechat alias normalized", input: "wechat_in_app", expect: PaymentSourceWechatInAppResume},
{name: "canonical value preserved", input: PaymentSourceWechatInAppResume, expect: PaymentSourceWechatInAppResume},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := NormalizePaymentSource(tt.input); got != tt.expect {
t.Fatalf("NormalizePaymentSource(%q) = %q, want %q", tt.input, got, tt.expect)
}
})
}
}
func TestCanonicalizeReturnURL(t *testing.T) {
t.Parallel()
got, err := CanonicalizeReturnURL("https://example.com/pay/result?b=2#a")
if err != nil {
t.Fatalf("CanonicalizeReturnURL returned error: %v", err)
}
if got != "https://example.com/pay/result?b=2" {
t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/pay/result?b=2")
}
}
func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) {
t.Parallel()
if _, err := CanonicalizeReturnURL("/payment/result"); err == nil {
t.Fatal("CanonicalizeReturnURL should reject relative URLs")
}
}
func TestPaymentResumeTokenRoundTrip(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := svc.CreateToken(ResumeTokenClaims{
OrderID: 42,
UserID: 7,
ProviderInstanceID: "19",
ProviderKey: "easypay",
PaymentType: "wxpay",
CanonicalReturnURL: "https://example.com/payment/result",
IssuedAt: 1234567890,
})
if err != nil {
t.Fatalf("CreateToken returned error: %v", err)
}
claims, err := svc.ParseToken(token)
if err != nil {
t.Fatalf("ParseToken returned error: %v", err)
}
if claims.OrderID != 42 || claims.UserID != 7 {
t.Fatalf("claims mismatch: %+v", claims)
}
if claims.ProviderInstanceID != "19" || claims.ProviderKey != "easypay" || claims.PaymentType != "wxpay" {
t.Fatalf("claims provider snapshot mismatch: %+v", claims)
}
if claims.CanonicalReturnURL != "https://example.com/payment/result" {
t.Fatalf("claims return URL = %q", claims.CanonicalReturnURL)
}
}
func TestNormalizeVisibleMethodSource(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method string
input string
want string
}{
{name: "alipay official alias", method: payment.TypeAlipay, input: "alipay", want: VisibleMethodSourceOfficialAlipay},
{name: "alipay easypay alias", method: payment.TypeAlipay, input: "easypay", want: VisibleMethodSourceEasyPayAlipay},
{name: "wxpay official alias", method: payment.TypeWxpay, input: "wxpay", want: VisibleMethodSourceOfficialWechat},
{name: "wxpay easypay alias", method: payment.TypeWxpay, input: "easypay", want: VisibleMethodSourceEasyPayWechat},
{name: "unsupported source", method: payment.TypeWxpay, input: "stripe", want: ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := NormalizeVisibleMethodSource(tt.method, tt.input); got != tt.want {
t.Fatalf("NormalizeVisibleMethodSource(%q, %q) = %q, want %q", tt.method, tt.input, got, tt.want)
}
})
}
}
func TestVisibleMethodProviderKeyForSource(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method string
source string
want string
ok bool
}{
{name: "official alipay", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialAlipay, want: payment.TypeAlipay, ok: true},
{name: "easypay alipay", method: payment.TypeAlipay, source: VisibleMethodSourceEasyPayAlipay, want: payment.TypeEasyPay, ok: true},
{name: "official wechat", method: payment.TypeWxpay, source: VisibleMethodSourceOfficialWechat, want: payment.TypeWxpay, ok: true},
{name: "easypay wechat", method: payment.TypeWxpay, source: VisibleMethodSourceEasyPayWechat, want: payment.TypeEasyPay, ok: true},
{name: "mismatched method and source", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialWechat, want: "", ok: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, ok := VisibleMethodProviderKeyForSource(tt.method, tt.source)
if got != tt.want || ok != tt.ok {
t.Fatalf("VisibleMethodProviderKeyForSource(%q, %q) = (%q, %v), want (%q, %v)", tt.method, tt.source, got, ok, tt.want, tt.ok)
}
})
}
}
func TestVisibleMethodLoadBalancerUsesConfiguredSource(t *testing.T) {
t.Parallel()
inner := &captureLoadBalancer{}
configService := &PaymentConfigService{
settingRepo: &paymentSettingRepoStub{
values: map[string]string{
SettingPaymentVisibleMethodAlipayEnabled: "true",
SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay,
},
},
}
lb := newVisibleMethodLoadBalancer(inner, configService)
_, err := lb.SelectInstance(context.Background(), "", payment.TypeAlipay, payment.StrategyRoundRobin, 12.5)
if err != nil {
t.Fatalf("SelectInstance returned error: %v", err)
}
if inner.lastProviderKey != payment.TypeAlipay {
t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, payment.TypeAlipay)
}
}
func TestVisibleMethodLoadBalancerRejectsDisabledVisibleMethod(t *testing.T) {
t.Parallel()
inner := &captureLoadBalancer{}
configService := &PaymentConfigService{
settingRepo: &paymentSettingRepoStub{
values: map[string]string{
SettingPaymentVisibleMethodWxpayEnabled: "false",
SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
},
},
}
lb := newVisibleMethodLoadBalancer(inner, configService)
if _, err := lb.SelectInstance(context.Background(), "", payment.TypeWxpay, payment.StrategyRoundRobin, 9.9); err == nil {
t.Fatal("SelectInstance should reject disabled visible method")
}
}
type paymentSettingRepoStub struct {
values map[string]string
}
func (s *paymentSettingRepoStub) Get(context.Context, string) (*Setting, error) { return nil, nil }
func (s *paymentSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
return s.values[key], nil
}
func (s *paymentSettingRepoStub) Set(context.Context, string, string) error { return nil }
func (s *paymentSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
out[key] = s.values[key]
}
return out, nil
}
func (s *paymentSettingRepoStub) SetMultiple(context.Context, map[string]string) error { return nil }
func (s *paymentSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
return s.values, nil
}
func (s *paymentSettingRepoStub) Delete(context.Context, string) error { return nil }
type captureLoadBalancer struct {
lastProviderKey string
lastPaymentType string
}
func (c *captureLoadBalancer) GetInstanceConfig(context.Context, int64) (map[string]string, error) {
return map[string]string{}, nil
}
func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey string, paymentType payment.PaymentType, _ payment.Strategy, _ float64) (*payment.InstanceSelection, error) {
c.lastProviderKey = providerKey
c.lastPaymentType = paymentType
return &payment.InstanceSelection{ProviderKey: providerKey, SupportedTypes: paymentType}, nil
}

View File

@ -65,15 +65,17 @@ func generateRandomString(n int) string {
}
type CreateOrderRequest struct {
UserID int64
Amount float64
PaymentType string
ClientIP string
IsMobile bool
SrcHost string
SrcURL string
OrderType string
PlanID int64
UserID int64
Amount float64
PaymentType string
ClientIP string
IsMobile bool
SrcHost string
SrcURL string
ReturnURL string
PaymentSource string
OrderType string
PlanID int64
}
type CreateOrderResponse struct {
@ -88,6 +90,7 @@ type CreateOrderResponse struct {
ClientSecret string `json:"client_secret,omitempty"`
ExpiresAt time.Time `json:"expires_at"`
PaymentMode string `json:"payment_mode,omitempty"`
ResumeToken string `json:"resume_token,omitempty"`
}
type OrderListParams struct {
@ -165,10 +168,13 @@ type PaymentService struct {
configService *PaymentConfigService
userRepo UserRepository
groupRepo GroupRepository
resumeService *PaymentResumeService
}
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
return &PaymentService{entClient: entClient, registry: registry, loadBalancer: loadBalancer, redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService))
return svc
}
// --- Provider Registry ---
@ -262,6 +268,20 @@ func psNilIfEmpty(s string) *string {
return &s
}
func (s *PaymentService) paymentResume() *PaymentResumeService {
if s.resumeService != nil {
return s.resumeService
}
return NewPaymentResumeService(psResumeSigningKey(s.configService))
}
func psResumeSigningKey(configService *PaymentConfigService) []byte {
if configService == nil {
return nil
}
return configService.encryptionKey
}
func psSliceContains(sl []string, s string) bool {
for _, v := range sl {
if v == s {

View File

@ -9,6 +9,7 @@ import (
"fmt"
"log/slog"
"net/url"
"os"
"sort"
"strconv"
"strings"
@ -114,6 +115,66 @@ type SettingService struct {
webSearchManagerBuilder WebSearchManagerBuilder
}
type ProviderDefaultGrantSettings struct {
Balance float64
Concurrency int
Subscriptions []DefaultSubscriptionSetting
GrantOnSignup bool
GrantOnFirstBind bool
}
type AuthSourceDefaultSettings struct {
Email ProviderDefaultGrantSettings
LinuxDo ProviderDefaultGrantSettings
OIDC ProviderDefaultGrantSettings
WeChat ProviderDefaultGrantSettings
ForceEmailOnThirdPartySignup bool
}
type authSourceDefaultKeySet struct {
balance string
concurrency string
subscriptions string
grantOnSignup string
grantOnFirstBind string
}
var (
emailAuthSourceDefaultKeys = authSourceDefaultKeySet{
balance: SettingKeyAuthSourceDefaultEmailBalance,
concurrency: SettingKeyAuthSourceDefaultEmailConcurrency,
subscriptions: SettingKeyAuthSourceDefaultEmailSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultEmailGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultEmailGrantOnFirstBind,
}
linuxDoAuthSourceDefaultKeys = authSourceDefaultKeySet{
balance: SettingKeyAuthSourceDefaultLinuxDoBalance,
concurrency: SettingKeyAuthSourceDefaultLinuxDoConcurrency,
subscriptions: SettingKeyAuthSourceDefaultLinuxDoSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind,
}
oidcAuthSourceDefaultKeys = authSourceDefaultKeySet{
balance: SettingKeyAuthSourceDefaultOIDCBalance,
concurrency: SettingKeyAuthSourceDefaultOIDCConcurrency,
subscriptions: SettingKeyAuthSourceDefaultOIDCSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultOIDCGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind,
}
weChatAuthSourceDefaultKeys = authSourceDefaultKeySet{
balance: SettingKeyAuthSourceDefaultWeChatBalance,
concurrency: SettingKeyAuthSourceDefaultWeChatConcurrency,
subscriptions: SettingKeyAuthSourceDefaultWeChatSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
}
)
const (
defaultAuthSourceBalance = 0
defaultAuthSourceConcurrency = 5
)
// NewSettingService 创建系统设置服务实例
func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *SettingService {
return &SettingService{
@ -212,6 +273,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
if oidcProviderName == "" {
oidcProviderName = "OIDC"
}
weChatEnabled := isWeChatOAuthConfigured()
// Password reset requires email verification to be enabled
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
@ -254,6 +316,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled,
WeChatOAuthEnabled: weChatEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
OIDCOAuthEnabled: oidcEnabled,
@ -310,6 +373,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
@ -344,6 +408,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
PaymentEnabled: settings.PaymentEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
@ -392,6 +457,14 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage {
return result
}
func isWeChatOAuthConfigured() bool {
openConfigured := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID")) != "" &&
strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET")) != ""
mpConfigured := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) != "" &&
strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) != ""
return openConfigured || mpConfigured
}
// safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]".
func safeRawJSONArray(raw string) json.RawMessage {
raw = strings.TrimSpace(raw)
@ -919,6 +992,74 @@ func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultS
return parseDefaultSubscriptions(value)
}
func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*AuthSourceDefaultSettings, error) {
keys := []string{
SettingKeyAuthSourceDefaultEmailBalance,
SettingKeyAuthSourceDefaultEmailConcurrency,
SettingKeyAuthSourceDefaultEmailSubscriptions,
SettingKeyAuthSourceDefaultEmailGrantOnSignup,
SettingKeyAuthSourceDefaultEmailGrantOnFirstBind,
SettingKeyAuthSourceDefaultLinuxDoBalance,
SettingKeyAuthSourceDefaultLinuxDoConcurrency,
SettingKeyAuthSourceDefaultLinuxDoSubscriptions,
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup,
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind,
SettingKeyAuthSourceDefaultOIDCBalance,
SettingKeyAuthSourceDefaultOIDCConcurrency,
SettingKeyAuthSourceDefaultOIDCSubscriptions,
SettingKeyAuthSourceDefaultOIDCGrantOnSignup,
SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind,
SettingKeyAuthSourceDefaultWeChatBalance,
SettingKeyAuthSourceDefaultWeChatConcurrency,
SettingKeyAuthSourceDefaultWeChatSubscriptions,
SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
SettingKeyForceEmailOnThirdPartySignup,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return nil, fmt.Errorf("get auth source default settings: %w", err)
}
return &AuthSourceDefaultSettings{
Email: parseProviderDefaultGrantSettings(settings, emailAuthSourceDefaultKeys),
LinuxDo: parseProviderDefaultGrantSettings(settings, linuxDoAuthSourceDefaultKeys),
OIDC: parseProviderDefaultGrantSettings(settings, oidcAuthSourceDefaultKeys),
WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys),
ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
}, nil
}
func (s *SettingService) UpdateAuthSourceDefaultSettings(ctx context.Context, settings *AuthSourceDefaultSettings) error {
if settings == nil {
return nil
}
for _, subscriptions := range [][]DefaultSubscriptionSetting{
settings.Email.Subscriptions,
settings.LinuxDo.Subscriptions,
settings.OIDC.Subscriptions,
settings.WeChat.Subscriptions,
} {
if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
return err
}
}
updates := make(map[string]string, 21)
writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
if err := s.settingRepo.SetMultiple(ctx, updates); err != nil {
return fmt.Errorf("update auth source default settings: %w", err)
}
return nil
}
// InitializeDefaultSettings 初始化默认设置
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 检查是否已有设置
@ -933,25 +1074,46 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 初始化默认设置
defaults := map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "false",
SettingKeyRegistrationEmailSuffixWhitelist: "[]",
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
SettingKeySiteName: "Sub2API",
SettingKeySiteLogo: "",
SettingKeyPurchaseSubscriptionEnabled: "false",
SettingKeyPurchaseSubscriptionURL: "",
SettingKeyTableDefaultPageSize: "20",
SettingKeyTablePageSizeOptions: "[10,20,50,100]",
SettingKeyCustomMenuItems: "[]",
SettingKeyCustomEndpoints: "[]",
SettingKeyOIDCConnectEnabled: "false",
SettingKeyOIDCConnectProviderName: "OIDC",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyDefaultSubscriptions: "[]",
SettingKeySMTPPort: "587",
SettingKeySMTPUseTLS: "false",
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "false",
SettingKeyRegistrationEmailSuffixWhitelist: "[]",
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
SettingKeySiteName: "Sub2API",
SettingKeySiteLogo: "",
SettingKeyPurchaseSubscriptionEnabled: "false",
SettingKeyPurchaseSubscriptionURL: "",
SettingKeyTableDefaultPageSize: "20",
SettingKeyTablePageSizeOptions: "[10,20,50,100]",
SettingKeyCustomMenuItems: "[]",
SettingKeyCustomEndpoints: "[]",
SettingKeyOIDCConnectEnabled: "false",
SettingKeyOIDCConnectProviderName: "OIDC",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyDefaultSubscriptions: "[]",
SettingKeyAuthSourceDefaultEmailBalance: "0",
SettingKeyAuthSourceDefaultEmailConcurrency: "5",
SettingKeyAuthSourceDefaultEmailSubscriptions: "[]",
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false",
SettingKeyAuthSourceDefaultLinuxDoBalance: "0",
SettingKeyAuthSourceDefaultLinuxDoConcurrency: "5",
SettingKeyAuthSourceDefaultLinuxDoSubscriptions: "[]",
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "false",
SettingKeyAuthSourceDefaultOIDCBalance: "0",
SettingKeyAuthSourceDefaultOIDCConcurrency: "5",
SettingKeyAuthSourceDefaultOIDCSubscriptions: "[]",
SettingKeyAuthSourceDefaultOIDCGrantOnSignup: "true",
SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "false",
SettingKeyAuthSourceDefaultWeChatBalance: "0",
SettingKeyAuthSourceDefaultWeChatConcurrency: "5",
SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]",
SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "true",
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false",
SettingKeyForceEmailOnThirdPartySignup: "false",
SettingKeySMTPPort: "587",
SettingKeySMTPUseTLS: "false",
// Model fallback defaults
SettingKeyEnableModelFallback: "false",
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
@ -1164,6 +1326,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
} else {
result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken
}
result.OIDCConnectUsePKCE = true
result.OIDCConnectValidateIDToken = true
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v)
} else {
@ -1317,6 +1481,51 @@ func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting {
return normalized
}
func parseProviderDefaultGrantSettings(settings map[string]string, keys authSourceDefaultKeySet) ProviderDefaultGrantSettings {
result := ProviderDefaultGrantSettings{
Balance: defaultAuthSourceBalance,
Concurrency: defaultAuthSourceConcurrency,
Subscriptions: []DefaultSubscriptionSetting{},
GrantOnSignup: true,
GrantOnFirstBind: false,
}
if v, err := strconv.ParseFloat(strings.TrimSpace(settings[keys.balance]), 64); err == nil {
result.Balance = v
}
if v, err := strconv.Atoi(strings.TrimSpace(settings[keys.concurrency])); err == nil {
result.Concurrency = v
}
if items := parseDefaultSubscriptions(settings[keys.subscriptions]); items != nil {
result.Subscriptions = items
}
if raw, ok := settings[keys.grantOnSignup]; ok {
result.GrantOnSignup = raw == "true"
}
if raw, ok := settings[keys.grantOnFirstBind]; ok {
result.GrantOnFirstBind = raw == "true"
}
return result
}
func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSourceDefaultKeySet, settings ProviderDefaultGrantSettings) {
updates[keys.balance] = strconv.FormatFloat(settings.Balance, 'f', 8, 64)
updates[keys.concurrency] = strconv.Itoa(settings.Concurrency)
subscriptions := settings.Subscriptions
if subscriptions == nil {
subscriptions = []DefaultSubscriptionSetting{}
}
raw, err := json.Marshal(subscriptions)
if err != nil {
raw = []byte("[]")
}
updates[keys.subscriptions] = string(raw)
updates[keys.grantOnSignup] = strconv.FormatBool(settings.GrantOnSignup)
updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind)
}
func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) {
defaultPageSize := 20
if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil {
@ -1539,6 +1748,7 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
effective.RedirectURL = strings.TrimSpace(v)
}
effective.UsePKCE = true
if !effective.Enabled {
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
@ -1587,9 +1797,6 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
}
case "none":
if !effective.UsePKCE {
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
}
default:
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
}
@ -1737,6 +1944,8 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
effective.ValidateIDToken = raw == "true"
}
effective.UsePKCE = true
effective.ValidateIDToken = true
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
effective.AllowedSigningAlgs = strings.TrimSpace(v)
}
@ -1864,9 +2073,6 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
}
case "none":
if !effective.UsePKCE {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
}
default:
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
}

View File

@ -0,0 +1,136 @@
//go:build unit
package service
import (
"context"
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type authSourceDefaultsRepoStub struct {
values map[string]string
updates map[string]string
}
func (s *authSourceDefaultsRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *authSourceDefaultsRepoStub) GetValue(ctx context.Context, key string) (string, error) {
panic("unexpected GetValue call")
}
func (s *authSourceDefaultsRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *authSourceDefaultsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
if value, ok := s.values[key]; ok {
out[key] = value
}
}
return out, nil
}
func (s *authSourceDefaultsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
s.updates = make(map[string]string, len(settings))
for key, value := range settings {
s.updates[key] = value
if s.values == nil {
s.values = map[string]string{}
}
s.values[key] = value
}
return nil
}
func (s *authSourceDefaultsRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *authSourceDefaultsRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
func TestSettingService_GetAuthSourceDefaultSettings_ParsesValuesAndDefaults(t *testing.T) {
repo := &authSourceDefaultsRepoStub{
values: map[string]string{
SettingKeyAuthSourceDefaultEmailBalance: "12.5",
SettingKeyAuthSourceDefaultEmailConcurrency: "7",
SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "true",
SettingKeyForceEmailOnThirdPartySignup: "true",
},
}
svc := NewSettingService(repo, &config.Config{})
got, err := svc.GetAuthSourceDefaultSettings(context.Background())
require.NoError(t, err)
require.Equal(t, 12.5, got.Email.Balance)
require.Equal(t, 7, got.Email.Concurrency)
require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 11, ValidityDays: 30}}, got.Email.Subscriptions)
require.False(t, got.Email.GrantOnSignup)
require.False(t, got.Email.GrantOnFirstBind)
require.Equal(t, 0.0, got.LinuxDo.Balance)
require.Equal(t, 5, got.LinuxDo.Concurrency)
require.Equal(t, []DefaultSubscriptionSetting{}, got.LinuxDo.Subscriptions)
require.True(t, got.LinuxDo.GrantOnSignup)
require.True(t, got.LinuxDo.GrantOnFirstBind)
require.Equal(t, 5, got.OIDC.Concurrency)
require.Equal(t, 5, got.WeChat.Concurrency)
require.True(t, got.ForceEmailOnThirdPartySignup)
}
func TestSettingService_UpdateAuthSourceDefaultSettings_PersistsAllKeys(t *testing.T) {
repo := &authSourceDefaultsRepoStub{}
svc := NewSettingService(repo, &config.Config{})
err := svc.UpdateAuthSourceDefaultSettings(context.Background(), &AuthSourceDefaultSettings{
Email: ProviderDefaultGrantSettings{
Balance: 1.25,
Concurrency: 3,
Subscriptions: []DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 14}},
GrantOnSignup: false,
GrantOnFirstBind: true,
},
LinuxDo: ProviderDefaultGrantSettings{
Balance: 2,
Concurrency: 4,
Subscriptions: []DefaultSubscriptionSetting{{GroupID: 22, ValidityDays: 30}},
GrantOnSignup: true,
GrantOnFirstBind: false,
},
OIDC: ProviderDefaultGrantSettings{
Balance: 3,
Concurrency: 5,
Subscriptions: []DefaultSubscriptionSetting{{GroupID: 23, ValidityDays: 60}},
GrantOnSignup: true,
GrantOnFirstBind: true,
},
WeChat: ProviderDefaultGrantSettings{
Balance: 4,
Concurrency: 6,
Subscriptions: []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}},
GrantOnSignup: false,
GrantOnFirstBind: false,
},
ForceEmailOnThirdPartySignup: true,
})
require.NoError(t, err)
require.Equal(t, "1.25000000", repo.updates[SettingKeyAuthSourceDefaultEmailBalance])
require.Equal(t, "3", repo.updates[SettingKeyAuthSourceDefaultEmailConcurrency])
require.Equal(t, "false", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnSignup])
require.Equal(t, "true", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnFirstBind])
require.Equal(t, "true", repo.updates[SettingKeyForceEmailOnThirdPartySignup])
var got []DefaultSubscriptionSetting
require.NoError(t, json.Unmarshal([]byte(repo.updates[SettingKeyAuthSourceDefaultWeChatSubscriptions]), &got))
require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}}, got)
}

View File

@ -152,6 +152,7 @@ type PublicSettings struct {
CustomEndpoints string // JSON array of custom endpoints
LinuxDoOAuthEnabled bool
WeChatOAuthEnabled bool
BackendModeEnabled bool
PaymentEnabled bool
OIDCOAuthEnabled bool

View File

@ -7,19 +7,27 @@ import (
)
type User struct {
ID int64
Email string
Username string
Notes string
PasswordHash string
Role string
Balance float64
Concurrency int
Status string
AllowedGroups []int64
TokenVersion int64 // Incremented on password change to invalidate existing tokens
CreatedAt time.Time
UpdatedAt time.Time
ID int64
Email string
Username string
Notes string
AvatarURL string
AvatarSource string
AvatarMIME string
AvatarByteSize int
AvatarSHA256 string
PasswordHash string
Role string
Balance float64
Concurrency int
Status string
AllowedGroups []int64
TokenVersion int64 // Incremented on password change to invalidate existing tokens
SignupSource string
LastLoginAt *time.Time
LastActiveAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier

View File

@ -2,9 +2,13 @@ package service
import (
"context"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"fmt"
"log/slog"
"net/url"
"strings"
"time"
@ -17,10 +21,14 @@ var (
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later")
ErrAvatarInvalid = infraerrors.BadRequest("AVATAR_INVALID", "avatar must be a valid image data URL or http(s) URL")
ErrAvatarTooLarge = infraerrors.BadRequest("AVATAR_TOO_LARGE", "avatar image must be 100KB or smaller")
ErrAvatarNotImage = infraerrors.BadRequest("AVATAR_NOT_IMAGE", "avatar content must be an image")
)
const (
maxNotifyEmails = 3 // Maximum number of notification emails per user
maxNotifyEmails = 3 // Maximum number of notification emails per user
maxInlineAvatarBytes = 100 * 1024
// User-level rate limiting for notify email verification codes
notifyCodeUserRateLimit = 5
@ -47,6 +55,9 @@ type UserRepository interface {
GetFirstAdmin(ctx context.Context) (*User, error)
Update(ctx context.Context, user *User) error
Delete(ctx context.Context, id int64) error
GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error)
UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error)
DeleteUserAvatar(ctx context.Context, userID int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error)
@ -71,11 +82,30 @@ type UserRepository interface {
type UpdateProfileRequest struct {
Email *string `json:"email"`
Username *string `json:"username"`
AvatarURL *string `json:"avatar_url"`
Concurrency *int `json:"concurrency"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
type UserAvatar struct {
StorageProvider string
StorageKey string
URL string
ContentType string
ByteSize int
SHA256 string
}
type UpsertUserAvatarInput struct {
StorageProvider string
StorageKey string
URL string
ContentType string
ByteSize int
SHA256 string
}
// ChangePasswordRequest 修改密码请求
type ChangePasswordRequest struct {
CurrentPassword string `json:"current_password"`
@ -115,6 +145,9 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
if err := s.hydrateUserAvatar(ctx, user); err != nil {
return nil, fmt.Errorf("get user avatar: %w", err)
}
return user, nil
}
@ -143,6 +176,27 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
user.Username = *req.Username
}
if req.AvatarURL != nil {
avatarValue := strings.TrimSpace(*req.AvatarURL)
switch {
case avatarValue == "":
if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil {
return nil, fmt.Errorf("delete avatar: %w", err)
}
applyUserAvatar(user, nil)
default:
avatarInput, err := normalizeUserAvatarInput(avatarValue)
if err != nil {
return nil, err
}
avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput)
if err != nil {
return nil, fmt.Errorf("upsert avatar: %w", err)
}
applyUserAvatar(user, avatar)
}
}
if req.Concurrency != nil {
user.Concurrency = *req.Concurrency
}
@ -168,6 +222,87 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
return user, nil
}
func applyUserAvatar(user *User, avatar *UserAvatar) {
if user == nil {
return
}
if avatar == nil {
user.AvatarURL = ""
user.AvatarSource = ""
user.AvatarMIME = ""
user.AvatarByteSize = 0
user.AvatarSHA256 = ""
return
}
user.AvatarURL = avatar.URL
user.AvatarSource = avatar.StorageProvider
user.AvatarMIME = avatar.ContentType
user.AvatarByteSize = avatar.ByteSize
user.AvatarSHA256 = avatar.SHA256
}
func normalizeUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
if strings.HasPrefix(raw, "data:") {
return normalizeInlineUserAvatarInput(raw)
}
parsed, err := url.Parse(raw)
if err != nil || parsed == nil {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
if !strings.EqualFold(parsed.Scheme, "http") && !strings.EqualFold(parsed.Scheme, "https") {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
if strings.TrimSpace(parsed.Host) == "" {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
return UpsertUserAvatarInput{
StorageProvider: "remote_url",
URL: raw,
}, nil
}
func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
body := strings.TrimPrefix(raw, "data:")
meta, encoded, ok := strings.Cut(body, ",")
if !ok {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
meta = strings.TrimSpace(meta)
encoded = strings.TrimSpace(encoded)
if !strings.HasSuffix(strings.ToLower(meta), ";base64") {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
contentType := strings.TrimSpace(meta[:len(meta)-len(";base64")])
if contentType == "" || !strings.HasPrefix(strings.ToLower(contentType), "image/") {
return UpsertUserAvatarInput{}, ErrAvatarNotImage
}
decoded, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
if len(decoded) > maxInlineAvatarBytes {
return UpsertUserAvatarInput{}, ErrAvatarTooLarge
}
sum := sha256.Sum256(decoded)
return UpsertUserAvatarInput{
StorageProvider: "inline",
URL: raw,
ContentType: contentType,
ByteSize: len(decoded),
SHA256: hex.EncodeToString(sum[:]),
}, nil
}
// ChangePassword 修改密码
// Security: Increments TokenVersion to invalidate all existing JWT tokens
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
@ -202,9 +337,25 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) {
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
if err := s.hydrateUserAvatar(ctx, user); err != nil {
return nil, fmt.Errorf("get user avatar: %w", err)
}
return user, nil
}
func (s *UserService) hydrateUserAvatar(ctx context.Context, user *User) error {
if s == nil || s.userRepo == nil || user == nil || user.ID == 0 {
return nil
}
avatar, err := s.userRepo.GetUserAvatar(ctx, user.ID)
if err != nil {
return err
}
applyUserAvatar(user, avatar)
return nil
}
// List 获取用户列表(管理员功能)
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
users, pagination, err := s.userRepo.List(ctx, params)

View File

@ -4,6 +4,9 @@ package service
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"sync"
"sync/atomic"
@ -19,14 +22,65 @@ import (
type mockUserRepo struct {
updateBalanceErr error
updateBalanceFn func(ctx context.Context, id int64, amount float64) error
getByIDUser *User
getByIDErr error
updateFn func(ctx context.Context, user *User) error
updateCalls int
upsertAvatarFn func(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error)
upsertAvatarArgs []UpsertUserAvatarInput
deleteAvatarFn func(ctx context.Context, userID int64) error
deleteAvatarIDs []int64
getAvatarFn func(ctx context.Context, userID int64) (*UserAvatar, error)
}
func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil }
func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) {
if m.getByIDErr != nil {
return nil, m.getByIDErr
}
if m.getByIDUser != nil {
cloned := *m.getByIDUser
return &cloned, nil
}
return &User{}, nil
}
func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil }
func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil }
func (m *mockUserRepo) Update(context.Context, *User) error { return nil }
func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
func (m *mockUserRepo) Update(ctx context.Context, user *User) error {
m.updateCalls++
if m.updateFn != nil {
return m.updateFn(ctx, user)
}
return nil
}
func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
func (m *mockUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
if m.getAvatarFn != nil {
return m.getAvatarFn(ctx, userID)
}
return nil, nil
}
func (m *mockUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
m.upsertAvatarArgs = append(m.upsertAvatarArgs, input)
if m.upsertAvatarFn != nil {
return m.upsertAvatarFn(ctx, userID, input)
}
return &UserAvatar{
StorageProvider: input.StorageProvider,
StorageKey: input.StorageKey,
URL: input.URL,
ContentType: input.ContentType,
ByteSize: input.ByteSize,
SHA256: input.SHA256,
}, nil
}
func (m *mockUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
m.deleteAvatarIDs = append(m.deleteAvatarIDs, userID)
if m.deleteAvatarFn != nil {
return m.deleteAvatarFn(ctx, userID)
}
return nil
}
func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
@ -200,3 +254,121 @@ func TestNewUserService_FieldsAssignment(t *testing.T) {
require.Equal(t, auth, svc.authCacheInvalidator)
require.Equal(t, cache, svc.billingCache)
}
func TestUpdateProfile_StoresInlineAvatarWithinLimit(t *testing.T) {
raw := []byte("small-avatar")
dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw)
expectedSum := sha256.Sum256(raw)
repo := &mockUserRepo{
getByIDUser: &User{
ID: 7,
Email: "avatar@example.com",
Username: "avatar-user",
},
}
svc := NewUserService(repo, nil, nil, nil)
updated, err := svc.UpdateProfile(context.Background(), 7, UpdateProfileRequest{
AvatarURL: &dataURL,
})
require.NoError(t, err)
require.Len(t, repo.upsertAvatarArgs, 1)
require.Equal(t, "inline", repo.upsertAvatarArgs[0].StorageProvider)
require.Equal(t, "image/png", repo.upsertAvatarArgs[0].ContentType)
require.Equal(t, len(raw), repo.upsertAvatarArgs[0].ByteSize)
require.Equal(t, hex.EncodeToString(expectedSum[:]), repo.upsertAvatarArgs[0].SHA256)
require.Equal(t, dataURL, updated.AvatarURL)
require.Equal(t, "inline", updated.AvatarSource)
require.Equal(t, "image/png", updated.AvatarMIME)
require.Equal(t, len(raw), updated.AvatarByteSize)
require.Equal(t, hex.EncodeToString(expectedSum[:]), updated.AvatarSHA256)
}
func TestUpdateProfile_RejectsInlineAvatarOverLimit(t *testing.T) {
raw := make([]byte, maxInlineAvatarBytes+1)
dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw)
repo := &mockUserRepo{
getByIDUser: &User{
ID: 8,
Email: "large-avatar@example.com",
Username: "too-large",
},
}
svc := NewUserService(repo, nil, nil, nil)
_, err := svc.UpdateProfile(context.Background(), 8, UpdateProfileRequest{
AvatarURL: &dataURL,
})
require.ErrorIs(t, err, ErrAvatarTooLarge)
require.Empty(t, repo.upsertAvatarArgs)
require.Empty(t, repo.deleteAvatarIDs)
require.Zero(t, repo.updateCalls)
}
func TestUpdateProfile_StoresRemoteAvatarURL(t *testing.T) {
remoteURL := "https://cdn.example.com/avatar.png"
repo := &mockUserRepo{
getByIDUser: &User{
ID: 9,
Email: "remote-avatar@example.com",
Username: "remote-avatar",
},
}
svc := NewUserService(repo, nil, nil, nil)
updated, err := svc.UpdateProfile(context.Background(), 9, UpdateProfileRequest{
AvatarURL: &remoteURL,
})
require.NoError(t, err)
require.Len(t, repo.upsertAvatarArgs, 1)
require.Equal(t, "remote_url", repo.upsertAvatarArgs[0].StorageProvider)
require.Equal(t, remoteURL, repo.upsertAvatarArgs[0].URL)
require.Equal(t, remoteURL, updated.AvatarURL)
require.Equal(t, "remote_url", updated.AvatarSource)
require.Zero(t, updated.AvatarByteSize)
}
func TestUpdateProfile_DeletesAvatarOnEmptyString(t *testing.T) {
empty := ""
repo := &mockUserRepo{
getByIDUser: &User{
ID: 10,
Email: "delete-avatar@example.com",
Username: "delete-avatar",
AvatarURL: "https://cdn.example.com/old.png",
AvatarSource: "remote_url",
},
}
svc := NewUserService(repo, nil, nil, nil)
updated, err := svc.UpdateProfile(context.Background(), 10, UpdateProfileRequest{
AvatarURL: &empty,
})
require.NoError(t, err)
require.Equal(t, []int64{10}, repo.deleteAvatarIDs)
require.Empty(t, repo.upsertAvatarArgs)
require.Empty(t, updated.AvatarURL)
require.Empty(t, updated.AvatarSource)
}
func TestGetProfile_HydratesAvatarFromRepository(t *testing.T) {
repo := &mockUserRepo{
getByIDUser: &User{
ID: 12,
Email: "profile-avatar@example.com",
Username: "profile-avatar",
},
getAvatarFn: func(context.Context, int64) (*UserAvatar, error) {
return &UserAvatar{
StorageProvider: "remote_url",
URL: "https://cdn.example.com/profile.png",
}, nil
},
}
svc := NewUserService(repo, nil, nil, nil)
user, err := svc.GetProfile(context.Background(), 12)
require.NoError(t, err)
require.Equal(t, "https://cdn.example.com/profile.png", user.AvatarURL)
require.Equal(t, "remote_url", user.AvatarSource)
}

View File

@ -0,0 +1,141 @@
ALTER TABLE users
ADD COLUMN IF NOT EXISTS signup_source VARCHAR(20) NOT NULL DEFAULT 'email',
ADD COLUMN IF NOT EXISTS last_login_at TIMESTAMPTZ NULL,
ADD COLUMN IF NOT EXISTS last_active_at TIMESTAMPTZ NULL;
UPDATE users
SET signup_source = 'email'
WHERE signup_source IS NULL OR signup_source = '';
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1
FROM pg_constraint
WHERE conname = 'users_signup_source_check'
) THEN
ALTER TABLE users
ADD CONSTRAINT users_signup_source_check
CHECK (signup_source IN ('email', 'linuxdo', 'wechat', 'oidc'));
END IF;
END $$;
CREATE TABLE IF NOT EXISTS auth_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
provider_type VARCHAR(20) NOT NULL,
provider_key TEXT NOT NULL,
provider_subject TEXT NOT NULL,
verified_at TIMESTAMPTZ NULL,
issuer TEXT NULL,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT auth_identities_provider_type_check
CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
);
CREATE UNIQUE INDEX IF NOT EXISTS auth_identities_provider_subject_key
ON auth_identities (provider_type, provider_key, provider_subject);
CREATE INDEX IF NOT EXISTS auth_identities_user_id_idx
ON auth_identities (user_id);
CREATE INDEX IF NOT EXISTS auth_identities_user_provider_idx
ON auth_identities (user_id, provider_type);
CREATE TABLE IF NOT EXISTS auth_identity_channels (
id BIGSERIAL PRIMARY KEY,
identity_id BIGINT NOT NULL REFERENCES auth_identities(id) ON DELETE CASCADE,
provider_type VARCHAR(20) NOT NULL,
provider_key TEXT NOT NULL,
channel VARCHAR(20) NOT NULL,
channel_app_id TEXT NOT NULL,
channel_subject TEXT NOT NULL,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT auth_identity_channels_provider_type_check
CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
);
CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_channels_channel_key
ON auth_identity_channels (provider_type, provider_key, channel, channel_app_id, channel_subject);
CREATE INDEX IF NOT EXISTS auth_identity_channels_identity_id_idx
ON auth_identity_channels (identity_id);
CREATE TABLE IF NOT EXISTS pending_auth_sessions (
id BIGSERIAL PRIMARY KEY,
session_token VARCHAR(255) NOT NULL,
intent VARCHAR(40) NOT NULL,
provider_type VARCHAR(20) NOT NULL,
provider_key TEXT NOT NULL,
provider_subject TEXT NOT NULL,
target_user_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL,
redirect_to TEXT NOT NULL DEFAULT '',
resolved_email TEXT NOT NULL DEFAULT '',
registration_password_hash TEXT NOT NULL DEFAULT '',
upstream_identity_claims JSONB NOT NULL DEFAULT '{}'::jsonb,
local_flow_state JSONB NOT NULL DEFAULT '{}'::jsonb,
browser_session_key TEXT NOT NULL DEFAULT '',
completion_code_hash TEXT NOT NULL DEFAULT '',
completion_code_expires_at TIMESTAMPTZ NULL,
email_verified_at TIMESTAMPTZ NULL,
password_verified_at TIMESTAMPTZ NULL,
totp_verified_at TIMESTAMPTZ NULL,
expires_at TIMESTAMPTZ NOT NULL,
consumed_at TIMESTAMPTZ NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT pending_auth_sessions_intent_check
CHECK (intent IN ('login', 'bind_current_user', 'adopt_existing_user_by_email')),
CONSTRAINT pending_auth_sessions_provider_type_check
CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
);
CREATE UNIQUE INDEX IF NOT EXISTS pending_auth_sessions_session_token_key
ON pending_auth_sessions (session_token);
CREATE INDEX IF NOT EXISTS pending_auth_sessions_target_user_id_idx
ON pending_auth_sessions (target_user_id);
CREATE INDEX IF NOT EXISTS pending_auth_sessions_expires_at_idx
ON pending_auth_sessions (expires_at);
CREATE INDEX IF NOT EXISTS pending_auth_sessions_provider_idx
ON pending_auth_sessions (provider_type, provider_key, provider_subject);
CREATE INDEX IF NOT EXISTS pending_auth_sessions_completion_code_idx
ON pending_auth_sessions (completion_code_hash);
CREATE TABLE IF NOT EXISTS identity_adoption_decisions (
id BIGSERIAL PRIMARY KEY,
pending_auth_session_id BIGINT NOT NULL REFERENCES pending_auth_sessions(id) ON DELETE CASCADE,
identity_id BIGINT NULL REFERENCES auth_identities(id) ON DELETE SET NULL,
adopt_display_name BOOLEAN NOT NULL DEFAULT FALSE,
adopt_avatar BOOLEAN NOT NULL DEFAULT FALSE,
decided_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX IF NOT EXISTS identity_adoption_decisions_pending_auth_session_id_key
ON identity_adoption_decisions (pending_auth_session_id);
CREATE INDEX IF NOT EXISTS identity_adoption_decisions_identity_id_idx
ON identity_adoption_decisions (identity_id);
CREATE TABLE IF NOT EXISTS auth_identity_migration_reports (
id BIGSERIAL PRIMARY KEY,
report_type VARCHAR(40) NOT NULL,
report_key TEXT NOT NULL,
details JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS auth_identity_migration_reports_type_idx
ON auth_identity_migration_reports (report_type);
CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_migration_reports_type_key
ON auth_identity_migration_reports (report_type, report_key);

View File

@ -0,0 +1,125 @@
INSERT INTO auth_identities (
user_id,
provider_type,
provider_key,
provider_subject,
verified_at,
metadata
)
SELECT
u.id,
'email',
'email',
LOWER(BTRIM(u.email)),
COALESCE(u.updated_at, u.created_at, NOW()),
jsonb_build_object(
'backfill_source', 'users.email',
'migration', '109_auth_identity_compat_backfill'
)
FROM users AS u
WHERE u.deleted_at IS NULL
AND BTRIM(COALESCE(u.email, '')) <> ''
AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@linuxdo-connect.invalid')) <> '@linuxdo-connect.invalid'
AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@oidc-connect.invalid')) <> '@oidc-connect.invalid'
AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@wechat-connect.invalid')) <> '@wechat-connect.invalid'
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
INSERT INTO auth_identities (
user_id,
provider_type,
provider_key,
provider_subject,
verified_at,
metadata
)
SELECT
u.id,
'linuxdo',
'linuxdo',
SUBSTRING(BTRIM(u.email) FROM '(?i)^linuxdo-(.+)@linuxdo-connect\.invalid$'),
COALESCE(u.updated_at, u.created_at, NOW()),
jsonb_build_object(
'backfill_source', 'synthetic_email',
'legacy_email', BTRIM(u.email),
'migration', '109_auth_identity_compat_backfill'
)
FROM users AS u
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(u.email)) ~ '^linuxdo-.+@linuxdo-connect\.invalid$'
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
INSERT INTO auth_identities (
user_id,
provider_type,
provider_key,
provider_subject,
verified_at,
metadata
)
SELECT
u.id,
'wechat',
'wechat',
SUBSTRING(BTRIM(u.email) FROM '(?i)^wechat-(.+)@wechat-connect\.invalid$'),
COALESCE(u.updated_at, u.created_at, NOW()),
jsonb_build_object(
'backfill_source', 'synthetic_email',
'legacy_email', BTRIM(u.email),
'migration', '109_auth_identity_compat_backfill'
)
FROM users AS u
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$'
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
UPDATE users
SET signup_source = 'linuxdo'
WHERE deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^linuxdo-.+@linuxdo-connect\.invalid$';
UPDATE users
SET signup_source = 'wechat'
WHERE deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^wechat-.+@wechat-connect\.invalid$';
UPDATE users
SET signup_source = 'oidc'
WHERE deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^oidc-.+@oidc-connect\.invalid$';
INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
SELECT
'oidc_synthetic_email_requires_manual_recovery',
CAST(u.id AS TEXT),
jsonb_build_object(
'user_id', u.id,
'email', LOWER(BTRIM(u.email)),
'reason', 'cannot recover issuer_plus_sub deterministically from synthetic email alone',
'migration', '109_auth_identity_compat_backfill'
)
FROM users AS u
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(u.email)) ~ '^oidc-.+@oidc-connect\.invalid$'
ON CONFLICT (report_type, report_key) DO NOTHING;
INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
SELECT
'wechat_openid_only_requires_remediation',
CAST(u.id AS TEXT),
jsonb_build_object(
'user_id', u.id,
'email', LOWER(BTRIM(u.email)),
'reason', 'legacy wechat synthetic identity requires explicit unionid remediation if channel-only data exists',
'migration', '109_auth_identity_compat_backfill'
)
FROM users AS u
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$'
AND NOT EXISTS (
SELECT 1
FROM auth_identities ai
WHERE ai.user_id = u.id
AND ai.provider_type = 'wechat'
AND ai.provider_key = 'wechat'
)
ON CONFLICT (report_type, report_key) DO NOTHING;

View File

@ -0,0 +1,60 @@
CREATE TABLE IF NOT EXISTS user_provider_default_grants (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
provider_type VARCHAR(20) NOT NULL,
grant_reason VARCHAR(20) NOT NULL DEFAULT 'first_bind',
granted_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT user_provider_default_grants_provider_type_check
CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')),
CONSTRAINT user_provider_default_grants_reason_check
CHECK (grant_reason IN ('signup', 'first_bind'))
);
CREATE UNIQUE INDEX IF NOT EXISTS user_provider_default_grants_user_provider_reason_key
ON user_provider_default_grants (user_id, provider_type, grant_reason);
CREATE INDEX IF NOT EXISTS user_provider_default_grants_user_id_idx
ON user_provider_default_grants (user_id);
CREATE TABLE IF NOT EXISTS user_avatars (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
storage_provider VARCHAR(20) NOT NULL DEFAULT 'database',
storage_key TEXT NOT NULL DEFAULT '',
url TEXT NOT NULL DEFAULT '',
content_type VARCHAR(100) NOT NULL DEFAULT '',
byte_size INT NOT NULL DEFAULT 0,
sha256 VARCHAR(64) NOT NULL DEFAULT '',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX IF NOT EXISTS user_avatars_user_id_key
ON user_avatars (user_id);
INSERT INTO settings (key, value)
VALUES
('auth_source_default_email_balance', '0'),
('auth_source_default_email_concurrency', '5'),
('auth_source_default_email_subscriptions', '[]'),
('auth_source_default_email_grant_on_signup', 'true'),
('auth_source_default_email_grant_on_first_bind', 'false'),
('auth_source_default_linuxdo_balance', '0'),
('auth_source_default_linuxdo_concurrency', '5'),
('auth_source_default_linuxdo_subscriptions', '[]'),
('auth_source_default_linuxdo_grant_on_signup', 'true'),
('auth_source_default_linuxdo_grant_on_first_bind', 'false'),
('auth_source_default_oidc_balance', '0'),
('auth_source_default_oidc_concurrency', '5'),
('auth_source_default_oidc_subscriptions', '[]'),
('auth_source_default_oidc_grant_on_signup', 'true'),
('auth_source_default_oidc_grant_on_first_bind', 'false'),
('auth_source_default_wechat_balance', '0'),
('auth_source_default_wechat_concurrency', '5'),
('auth_source_default_wechat_subscriptions', '[]'),
('auth_source_default_wechat_grant_on_signup', 'true'),
('auth_source_default_wechat_grant_on_first_bind', 'false'),
('force_email_on_third_party_signup', 'false')
ON CONFLICT (key) DO NOTHING;

View File

@ -0,0 +1,8 @@
INSERT INTO settings (key, value)
VALUES
('payment_visible_method_alipay_source', ''),
('payment_visible_method_wxpay_source', ''),
('payment_visible_method_alipay_enabled', 'false'),
('payment_visible_method_wxpay_enabled', 'false'),
('openai_advanced_scheduler_enabled', 'false')
ON CONFLICT (key) DO NOTHING;

Some files were not shown because too many files have changed in this diff Show More