DaydreamCoding b19da9c7fe feat(dingtalk): 钉钉 OAuth 登录接入与 internal_only 用户属性同步
⚠️ 应用类型约束:当前实现仅支持「钉钉登录-企业内部应用」(DingTalk 开放平台
internal_app 类型)。第三方个人应用、第三方企业应用类型暂不支持——OAuth 流程
相同但 corp 校验、跨企业行为不同。backend 通过 DingTalkAppKind 校验对非
internal_app 类型 fail-closed(硬约束)。

钉钉 OAuth 登录主链
- 4 步 OAuth 链:ExchangeCodeForUserToken / GetUnionIdByUserToken /
  GetUserIdByUnionId / GetStaffInfoByUserId;app token 缓存
- pending session 机制持久化 OAuth 中间态;cookie-only token 持久化
- 三种分流:bind_login_required / email_completion / choose_account_action
- corp_restriction_policy 支持 none + internal_only;stale "whitelist" 在
  加载层与写入层均静默 coerce 为 none + slog.Warn
- bypass_registration 开关:企业内部模式豁免全局 REGISTRATION_DISABLED
- isReservedEmail / signup_source / canUnbindProvider / OAuth pending flow
  等横切点支持 dingtalk provider
- migration 136:4 表 CHECK 约束加入 'dingtalk' provider 值

internal_only 模式同步企业邮箱/姓名/部门到用户属性
- SyncCorpEmail / SyncDisplayName / SyncDept 三个独立开关 + 对应
  SyncXxxAttrKey 目标属性 key(默认 dingtalk_email / dingtalk_name /
  dingtalk_department);非 internal_only policy 在写入层与加载层均
  coerce 为 false,admin handler 与 setting_service 双层兜底
- 同步语义:首次注册写 users.username(昵称优先 → 企业姓名 fallback),
  之后每次登录刷新 3 个属性;空值也写入以覆盖旧值
- 邮箱三级 fallback:org_email > email > extension["企业邮箱"]
  (钉钉自定义字段 JSON)
- 部门路径递归向上拼接,跳过 dept_id=1 选首个真实子部门,剥离根组织名
- GetUnionIdByUserToken 同时返回 OIDC /contact/users/me 的 nick 字段;
  新增 GetDeptInfo 调用 OAPI /topapi/v2/department/get
- AuthHandler 注入 UserAttributeService;OAuth pending flow 在
  createPendingOAuthAccount / bindPendingOAuthLogin 分别派发到
  AfterRegistration(syncUsername=true)/ AfterLogin
- migration 137 seed dingtalk_email/name/department 三个用户属性定义

附带修复(同集成路径暴露的两个 OAuth 注册回归)
- LoginOrRegisterOAuthWithTokenPair 新建用户分支用 inferLegacySignupSource
  覆写 caller 显式传入的 signupSource,导致 dingtalk/linuxdo/oidc/wechat
  渠道授权按 email 渠道读取;改为只在 caller 未显式传入时回退邮箱推断
- mergeProviderDefaultGrantSettings 把 parse fallback 默认值
  (Concurrency=5 / Balance=0) 当作"未配置"哨兵,admin 显式设 5 时被误判
  退回全局默认(复现:全局默认 1 + 渠道默认并发 5 + grant_on_signup → 新
  用户实际 concurrency=1);去掉哨兵,admin 任何 >=0 值都覆盖 globalDefaults

前端
- DingTalk Login / Callback / EmailCompletion / ChoiceAccount / Error
  视图;router + auth API client
- admin SettingsView:corp policy radio(none / internal_only)+ bypass
  注册开关 + i18n;internal_only 下展示三同步开关 + 目标 attr key 下拉
  (拉取 user attribute definitions),展示 fieldEmail /
  qyapi_get_department_list 钉钉权限申请提示
- Profile:S1 主动绑定 / S5 解绑钉钉按钮 + 合成邮箱防自锁

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-19 15:27:47 +08:00

