package handler import ( "context" "net/http" "net/http/httptest" "net/url" "strings" "testing" "time" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/redeemcode" dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) func TestEmailOAuthCallbackRequiresPendingRegistrationWhenInvitationEnabled(t *testing.T) { handler, client := newOAuthPendingFlowTestHandler(t, true) ctx := context.Background() state := "github-oauth-state" recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/github/callback?code=code-1&state="+url.QueryEscape(state), nil) req.AddCookie(&http.Cookie{Name: emailOAuthStateCookieName, Value: encodeCookieValue(state)}) req.AddCookie(&http.Cookie{Name: emailOAuthRedirectCookie, Value: encodeCookieValue("/dashboard")}) req.AddCookie(&http.Cookie{Name: emailOAuthProviderCookie, Value: encodeCookieValue("github")}) c.Request = req profile := &emailOAuthProfile{ Subject: "github-123", Email: "fresh@example.com", EmailVerified: true, Username: "fresh", DisplayName: "Fresh User", AvatarURL: "https://cdn.example/fresh.png", Metadata: map[string]any{ "login": "fresh", }, } handler.emailOAuthCallbackWithProfile(c, "github", config.EmailOAuthProviderConfig{ Enabled: true, ClientID: "github-client", ClientSecret: "github-secret", RedirectURL: "https://app.example/api/v1/auth/oauth/github/callback", FrontendRedirectURL: "/auth/oauth/callback", }, "/auth/oauth/callback", "/dashboard", profile) require.Equal(t, http.StatusFound, recorder.Code) location := recorder.Header().Get("Location") require.Contains(t, location, "/auth/oauth/callback") require.NotContains(t, location, "access_token=") userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx) require.NoError(t, err) require.Zero(t, userCount) session, err := client.PendingAuthSession.Query().Only(ctx) require.NoError(t, err) require.Equal(t, "github", session.ProviderType) require.Equal(t, "github", session.ProviderKey) require.Equal(t, "github-123", session.ProviderSubject) require.Equal(t, "fresh@example.com", session.ResolvedEmail) require.Equal(t, "/dashboard", session.RedirectTo) require.Nil(t, session.TargetUserID) completion, ok := readCompletionResponse(session.LocalFlowState) require.True(t, ok) require.Equal(t, oauthPendingChoiceStep, completion["step"]) require.Equal(t, "invitation_required", completion["error"]) require.Equal(t, "fresh@example.com", completion["email"]) require.Equal(t, "fresh@example.com", completion["resolved_email"]) require.Equal(t, true, completion["create_account_allowed"]) require.NotEmpty(t, findSetCookieValue(recorder.Result().Cookies(), oauthPendingSessionCookieName)) require.NotEmpty(t, findSetCookieValue(recorder.Result().Cookies(), oauthPendingBrowserCookieName)) } func TestEmailOAuthCallbackExistingEmailLogsInWhenInvitationEnabled(t *testing.T) { handler, client := newOAuthPendingFlowTestHandler(t, true) ctx := context.Background() user, err := client.User.Create(). SetEmail("existing@example.com"). SetUsername("existing"). SetPasswordHash("hash"). SetRole(service.RoleUser). SetStatus(service.StatusActive). Save(ctx) require.NoError(t, err) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/google/callback", nil) handler.emailOAuthCallbackWithProfile(c, "google", config.EmailOAuthProviderConfig{ Enabled: true, ClientID: "google-client", ClientSecret: "google-secret", RedirectURL: "https://app.example/api/v1/auth/oauth/google/callback", FrontendRedirectURL: "/auth/oauth/callback", }, "/auth/oauth/callback", "/dashboard", &emailOAuthProfile{ Subject: "google-123", Email: "existing@example.com", EmailVerified: true, Username: "existing", }) require.Equal(t, http.StatusFound, recorder.Code) location := recorder.Header().Get("Location") require.Contains(t, location, "access_token=") require.Contains(t, location, "redirect=%252Fdashboard") sessionCount, err := client.PendingAuthSession.Query().Count(ctx) require.NoError(t, err) require.Zero(t, sessionCount) identityCount, err := client.AuthIdentity.Query().Where( authidentity.ProviderTypeEQ("google"), authidentity.ProviderSubjectEQ("google-123"), ).Count(ctx) require.NoError(t, err) require.Equal(t, 1, identityCount) _ = user } func TestEmailOAuthCallbackAutoRegistrationAppliesAffiliateCode(t *testing.T) { affiliateRepo := newOAuthEmailAffiliateRepoStub(map[string]int64{"AFF123": 1001}) handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ settingValues: map[string]string{ service.SettingKeyAffiliateEnabled: "true", }, affiliateFactory: func(_ *dbent.Client, settingSvc *service.SettingService) *service.AffiliateService { return service.NewAffiliateService(affiliateRepo, settingSvc, nil, nil) }, }) ctx := context.Background() recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/github/callback", nil) req.AddCookie(&http.Cookie{Name: emailOAuthAffiliateCookie, Value: encodeCookieValue("AFF123")}) c.Request = req handler.emailOAuthCallbackWithProfile(c, "github", config.EmailOAuthProviderConfig{ Enabled: true, ClientID: "github-client", ClientSecret: "github-secret", RedirectURL: "https://app.example/api/v1/auth/oauth/github/callback", FrontendRedirectURL: "/auth/oauth/callback", }, "/auth/oauth/callback", "/dashboard", &emailOAuthProfile{ Subject: "github-aff-user", Email: "aff-user@example.com", EmailVerified: true, Username: "aff-user", }) require.Equal(t, http.StatusFound, recorder.Code) require.Contains(t, recorder.Header().Get("Location"), "access_token=") user, err := client.User.Query().Where(dbuser.EmailEQ("aff-user@example.com")).Only(ctx) require.NoError(t, err) require.Equal(t, []int64{user.ID, user.ID}, affiliateRepo.ensureUserIDs) require.Equal(t, []oauthEmailAffiliateBindCall{{userID: user.ID, inviterID: 1001}}, affiliateRepo.bindCalls) } func TestCompleteEmailOAuthRegistrationUsesAffiliateCodeFromPendingSession(t *testing.T) { affiliateRepo := newOAuthEmailAffiliateRepoStub(map[string]int64{"AFF456": 2002}) handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ invitationEnabled: true, settingValues: map[string]string{ service.SettingKeyAffiliateEnabled: "true", }, affiliateFactory: func(_ *dbent.Client, settingSvc *service.SettingService) *service.AffiliateService { return service.NewAffiliateService(affiliateRepo, settingSvc, nil, nil) }, }) ctx := context.Background() invitation, err := client.RedeemCode.Create(). SetCode("INVITE456"). SetType(service.RedeemTypeInvitation). SetStatus(service.StatusUnused). SetValue(0). Save(ctx) require.NoError(t, err) session, err := client.PendingAuthSession.Create(). SetSessionToken("email-oauth-aff-session-token"). SetIntent(oauthIntentLogin). SetProviderType("google"). SetProviderKey("google"). SetProviderSubject("google-aff-user"). SetResolvedEmail("pending-aff@example.com"). SetRedirectTo("/dashboard"). SetBrowserSessionKey("browser-aff-key"). SetUpstreamIdentityClaims(map[string]any{ "email": "pending-aff@example.com", "email_verified": true, "username": "pending-aff", "provider": "google", "provider_key": "google", "provider_subject": "google-aff-user", "aff_code": "AFF456", }). SetLocalFlowState(map[string]any{ "step": oauthPendingChoiceStep, "error": "invitation_required", }). SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). Save(ctx) require.NoError(t, err) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/google/complete-registration", strings.NewReader(`{"invitation_code":"INVITE456"}`)) req.Header.Set("Content-Type", "application/json") req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-aff-key")}) c.Request = req handler.completeEmailOAuthRegistration(c, "google") require.Equal(t, http.StatusOK, recorder.Code) user, err := client.User.Query().Where(dbuser.EmailEQ("pending-aff@example.com")).Only(ctx) require.NoError(t, err) require.Equal(t, []oauthEmailAffiliateBindCall{{userID: user.ID, inviterID: 2002}}, affiliateRepo.bindCalls) storedInvitation, err := client.RedeemCode.Query().Where(redeemcode.IDEQ(invitation.ID)).Only(ctx) require.NoError(t, err) require.NotNil(t, storedInvitation.UsedBy) require.Equal(t, user.ID, *storedInvitation.UsedBy) } type oauthEmailAffiliateBindCall struct { userID int64 inviterID int64 } type oauthEmailAffiliateRepoStub struct { codeOwners map[string]int64 ensureUserIDs []int64 bindCalls []oauthEmailAffiliateBindCall } func newOAuthEmailAffiliateRepoStub(codeOwners map[string]int64) *oauthEmailAffiliateRepoStub { return &oauthEmailAffiliateRepoStub{codeOwners: codeOwners} } func (r *oauthEmailAffiliateRepoStub) EnsureUserAffiliate(_ context.Context, userID int64) (*service.AffiliateSummary, error) { r.ensureUserIDs = append(r.ensureUserIDs, userID) return &service.AffiliateSummary{UserID: userID, AffCode: "SELF"}, nil } func (r *oauthEmailAffiliateRepoStub) GetAffiliateByCode(_ context.Context, code string) (*service.AffiliateSummary, error) { userID, ok := r.codeOwners[strings.ToUpper(strings.TrimSpace(code))] if !ok { return nil, service.ErrAffiliateProfileNotFound } return &service.AffiliateSummary{UserID: userID, AffCode: strings.ToUpper(strings.TrimSpace(code))}, nil } func (r *oauthEmailAffiliateRepoStub) BindInviter(_ context.Context, userID, inviterID int64) (bool, error) { r.bindCalls = append(r.bindCalls, oauthEmailAffiliateBindCall{userID: userID, inviterID: inviterID}) return true, nil } func (r *oauthEmailAffiliateRepoStub) AccrueQuota(context.Context, int64, int64, float64, int, *int64) (bool, error) { panic("unexpected AccrueQuota call") } func (r *oauthEmailAffiliateRepoStub) GetAccruedRebateFromInvitee(context.Context, int64, int64) (float64, error) { panic("unexpected GetAccruedRebateFromInvitee call") } func (r *oauthEmailAffiliateRepoStub) ThawFrozenQuota(context.Context, int64) (float64, error) { panic("unexpected ThawFrozenQuota call") } func (r *oauthEmailAffiliateRepoStub) TransferQuotaToBalance(context.Context, int64) (float64, float64, error) { panic("unexpected TransferQuotaToBalance call") } func (r *oauthEmailAffiliateRepoStub) ListInvitees(context.Context, int64, int) ([]service.AffiliateInvitee, error) { panic("unexpected ListInvitees call") } func (r *oauthEmailAffiliateRepoStub) UpdateUserAffCode(context.Context, int64, string) error { panic("unexpected UpdateUserAffCode call") } func (r *oauthEmailAffiliateRepoStub) ResetUserAffCode(context.Context, int64) (string, error) { panic("unexpected ResetUserAffCode call") } func (r *oauthEmailAffiliateRepoStub) SetUserRebateRate(context.Context, int64, *float64) error { panic("unexpected SetUserRebateRate call") } func (r *oauthEmailAffiliateRepoStub) BatchSetUserRebateRate(context.Context, []int64, *float64) error { panic("unexpected BatchSetUserRebateRate call") } func (r *oauthEmailAffiliateRepoStub) ListUsersWithCustomSettings(context.Context, service.AffiliateAdminFilter) ([]service.AffiliateAdminEntry, int64, error) { panic("unexpected ListUsersWithCustomSettings call") } func (r *oauthEmailAffiliateRepoStub) ListAffiliateInviteRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateInviteRecord, int64, error) { panic("unexpected ListAffiliateInviteRecords call") } func (r *oauthEmailAffiliateRepoStub) ListAffiliateRebateRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateRebateRecord, int64, error) { panic("unexpected ListAffiliateRebateRecords call") } func (r *oauthEmailAffiliateRepoStub) ListAffiliateTransferRecords(context.Context, service.AffiliateRecordFilter) ([]service.AffiliateTransferRecord, int64, error) { panic("unexpected ListAffiliateTransferRecords call") } func (r *oauthEmailAffiliateRepoStub) GetAffiliateUserOverview(context.Context, int64) (*service.AffiliateUserOverview, error) { panic("unexpected GetAffiliateUserOverview call") } func findSetCookieValue(cookies []*http.Cookie, name string) string { for _, cookie := range cookies { if cookie != nil && strings.EqualFold(cookie.Name, name) && cookie.MaxAge >= 0 { return cookie.Value } } return "" }