fix: unify email identity sync and retry first-bind defaults
This commit is contained in:
parent
7a9488ff37
commit
ea27ac6fd7
@ -209,14 +209,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *userRepository) EnsureEmailAuthIdentity(ctx context.Context, userID int64, email string) error {
|
|
||||||
return ensureEmailAuthIdentityWithClient(ctx, r.client, userID, email, "service_dual_write")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *userRepository) ReplaceEmailAuthIdentity(ctx context.Context, userID int64, oldEmail, newEmail string) error {
|
|
||||||
return replaceEmailAuthIdentityWithClient(ctx, r.client, userID, oldEmail, newEmail, "service_dual_write")
|
|
||||||
}
|
|
||||||
|
|
||||||
func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, email string, source string) error {
|
func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, email string, source string) error {
|
||||||
client = clientFromContext(ctx, client)
|
client = clientFromContext(ctx, client)
|
||||||
if client == nil || userID <= 0 {
|
if client == nil || userID <= 0 {
|
||||||
|
|||||||
@ -650,9 +650,6 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
|
|||||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := ensureEmailAuthIdentitySync(ctx, s.userRepo, user.ID, user.Email); err != nil {
|
|
||||||
return nil, fmt.Errorf("sync email auth identity: %w", err)
|
|
||||||
}
|
|
||||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
@ -688,7 +685,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
oldConcurrency := user.Concurrency
|
oldConcurrency := user.Concurrency
|
||||||
oldStatus := user.Status
|
oldStatus := user.Status
|
||||||
oldRole := user.Role
|
oldRole := user.Role
|
||||||
oldEmail := user.Email
|
|
||||||
|
|
||||||
if input.Email != "" {
|
if input.Email != "" {
|
||||||
user.Email = input.Email
|
user.Email = input.Email
|
||||||
@ -721,9 +717,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := replaceEmailAuthIdentitySync(ctx, s.userRepo, user.ID, oldEmail, user.Email); err != nil {
|
|
||||||
return nil, fmt.Errorf("sync email auth identity: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 同步用户专属分组倍率
|
// 同步用户专属分组倍率
|
||||||
if input.GroupRates != nil && s.userGroupRateRepo != nil {
|
if input.GroupRates != nil && s.userGroupRateRepo != nil {
|
||||||
|
|||||||
@ -31,6 +31,8 @@ type emailSyncRepoStub struct {
|
|||||||
updated []*User
|
updated []*User
|
||||||
ensureCalls []ensureEmailCall
|
ensureCalls []ensureEmailCall
|
||||||
replaceCalls []replaceEmailCall
|
replaceCalls []replaceEmailCall
|
||||||
|
ensureErr error
|
||||||
|
replaceErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *emailSyncRepoStub) Create(_ context.Context, user *User) error {
|
func (s *emailSyncRepoStub) Create(_ context.Context, user *User) error {
|
||||||
@ -125,7 +127,7 @@ func (s *emailSyncRepoStub) DisableTotp(context.Context, int64) error { return n
|
|||||||
|
|
||||||
func (s *emailSyncRepoStub) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error {
|
func (s *emailSyncRepoStub) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error {
|
||||||
s.ensureCalls = append(s.ensureCalls, ensureEmailCall{userID: userID, email: email})
|
s.ensureCalls = append(s.ensureCalls, ensureEmailCall{userID: userID, email: email})
|
||||||
return nil
|
return s.ensureErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error {
|
func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error {
|
||||||
@ -134,11 +136,14 @@ func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID i
|
|||||||
oldEmail: oldEmail,
|
oldEmail: oldEmail,
|
||||||
newEmail: newEmail,
|
newEmail: newEmail,
|
||||||
})
|
})
|
||||||
return nil
|
return s.replaceErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAdminService_CreateUser_EnsuresEmailAuthIdentity(t *testing.T) {
|
func TestAdminService_CreateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
|
||||||
repo := &emailSyncRepoStub{nextID: 55}
|
repo := &emailSyncRepoStub{
|
||||||
|
nextID: 55,
|
||||||
|
ensureErr: fmt.Errorf("unexpected email resync"),
|
||||||
|
}
|
||||||
svc := &adminServiceImpl{userRepo: repo}
|
svc := &adminServiceImpl{userRepo: repo}
|
||||||
|
|
||||||
user, err := svc.CreateUser(context.Background(), &CreateUserInput{
|
user, err := svc.CreateUser(context.Background(), &CreateUserInput{
|
||||||
@ -147,14 +152,12 @@ func TestAdminService_CreateUser_EnsuresEmailAuthIdentity(t *testing.T) {
|
|||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, user)
|
require.NotNil(t, user)
|
||||||
require.Equal(t, []ensureEmailCall{{
|
require.Equal(t, int64(55), user.ID)
|
||||||
userID: 55,
|
require.Empty(t, repo.ensureCalls)
|
||||||
email: "admin-created@example.com",
|
|
||||||
}}, repo.ensureCalls)
|
|
||||||
require.Empty(t, repo.replaceCalls)
|
require.Empty(t, repo.replaceCalls)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) {
|
func TestAdminService_UpdateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
|
||||||
repo := &emailSyncRepoStub{
|
repo := &emailSyncRepoStub{
|
||||||
user: &User{
|
user: &User{
|
||||||
ID: 91,
|
ID: 91,
|
||||||
@ -163,6 +166,7 @@ func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) {
|
|||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
Concurrency: 3,
|
Concurrency: 3,
|
||||||
},
|
},
|
||||||
|
replaceErr: fmt.Errorf("unexpected email resync"),
|
||||||
}
|
}
|
||||||
svc := &adminServiceImpl{userRepo: repo}
|
svc := &adminServiceImpl{userRepo: repo}
|
||||||
|
|
||||||
@ -172,10 +176,6 @@ func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, updated)
|
require.NotNil(t, updated)
|
||||||
require.Equal(t, "after@example.com", updated.Email)
|
require.Equal(t, "after@example.com", updated.Email)
|
||||||
require.Equal(t, []replaceEmailCall{{
|
require.Empty(t, repo.replaceCalls)
|
||||||
userID: 91,
|
|
||||||
oldEmail: "before@example.com",
|
|
||||||
newEmail: "after@example.com",
|
|
||||||
}}, repo.replaceCalls)
|
|
||||||
require.Empty(t, repo.ensureCalls)
|
require.Empty(t, repo.ensureCalls)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -768,9 +768,6 @@ func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, sig
|
|||||||
}
|
}
|
||||||
s.updateUserSignupSource(ctx, user.ID, signupSource)
|
s.updateUserSignupSource(ctx, user.ID, signupSource)
|
||||||
|
|
||||||
if signupSource == "email" {
|
|
||||||
s.ensureEmailAuthIdentity(ctx, user)
|
|
||||||
}
|
|
||||||
if touchLogin {
|
if touchLogin {
|
||||||
s.touchUserLogin(ctx, user.ID)
|
s.touchUserLogin(ctx, user.ID)
|
||||||
}
|
}
|
||||||
@ -807,21 +804,81 @@ func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context
|
|||||||
if s == nil || user == nil || user.ID <= 0 {
|
if s == nil || user == nil || user.ID <= 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.ensureEmailAuthIdentity(ctx, user) {
|
identity, created := s.ensureEmailAuthIdentity(ctx, user)
|
||||||
|
if s.shouldApplyEmailFirstBindDefaults(ctx, user.ID, identity, created) {
|
||||||
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, "email"); err != nil {
|
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, "email"); err != nil {
|
||||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to apply email first bind defaults: user_id=%d err=%v", user.ID, err)
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to apply email first bind defaults: user_id=%d err=%v", user.ID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) bool {
|
func (s *AuthService) shouldApplyEmailFirstBindDefaults(
|
||||||
if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
|
ctx context.Context,
|
||||||
|
userID int64,
|
||||||
|
identity *dbent.AuthIdentity,
|
||||||
|
created bool,
|
||||||
|
) bool {
|
||||||
|
if created {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if s == nil || s.entClient == nil || userID <= 0 || identity == nil || identity.UserID != userID {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
if emailAuthIdentitySource(identity.Metadata) != "auth_service_dual_write" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
hasGrant, err := s.hasProviderGrantRecord(ctx, userID, "email", "first_bind")
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email first bind grant state: user_id=%d err=%v", userID, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !hasGrant
|
||||||
|
}
|
||||||
|
|
||||||
|
func emailAuthIdentitySource(metadata map[string]any) string {
|
||||||
|
if len(metadata) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
raw, ok := metadata["source"]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(fmt.Sprint(raw))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AuthService) hasProviderGrantRecord(
|
||||||
|
ctx context.Context,
|
||||||
|
userID int64,
|
||||||
|
providerType string,
|
||||||
|
grantReason string,
|
||||||
|
) (bool, error) {
|
||||||
|
if s == nil || s.entClient == nil || userID <= 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := s.entClient.QueryContext(
|
||||||
|
ctx,
|
||||||
|
`SELECT 1 FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ? LIMIT 1`,
|
||||||
|
userID,
|
||||||
|
strings.TrimSpace(providerType),
|
||||||
|
strings.TrimSpace(grantReason),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return rows.Next(), rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) (*dbent.AuthIdentity, bool) {
|
||||||
|
if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
email := strings.ToLower(strings.TrimSpace(user.Email))
|
email := strings.ToLower(strings.TrimSpace(user.Email))
|
||||||
if email == "" || isReservedEmail(email) {
|
if email == "" || isReservedEmail(email) {
|
||||||
return false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
client := s.entClient
|
client := s.entClient
|
||||||
@ -840,7 +897,7 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) b
|
|||||||
existed, err := buildQuery().Exist(ctx)
|
existed, err := buildQuery().Exist(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
||||||
return false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !existed {
|
if !existed {
|
||||||
@ -861,21 +918,21 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) b
|
|||||||
DoNothing().
|
DoNothing().
|
||||||
Exec(ctx); err != nil {
|
Exec(ctx); err != nil {
|
||||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
||||||
return false
|
return nil, false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
identity, err := buildQuery().Only(ctx)
|
identity, err := buildQuery().Only(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to reload email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to reload email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
||||||
return false
|
return nil, false
|
||||||
}
|
}
|
||||||
if identity.UserID != user.ID {
|
if identity.UserID != user.ID {
|
||||||
logger.LegacyPrintf("service.auth", "[Auth] Email auth identity ownership mismatch: user_id=%d email=%s owner_id=%d", user.ID, email, identity.UserID)
|
logger.LegacyPrintf("service.auth", "[Auth] Email auth identity ownership mismatch: user_id=%d email=%s owner_id=%d", user.ID, email, identity.UserID)
|
||||||
return false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
return !existed
|
return identity, !existed
|
||||||
}
|
}
|
||||||
|
|
||||||
func inferLegacySignupSource(email string) string {
|
func inferLegacySignupSource(email string) string {
|
||||||
|
|||||||
@ -5,6 +5,7 @@ package service_test
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -34,6 +35,24 @@ func (s *authIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
|
|||||||
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
|
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type flakyAuthIdentityDefaultSubAssignerStub struct {
|
||||||
|
failuresRemaining int
|
||||||
|
calls []*service.AssignSubscriptionInput
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *flakyAuthIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
|
||||||
|
_ context.Context,
|
||||||
|
input *service.AssignSubscriptionInput,
|
||||||
|
) (*service.UserSubscription, bool, error) {
|
||||||
|
cloned := *input
|
||||||
|
s.calls = append(s.calls, &cloned)
|
||||||
|
if s.failuresRemaining > 0 {
|
||||||
|
s.failuresRemaining--
|
||||||
|
return nil, false, errors.New("temporary assign failure")
|
||||||
|
}
|
||||||
|
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
type authIdentitySettingRepoStub struct {
|
type authIdentitySettingRepoStub struct {
|
||||||
values map[string]string
|
values map[string]string
|
||||||
}
|
}
|
||||||
@ -333,6 +352,55 @@ func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyE
|
|||||||
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
|
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthServiceLogin_RetriesEmailFirstBindDefaultsAfterPreviousFailure(t *testing.T) {
|
||||||
|
assigner := &flakyAuthIdentityDefaultSubAssignerStub{failuresRemaining: 1}
|
||||||
|
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
|
||||||
|
service.SettingKeyRegistrationEnabled: "true",
|
||||||
|
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
|
||||||
|
service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
|
||||||
|
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
|
||||||
|
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
|
||||||
|
}, assigner)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
passwordHash, err := svc.HashPassword("password")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user, err := client.User.Create().
|
||||||
|
SetEmail("retry-first-bind@example.com").
|
||||||
|
SetUsername("retry-user").
|
||||||
|
SetPasswordHash(passwordHash).
|
||||||
|
SetBalance(1.5).
|
||||||
|
SetConcurrency(2).
|
||||||
|
SetRole(service.RoleUser).
|
||||||
|
SetStatus(service.StatusActive).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token, gotUser, err := svc.Login(ctx, user.Email, "password")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, token)
|
||||||
|
require.NotNil(t, gotUser)
|
||||||
|
|
||||||
|
storedUser, err := client.User.Get(ctx, user.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1.5, storedUser.Balance)
|
||||||
|
require.Equal(t, 2, storedUser.Concurrency)
|
||||||
|
require.Len(t, assigner.calls, 1)
|
||||||
|
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
|
||||||
|
|
||||||
|
token, gotUser, err = svc.Login(ctx, user.Email, "password")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, token)
|
||||||
|
require.NotNil(t, gotUser)
|
||||||
|
|
||||||
|
storedUser, err = client.User.Get(ctx, user.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 10.0, storedUser.Balance)
|
||||||
|
require.Equal(t, 6, storedUser.Concurrency)
|
||||||
|
require.Len(t, assigner.calls, 2)
|
||||||
|
require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
|
||||||
|
}
|
||||||
|
|
||||||
func countProviderGrantRecords(
|
func countProviderGrantRecords(
|
||||||
t *testing.T,
|
t *testing.T,
|
||||||
client *dbent.Client,
|
client *dbent.Client,
|
||||||
|
|||||||
@ -161,33 +161,6 @@ type userAuthIdentityReader interface {
|
|||||||
ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error)
|
ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type emailAuthIdentitySynchronizer interface {
|
|
||||||
EnsureEmailAuthIdentity(ctx context.Context, userID int64, email string) error
|
|
||||||
ReplaceEmailAuthIdentity(ctx context.Context, userID int64, oldEmail, newEmail string) error
|
|
||||||
}
|
|
||||||
|
|
||||||
func ensureEmailAuthIdentitySync(ctx context.Context, repo UserRepository, userID int64, email string) error {
|
|
||||||
syncer, ok := repo.(emailAuthIdentitySynchronizer)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return syncer.EnsureEmailAuthIdentity(ctx, userID, email)
|
|
||||||
}
|
|
||||||
|
|
||||||
func replaceEmailAuthIdentitySync(ctx context.Context, repo UserRepository, userID int64, oldEmail, newEmail string) error {
|
|
||||||
oldNormalized := strings.ToLower(strings.TrimSpace(oldEmail))
|
|
||||||
newNormalized := strings.ToLower(strings.TrimSpace(newEmail))
|
|
||||||
if oldNormalized == newNormalized {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
syncer, ok := repo.(emailAuthIdentitySynchronizer)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return syncer.ReplaceEmailAuthIdentity(ctx, userID, oldEmail, newEmail)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChangePasswordRequest 修改密码请求
|
// ChangePasswordRequest 修改密码请求
|
||||||
type ChangePasswordRequest struct {
|
type ChangePasswordRequest struct {
|
||||||
CurrentPassword string `json:"current_password"`
|
CurrentPassword string `json:"current_password"`
|
||||||
@ -281,7 +254,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
|
|||||||
return nil, fmt.Errorf("get user: %w", err)
|
return nil, fmt.Errorf("get user: %w", err)
|
||||||
}
|
}
|
||||||
oldConcurrency := user.Concurrency
|
oldConcurrency := user.Concurrency
|
||||||
oldEmail := user.Email
|
|
||||||
|
|
||||||
// 更新字段
|
// 更新字段
|
||||||
if req.Email != nil {
|
if req.Email != nil {
|
||||||
@ -326,9 +298,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
|
|||||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||||
return nil, fmt.Errorf("update user: %w", err)
|
return nil, fmt.Errorf("update user: %w", err)
|
||||||
}
|
}
|
||||||
if err := replaceEmailAuthIdentitySync(ctx, s.userRepo, user.ID, oldEmail, user.Email); err != nil {
|
|
||||||
return nil, fmt.Errorf("sync email auth identity: %w", err)
|
|
||||||
}
|
|
||||||
if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency {
|
if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency {
|
||||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -9,7 +9,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) {
|
func TestUpdateProfile_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
|
||||||
repo := &emailSyncRepoStub{
|
repo := &emailSyncRepoStub{
|
||||||
user: &User{
|
user: &User{
|
||||||
ID: 19,
|
ID: 19,
|
||||||
@ -17,6 +17,7 @@ func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) {
|
|||||||
Username: "tester",
|
Username: "tester",
|
||||||
Concurrency: 2,
|
Concurrency: 2,
|
||||||
},
|
},
|
||||||
|
replaceErr: context.DeadlineExceeded,
|
||||||
}
|
}
|
||||||
svc := NewUserService(repo, nil, nil, nil)
|
svc := NewUserService(repo, nil, nil, nil)
|
||||||
|
|
||||||
@ -28,10 +29,6 @@ func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) {
|
|||||||
require.NotNil(t, updated)
|
require.NotNil(t, updated)
|
||||||
require.Equal(t, newEmail, updated.Email)
|
require.Equal(t, newEmail, updated.Email)
|
||||||
require.Equal(t, 1, repo.updateCalls)
|
require.Equal(t, 1, repo.updateCalls)
|
||||||
require.Equal(t, []replaceEmailCall{{
|
require.Empty(t, repo.replaceCalls)
|
||||||
userID: 19,
|
|
||||||
oldEmail: "profile-before@example.com",
|
|
||||||
newEmail: "profile-after@example.com",
|
|
||||||
}}, repo.replaceCalls)
|
|
||||||
require.Empty(t, repo.ensureCalls)
|
require.Empty(t, repo.ensureCalls)
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user