fix: 完善邮箱快捷登录注册流程
This commit is contained in:
parent
81edaa8986
commit
e69256a706
@ -9,6 +9,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||||
@ -168,10 +169,22 @@ func (h *AuthHandler) emailOAuthCallbackWithProfile(
|
|||||||
UpstreamMetadata: profile.Metadata,
|
UpstreamMetadata: profile.Metadata,
|
||||||
}
|
}
|
||||||
affiliateCode := h.emailOAuthAffiliateCode(c)
|
affiliateCode := h.emailOAuthAffiliateCode(c)
|
||||||
|
if shouldCreate, err := h.emailOAuthShouldCreatePendingRegistration(c.Request.Context(), input); err != nil {
|
||||||
|
redirectOAuthError(c, frontendCallback, infraerrors.Reason(err), infraerrors.Message(err), "")
|
||||||
|
return
|
||||||
|
} else if shouldCreate {
|
||||||
|
if pendingErr := h.createEmailOAuthRegistrationPendingSession(c, provider, frontendCallback, redirectTo, profile); pendingErr != nil {
|
||||||
|
redirectOAuthError(c, frontendCallback, infraerrors.Reason(pendingErr), infraerrors.Message(pendingErr), "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
redirectToFrontendCallback(c, frontendCallback)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
tokenPair, user, err := h.authService.LoginOrRegisterVerifiedEmailOAuthWithInvitation(c.Request.Context(), input, "", affiliateCode)
|
tokenPair, user, err := h.authService.LoginOrRegisterVerifiedEmailOAuthWithInvitation(c.Request.Context(), input, "", affiliateCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, service.ErrOAuthInvitationRequired) {
|
if errors.Is(err, service.ErrOAuthInvitationRequired) {
|
||||||
if pendingErr := h.createEmailOAuthInvitationPendingSession(c, provider, frontendCallback, redirectTo, profile); pendingErr != nil {
|
if pendingErr := h.createEmailOAuthRegistrationPendingSession(c, provider, frontendCallback, redirectTo, profile); pendingErr != nil {
|
||||||
redirectOAuthError(c, frontendCallback, infraerrors.Reason(pendingErr), infraerrors.Message(pendingErr), "")
|
redirectOAuthError(c, frontendCallback, infraerrors.Reason(pendingErr), infraerrors.Message(pendingErr), "")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -195,6 +208,35 @@ func (h *AuthHandler) emailOAuthCallbackWithProfile(
|
|||||||
redirectWithFragment(c, frontendCallback, fragment)
|
redirectWithFragment(c, frontendCallback, fragment)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) emailOAuthShouldCreatePendingRegistration(ctx context.Context, input service.EmailOAuthIdentityInput) (bool, error) {
|
||||||
|
client := h.entClient()
|
||||||
|
if client == nil {
|
||||||
|
return false, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
|
||||||
|
}
|
||||||
|
identityUser, err := h.findOAuthIdentityUser(ctx, service.PendingAuthIdentityKey{
|
||||||
|
ProviderType: strings.TrimSpace(input.ProviderType),
|
||||||
|
ProviderKey: strings.TrimSpace(input.ProviderKey),
|
||||||
|
ProviderSubject: strings.TrimSpace(input.ProviderSubject),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
email := strings.TrimSpace(strings.ToLower(input.Email))
|
||||||
|
if identityUser != nil {
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(identityUser.Email), email) {
|
||||||
|
return false, infraerrors.Conflict("AUTH_IDENTITY_EMAIL_MISMATCH", "oauth identity belongs to a different email")
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if _, err := findUserByNormalizedEmail(ctx, client, email); err != nil {
|
||||||
|
if errors.Is(err, service.ErrUserNotFound) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (h *AuthHandler) emailOAuthAffiliateCode(c *gin.Context) string {
|
func (h *AuthHandler) emailOAuthAffiliateCode(c *gin.Context) string {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return ""
|
return ""
|
||||||
@ -205,7 +247,7 @@ func (h *AuthHandler) emailOAuthAffiliateCode(c *gin.Context) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *AuthHandler) createEmailOAuthInvitationPendingSession(
|
func (h *AuthHandler) createEmailOAuthRegistrationPendingSession(
|
||||||
c *gin.Context,
|
c *gin.Context,
|
||||||
provider string,
|
provider string,
|
||||||
frontendCallback string,
|
frontendCallback string,
|
||||||
@ -247,14 +289,22 @@ func (h *AuthHandler) createEmailOAuthInvitationPendingSession(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
invitationRequired := h != nil && h.settingSvc != nil && h.settingSvc.IsInvitationCodeEnabled(c.Request.Context())
|
||||||
|
pendingError := "registration_completion_required"
|
||||||
|
choiceReason := "registration_completion_required"
|
||||||
|
if invitationRequired {
|
||||||
|
pendingError = "invitation_required"
|
||||||
|
choiceReason = "invitation_required"
|
||||||
|
}
|
||||||
completionResponse := map[string]any{
|
completionResponse := map[string]any{
|
||||||
"step": oauthPendingChoiceStep,
|
"step": oauthPendingChoiceStep,
|
||||||
"error": "invitation_required",
|
"error": pendingError,
|
||||||
"choice_reason": "invitation_required",
|
"choice_reason": choiceReason,
|
||||||
"adoption_required": false,
|
"adoption_required": false,
|
||||||
"create_account_allowed": true,
|
"create_account_allowed": true,
|
||||||
"existing_account_bindable": false,
|
"existing_account_bindable": false,
|
||||||
"force_email_on_signup": true,
|
"force_email_on_signup": true,
|
||||||
|
"invitation_required": invitationRequired,
|
||||||
"email": email,
|
"email": email,
|
||||||
"resolved_email": email,
|
"resolved_email": email,
|
||||||
"provider": provider,
|
"provider": provider,
|
||||||
@ -276,7 +326,8 @@ func (h *AuthHandler) createEmailOAuthInvitationPendingSession(
|
|||||||
}
|
}
|
||||||
|
|
||||||
type completeEmailOAuthRequest struct {
|
type completeEmailOAuthRequest struct {
|
||||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
Password string `json:"password" binding:"required,min=6"`
|
||||||
|
InvitationCode string `json:"invitation_code,omitempty"`
|
||||||
AffCode string `json:"aff_code,omitempty"`
|
AffCode string `json:"aff_code,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -310,21 +361,12 @@ func (h *AuthHandler) completeEmailOAuthRegistration(c *gin.Context, provider st
|
|||||||
affiliateCode = pendingSessionStringValue(session.UpstreamIdentityClaims, "aff_code")
|
affiliateCode = pendingSessionStringValue(session.UpstreamIdentityClaims, "aff_code")
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenPair, user, err := h.authService.LoginOrRegisterVerifiedEmailOAuthWithInvitation(
|
tokenPair, user, err := h.authService.RegisterVerifiedOAuthEmailAccount(
|
||||||
c.Request.Context(),
|
c.Request.Context(),
|
||||||
service.EmailOAuthIdentityInput{
|
strings.TrimSpace(session.ResolvedEmail),
|
||||||
ProviderType: strings.TrimSpace(session.ProviderType),
|
req.Password,
|
||||||
ProviderKey: strings.TrimSpace(session.ProviderKey),
|
|
||||||
ProviderSubject: strings.TrimSpace(session.ProviderSubject),
|
|
||||||
Email: strings.TrimSpace(session.ResolvedEmail),
|
|
||||||
EmailVerified: true,
|
|
||||||
Username: pendingSessionStringValue(session.UpstreamIdentityClaims, "username"),
|
|
||||||
DisplayName: pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name"),
|
|
||||||
AvatarURL: pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url"),
|
|
||||||
UpstreamMetadata: clonePendingMap(session.UpstreamIdentityClaims),
|
|
||||||
},
|
|
||||||
strings.TrimSpace(req.InvitationCode),
|
strings.TrimSpace(req.InvitationCode),
|
||||||
affiliateCode,
|
strings.TrimSpace(session.ProviderType),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
@ -342,13 +384,46 @@ func (h *AuthHandler) completeEmailOAuthRegistration(c *gin.Context, provider st
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() { _ = tx.Rollback() }()
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
txCtx := dbent.NewTxContext(c.Request.Context(), tx)
|
||||||
|
sessionForBinding := *session
|
||||||
|
sessionForBinding.UpstreamIdentityClaims = clonePendingMap(session.UpstreamIdentityClaims)
|
||||||
|
if strings.TrimSpace(req.InvitationCode) != "" {
|
||||||
|
sessionForBinding.UpstreamIdentityClaims["invitation_code"] = strings.TrimSpace(req.InvitationCode)
|
||||||
|
}
|
||||||
|
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{})
|
||||||
|
if err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := applyPendingOAuthBinding(txCtx, client, h.authService, h.userService, &sessionForBinding, decision, &user.ID, true, false); err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
|
||||||
|
respondPendingOAuthBindingApplyError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.authService.FinalizeOAuthEmailAccount(
|
||||||
|
txCtx,
|
||||||
|
user,
|
||||||
|
strings.TrimSpace(req.InvitationCode),
|
||||||
|
strings.TrimSpace(session.ProviderType),
|
||||||
|
affiliateCode,
|
||||||
|
); err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
if err := consumePendingOAuthBrowserSessionTx(c.Request.Context(), tx, session); err != nil {
|
if err := consumePendingOAuthBrowserSessionTx(c.Request.Context(), tx, session); err != nil {
|
||||||
_ = tx.Rollback()
|
_ = tx.Rollback()
|
||||||
|
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
|
||||||
clearCookies()
|
clearCookies()
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := tx.Commit(); err != nil {
|
if err := tx.Commit(); err != nil {
|
||||||
|
_ = h.authService.RollbackOAuthEmailAccountCreation(c.Request.Context(), user.ID, strings.TrimSpace(req.InvitationCode))
|
||||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to consume pending oauth session").WithCause(err))
|
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to consume pending oauth session").WithCause(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -438,17 +513,17 @@ func parseGitHubOAuthProfile(ctx context.Context, cfg config.EmailOAuthProviderC
|
|||||||
if subject == "" {
|
if subject == "" {
|
||||||
return nil, errors.New("github user id is missing")
|
return nil, errors.New("github user id is missing")
|
||||||
}
|
}
|
||||||
email := strings.TrimSpace(gjson.Get(body, "email").String())
|
email := ""
|
||||||
emailVerified := email != ""
|
emailsURL := strings.TrimSpace(cfg.EmailsURL)
|
||||||
if strings.TrimSpace(cfg.EmailsURL) != "" {
|
if emailsURL == "" {
|
||||||
if verifiedEmail, err := fetchGitHubPrimaryVerifiedEmail(ctx, cfg.EmailsURL, token.AccessToken); err == nil && verifiedEmail != "" {
|
return nil, errors.New("github verified email is missing")
|
||||||
email = verifiedEmail
|
|
||||||
emailVerified = true
|
|
||||||
} else if email == "" && err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if email == "" || !emailVerified {
|
verifiedEmail, err := fetchGitHubPrimaryVerifiedEmail(ctx, emailsURL, token.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
email = verifiedEmail
|
||||||
|
if email == "" {
|
||||||
return nil, errors.New("github verified email is missing")
|
return nil, errors.New("github verified email is missing")
|
||||||
}
|
}
|
||||||
login := strings.TrimSpace(gjson.Get(body, "login").String())
|
login := strings.TrimSpace(gjson.Get(body, "login").String())
|
||||||
|
|||||||
@ -73,6 +73,7 @@ func TestEmailOAuthCallbackRequiresPendingRegistrationWhenInvitationEnabled(t *t
|
|||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
require.Equal(t, oauthPendingChoiceStep, completion["step"])
|
require.Equal(t, oauthPendingChoiceStep, completion["step"])
|
||||||
require.Equal(t, "invitation_required", completion["error"])
|
require.Equal(t, "invitation_required", completion["error"])
|
||||||
|
require.Equal(t, true, completion["invitation_required"])
|
||||||
require.Equal(t, "fresh@example.com", completion["email"])
|
require.Equal(t, "fresh@example.com", completion["email"])
|
||||||
require.Equal(t, "fresh@example.com", completion["resolved_email"])
|
require.Equal(t, "fresh@example.com", completion["resolved_email"])
|
||||||
require.Equal(t, true, completion["create_account_allowed"])
|
require.Equal(t, true, completion["create_account_allowed"])
|
||||||
@ -129,7 +130,7 @@ func TestEmailOAuthCallbackExistingEmailLogsInWhenInvitationEnabled(t *testing.T
|
|||||||
_ = user
|
_ = user
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEmailOAuthCallbackAutoRegistrationAppliesAffiliateCode(t *testing.T) {
|
func TestEmailOAuthCallbackCreatesPasswordRegistrationSessionForNewEmail(t *testing.T) {
|
||||||
affiliateRepo := newOAuthEmailAffiliateRepoStub(map[string]int64{"AFF123": 1001})
|
affiliateRepo := newOAuthEmailAffiliateRepoStub(map[string]int64{"AFF123": 1001})
|
||||||
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
||||||
settingValues: map[string]string{
|
settingValues: map[string]string{
|
||||||
@ -161,11 +162,26 @@ func TestEmailOAuthCallbackAutoRegistrationAppliesAffiliateCode(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
require.Equal(t, http.StatusFound, recorder.Code)
|
require.Equal(t, http.StatusFound, recorder.Code)
|
||||||
require.Contains(t, recorder.Header().Get("Location"), "access_token=")
|
require.NotContains(t, recorder.Header().Get("Location"), "access_token=")
|
||||||
user, err := client.User.Query().Where(dbuser.EmailEQ("aff-user@example.com")).Only(ctx)
|
userCount, err := client.User.Query().Where(dbuser.EmailEQ("aff-user@example.com")).Count(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, []int64{user.ID, user.ID}, affiliateRepo.ensureUserIDs)
|
require.Zero(t, userCount)
|
||||||
require.Equal(t, []oauthEmailAffiliateBindCall{{userID: user.ID, inviterID: 1001}}, affiliateRepo.bindCalls)
|
require.Empty(t, affiliateRepo.ensureUserIDs)
|
||||||
|
require.Empty(t, affiliateRepo.bindCalls)
|
||||||
|
|
||||||
|
session, err := client.PendingAuthSession.Query().Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "aff-user@example.com", session.ResolvedEmail)
|
||||||
|
require.Equal(t, "AFF123", pendingSessionStringValue(session.UpstreamIdentityClaims, "aff_code"))
|
||||||
|
|
||||||
|
completion, ok := readCompletionResponse(session.LocalFlowState)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, oauthPendingChoiceStep, completion["step"])
|
||||||
|
require.Equal(t, "registration_completion_required", completion["error"])
|
||||||
|
require.Equal(t, false, completion["invitation_required"])
|
||||||
|
require.Equal(t, true, completion["create_account_allowed"])
|
||||||
|
require.Equal(t, true, completion["force_email_on_signup"])
|
||||||
|
require.Equal(t, "aff-user@example.com", completion["resolved_email"])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCompleteEmailOAuthRegistrationUsesAffiliateCodeFromPendingSession(t *testing.T) {
|
func TestCompleteEmailOAuthRegistrationUsesAffiliateCodeFromPendingSession(t *testing.T) {
|
||||||
@ -216,7 +232,7 @@ func TestCompleteEmailOAuthRegistrationUsesAffiliateCodeFromPendingSession(t *te
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/google/complete-registration", strings.NewReader(`{"invitation_code":"INVITE456"}`))
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/google/complete-registration", strings.NewReader(`{"password":"secret-123","invitation_code":"INVITE456","email":"tampered@example.com"}`))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-aff-key")})
|
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-aff-key")})
|
||||||
@ -227,6 +243,11 @@ func TestCompleteEmailOAuthRegistrationUsesAffiliateCodeFromPendingSession(t *te
|
|||||||
require.Equal(t, http.StatusOK, recorder.Code)
|
require.Equal(t, http.StatusOK, recorder.Code)
|
||||||
user, err := client.User.Query().Where(dbuser.EmailEQ("pending-aff@example.com")).Only(ctx)
|
user, err := client.User.Query().Where(dbuser.EmailEQ("pending-aff@example.com")).Only(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, user.PasswordHash)
|
||||||
|
require.NotEqual(t, "secret-123", user.PasswordHash)
|
||||||
|
tamperedCount, err := client.User.Query().Where(dbuser.EmailEQ("tampered@example.com")).Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Zero(t, tamperedCount)
|
||||||
require.Equal(t, []oauthEmailAffiliateBindCall{{userID: user.ID, inviterID: 2002}}, affiliateRepo.bindCalls)
|
require.Equal(t, []oauthEmailAffiliateBindCall{{userID: user.ID, inviterID: 2002}}, affiliateRepo.bindCalls)
|
||||||
storedInvitation, err := client.RedeemCode.Query().Where(redeemcode.IDEQ(invitation.ID)).Only(ctx)
|
storedInvitation, err := client.RedeemCode.Query().Where(redeemcode.IDEQ(invitation.ID)).Only(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -234,6 +255,66 @@ func TestCompleteEmailOAuthRegistrationUsesAffiliateCodeFromPendingSession(t *te
|
|||||||
require.Equal(t, user.ID, *storedInvitation.UsedBy)
|
require.Equal(t, user.ID, *storedInvitation.UsedBy)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCompleteEmailOAuthRegistrationRequiresPassword(t *testing.T) {
|
||||||
|
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
session, err := client.PendingAuthSession.Create().
|
||||||
|
SetSessionToken("email-oauth-password-session-token").
|
||||||
|
SetIntent(oauthIntentLogin).
|
||||||
|
SetProviderType("github").
|
||||||
|
SetProviderKey("github").
|
||||||
|
SetProviderSubject("github-password-user").
|
||||||
|
SetResolvedEmail("password-required@example.com").
|
||||||
|
SetRedirectTo("/dashboard").
|
||||||
|
SetBrowserSessionKey("browser-password-key").
|
||||||
|
SetUpstreamIdentityClaims(map[string]any{
|
||||||
|
"email": "password-required@example.com",
|
||||||
|
"email_verified": true,
|
||||||
|
"username": "password-required",
|
||||||
|
"provider": "github",
|
||||||
|
"provider_key": "github",
|
||||||
|
"provider_subject": "github-password-user",
|
||||||
|
}).
|
||||||
|
SetLocalFlowState(map[string]any{
|
||||||
|
"step": oauthPendingChoiceStep,
|
||||||
|
"error": "registration_completion_required",
|
||||||
|
}).
|
||||||
|
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/github/complete-registration", strings.NewReader(`{}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-password-key")})
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
handler.completeEmailOAuthRegistration(c, "github")
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||||
|
userCount, err := client.User.Query().Where(dbuser.EmailEQ("password-required@example.com")).Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Zero(t, userCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseGitHubOAuthProfileRejectsPublicEmailWhenEmailsEndpointFails(t *testing.T) {
|
||||||
|
emailServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
http.Error(w, "missing scope", http.StatusForbidden)
|
||||||
|
}))
|
||||||
|
t.Cleanup(emailServer.Close)
|
||||||
|
|
||||||
|
profile, err := parseGitHubOAuthProfile(context.Background(), config.EmailOAuthProviderConfig{
|
||||||
|
EmailsURL: emailServer.URL,
|
||||||
|
}, &emailOAuthTokenResponse{AccessToken: "token"}, `{"id":123,"login":"octo","email":"public@example.com"}`)
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, profile)
|
||||||
|
require.Contains(t, err.Error(), "github emails endpoint status 403")
|
||||||
|
}
|
||||||
|
|
||||||
type oauthEmailAffiliateBindCall struct {
|
type oauthEmailAffiliateBindCall struct {
|
||||||
userID int64
|
userID int64
|
||||||
inviterID int64
|
inviterID int64
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
func normalizeOAuthSignupSource(signupSource string) string {
|
func normalizeOAuthSignupSource(signupSource string) string {
|
||||||
@ -168,6 +169,87 @@ func (s *AuthService) RegisterOAuthEmailAccount(
|
|||||||
return tokenPair, user, nil
|
return tokenPair, user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterVerifiedOAuthEmailAccount creates a local account from an OAuth
|
||||||
|
// provider that has already returned a verified email address.
|
||||||
|
func (s *AuthService) RegisterVerifiedOAuthEmailAccount(
|
||||||
|
ctx context.Context,
|
||||||
|
email string,
|
||||||
|
password string,
|
||||||
|
invitationCode string,
|
||||||
|
signupSource string,
|
||||||
|
) (*TokenPair, *User, error) {
|
||||||
|
if s == nil {
|
||||||
|
return nil, nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||||
|
return nil, nil, ErrRegDisabled
|
||||||
|
}
|
||||||
|
|
||||||
|
email = strings.TrimSpace(strings.ToLower(email))
|
||||||
|
if email == "" || len(email) > 255 {
|
||||||
|
return nil, nil, ErrEmailVerifyRequired
|
||||||
|
}
|
||||||
|
if _, err := mail.ParseAddress(email); err != nil {
|
||||||
|
return nil, nil, ErrEmailVerifyRequired
|
||||||
|
}
|
||||||
|
if isReservedEmail(email) {
|
||||||
|
return nil, nil, ErrEmailReserved
|
||||||
|
}
|
||||||
|
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(password) == "" {
|
||||||
|
return nil, nil, infraerrors.BadRequest("PASSWORD_REQUIRED", "password is required")
|
||||||
|
}
|
||||||
|
if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
if existsEmail {
|
||||||
|
return nil, nil, ErrEmailExists
|
||||||
|
}
|
||||||
|
|
||||||
|
hashedPassword, err := s.HashPassword(password)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("hash password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
signupSource = normalizeOAuthSignupSource(signupSource)
|
||||||
|
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||||
|
var defaultRPMLimit int
|
||||||
|
if s.settingService != nil {
|
||||||
|
defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
|
||||||
|
}
|
||||||
|
user := &User{
|
||||||
|
Email: email,
|
||||||
|
PasswordHash: hashedPassword,
|
||||||
|
Role: RoleUser,
|
||||||
|
Balance: grantPlan.Balance,
|
||||||
|
Concurrency: grantPlan.Concurrency,
|
||||||
|
RPMLimit: defaultRPMLimit,
|
||||||
|
Status: StatusActive,
|
||||||
|
SignupSource: signupSource,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||||
|
if errors.Is(err, ErrEmailExists) {
|
||||||
|
return nil, nil, ErrEmailExists
|
||||||
|
}
|
||||||
|
return nil, nil, ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
|
||||||
|
if err != nil {
|
||||||
|
_ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, "")
|
||||||
|
return nil, nil, fmt.Errorf("generate token pair: %w", err)
|
||||||
|
}
|
||||||
|
return tokenPair, user, nil
|
||||||
|
}
|
||||||
|
|
||||||
// FinalizeOAuthEmailAccount applies invitation usage and normal signup bootstrap
|
// FinalizeOAuthEmailAccount applies invitation usage and normal signup bootstrap
|
||||||
// only after the pending OAuth flow has fully reached its last reversible step.
|
// only after the pending OAuth flow has fully reached its last reversible step.
|
||||||
func (s *AuthService) FinalizeOAuthEmailAccount(
|
func (s *AuthService) FinalizeOAuthEmailAccount(
|
||||||
|
|||||||
@ -11,31 +11,68 @@
|
|||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div v-else-if="needsInvitation" class="card p-6">
|
<div v-else-if="needsRegistrationCompletion" class="card p-6">
|
||||||
<h1 class="text-lg font-semibold text-gray-900 dark:text-white">
|
<h1 class="text-lg font-semibold text-gray-900 dark:text-white">
|
||||||
{{ t('auth.oidc.callbackTitle', { providerName }) }}
|
{{ t('auth.oidc.callbackTitle', { providerName }) }}
|
||||||
</h1>
|
</h1>
|
||||||
<p class="mt-2 text-sm text-gray-600 dark:text-gray-400">
|
<p class="mt-2 text-sm text-gray-600 dark:text-gray-400">
|
||||||
{{ t('auth.oidc.invitationRequired', { providerName }) }}
|
{{ registrationHint }}
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<div class="mt-6 space-y-4">
|
<div class="mt-6 space-y-4">
|
||||||
<input
|
<div>
|
||||||
v-model="invitationCode"
|
<label class="input-label">{{ t('auth.emailLabel') }}</label>
|
||||||
type="text"
|
<input
|
||||||
class="input w-full"
|
class="input w-full"
|
||||||
:placeholder="t('auth.invitationCodePlaceholder')"
|
type="email"
|
||||||
:disabled="isSubmitting"
|
:value="registrationEmail"
|
||||||
@keyup.enter="handleSubmitInvitation"
|
readonly
|
||||||
/>
|
disabled
|
||||||
<p v-if="invitationError" class="text-sm text-red-600 dark:text-red-400">
|
/>
|
||||||
{{ invitationError }}
|
</div>
|
||||||
|
<div>
|
||||||
|
<label class="input-label">{{ t('auth.passwordLabel') }}</label>
|
||||||
|
<input
|
||||||
|
v-model="password"
|
||||||
|
type="password"
|
||||||
|
class="input w-full"
|
||||||
|
:placeholder="t('auth.createPasswordPlaceholder')"
|
||||||
|
:disabled="isSubmitting"
|
||||||
|
autocomplete="new-password"
|
||||||
|
@keyup.enter="handleSubmitRegistration"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label class="input-label">{{ t('auth.confirmPassword') }}</label>
|
||||||
|
<input
|
||||||
|
v-model="confirmPassword"
|
||||||
|
type="password"
|
||||||
|
class="input w-full"
|
||||||
|
:placeholder="t('auth.confirmPasswordPlaceholder')"
|
||||||
|
:disabled="isSubmitting"
|
||||||
|
autocomplete="new-password"
|
||||||
|
@keyup.enter="handleSubmitRegistration"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div v-if="invitationRequired">
|
||||||
|
<label class="input-label">{{ t('auth.invitationCodeLabel') }}</label>
|
||||||
|
<input
|
||||||
|
v-model="invitationCode"
|
||||||
|
type="text"
|
||||||
|
class="input w-full"
|
||||||
|
:placeholder="t('auth.invitationCodePlaceholder')"
|
||||||
|
:disabled="isSubmitting"
|
||||||
|
@keyup.enter="handleSubmitRegistration"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<p v-if="registrationError" class="text-sm text-red-600 dark:text-red-400">
|
||||||
|
{{ registrationError }}
|
||||||
</p>
|
</p>
|
||||||
<button
|
<button
|
||||||
class="btn btn-primary w-full"
|
class="btn btn-primary w-full"
|
||||||
type="button"
|
type="button"
|
||||||
:disabled="isSubmitting || !invitationCode.trim()"
|
:disabled="isSubmitting || !canSubmitRegistration"
|
||||||
@click="handleSubmitInvitation"
|
@click="handleSubmitRegistration"
|
||||||
>
|
>
|
||||||
{{ isSubmitting ? t('common.processing') : t('auth.oidc.completeRegistration') }}
|
{{ isSubmitting ? t('common.processing') : t('auth.oidc.completeRegistration') }}
|
||||||
</button>
|
</button>
|
||||||
@ -134,9 +171,13 @@ const appStore = useAppStore()
|
|||||||
const authStore = useAuthStore()
|
const authStore = useAuthStore()
|
||||||
const isProcessing = ref(false)
|
const isProcessing = ref(false)
|
||||||
const isSubmitting = ref(false)
|
const isSubmitting = ref(false)
|
||||||
const needsInvitation = ref(false)
|
const needsRegistrationCompletion = ref(false)
|
||||||
|
const invitationRequired = ref(false)
|
||||||
|
const registrationEmail = ref('')
|
||||||
|
const password = ref('')
|
||||||
|
const confirmPassword = ref('')
|
||||||
const invitationCode = ref('')
|
const invitationCode = ref('')
|
||||||
const invitationError = ref('')
|
const registrationError = ref('')
|
||||||
const pendingProvider = ref<'github' | 'google'>('github')
|
const pendingProvider = ref<'github' | 'google'>('github')
|
||||||
const redirectTo = ref('/dashboard')
|
const redirectTo = ref('/dashboard')
|
||||||
const invalidCallback = ref(false)
|
const invalidCallback = ref(false)
|
||||||
@ -146,6 +187,9 @@ type EmailOAuthPendingCompletion = Partial<OAuthTokenResponse> & {
|
|||||||
error?: string
|
error?: string
|
||||||
provider?: string
|
provider?: string
|
||||||
redirect?: string
|
redirect?: string
|
||||||
|
email?: string
|
||||||
|
resolved_email?: string
|
||||||
|
invitation_required?: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
const code = computed(() => (route.query.code as string) || '')
|
const code = computed(() => (route.query.code as string) || '')
|
||||||
@ -161,6 +205,18 @@ const fullUrl = computed(() => {
|
|||||||
const providerName = computed(() =>
|
const providerName = computed(() =>
|
||||||
pendingProvider.value === 'google' ? 'Google' : 'GitHub'
|
pendingProvider.value === 'google' ? 'Google' : 'GitHub'
|
||||||
)
|
)
|
||||||
|
const registrationHint = computed(() =>
|
||||||
|
invitationRequired.value
|
||||||
|
? t('auth.oidc.invitationRequired', { providerName: providerName.value })
|
||||||
|
: t('auth.oidc.completeRegistration')
|
||||||
|
)
|
||||||
|
const canSubmitRegistration = computed(() => {
|
||||||
|
if (!registrationEmail.value.trim()) return false
|
||||||
|
if (password.value.length < 6) return false
|
||||||
|
if (password.value !== confirmPassword.value) return false
|
||||||
|
if (invitationRequired.value && !invitationCode.value.trim()) return false
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
function parseFragmentParams(): URLSearchParams {
|
function parseFragmentParams(): URLSearchParams {
|
||||||
const raw = typeof window !== 'undefined' ? window.location.hash : ''
|
const raw = typeof window !== 'undefined' ? window.location.hash : ''
|
||||||
@ -247,8 +303,10 @@ async function resumePendingEmailOAuth() {
|
|||||||
}
|
}
|
||||||
redirectTo.value = sanitizeRedirectPath(completionRedirect)
|
redirectTo.value = sanitizeRedirectPath(completionRedirect)
|
||||||
|
|
||||||
if (completion.error === 'invitation_required') {
|
if (completion.error === 'invitation_required' || completion.error === 'registration_completion_required') {
|
||||||
needsInvitation.value = true
|
invitationRequired.value = completion.error === 'invitation_required' || completion.invitation_required === true
|
||||||
|
registrationEmail.value = String(completion.resolved_email || completion.email || '').trim()
|
||||||
|
needsRegistrationCompletion.value = true
|
||||||
isProcessing.value = false
|
isProcessing.value = false
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -260,30 +318,46 @@ async function resumePendingEmailOAuth() {
|
|||||||
appStore.showError(message)
|
appStore.showError(message)
|
||||||
invalidCallback.value = true
|
invalidCallback.value = true
|
||||||
} finally {
|
} finally {
|
||||||
if (!needsInvitation.value) {
|
if (!needsRegistrationCompletion.value) {
|
||||||
isProcessing.value = false
|
isProcessing.value = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async function handleSubmitInvitation() {
|
async function handleSubmitRegistration() {
|
||||||
invitationError.value = ''
|
registrationError.value = ''
|
||||||
|
if (!registrationEmail.value.trim()) {
|
||||||
|
registrationError.value = t('auth.emailRequired')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if (password.value.length < 6) {
|
||||||
|
registrationError.value = t('auth.passwordMinLength')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if (password.value !== confirmPassword.value) {
|
||||||
|
registrationError.value = t('auth.passwordsDoNotMatch')
|
||||||
|
return
|
||||||
|
}
|
||||||
const code = invitationCode.value.trim()
|
const code = invitationCode.value.trim()
|
||||||
if (!code) return
|
if (invitationRequired.value && !code) return
|
||||||
|
|
||||||
isSubmitting.value = true
|
isSubmitting.value = true
|
||||||
try {
|
try {
|
||||||
|
const payload: { password: string; invitation_code?: string; aff_code?: string } = {
|
||||||
|
password: password.value,
|
||||||
|
...oauthAffiliatePayload(loadOAuthAffiliateCode())
|
||||||
|
}
|
||||||
|
if (invitationRequired.value) {
|
||||||
|
payload.invitation_code = code
|
||||||
|
}
|
||||||
const { data } = await apiClient.post<OAuthTokenResponse>(
|
const { data } = await apiClient.post<OAuthTokenResponse>(
|
||||||
`/auth/oauth/${pendingProvider.value}/complete-registration`,
|
`/auth/oauth/${pendingProvider.value}/complete-registration`,
|
||||||
{
|
payload
|
||||||
invitation_code: code,
|
|
||||||
...oauthAffiliatePayload(loadOAuthAffiliateCode())
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
await finalizeTokenResponse(data, redirectTo.value)
|
await finalizeTokenResponse(data, redirectTo.value)
|
||||||
} catch (e: unknown) {
|
} catch (e: unknown) {
|
||||||
const err = e as { message?: string; response?: { data?: { message?: string } } }
|
const err = e as { message?: string; response?: { data?: { message?: string } } }
|
||||||
invitationError.value =
|
registrationError.value =
|
||||||
err.response?.data?.message || err.message || t('auth.oidc.completeRegistrationFailed')
|
err.response?.data?.message || err.message || t('auth.oidc.completeRegistrationFailed')
|
||||||
} finally {
|
} finally {
|
||||||
isSubmitting.value = false
|
isSubmitting.value = false
|
||||||
|
|||||||
@ -161,6 +161,8 @@ describe('OAuthCallbackView', () => {
|
|||||||
error: 'invitation_required',
|
error: 'invitation_required',
|
||||||
provider: 'google',
|
provider: 'google',
|
||||||
redirect: '/dashboard',
|
redirect: '/dashboard',
|
||||||
|
resolved_email: 'pending@example.com',
|
||||||
|
invitation_required: true,
|
||||||
})
|
})
|
||||||
apiPostMock.mockResolvedValue({
|
apiPostMock.mockResolvedValue({
|
||||||
data: {
|
data: {
|
||||||
@ -171,14 +173,54 @@ describe('OAuthCallbackView', () => {
|
|||||||
|
|
||||||
const wrapper = mount(OAuthCallbackView)
|
const wrapper = mount(OAuthCallbackView)
|
||||||
await vi.dynamicImportSettled()
|
await vi.dynamicImportSettled()
|
||||||
const input = wrapper.find('input[type="text"]')
|
const passwordInputs = wrapper.findAll('input[type="password"]')
|
||||||
await input.setValue('INVITE456')
|
await passwordInputs[0].setValue('secret-123')
|
||||||
|
await passwordInputs[1].setValue('secret-123')
|
||||||
|
const invitationInput = wrapper.find('input[type="text"]')
|
||||||
|
await invitationInput.setValue('INVITE456')
|
||||||
await wrapper.findAll('button').at(0)?.trigger('click')
|
await wrapper.findAll('button').at(0)?.trigger('click')
|
||||||
|
|
||||||
expect(apiPostMock).toHaveBeenCalledWith('/auth/oauth/google/complete-registration', {
|
expect(apiPostMock).toHaveBeenCalledWith('/auth/oauth/google/complete-registration', {
|
||||||
|
password: 'secret-123',
|
||||||
invitation_code: 'INVITE456',
|
invitation_code: 'INVITE456',
|
||||||
aff_code: 'AFF456',
|
aff_code: 'AFF456',
|
||||||
})
|
})
|
||||||
expect(setTokenMock).toHaveBeenCalledWith('token-1')
|
expect(setTokenMock).toHaveBeenCalledWith('token-1')
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('completes email oauth registration with readonly email and without posting email', async () => {
|
||||||
|
routeState.path = '/auth/oauth/callback'
|
||||||
|
exchangePendingOAuthCompletionMock.mockResolvedValue({
|
||||||
|
error: 'registration_completion_required',
|
||||||
|
provider: 'github',
|
||||||
|
redirect: '/dashboard',
|
||||||
|
resolved_email: 'verified@example.com',
|
||||||
|
invitation_required: false,
|
||||||
|
})
|
||||||
|
apiPostMock.mockResolvedValue({
|
||||||
|
data: {
|
||||||
|
access_token: 'token-2',
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
const wrapper = mount(OAuthCallbackView)
|
||||||
|
await vi.dynamicImportSettled()
|
||||||
|
|
||||||
|
const emailInput = wrapper.find('input[type="email"]')
|
||||||
|
expect(emailInput.exists()).toBe(true)
|
||||||
|
expect((emailInput.element as HTMLInputElement).value).toBe('verified@example.com')
|
||||||
|
expect(emailInput.attributes('readonly')).toBeDefined()
|
||||||
|
expect(emailInput.attributes('disabled')).toBeDefined()
|
||||||
|
|
||||||
|
const passwordInputs = wrapper.findAll('input[type="password"]')
|
||||||
|
await passwordInputs[0].setValue('secret-456')
|
||||||
|
await passwordInputs[1].setValue('secret-456')
|
||||||
|
await wrapper.findAll('button').at(0)?.trigger('click')
|
||||||
|
|
||||||
|
expect(apiPostMock).toHaveBeenCalledWith('/auth/oauth/github/complete-registration', {
|
||||||
|
password: 'secret-456',
|
||||||
|
})
|
||||||
|
expect(apiPostMock.mock.calls[0][1]).not.toHaveProperty('email')
|
||||||
|
expect(setTokenMock).toHaveBeenCalledWith('token-2')
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user