1014 lines
28 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"context"
"database/sql"
"errors"
"fmt"
"sort"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/predicate"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
entsql "entgo.io/ent/dialect/sql"
)
// userRepository is the ent-backed user store in this package; NewUserRepository
// hands it out as a service.UserRepository.
type userRepository struct {
// client is the ent client used for entity queries and for opening transactions.
client *dbent.Client
// sql is the raw SQL executor used by the hand-written queries in this file.
sql sqlExecutor
}
// NewUserRepository builds the user repository on top of the given ent client
// and *sql.DB and exposes it through the service.UserRepository interface.
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
	repo := newUserRepositoryWithSQL(client, sqlDB)
	return repo
}
// newUserRepositoryWithSQL wires a userRepository to an ent client and an
// arbitrary sqlExecutor.
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
	repo := new(userRepository)
	repo.client = client
	repo.sql = sqlq
	return repo
}
// Create inserts userIn as a new user row, rewrites its allowed-group rows and
// makes sure the email auth identity for the new address exists.
//
// A nil userIn is treated as a no-op and returns nil. The work runs inside an
// ent transaction: either one opened (and committed) here, or, when
// r.client.Tx reports dbent.ErrTxStarted, the transaction found in ctx
// (falling back to r.client), which is left for the caller to finish.
//
// On success the stored row's ID, signup source, last-login/last-active
// timestamps and created/updated timestamps are copied back into userIn.
// An email already taken by another user yields service.ErrEmailExists.
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
if userIn == nil {
return nil
}
// Always go through an ent transaction: it keeps the user row and the
// allowed-groups update atomic, and avoids the ExecQuerier assertion error
// caused by hand-building an ent client on top of a *sql.Tx.
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return err
}
var txClient *dbent.Client
txCtx := ctx
if err == nil {
// This transaction is owned here: roll back on any early return. The
// rollback error is deliberately discarded.
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
txCtx = dbent.NewTxContext(ctx, tx)
} else {
// Already inside an outer transaction (ErrTxStarted): reuse the current
// transaction's client; the caller is responsible for commit/rollback.
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
txClient = existingTx.Client()
} else {
txClient = r.client
}
}
// Take the lock keyed on the normalized email before the availability check.
// NOTE(review): presumably this serializes concurrent writers of the same
// normalized email — confirm in lockRepositoryScopedKeys.
releaseEmailLock, err := lockRepositoryScopedKeys(
txCtx,
txClient,
txAwareSQLExecutor(txCtx, r.sql, r.client),
normalizedEmailUniquenessLockKey(userIn.Email),
)
if err != nil {
return err
}
defer releaseEmailLock()
// User ID 0 excludes no existing row, so any user already matching the
// normalized email is reported as a conflict (assuming real IDs are never 0).
if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, 0, userIn.Email); err != nil {
return err
}
created, err := txClient.User.Create().
SetEmail(userIn.Email).
SetUsername(userIn.Username).
SetNotes(userIn.Notes).
SetPasswordHash(userIn.PasswordHash).
SetRole(userIn.Role).
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
SetNillableLastLoginAt(userIn.LastLoginAt).
SetNillableLastActiveAt(userIn.LastActiveAt).
SetRpmLimit(userIn.RPMLimit).
Save(txCtx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, created.ID, userIn.AllowedGroups); err != nil {
return err
}
if err := ensureEmailAuthIdentityWithClient(txCtx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
return err
}
// tx is nil when an outer transaction was reused; only commit what was
// opened here.
if tx != nil {
if err := tx.Commit(); err != nil {
return err
}
}
applyUserEntityToService(userIn, created)
return nil
}
// GetByID loads the user with the given id together with its allowed group IDs.
// Lookup failures are passed through translatePersistenceError with
// service.ErrUserNotFound as the not-found error.
func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, error) {
	entity, err := r.client.User.Query().
		Where(dbuser.IDEQ(id)).
		Only(ctx)
	if err != nil {
		return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
	}

	user := userEntityToService(entity)

	allowedByUser, err := r.loadAllowedGroups(ctx, []int64{id})
	if err != nil {
		return nil, err
	}
	if groupIDs, found := allowedByUser[id]; found {
		user.AllowedGroups = groupIDs
	}
	return user, nil
}
// GetByEmail looks a user up by email using userEmailLookupPredicate and
// attaches its allowed group IDs.
//
// It returns service.ErrUserNotFound when nothing matches and a descriptive
// error when more than one row matches the same lookup.
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
	candidates, err := r.client.User.Query().
		Where(userEmailLookupPredicate(email)).
		Order(dbent.Asc(dbuser.FieldID)).
		All(ctx)
	if err != nil {
		return nil, err
	}

	switch n := len(candidates); {
	case n == 0:
		return nil, service.ErrUserNotFound
	case n > 1:
		return nil, fmt.Errorf("normalized email lookup matched multiple users for %q", strings.TrimSpace(email))
	}

	entity := candidates[0]
	user := userEntityToService(entity)

	allowedByUser, err := r.loadAllowedGroups(ctx, []int64{entity.ID})
	if err != nil {
		return nil, err
	}
	if groupIDs, found := allowedByUser[entity.ID]; found {
		user.AllowedGroups = groupIDs
	}
	return user, nil
}
// Update overwrites the stored user identified by userIn.ID with the values in
// userIn, rewrites its allowed-group rows and swaps the email auth identity
// when the address changed.
//
// A nil userIn is a no-op returning nil. Transaction handling mirrors the
// visible pattern below: a new ent transaction is opened and committed here,
// unless r.client.Tx reports dbent.ErrTxStarted, in which case the transaction
// from ctx (or r.client) is reused and left to the caller.
//
// SignupSource, LastLoginAt and LastActiveAt are only written when set on
// userIn; a nil BalanceNotifyThreshold clears the stored threshold. On success
// only userIn.UpdatedAt is refreshed from the stored row.
func (r *userRepository) Update(ctx context.Context, userIn *service.User) error {
if userIn == nil {
return nil
}
// Wrap the user update and the allowed_groups sync in one ent transaction to
// avoid cross-layer transaction inconsistencies.
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return err
}
var txClient *dbent.Client
txCtx := ctx
if err == nil {
// This transaction is owned here: roll back on any early return. The
// rollback error is deliberately discarded.
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
txCtx = dbent.NewTxContext(ctx, tx)
} else {
// Already inside an outer transaction (ErrTxStarted): reuse the current
// transaction's client; the caller is responsible for commit/rollback.
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
txClient = existingTx.Client()
} else {
txClient = r.client
}
}
// Take the lock keyed on the (new) normalized email before checking that no
// other user already holds it.
releaseEmailLock, err := lockRepositoryScopedKeys(
txCtx,
txClient,
txAwareSQLExecutor(txCtx, r.sql, r.client),
normalizedEmailUniquenessLockKey(userIn.Email),
)
if err != nil {
return err
}
defer releaseEmailLock()
// The user's own row is excluded from the conflict check via userIn.ID.
if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, userIn.ID, userIn.Email); err != nil {
return err
}
existing, err := clientFromContext(txCtx, txClient).User.Get(txCtx, userIn.ID)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
// Remember the previous address so the old email identity can be removed
// after the row has been updated.
oldEmail := existing.Email
updateOp := txClient.User.UpdateOneID(userIn.ID).
SetEmail(userIn.Email).
SetUsername(userIn.Username).
SetNotes(userIn.Notes).
SetPasswordHash(userIn.PasswordHash).
SetRole(userIn.Role).
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
SetBalanceNotifyEnabled(userIn.BalanceNotifyEnabled).
SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
SetTotalRecharged(userIn.TotalRecharged).
SetRpmLimit(userIn.RPMLimit)
// An empty signup source leaves the stored value untouched.
if userIn.SignupSource != "" {
updateOp = updateOp.SetSignupSource(userIn.SignupSource)
}
// Nil timestamps leave the stored values untouched.
if userIn.LastLoginAt != nil {
updateOp = updateOp.SetLastLoginAt(*userIn.LastLoginAt)
}
if userIn.LastActiveAt != nil {
updateOp = updateOp.SetLastActiveAt(*userIn.LastActiveAt)
}
// A nil threshold explicitly clears the stored threshold.
if userIn.BalanceNotifyThreshold == nil {
updateOp = updateOp.ClearBalanceNotifyThreshold()
}
updated, err := updateOp.Save(txCtx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
}
if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
return err
}
if err := replaceEmailAuthIdentityWithClient(txCtx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
return err
}
// tx is nil when an outer transaction was reused; only commit what was
// opened here.
if tx != nil {
if err := tx.Commit(); err != nil {
return err
}
}
userIn.UpdatedAt = updated.UpdatedAt
return nil
}
// ensureEmailAuthIdentityWithClient makes sure an "email" auth identity exists
// for the normalized form of email and that it belongs to userID.
//
// It is a no-op (nil) when no client can be resolved, when userID is not
// positive, or when the email normalizes to an empty subject. The insert is
// conflict-tolerant (DO NOTHING on provider type/key/subject); afterwards the
// stored identity is read back and ErrAuthIdentityOwnershipConflict is
// returned if it is owned by a different user.
func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, email string, source string) error {
	client = clientFromContext(ctx, client)
	if client == nil || userID <= 0 {
		return nil
	}
	subject := normalizeEmailAuthIdentitySubject(email)
	if subject == "" {
		return nil
	}

	insertErr := client.AuthIdentity.Create().
		SetUserID(userID).
		SetProviderType("email").
		SetProviderKey("email").
		SetProviderSubject(subject).
		SetVerifiedAt(time.Now().UTC()).
		SetMetadata(map[string]any{"source": source}).
		OnConflictColumns(
			authidentity.FieldProviderType,
			authidentity.FieldProviderKey,
			authidentity.FieldProviderSubject,
		).
		DoNothing().
		Exec(ctx)
	// A "no rows" error is tolerated here; everything else is fatal.
	if insertErr != nil && !isSQLNoRowsError(insertErr) {
		return insertErr
	}

	owner, err := client.AuthIdentity.Query().
		Where(
			authidentity.ProviderTypeEQ("email"),
			authidentity.ProviderKeyEQ("email"),
			authidentity.ProviderSubjectEQ(subject),
		).
		Only(ctx)
	switch {
	case err == nil:
		// Found: fall through to the ownership check.
	case dbent.IsNotFound(err):
		return nil
	default:
		return err
	}

	if owner.UserID != userID {
		return ErrAuthIdentityOwnershipConflict
	}
	return nil
}
// replaceEmailAuthIdentityWithClient ensures the email identity for newEmail
// exists for userID and then deletes the user's identity for oldEmail.
//
// The delete is skipped when the old address normalizes to an empty subject or
// to the same subject as the new one.
func replaceEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, oldEmail, newEmail string, source string) error {
	if err := ensureEmailAuthIdentityWithClient(ctx, client, userID, newEmail, source); err != nil {
		return err
	}

	previous := normalizeEmailAuthIdentitySubject(oldEmail)
	current := normalizeEmailAuthIdentitySubject(newEmail)
	if previous == "" || previous == current {
		return nil
	}

	_, err := clientFromContext(ctx, client).AuthIdentity.Delete().
		Where(
			authidentity.UserIDEQ(userID),
			authidentity.ProviderTypeEQ("email"),
			authidentity.ProviderKeyEQ("email"),
			authidentity.ProviderSubjectEQ(previous),
		).
		Exec(ctx)
	return err
}
// normalizeEmailAuthIdentitySubject returns the trimmed, lower-cased email to
// use as an auth identity subject, or "" when the address is blank or ends in
// one of the synthetic LinuxDo / OIDC / WeChat / DingTalk connect domains.
func normalizeEmailAuthIdentitySubject(email string) string {
	subject := strings.TrimSpace(email)
	subject = strings.ToLower(subject)
	if subject == "" {
		return ""
	}

	syntheticDomains := [...]string{
		service.LinuxDoConnectSyntheticEmailDomain,
		service.OIDCConnectSyntheticEmailDomain,
		service.WeChatConnectSyntheticEmailDomain,
		service.DingTalkConnectSyntheticEmailDomain,
	}
	for _, domain := range syntheticDomains {
		if strings.HasSuffix(subject, domain) {
			return ""
		}
	}
	return subject
}
// Delete removes the user with the given id together with its auth identities
// and the rows that reference those identities.
//
// The steps run on a transaction client: a transaction opened (and committed)
// here, or, when r.client.Tx reports dbent.ErrTxStarted, the transaction found
// in ctx (falling back to r.client). It returns service.ErrUserNotFound when
// no user row was deleted. Other errors pass through translatePersistenceError.
func (r *userRepository) Delete(ctx context.Context, id int64) error {
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
var txClient *dbent.Client
if err == nil {
// This transaction is owned here: roll back on any early return. The
// rollback error is deliberately discarded.
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
} else {
// An outer transaction is already running: reuse it and leave
// commit/rollback to whoever started it.
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
txClient = existingTx.Client()
} else {
txClient = r.client
}
}
// Collect the user's identity IDs first; the dependent rows below are keyed
// by identity ID rather than by user ID.
identityIDs, err := txClient.AuthIdentity.Query().
Where(authidentity.UserIDEQ(id)).
IDs(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
if len(identityIDs) > 0 {
// Adoption decisions are kept but detached from the identities.
if _, err := txClient.IdentityAdoptionDecision.Update().
Where(identityadoptiondecision.IdentityIDIn(identityIDs...)).
ClearIdentityID().
Save(ctx); err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
// Channels of those identities are deleted before the identities themselves.
if _, err := txClient.AuthIdentityChannel.Delete().
Where(authidentitychannel.IdentityIDIn(identityIDs...)).
Exec(ctx); err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
if _, err := txClient.AuthIdentity.Delete().
Where(authidentity.UserIDEQ(id)).
Exec(ctx); err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
}
affected, err := txClient.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
// No row deleted means there was no such user.
if affected == 0 {
return service.ErrUserNotFound
}
// tx is nil when an outer transaction was reused; only commit what was
// opened here.
if tx != nil {
if err := tx.Commit(); err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
}
return nil
}
// List returns one page of users with no filters applied; it delegates to
// ListWithFilters with a zero-value filter set.
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
	var noFilters service.UserListFilters
	return r.ListWithFilters(ctx, params, noFilters)
}
// ListWithFilters returns one page of users matching filters, plus a pagination
// result built from the total number of matches.
//
// Supported filters:
//   - Status / Role: exact match.
//   - Search: ContainsFold match on email, username, notes, or the key of any
//     of the user's API keys.
//   - GroupName: ContainsFold match on the name of one of the user's allowed groups.
//   - Attributes: resolved to user IDs through filterUsersByAttributes; when no
//     user matches, an empty page is returned without running the user query.
//
// Each returned user carries its allowed group IDs. Active subscriptions
// (with their group) are attached unless filters.IncludeSubscriptions points
// to false.
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
q := r.client.User.Query()
if filters.Status != "" {
q = q.Where(dbuser.StatusEQ(filters.Status))
}
if filters.Role != "" {
q = q.Where(dbuser.RoleEQ(filters.Role))
}
if filters.Search != "" {
q = q.Where(
dbuser.Or(
dbuser.EmailContainsFold(filters.Search),
dbuser.UsernameContainsFold(filters.Search),
dbuser.NotesContainsFold(filters.Search),
dbuser.HasAPIKeysWith(apikey.KeyContainsFold(filters.Search)),
),
)
}
if filters.GroupName != "" {
q = q.Where(dbuser.HasAllowedGroupsWith(
dbgroup.NameContainsFold(filters.GroupName),
))
}
// If attribute filters are specified, we need to filter by user IDs first
var allowedUserIDs []int64
if len(filters.Attributes) > 0 {
var attrErr error
allowedUserIDs, attrErr = r.filterUsersByAttributes(ctx, filters.Attributes)
if attrErr != nil {
return nil, nil, attrErr
}
if len(allowedUserIDs) == 0 {
// No users match the attribute filters
return []service.User{}, paginationResultFromTotal(0, params), nil
}
q = q.Where(dbuser.IDIn(allowedUserIDs...))
}
// Count on a clone so the paging/ordering applied below does not affect the total.
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, nil, err
}
usersQuery := q.
Offset(params.Offset()).
Limit(params.Limit())
for _, order := range userListOrder(params) {
usersQuery = usersQuery.Order(order)
}
users, err := usersQuery.All(ctx)
if err != nil {
return nil, nil, err
}
outUsers := make([]service.User, 0, len(users))
if len(users) == 0 {
return outUsers, paginationResultFromTotal(int64(total), params), nil
}
userIDs := make([]int64, 0, len(users))
userMap := make(map[int64]*service.User, len(users))
for i := range users {
userIDs = append(userIDs, users[i].ID)
u := userEntityToService(users[i])
outUsers = append(outUsers, *u)
// Point at the slice element (not at u) so the enrichment below lands in
// the returned slice. outUsers was allocated with capacity len(users), so
// these appends never reallocate and the pointers stay valid.
userMap[u.ID] = &outUsers[len(outUsers)-1]
}
// A nil IncludeSubscriptions means "load them".
shouldLoadSubscriptions := filters.IncludeSubscriptions == nil || *filters.IncludeSubscriptions
if shouldLoadSubscriptions {
// Batch load active subscriptions with groups to avoid N+1.
subs, err := r.client.UserSubscription.Query().
Where(
usersubscription.UserIDIn(userIDs...),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
).
WithGroup().
All(ctx)
if err != nil {
return nil, nil, err
}
for i := range subs {
if u, ok := userMap[subs[i].UserID]; ok {
u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i]))
}
}
}
// Allowed groups are loaded for the whole page in one query.
allowedGroupsByUser, err := r.loadAllowedGroups(ctx, userIDs)
if err != nil {
return nil, nil, err
}
for id, u := range userMap {
if groups, ok := allowedGroupsByUser[id]; ok {
u.AllowedGroups = groups
}
}
return outUsers, paginationResultFromTotal(int64(total), params), nil
}
// userListOrder translates the pagination sort settings into ent order
// functions for the user list.
//
// "last_used_at" is delegated to userLastUsedAtOrder. A recognized column is
// ordered in the requested direction (descending by default) with the user ID
// as tie-breaker in the same direction; "last_active_at" additionally places
// NULLs last. Any other sort key orders by user ID only.
func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
	key := strings.ToLower(strings.TrimSpace(params.SortBy))
	direction := params.NormalizedSortOrder(pagination.SortOrderDesc)
	if key == "last_used_at" {
		return userLastUsedAtOrder(direction)
	}

	var (
		column     string
		recognized = true
		nullsLast  bool
	)
	switch key {
	case "email":
		column = dbuser.FieldEmail
	case "username":
		column = dbuser.FieldUsername
	case "role":
		column = dbuser.FieldRole
	case "balance":
		column = dbuser.FieldBalance
	case "concurrency":
		column = dbuser.FieldConcurrency
	case "status":
		column = dbuser.FieldStatus
	case "created_at":
		column = dbuser.FieldCreatedAt
	case "last_active_at":
		column, nullsLast = dbuser.FieldLastActiveAt, true
	default:
		recognized = false
	}

	ascending := direction == pagination.SortOrderAsc

	// The ID ordering is either the only ordering or the tie-breaker.
	var byID func(*entsql.Selector)
	if ascending {
		byID = dbent.Asc(dbuser.FieldID)
	} else {
		byID = dbent.Desc(dbuser.FieldID)
	}
	if !recognized {
		return []func(*entsql.Selector){byID}
	}

	var primary func(*entsql.Selector)
	switch {
	case nullsLast && ascending:
		primary = entsql.OrderByField(column, entsql.OrderNullsLast()).ToFunc()
	case nullsLast:
		primary = entsql.OrderByField(column, entsql.OrderDesc(), entsql.OrderNullsLast()).ToFunc()
	case ascending:
		primary = dbent.Asc(column)
	default:
		primary = dbent.Desc(column)
	}
	return []func(*entsql.Selector){primary, byID}
}
// GetLatestUsedAtByUserIDs returns, per requested user ID, the most recent
// usage_logs.created_at value converted to UTC. Users without any usage log
// row are absent from the map.
//
// An empty userIDs slice yields an empty map; a missing SQL executor is
// reported as an error.
func (r *userRepository) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
	latest := make(map[int64]*time.Time, len(userIDs))
	if len(userIDs) == 0 {
		return latest, nil
	}
	if r.sql == nil {
		return nil, fmt.Errorf("sql executor is not configured")
	}

	const query = `
SELECT user_id, MAX(created_at) AS last_used_at
FROM usage_logs
WHERE user_id = ANY($1)
GROUP BY user_id
`
	rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs))
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	for rows.Next() {
		var id int64
		var usedAt time.Time
		if err := rows.Scan(&id, &usedAt); err != nil {
			return nil, err
		}
		// A fresh variable per row, so every map entry gets its own pointer.
		usedAtUTC := usedAt.UTC()
		latest[id] = &usedAtUTC
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return latest, nil
}
// GetLatestUsedAtByUserID is the single-user form of GetLatestUsedAtByUserIDs.
// It returns a nil time (and nil error) when that lookup has no entry for the user.
func (r *userRepository) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
	byUser, err := r.GetLatestUsedAtByUserIDs(ctx, []int64{userID})
	if err != nil {
		return nil, err
	}
	if usedAt, found := byUser[userID]; found {
		return usedAt, nil
	}
	return nil, nil
}
// userLastUsedAtOrder orders users by the newest usage_logs.created_at of each
// user (computed with a correlated subquery), using the user ID as tie-breaker.
//
// Ascending order puts users without usage first (NULLS FIRST); any other
// sortOrder sorts descending with them last (NULLS LAST).
func userLastUsedAtOrder(sortOrder string) []func(*entsql.Selector) {
	direction, nulls, tieBreak := "DESC", "LAST", entsql.Desc
	if sortOrder == pagination.SortOrderAsc {
		direction, nulls, tieBreak = "ASC", "FIRST", entsql.Asc
	}

	order := func(s *entsql.Selector) {
		idColumn := s.C(dbuser.FieldID)
		lastUsed := fmt.Sprintf("(SELECT MAX(created_at) FROM usage_logs WHERE user_id = %s)", idColumn)
		s.OrderExpr(entsql.Expr(lastUsed + " " + direction + " NULLS " + nulls))
		s.OrderBy(tieBreak(idColumn))
	}
	return []func(*entsql.Selector){order}
}
// filterUsersByAttributes returns the IDs of users that match ALL of the given
// attribute filters.
//
// attrs maps an attribute ID to the text its value must contain. The match is
// a case-insensitive substring match (ILIKE); the text is taken literally, so
// the LIKE metacharacters '%', '_' and '\' in it are escaped rather than
// interpreted as wildcards. A user qualifies only when it has a matching value
// for every requested attribute (HAVING COUNT(DISTINCT attribute_id) = len(attrs)).
//
// It returns (nil, nil) for an empty filter set, an error when no SQL executor
// is configured, and a non-nil, possibly empty slice otherwise.
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
	if len(attrs) == 0 {
		return nil, nil
	}
	if r.sql == nil {
		return nil, fmt.Errorf("sql executor is not configured")
	}

	// Backslash is PostgreSQL's default LIKE escape character. Escaping it
	// first is safe because the replacer works in a single pass.
	likeEscaper := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`)

	// Walk the filters in a stable order so the generated SQL text (and the
	// argument order) is identical for identical filter sets.
	attrIDs := make([]int64, 0, len(attrs))
	for attrID := range attrs {
		attrIDs = append(attrIDs, attrID)
	}
	sort.Slice(attrIDs, func(i, j int) bool { return attrIDs[i] < attrIDs[j] })

	clauses := make([]string, 0, len(attrs))
	args := make([]any, 0, len(attrs)*2+1)
	for i, attrID := range attrIDs {
		placeholder := i*2 + 1
		clauses = append(clauses, fmt.Sprintf("(attribute_id = $%d AND value ILIKE $%d)", placeholder, placeholder+1))
		args = append(args, attrID, "%"+likeEscaper.Replace(attrs[attrID])+"%")
	}
	// The final placeholder carries the number of attributes that must match.
	args = append(args, len(attrs))

	query := fmt.Sprintf(
		`SELECT user_id
FROM user_attribute_values
WHERE %s
GROUP BY user_id
HAVING COUNT(DISTINCT attribute_id) = $%d`,
		strings.Join(clauses, " OR "),
		len(args),
	)

	rows, err := r.sql.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	result := make([]int64, 0)
	for rows.Next() {
		var userID int64
		if scanErr := rows.Scan(&userID); scanErr != nil {
			return nil, scanErr
		}
		result = append(result, userID)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
// UpdateBalance adds amount (which may be negative) to the user's balance.
// Positive amounts are also added to the user's cumulative recharge total.
// It returns service.ErrUserNotFound when no row was updated.
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
	mutation := clientFromContext(ctx, r.client).User.Update().
		Where(dbuser.IDEQ(id)).
		AddBalance(amount)
	// Track cumulative recharge amount for percentage-based notifications
	if amount > 0 {
		mutation = mutation.AddTotalRecharged(amount)
	}

	affected, err := mutation.Save(ctx)
	switch {
	case err != nil:
		return translatePersistenceError(err, service.ErrUserNotFound, nil)
	case affected == 0:
		return service.ErrUserNotFound
	}
	return nil
}
// DeductBalance subtracts amount from the user's balance.
// Overdraft policy: the balance is allowed to go negative so that the current
// request can complete; per the original note, middleware blocks users whose
// balance is <= 0 from issuing further requests.
// It returns service.ErrUserNotFound when no row was updated.
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
	affected, err := clientFromContext(ctx, r.client).User.Update().
		Where(dbuser.IDEQ(id)).
		AddBalance(-amount).
		Save(ctx)
	if err != nil {
		return err
	}
	if affected != 0 {
		return nil
	}
	return service.ErrUserNotFound
}
// UpdateConcurrency adds amount (a delta, possibly negative) to the user's
// concurrency. It returns service.ErrUserNotFound when no row was updated.
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
	affected, err := clientFromContext(ctx, r.client).User.Update().
		Where(dbuser.IDEQ(id)).
		AddConcurrency(amount).
		Save(ctx)
	switch {
	case err != nil:
		return translatePersistenceError(err, service.ErrUserNotFound, nil)
	case affected == 0:
		return service.ErrUserNotFound
	}
	return nil
}
// BatchSetConcurrency sets the concurrency of every listed user whose
// deleted_at is NULL to value and returns the number of rows updated.
//
// An empty userIDs slice is a no-op returning (0, nil). Negative values are
// clamped to 0. An error is returned when no SQL executor is configured or
// when the UPDATE fails.
func (r *userRepository) BatchSetConcurrency(ctx context.Context, userIDs []int64, value int) (int, error) {
	if len(userIDs) == 0 {
		return 0, nil
	}
	// Guard the raw-SQL path the same way the other raw-SQL methods of this
	// repository do, instead of panicking on a nil executor.
	if r.sql == nil {
		return 0, fmt.Errorf("sql executor is not configured")
	}
	if value < 0 {
		value = 0
	}

	res, err := r.sql.ExecContext(ctx,
		"UPDATE users SET concurrency = $1, updated_at = NOW() WHERE id = ANY($2) AND deleted_at IS NULL",
		value, pq.Array(userIDs))
	if err != nil {
		return 0, fmt.Errorf("batch set concurrency: %w", err)
	}
	// The UPDATE has already been applied at this point; a driver that cannot
	// report the affected-row count only degrades the returned count to 0.
	affected, _ := res.RowsAffected()
	return int(affected), nil
}
// BatchAddConcurrency adds delta to the concurrency of every listed user whose
// deleted_at is NULL and returns the number of rows updated. The resulting
// concurrency is floored at 0 by the query (GREATEST(concurrency + delta, 0)).
//
// An empty userIDs slice is a no-op returning (0, nil). An error is returned
// when no SQL executor is configured or when the UPDATE fails.
func (r *userRepository) BatchAddConcurrency(ctx context.Context, userIDs []int64, delta int) (int, error) {
	if len(userIDs) == 0 {
		return 0, nil
	}
	// Guard the raw-SQL path the same way the other raw-SQL methods of this
	// repository do, instead of panicking on a nil executor.
	if r.sql == nil {
		return 0, fmt.Errorf("sql executor is not configured")
	}

	res, err := r.sql.ExecContext(ctx,
		"UPDATE users SET concurrency = GREATEST(concurrency + $1, 0), updated_at = NOW() WHERE id = ANY($2) AND deleted_at IS NULL",
		delta, pq.Array(userIDs))
	if err != nil {
		return 0, fmt.Errorf("batch add concurrency: %w", err)
	}
	// The UPDATE has already been applied at this point; a driver that cannot
	// report the affected-row count only degrades the returned count to 0.
	affected, _ := res.RowsAffected()
	return int(affected), nil
}
// ExistsByEmail reports whether any user matches email under
// userEmailLookupPredicate.
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
	emailMatches := userEmailLookupPredicate(email)
	return r.client.User.Query().
		Where(emailMatches).
		Exist(ctx)
}
// ensureNormalizedEmailAvailableWithClient returns service.ErrEmailExists when
// a user other than userID already matches email under
// userEmailLookupPredicate. It is a no-op when no client can be resolved.
func ensureNormalizedEmailAvailableWithClient(ctx context.Context, client *dbent.Client, userID int64, email string) error {
	db := clientFromContext(ctx, client)
	if db == nil {
		return nil
	}

	holders, err := db.User.Query().
		Where(userEmailLookupPredicate(email)).
		All(ctx)
	if err != nil {
		return err
	}
	for i := range holders {
		if holders[i].ID == userID {
			// The user's own row never counts as a conflict.
			continue
		}
		return service.ErrEmailExists
	}
	return nil
}
// userEmailLookupPredicate builds the predicate used for email lookups.
//
// When email normalizes to a non-empty value, the predicate compares
// LOWER(TRIM(email column)) with that normalized value; otherwise it falls
// back to an exact equality match on the raw input.
func userEmailLookupPredicate(email string) predicate.User {
	needle := normalizeEmailLookupValue(email)
	if needle == "" {
		return dbuser.EmailEQ(email)
	}

	matchNormalized := func(s *entsql.Selector) {
		condition := entsql.P(func(b *entsql.Builder) {
			b.WriteString("LOWER(TRIM(")
			b.Ident(s.C(dbuser.FieldEmail))
			b.WriteString(")) = ")
			b.Arg(needle)
		})
		s.Where(condition)
	}
	return predicate.User(matchNormalized)
}
// normalizeEmailLookupValue returns email with surrounding whitespace removed
// and all letters lower-cased.
func normalizeEmailLookupValue(email string) string {
	trimmed := strings.TrimSpace(email)
	return strings.ToLower(trimmed)
}
// normalizedEmailUniquenessLockKey returns the lock key for the normalized
// form of email ("users:normalized-email:" + normalized value), or "" when the
// email normalizes to an empty string.
func normalizedEmailUniquenessLockKey(email string) string {
	if key := normalizeEmailLookupValue(email); key != "" {
		return "users:normalized-email:" + key
	}
	return ""
}
// AddGroupToAllowedGroups grants userID access to groupID. An already existing
// (user, group) pair is left untouched (ON CONFLICT DO NOTHING), and the
// resulting "no rows" error is treated as success.
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
	insertErr := clientFromContext(ctx, r.client).UserAllowedGroup.Create().
		SetUserID(userID).
		SetGroupID(groupID).
		OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
		DoNothing().
		Exec(ctx)
	if insertErr != nil && !isSQLNoRowsError(insertErr) {
		return insertErr
	}
	return nil
}
// RemoveGroupFromAllowedGroups revokes groupID from every user and returns the
// number of join rows removed.
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
	// Only the user_allowed_groups join table is touched; the legacy
	// users.allowed_groups column is deprecated.
	removed, err := r.client.UserAllowedGroup.Delete().
		Where(userallowedgroup.GroupIDEQ(groupID)).
		Exec(ctx)
	if err == nil {
		return int64(removed), nil
	}
	return 0, err
}
// RemoveGroupFromUserAllowedGroups revokes a single group from a single user.
// Deleting a pair that does not exist is not an error.
func (r *userRepository) RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
	deletion := clientFromContext(ctx, r.client).UserAllowedGroup.Delete().
		Where(
			userallowedgroup.UserIDEQ(userID),
			userallowedgroup.GroupIDEQ(groupID),
		)
	_, err := deletion.Exec(ctx)
	return err
}
// GetFirstAdmin returns the active admin user with the lowest ID, including
// its allowed group IDs. Lookup failures are passed through
// translatePersistenceError with service.ErrUserNotFound as the not-found error.
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) {
	entity, err := r.client.User.Query().
		Where(
			dbuser.RoleEQ(service.RoleAdmin),
			dbuser.StatusEQ(service.StatusActive),
		).
		Order(dbent.Asc(dbuser.FieldID)).
		First(ctx)
	if err != nil {
		return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
	}

	admin := userEntityToService(entity)

	allowedByUser, err := r.loadAllowedGroups(ctx, []int64{entity.ID})
	if err != nil {
		return nil, err
	}
	if groupIDs, found := allowedByUser[entity.ID]; found {
		admin.AllowedGroups = groupIDs
	}
	return admin, nil
}
// loadAllowedGroups returns, for each of the given users that has at least one
// allowed group, its group IDs sorted ascending. Users without allowed groups
// have no entry in the map.
func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64) (map[int64][]int64, error) {
	groupsByUser := make(map[int64][]int64, len(userIDs))
	if len(userIDs) == 0 {
		return groupsByUser, nil
	}

	links, err := r.client.UserAllowedGroup.Query().
		Where(userallowedgroup.UserIDIn(userIDs...)).
		All(ctx)
	if err != nil {
		return nil, err
	}

	for _, link := range links {
		groupsByUser[link.UserID] = append(groupsByUser[link.UserID], link.GroupID)
	}
	for _, groupIDs := range groupsByUser {
		// Sorting the slice in place also orders the slice stored in the map,
		// since both share the same backing array.
		sort.Slice(groupIDs, func(a, b int) bool { return groupIDs[a] < groupIDs[b] })
	}
	return groupsByUser, nil
}
// syncUserAllowedGroupsWithClient replaces the user's allowed groups with
// groupIDs using the given ent client (typically a transaction client).
// Only the user_allowed_groups join table is touched; the legacy
// users.allowed_groups column is deprecated.
//
// Non-positive and duplicate group IDs are dropped. A nil client makes the
// call a no-op.
func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, client *dbent.Client, userID int64, groupIDs []int64) error {
	if client == nil {
		return nil
	}

	// Keep join table as the source of truth for reads.
	_, err := client.UserAllowedGroup.Delete().
		Where(userallowedgroup.UserIDEQ(userID)).
		Exec(ctx)
	if err != nil {
		return err
	}

	wanted := make(map[int64]struct{}, len(groupIDs))
	for _, groupID := range groupIDs {
		if groupID > 0 {
			wanted[groupID] = struct{}{}
		}
	}
	if len(wanted) == 0 {
		return nil
	}

	builders := make([]*dbent.UserAllowedGroupCreate, 0, len(wanted))
	for groupID := range wanted {
		builders = append(builders, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID))
	}
	err = client.UserAllowedGroup.
		CreateBulk(builders...).
		OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
		DoNothing().
		Exec(ctx)
	// A "no rows" error from the conflict-tolerant insert counts as success.
	if err != nil && !isSQLNoRowsError(err) {
		return err
	}
	return nil
}
// applyUserEntityToService copies the fields carried over from the stored row
// (ID, signup source, last-login/last-active and created/updated timestamps)
// from src onto dst. It does nothing when either argument is nil.
func applyUserEntityToService(dst *service.User, src *dbent.User) {
	if src == nil || dst == nil {
		return
	}
	dst.ID, dst.SignupSource = src.ID, src.SignupSource
	dst.LastLoginAt, dst.LastActiveAt = src.LastLoginAt, src.LastActiveAt
	dst.CreatedAt, dst.UpdatedAt = src.CreatedAt, src.UpdatedAt
}
// userSignupSourceOrDefault normalizes signupSource (lower-cased, surrounding
// whitespace removed) and returns it when it is one of the known third-party
// sources "linuxdo", "wechat", "oidc" or "dingtalk". Every other input,
// including the empty string, yields "email".
func userSignupSourceOrDefault(signupSource string) string {
	normalized := strings.TrimSpace(strings.ToLower(signupSource))
	switch normalized {
	case "linuxdo", "wechat", "oidc", "dingtalk":
		return normalized
	}
	return "email"
}
// marshalExtraEmails converts the notify email entries into the string stored
// on the user row by delegating to service.MarshalNotifyEmails.
func marshalExtraEmails(entries []service.NotifyEmailEntry) string {
	encoded := service.MarshalNotifyEmails(entries)
	return encoded
}
// UpdateTotpSecret stores the user's encrypted TOTP secret; a nil
// encryptedSecret clears the stored secret instead.
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
	mutation := clientFromContext(ctx, r.client).User.UpdateOneID(userID)
	if encryptedSecret != nil {
		mutation = mutation.SetTotpSecretEncrypted(*encryptedSecret)
	} else {
		mutation = mutation.ClearTotpSecretEncrypted()
	}

	if _, err := mutation.Save(ctx); err != nil {
		return translatePersistenceError(err, service.ErrUserNotFound, nil)
	}
	return nil
}
// EnableTotp turns TOTP two-factor authentication on for the user and records
// the current time as the moment it was enabled.
func (r *userRepository) EnableTotp(ctx context.Context, userID int64) error {
	mutation := clientFromContext(ctx, r.client).User.UpdateOneID(userID).
		SetTotpEnabled(true).
		SetTotpEnabledAt(time.Now())

	if _, err := mutation.Save(ctx); err != nil {
		return translatePersistenceError(err, service.ErrUserNotFound, nil)
	}
	return nil
}
// DisableTotp turns TOTP two-factor authentication off for the user and clears
// both the enabled-at timestamp and the stored encrypted secret.
func (r *userRepository) DisableTotp(ctx context.Context, userID int64) error {
	mutation := clientFromContext(ctx, r.client).User.UpdateOneID(userID).
		SetTotpEnabled(false).
		ClearTotpEnabledAt().
		ClearTotpSecretEncrypted()

	if _, err := mutation.Save(ctx); err != nil {
		return translatePersistenceError(err, service.ErrUserNotFound, nil)
	}
	return nil
}