fix(auth): harden oauth identity upgrade paths
This commit is contained in:
parent
3d29f7c2fa
commit
36aed35957
@ -3,7 +3,9 @@ package schema
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"entgo.io/ent"
|
||||||
"entgo.io/ent/entc/load"
|
"entgo.io/ent/entc/load"
|
||||||
|
"entgo.io/ent/schema/field"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -74,6 +76,17 @@ func TestAuthIdentityFoundationSchemas(t *testing.T) {
|
|||||||
|
|
||||||
userSchema := requireSchema(t, schemas, "User")
|
userSchema := requireSchema(t, schemas, "User")
|
||||||
requireSchemaFields(t, userSchema, "signup_source", "last_login_at", "last_active_at")
|
requireSchemaFields(t, userSchema, "signup_source", "last_login_at", "last_active_at")
|
||||||
|
signupSource := requireSchemaField(t, userSchema, "signup_source")
|
||||||
|
require.Equal(t, field.TypeString, signupSource.Info.Type)
|
||||||
|
require.True(t, signupSource.Default)
|
||||||
|
require.Equal(t, "email", signupSource.DefaultValue)
|
||||||
|
require.Equal(t, 1, signupSource.Validators)
|
||||||
|
|
||||||
|
validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source")
|
||||||
|
for _, value := range []string{"email", "linuxdo", "wechat", "oidc"} {
|
||||||
|
require.NoError(t, validator(value))
|
||||||
|
}
|
||||||
|
require.Error(t, validator("github"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema {
|
func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema {
|
||||||
@ -98,6 +111,37 @@ func requireSchemaFields(t *testing.T, schema *load.Schema, names ...string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func requireSchemaField(t *testing.T, schema *load.Schema, name string) *load.Field {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
for _, schemaField := range schema.Fields {
|
||||||
|
if schemaField.Name == name {
|
||||||
|
return schemaField
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Failf(t, "missing schema field", "schema %s should include field %s", schema.Name, name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func requireStringFieldValidator(t *testing.T, fields []ent.Field, name string) func(string) error {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
for _, entField := range fields {
|
||||||
|
descriptor := entField.Descriptor()
|
||||||
|
if descriptor.Name != name {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
require.NotEmpty(t, descriptor.Validators, "field %s should include a validator", name)
|
||||||
|
validator, ok := descriptor.Validators[0].(func(string) error)
|
||||||
|
require.True(t, ok, "field %s validator should be func(string) error", name)
|
||||||
|
return validator
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Failf(t, "missing field validator", "schema should include field %s", name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func requireHasUniqueIndex(t *testing.T, schema *load.Schema, fields ...string) {
|
func requireHasUniqueIndex(t *testing.T, schema *load.Schema, fields ...string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
package schema
|
package schema
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||||
|
|
||||||
@ -73,7 +75,14 @@ func (User) Fields() []ent.Field {
|
|||||||
Optional().
|
Optional().
|
||||||
Nillable(),
|
Nillable(),
|
||||||
field.String("signup_source").
|
field.String("signup_source").
|
||||||
MaxLen(20).
|
Validate(func(value string) error {
|
||||||
|
switch value {
|
||||||
|
case "email", "linuxdo", "wechat", "oidc":
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc")
|
||||||
|
}
|
||||||
|
}).
|
||||||
Default("email"),
|
Default("email"),
|
||||||
field.Time("last_login_at").
|
field.Time("last_login_at").
|
||||||
Optional().
|
Optional().
|
||||||
|
|||||||
@ -211,25 +211,27 @@ type WeChatConnectConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OIDCConnectConfig struct {
|
type OIDCConnectConfig struct {
|
||||||
Enabled bool `mapstructure:"enabled"`
|
Enabled bool `mapstructure:"enabled"`
|
||||||
ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等
|
ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等
|
||||||
ClientID string `mapstructure:"client_id"`
|
ClientID string `mapstructure:"client_id"`
|
||||||
ClientSecret string `mapstructure:"client_secret"`
|
ClientSecret string `mapstructure:"client_secret"`
|
||||||
IssuerURL string `mapstructure:"issuer_url"`
|
IssuerURL string `mapstructure:"issuer_url"`
|
||||||
DiscoveryURL string `mapstructure:"discovery_url"`
|
DiscoveryURL string `mapstructure:"discovery_url"`
|
||||||
AuthorizeURL string `mapstructure:"authorize_url"`
|
AuthorizeURL string `mapstructure:"authorize_url"`
|
||||||
TokenURL string `mapstructure:"token_url"`
|
TokenURL string `mapstructure:"token_url"`
|
||||||
UserInfoURL string `mapstructure:"userinfo_url"`
|
UserInfoURL string `mapstructure:"userinfo_url"`
|
||||||
JWKSURL string `mapstructure:"jwks_url"`
|
JWKSURL string `mapstructure:"jwks_url"`
|
||||||
Scopes string `mapstructure:"scopes"` // 默认 "openid email profile"
|
Scopes string `mapstructure:"scopes"` // 默认 "openid email profile"
|
||||||
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
|
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
|
||||||
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback)
|
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback)
|
||||||
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
|
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
|
||||||
UsePKCE bool `mapstructure:"use_pkce"`
|
UsePKCE bool `mapstructure:"use_pkce"`
|
||||||
ValidateIDToken bool `mapstructure:"validate_id_token"`
|
ValidateIDToken bool `mapstructure:"validate_id_token"`
|
||||||
AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256"
|
UsePKCEExplicit bool `mapstructure:"-" yaml:"-"`
|
||||||
ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120
|
ValidateIDTokenExplicit bool `mapstructure:"-" yaml:"-"`
|
||||||
RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false
|
AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256"
|
||||||
|
ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120
|
||||||
|
RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false
|
||||||
|
|
||||||
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
|
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
|
||||||
// 为空时,服务端会尝试一组常见字段名。
|
// 为空时,服务端会尝试一组常见字段名。
|
||||||
@ -329,6 +331,14 @@ func shouldApplyLegacyWeChatEnv(configKey, envKey string) bool {
|
|||||||
return !hasNewEnv
|
return !hasNewEnv
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasExplicitConfigOrEnv(configKey, envKey string) bool {
|
||||||
|
if viper.InConfig(configKey) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
_, ok := os.LookupEnv(envKey)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
func applyLegacyWeChatConnectEnvCompatibility(cfg *WeChatConnectConfig) {
|
func applyLegacyWeChatConnectEnvCompatibility(cfg *WeChatConnectConfig) {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
return
|
return
|
||||||
@ -1262,6 +1272,8 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
|
|||||||
cfg.OIDC.UserInfoEmailPath = strings.TrimSpace(cfg.OIDC.UserInfoEmailPath)
|
cfg.OIDC.UserInfoEmailPath = strings.TrimSpace(cfg.OIDC.UserInfoEmailPath)
|
||||||
cfg.OIDC.UserInfoIDPath = strings.TrimSpace(cfg.OIDC.UserInfoIDPath)
|
cfg.OIDC.UserInfoIDPath = strings.TrimSpace(cfg.OIDC.UserInfoIDPath)
|
||||||
cfg.OIDC.UserInfoUsernamePath = strings.TrimSpace(cfg.OIDC.UserInfoUsernamePath)
|
cfg.OIDC.UserInfoUsernamePath = strings.TrimSpace(cfg.OIDC.UserInfoUsernamePath)
|
||||||
|
cfg.OIDC.UsePKCEExplicit = hasExplicitConfigOrEnv("oidc_connect.use_pkce", "OIDC_CONNECT_USE_PKCE")
|
||||||
|
cfg.OIDC.ValidateIDTokenExplicit = hasExplicitConfigOrEnv("oidc_connect.validate_id_token", "OIDC_CONNECT_VALIDATE_ID_TOKEN")
|
||||||
cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
|
cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
|
||||||
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
|
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
|
||||||
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
|
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
|
||||||
|
|||||||
@ -254,6 +254,21 @@ func TestLoadDefaultOIDCSecurityDefaults(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.True(t, cfg.OIDC.UsePKCE)
|
require.True(t, cfg.OIDC.UsePKCE)
|
||||||
require.True(t, cfg.OIDC.ValidateIDToken)
|
require.True(t, cfg.OIDC.ValidateIDToken)
|
||||||
|
require.False(t, cfg.OIDC.UsePKCEExplicit)
|
||||||
|
require.False(t, cfg.OIDC.ValidateIDTokenExplicit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadExplicitOIDCSecurityDefaultsFromEnvMarksFlagsExplicit(t *testing.T) {
|
||||||
|
resetViperWithJWTSecret(t)
|
||||||
|
t.Setenv("OIDC_CONNECT_USE_PKCE", "false")
|
||||||
|
t.Setenv("OIDC_CONNECT_VALIDATE_ID_TOKEN", "false")
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, cfg.OIDC.UsePKCE)
|
||||||
|
require.False(t, cfg.OIDC.ValidateIDToken)
|
||||||
|
require.True(t, cfg.OIDC.UsePKCEExplicit)
|
||||||
|
require.True(t, cfg.OIDC.ValidateIDTokenExplicit)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadForcedCodexInstructionsTemplate(t *testing.T) {
|
func TestLoadForcedCodexInstructionsTemplate(t *testing.T) {
|
||||||
|
|||||||
@ -335,6 +335,75 @@ func TestSettingHandler_UpdateSettings_PersistsExplicitFalseOIDCCompatibilityFla
|
|||||||
require.Equal(t, false, data["oidc_connect_validate_id_token"])
|
require.Equal(t, false, data["oidc_connect_validate_id_token"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSettingHandler_UpdateSettings_DoesNotSolidifyImplicitOIDCSecurityDefaultsOnLegacyUpgrade(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
repo := &settingHandlerRepoStub{
|
||||||
|
values: map[string]string{
|
||||||
|
service.SettingKeyPromoCodeEnabled: "true",
|
||||||
|
service.SettingKeyOIDCConnectEnabled: "true",
|
||||||
|
service.SettingKeyOIDCConnectProviderName: "OIDC",
|
||||||
|
service.SettingKeyOIDCConnectClientID: "oidc-client",
|
||||||
|
service.SettingKeyOIDCConnectClientSecret: "oidc-secret",
|
||||||
|
service.SettingKeyOIDCConnectIssuerURL: "https://issuer.example.com",
|
||||||
|
service.SettingKeyOIDCConnectAuthorizeURL: "https://issuer.example.com/auth",
|
||||||
|
service.SettingKeyOIDCConnectTokenURL: "https://issuer.example.com/token",
|
||||||
|
service.SettingKeyOIDCConnectUserInfoURL: "https://issuer.example.com/userinfo",
|
||||||
|
service.SettingKeyOIDCConnectJWKSURL: "https://issuer.example.com/jwks",
|
||||||
|
service.SettingKeyOIDCConnectScopes: "openid email profile",
|
||||||
|
service.SettingKeyOIDCConnectRedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
|
||||||
|
service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
|
||||||
|
service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
|
||||||
|
service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256",
|
||||||
|
service.SettingKeyOIDCConnectClockSkewSeconds: "120",
|
||||||
|
service.SettingKeyOIDCConnectRequireEmailVerified: "false",
|
||||||
|
service.SettingKeyOIDCConnectUserInfoEmailPath: "",
|
||||||
|
service.SettingKeyOIDCConnectUserInfoIDPath: "",
|
||||||
|
service.SettingKeyOIDCConnectUserInfoUsernamePath: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := service.NewSettingService(repo, &config.Config{
|
||||||
|
Default: config.DefaultConfig{UserConcurrency: 5},
|
||||||
|
OIDC: config.OIDCConnectConfig{
|
||||||
|
Enabled: true,
|
||||||
|
ProviderName: "OIDC",
|
||||||
|
ClientID: "oidc-client",
|
||||||
|
ClientSecret: "oidc-secret",
|
||||||
|
IssuerURL: "https://issuer.example.com",
|
||||||
|
AuthorizeURL: "https://issuer.example.com/auth",
|
||||||
|
TokenURL: "https://issuer.example.com/token",
|
||||||
|
UserInfoURL: "https://issuer.example.com/userinfo",
|
||||||
|
JWKSURL: "https://issuer.example.com/jwks",
|
||||||
|
Scopes: "openid email profile",
|
||||||
|
RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
|
||||||
|
FrontendRedirectURL: "/auth/oidc/callback",
|
||||||
|
TokenAuthMethod: "client_secret_post",
|
||||||
|
UsePKCE: true,
|
||||||
|
ValidateIDToken: true,
|
||||||
|
AllowedSigningAlgs: "RS256",
|
||||||
|
ClockSkewSeconds: 120,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
|
body := map[string]any{
|
||||||
|
"promo_code_enabled": true,
|
||||||
|
"oidc_connect_enabled": true,
|
||||||
|
}
|
||||||
|
rawBody, err := json.Marshal(body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
handler.UpdateSettings(c)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectUsePKCE])
|
||||||
|
require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectValidateIDToken])
|
||||||
|
}
|
||||||
|
|
||||||
func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) {
|
func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
repo := &settingHandlerRepoStub{
|
repo := &settingHandlerRepoStub{
|
||||||
|
|||||||
@ -355,15 +355,20 @@ func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
userEntity, err := client.User.Query().
|
userEntity, err := client.User.Query().
|
||||||
Where(dbuser.EmailEqualFold(email)).
|
Where(userNormalizedEmailPredicate(email)).
|
||||||
Only(ctx)
|
Order(dbent.Asc(dbuser.FieldID)).
|
||||||
|
All(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if dbent.IsNotFound(err) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err)
|
return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err)
|
||||||
}
|
}
|
||||||
return userEntity, nil
|
switch len(userEntity) {
|
||||||
|
case 0:
|
||||||
|
return nil, nil
|
||||||
|
case 1:
|
||||||
|
return userEntity[0], nil
|
||||||
|
default:
|
||||||
|
return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
|
func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
|
||||||
@ -411,9 +416,15 @@ func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
|
|||||||
completionResponse["choice_reason"] = "force_email_on_signup"
|
completionResponse["choice_reason"] = "force_email_on_signup"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var targetUserID *int64
|
||||||
|
if compatEmailUser != nil && compatEmailUser.ID > 0 {
|
||||||
|
targetUserID = &compatEmailUser.ID
|
||||||
|
}
|
||||||
|
|
||||||
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
||||||
Intent: oauthIntentLogin,
|
Intent: oauthIntentLogin,
|
||||||
Identity: identity,
|
Identity: identity,
|
||||||
|
TargetUserID: targetUserID,
|
||||||
ResolvedEmail: resolvedChoiceEmail,
|
ResolvedEmail: resolvedChoiceEmail,
|
||||||
RedirectTo: redirectTo,
|
RedirectTo: redirectTo,
|
||||||
BrowserSessionKey: browserSessionKey,
|
BrowserSessionKey: browserSessionKey,
|
||||||
@ -490,9 +501,13 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
client := h.entClient()
|
||||||
if err != nil {
|
if client == nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil {
|
||||||
|
respondPendingOAuthBindingApplyError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
|
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
|
||||||
@ -503,17 +518,16 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
|
|||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil {
|
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
|
if err != nil {
|
||||||
return
|
|
||||||
}
|
|
||||||
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
|
|
||||||
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
|
|
||||||
clearOAuthPendingSessionCookie(c, secureCookie)
|
|
||||||
clearOAuthPendingBrowserCookie(c, secureCookie)
|
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil {
|
||||||
|
respondPendingOAuthBindingApplyError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
|
||||||
clearOAuthPendingSessionCookie(c, secureCookie)
|
clearOAuthPendingSessionCookie(c, secureCookie)
|
||||||
clearOAuthPendingBrowserCookie(c, secureCookie)
|
clearOAuthPendingBrowserCookie(c, secureCookie)
|
||||||
|
|
||||||
|
|||||||
@ -508,7 +508,7 @@ func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *test
|
|||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
existingUser, err := client.User.Create().
|
existingUser, err := client.User.Create().
|
||||||
SetEmail("legacy@example.com").
|
SetEmail(" Legacy@Example.com ").
|
||||||
SetUsername("legacy-user").
|
SetUsername("legacy-user").
|
||||||
SetPasswordHash("hash").
|
SetPasswordHash("hash").
|
||||||
SetRole(service.RoleUser).
|
SetRole(service.RoleUser).
|
||||||
@ -539,16 +539,17 @@ func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *test
|
|||||||
Only(ctx)
|
Only(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, oauthIntentLogin, session.Intent)
|
require.Equal(t, oauthIntentLogin, session.Intent)
|
||||||
require.Nil(t, session.TargetUserID)
|
require.NotNil(t, session.TargetUserID)
|
||||||
require.Equal(t, existingUser.Email, session.ResolvedEmail)
|
require.Equal(t, existingUser.ID, *session.TargetUserID)
|
||||||
|
require.Equal(t, strings.TrimSpace(existingUser.Email), session.ResolvedEmail)
|
||||||
require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
|
require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
|
||||||
|
|
||||||
completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
|
completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
require.Equal(t, "/dashboard", completion["redirect"])
|
require.Equal(t, "/dashboard", completion["redirect"])
|
||||||
require.Equal(t, oauthPendingChoiceStep, completion["step"])
|
require.Equal(t, oauthPendingChoiceStep, completion["step"])
|
||||||
require.Equal(t, existingUser.Email, completion["email"])
|
require.Equal(t, strings.TrimSpace(existingUser.Email), completion["email"])
|
||||||
require.Equal(t, existingUser.Email, completion["existing_account_email"])
|
require.Equal(t, strings.TrimSpace(existingUser.Email), completion["existing_account_email"])
|
||||||
require.Equal(t, true, completion["existing_account_bindable"])
|
require.Equal(t, true, completion["existing_account_bindable"])
|
||||||
require.Equal(t, "compat_email_match", completion["choice_reason"])
|
require.Equal(t, "compat_email_match", completion["choice_reason"])
|
||||||
_, hasAccessToken := completion["access_token"]
|
_, hasAccessToken := completion["access_token"]
|
||||||
@ -943,6 +944,68 @@ func TestCompleteLinuxDoOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *te
|
|||||||
require.False(t, decision.AdoptAvatar)
|
require.False(t, decision.AdoptAvatar)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCompleteLinuxDoOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) {
|
||||||
|
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
existingOwner, err := client.User.Create().
|
||||||
|
SetEmail("owner@example.com").
|
||||||
|
SetUsername("owner-user").
|
||||||
|
SetPasswordHash("hash").
|
||||||
|
SetRole(service.RoleUser).
|
||||||
|
SetStatus(service.StatusActive).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = client.AuthIdentity.Create().
|
||||||
|
SetUserID(existingOwner.ID).
|
||||||
|
SetProviderType("linuxdo").
|
||||||
|
SetProviderKey("linuxdo").
|
||||||
|
SetProviderSubject("linuxdo-conflict-subject").
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
session, err := client.PendingAuthSession.Create().
|
||||||
|
SetSessionToken("linuxdo-complete-conflict-session").
|
||||||
|
SetIntent("login").
|
||||||
|
SetProviderType("linuxdo").
|
||||||
|
SetProviderKey("linuxdo").
|
||||||
|
SetProviderSubject("linuxdo-conflict-subject").
|
||||||
|
SetResolvedEmail("linuxdo-conflict-subject@linuxdo-connect.invalid").
|
||||||
|
SetBrowserSessionKey("linuxdo-conflict-browser").
|
||||||
|
SetUpstreamIdentityClaims(map[string]any{
|
||||||
|
"username": "linuxdo_user",
|
||||||
|
}).
|
||||||
|
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-conflict-browser")})
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
handler.CompleteLinuxDoOAuthRegistration(c)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusConflict, recorder.Code)
|
||||||
|
payload := decodeJSONBody(t, recorder)
|
||||||
|
require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"])
|
||||||
|
|
||||||
|
userCount, err := client.User.Query().
|
||||||
|
Where(dbuser.EmailEQ("linuxdo-conflict-subject@linuxdo-connect.invalid")).
|
||||||
|
Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Zero(t, userCount)
|
||||||
|
|
||||||
|
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, storedSession.ConsumedAt)
|
||||||
|
}
|
||||||
|
|
||||||
func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler {
|
func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg)
|
handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg)
|
||||||
|
|||||||
@ -519,7 +519,7 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
|
|||||||
|
|
||||||
email := strings.TrimSpace(strings.ToLower(req.Email))
|
email := strings.TrimSpace(strings.ToLower(req.Email))
|
||||||
if existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email); err == nil && existingUser != nil {
|
if existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email); err == nil && existingUser != nil {
|
||||||
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email)
|
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
@ -704,6 +704,38 @@ func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email
|
|||||||
return matches[0], nil
|
return matches[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ensurePendingOAuthRegistrationIdentityAvailable(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) error {
|
||||||
|
if client == nil || session == nil {
|
||||||
|
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
identity, err := client.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)),
|
||||||
|
authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)),
|
||||||
|
authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)),
|
||||||
|
).
|
||||||
|
Only(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if dbent.IsNotFound(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if identity == nil || identity.UserID <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
activeOwner, err := findActiveUserByID(ctx, client, identity.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if activeOwner != nil {
|
||||||
|
return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
|
func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
|
||||||
if session == nil {
|
if session == nil {
|
||||||
return nil
|
return nil
|
||||||
@ -1206,6 +1238,38 @@ func consumePendingOAuthBrowserSessionTx(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func applyPendingOAuthAdoptionAndConsumeSession(
|
||||||
|
ctx context.Context,
|
||||||
|
client *dbent.Client,
|
||||||
|
authService *service.AuthService,
|
||||||
|
userService *service.UserService,
|
||||||
|
session *dbent.PendingAuthSession,
|
||||||
|
decision *dbent.IdentityAdoptionDecision,
|
||||||
|
userID int64,
|
||||||
|
) error {
|
||||||
|
if client == nil {
|
||||||
|
return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
|
||||||
|
}
|
||||||
|
if session == nil || userID <= 0 {
|
||||||
|
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := client.Tx(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
|
txCtx := dbent.NewTxContext(ctx, tx)
|
||||||
|
if err := applyPendingOAuthAdoption(txCtx, client, authService, userService, session, decision, &userID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
func applyPendingOAuthAdoption(
|
func applyPendingOAuthAdoption(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
client *dbent.Client,
|
client *dbent.Client,
|
||||||
@ -1448,16 +1512,21 @@ func (h *AuthHandler) transitionPendingOAuthAccountToChoiceState(
|
|||||||
c *gin.Context,
|
c *gin.Context,
|
||||||
client *dbent.Client,
|
client *dbent.Client,
|
||||||
session *dbent.PendingAuthSession,
|
session *dbent.PendingAuthSession,
|
||||||
|
targetUser *dbent.User,
|
||||||
email string,
|
email string,
|
||||||
) (*dbent.PendingAuthSession, error) {
|
) (*dbent.PendingAuthSession, error) {
|
||||||
completionResponse := pendingOAuthChoiceCompletionResponse(session, email)
|
completionResponse := pendingOAuthChoiceCompletionResponse(session, email)
|
||||||
|
var targetUserID *int64
|
||||||
|
if targetUser != nil && targetUser.ID > 0 {
|
||||||
|
targetUserID = &targetUser.ID
|
||||||
|
}
|
||||||
session, err := updatePendingOAuthSessionProgress(
|
session, err := updatePendingOAuthSessionProgress(
|
||||||
c.Request.Context(),
|
c.Request.Context(),
|
||||||
client,
|
client,
|
||||||
session,
|
session,
|
||||||
strings.TrimSpace(session.Intent),
|
strings.TrimSpace(session.Intent),
|
||||||
email,
|
email,
|
||||||
nil,
|
targetUserID,
|
||||||
completionResponse,
|
completionResponse,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1601,7 +1670,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if existingUser != nil {
|
if existingUser != nil {
|
||||||
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email)
|
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
@ -1624,7 +1693,12 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
|||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, service.ErrEmailExists) {
|
if errors.Is(err, service.ErrEmailExists) {
|
||||||
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email)
|
existingUser, lookupErr := findUserByNormalizedEmail(c.Request.Context(), client, email)
|
||||||
|
if lookupErr != nil {
|
||||||
|
response.ErrorFrom(c, lookupErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@ -1045,7 +1045,7 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *
|
|||||||
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
|
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
_, err := client.User.Create().
|
existingUser, err := client.User.Create().
|
||||||
SetEmail("owner@example.com").
|
SetEmail("owner@example.com").
|
||||||
SetUsername("owner-user").
|
SetUsername("owner-user").
|
||||||
SetPasswordHash("hash").
|
SetPasswordHash("hash").
|
||||||
@ -1099,7 +1099,8 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *
|
|||||||
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, oauthIntentLogin, storedSession.Intent)
|
require.Equal(t, oauthIntentLogin, storedSession.Intent)
|
||||||
require.Nil(t, storedSession.TargetUserID)
|
require.NotNil(t, storedSession.TargetUserID)
|
||||||
|
require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
|
||||||
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
|
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
|
||||||
require.Nil(t, storedSession.ConsumedAt)
|
require.Nil(t, storedSession.ConsumedAt)
|
||||||
|
|
||||||
@ -1118,7 +1119,7 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
|
|||||||
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
|
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
_, err := client.User.Create().
|
existingUser, err := client.User.Create().
|
||||||
SetEmail(" Owner@Example.com ").
|
SetEmail(" Owner@Example.com ").
|
||||||
SetUsername("owner-user").
|
SetUsername("owner-user").
|
||||||
SetPasswordHash("hash").
|
SetPasswordHash("hash").
|
||||||
@ -1164,7 +1165,8 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
|
|||||||
|
|
||||||
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Nil(t, storedSession.TargetUserID)
|
require.NotNil(t, storedSession.TargetUserID)
|
||||||
|
require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
|
||||||
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
|
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1172,7 +1174,7 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing
|
|||||||
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
|
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
_, err := client.User.Create().
|
existingUser, err := client.User.Create().
|
||||||
SetEmail("owner@example.com").
|
SetEmail("owner@example.com").
|
||||||
SetUsername("owner-user").
|
SetUsername("owner-user").
|
||||||
SetPasswordHash("hash").
|
SetPasswordHash("hash").
|
||||||
@ -1220,7 +1222,8 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing
|
|||||||
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, oauthIntentLogin, storedSession.Intent)
|
require.Equal(t, oauthIntentLogin, storedSession.Intent)
|
||||||
require.Nil(t, storedSession.TargetUserID)
|
require.NotNil(t, storedSession.TargetUserID)
|
||||||
|
require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
|
||||||
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
|
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -563,10 +563,15 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession(
|
|||||||
if compatEmailUser != nil {
|
if compatEmailUser != nil {
|
||||||
resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email)
|
resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email)
|
||||||
}
|
}
|
||||||
|
var targetUserID *int64
|
||||||
|
if compatEmailUser != nil && compatEmailUser.ID > 0 {
|
||||||
|
targetUserID = &compatEmailUser.ID
|
||||||
|
}
|
||||||
|
|
||||||
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
||||||
Intent: oauthIntentLogin,
|
Intent: oauthIntentLogin,
|
||||||
Identity: identity,
|
Identity: identity,
|
||||||
|
TargetUserID: targetUserID,
|
||||||
ResolvedEmail: resolvedChoiceEmail,
|
ResolvedEmail: resolvedChoiceEmail,
|
||||||
RedirectTo: redirectTo,
|
RedirectTo: redirectTo,
|
||||||
BrowserSessionKey: browserSessionKey,
|
BrowserSessionKey: browserSessionKey,
|
||||||
@ -643,9 +648,13 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
client := h.entClient()
|
||||||
if err != nil {
|
if client == nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil {
|
||||||
|
respondPendingOAuthBindingApplyError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
|
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
|
||||||
@ -656,17 +665,16 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
|
|||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil {
|
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
|
if err != nil {
|
||||||
return
|
|
||||||
}
|
|
||||||
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
|
|
||||||
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
|
|
||||||
clearOAuthPendingSessionCookie(c, secureCookie)
|
|
||||||
clearOAuthPendingBrowserCookie(c, secureCookie)
|
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil {
|
||||||
|
respondPendingOAuthBindingApplyError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
|
||||||
clearOAuthPendingSessionCookie(c, secureCookie)
|
clearOAuthPendingSessionCookie(c, secureCookie)
|
||||||
clearOAuthPendingBrowserCookie(c, secureCookie)
|
clearOAuthPendingBrowserCookie(c, secureCookie)
|
||||||
|
|
||||||
|
|||||||
@ -438,7 +438,8 @@ func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing
|
|||||||
Only(ctx)
|
Only(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, oauthIntentLogin, session.Intent)
|
require.Equal(t, oauthIntentLogin, session.Intent)
|
||||||
require.Nil(t, session.TargetUserID)
|
require.NotNil(t, session.TargetUserID)
|
||||||
|
require.Equal(t, existingUser.ID, *session.TargetUserID)
|
||||||
require.Equal(t, existingUser.Email, session.ResolvedEmail)
|
require.Equal(t, existingUser.Email, session.ResolvedEmail)
|
||||||
require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
|
require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
|
||||||
|
|
||||||
@ -862,6 +863,69 @@ func TestCompleteOIDCOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testi
|
|||||||
require.False(t, decision.AdoptAvatar)
|
require.False(t, decision.AdoptAvatar)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCompleteOIDCOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) {
|
||||||
|
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
existingOwner, err := client.User.Create().
|
||||||
|
SetEmail("owner@example.com").
|
||||||
|
SetUsername("owner-user").
|
||||||
|
SetPasswordHash("hash").
|
||||||
|
SetRole(service.RoleUser).
|
||||||
|
SetStatus(service.StatusActive).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = client.AuthIdentity.Create().
|
||||||
|
SetUserID(existingOwner.ID).
|
||||||
|
SetProviderType("oidc").
|
||||||
|
SetProviderKey("https://issuer.example.com").
|
||||||
|
SetProviderSubject("oidc-conflict-subject").
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
session, err := client.PendingAuthSession.Create().
|
||||||
|
SetSessionToken("oidc-complete-conflict-session").
|
||||||
|
SetIntent("login").
|
||||||
|
SetProviderType("oidc").
|
||||||
|
SetProviderKey("https://issuer.example.com").
|
||||||
|
SetProviderSubject("oidc-conflict-subject").
|
||||||
|
SetResolvedEmail("f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid").
|
||||||
|
SetBrowserSessionKey("oidc-conflict-browser").
|
||||||
|
SetUpstreamIdentityClaims(map[string]any{
|
||||||
|
"username": "oidc_user",
|
||||||
|
"issuer": "https://issuer.example.com",
|
||||||
|
}).
|
||||||
|
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-conflict-browser")})
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
handler.CompleteOIDCOAuthRegistration(c)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusConflict, recorder.Code)
|
||||||
|
payload := decodeJSONBody(t, recorder)
|
||||||
|
require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"])
|
||||||
|
|
||||||
|
userCount, err := client.User.Query().
|
||||||
|
Where(dbuser.EmailEQ("f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid")).
|
||||||
|
Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Zero(t, userCount)
|
||||||
|
|
||||||
|
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, storedSession.ConsumedAt)
|
||||||
|
}
|
||||||
|
|
||||||
type oidcProviderFixture struct {
|
type oidcProviderFixture struct {
|
||||||
Subject string
|
Subject string
|
||||||
PreferredUsername string
|
PreferredUsername string
|
||||||
|
|||||||
@ -576,6 +576,258 @@ FROM auth_identity_migration_reports
|
|||||||
require.Equal(t, beforeCount, afterCount)
|
require.Equal(t, beforeCount, afterCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthIdentityLegacyExternalBackfillMigration_SkipsAmbiguousCanonicalSubjects(t *testing.T) {
|
||||||
|
tx := testTx(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
|
||||||
|
migrationSQL, err := os.ReadFile(migrationPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
prepareLegacyExternalIdentitiesTable(t, tx, ctx)
|
||||||
|
truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
|
||||||
|
|
||||||
|
var linuxDoFirstUserID int64
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
|
||||||
|
VALUES ('legacy-linuxdo-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
|
||||||
|
RETURNING id`).Scan(&linuxDoFirstUserID))
|
||||||
|
|
||||||
|
var linuxDoSecondUserID int64
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
|
||||||
|
VALUES ('legacy-linuxdo-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
|
||||||
|
RETURNING id`).Scan(&linuxDoSecondUserID))
|
||||||
|
|
||||||
|
var wechatFirstUserID int64
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
|
||||||
|
VALUES ('legacy-wechat-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
|
||||||
|
RETURNING id`).Scan(&wechatFirstUserID))
|
||||||
|
|
||||||
|
var wechatSecondUserID int64
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
|
||||||
|
VALUES ('legacy-wechat-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
|
||||||
|
RETURNING id`).Scan(&wechatSecondUserID))
|
||||||
|
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO user_external_identities (
|
||||||
|
user_id,
|
||||||
|
provider,
|
||||||
|
provider_user_id,
|
||||||
|
provider_union_id,
|
||||||
|
provider_username,
|
||||||
|
display_name,
|
||||||
|
metadata
|
||||||
|
) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-a', 'Legacy LinuxDo Ambiguous A', '{"source":"legacy"}')
|
||||||
|
RETURNING id
|
||||||
|
`, linuxDoFirstUserID).Scan(new(int64)))
|
||||||
|
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO user_external_identities (
|
||||||
|
user_id,
|
||||||
|
provider,
|
||||||
|
provider_user_id,
|
||||||
|
provider_union_id,
|
||||||
|
provider_username,
|
||||||
|
display_name,
|
||||||
|
metadata
|
||||||
|
) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-b', 'Legacy LinuxDo Ambiguous B', '{"source":"legacy"}')
|
||||||
|
RETURNING id
|
||||||
|
`, linuxDoSecondUserID).Scan(new(int64)))
|
||||||
|
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO user_external_identities (
|
||||||
|
user_id,
|
||||||
|
provider,
|
||||||
|
provider_user_id,
|
||||||
|
provider_union_id,
|
||||||
|
provider_username,
|
||||||
|
display_name,
|
||||||
|
metadata
|
||||||
|
) VALUES ($1, 'wechat', 'openid-ambiguous-a', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-a', 'Legacy WeChat Ambiguous A', '{"channel":"oa","appid":"wx-ambiguous-a"}')
|
||||||
|
RETURNING id
|
||||||
|
`, wechatFirstUserID).Scan(new(int64)))
|
||||||
|
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO user_external_identities (
|
||||||
|
user_id,
|
||||||
|
provider,
|
||||||
|
provider_user_id,
|
||||||
|
provider_union_id,
|
||||||
|
provider_username,
|
||||||
|
display_name,
|
||||||
|
metadata
|
||||||
|
) VALUES ($1, 'wechat', 'openid-ambiguous-b', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-b', 'Legacy WeChat Ambiguous B', '{"channel":"oa","appid":"wx-ambiguous-b"}')
|
||||||
|
RETURNING id
|
||||||
|
`, wechatSecondUserID).Scan(new(int64)))
|
||||||
|
|
||||||
|
_, err = tx.ExecContext(ctx, string(migrationSQL))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var linuxDoIdentityCount int
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM auth_identities
|
||||||
|
WHERE provider_type = 'linuxdo'
|
||||||
|
AND provider_key = 'linuxdo'
|
||||||
|
AND provider_subject = 'linuxdo-ambiguous-subject'
|
||||||
|
`).Scan(&linuxDoIdentityCount))
|
||||||
|
require.Zero(t, linuxDoIdentityCount)
|
||||||
|
|
||||||
|
var wechatIdentityCount int
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM auth_identities
|
||||||
|
WHERE provider_type = 'wechat'
|
||||||
|
AND provider_key = 'wechat-main'
|
||||||
|
AND provider_subject = 'union-ambiguous-subject'
|
||||||
|
`).Scan(&wechatIdentityCount))
|
||||||
|
require.Zero(t, wechatIdentityCount)
|
||||||
|
|
||||||
|
var wechatChannelCount int
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM auth_identity_channels
|
||||||
|
WHERE provider_type = 'wechat'
|
||||||
|
AND provider_key = 'wechat-main'
|
||||||
|
AND channel = 'oa'
|
||||||
|
AND channel_app_id IN ('wx-ambiguous-a', 'wx-ambiguous-b')
|
||||||
|
`).Scan(&wechatChannelCount))
|
||||||
|
require.Zero(t, wechatChannelCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthIdentityLegacyExternalMigrations_ReportAmbiguousCanonicalSubjectsWithoutWinnerAttribution(t *testing.T) {
|
||||||
|
tx := testTx(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
|
||||||
|
migration115SQL, err := os.ReadFile(migration115Path)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
|
||||||
|
migration116SQL, err := os.ReadFile(migration116Path)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
prepareLegacyExternalIdentitiesTable(t, tx, ctx)
|
||||||
|
truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
|
||||||
|
|
||||||
|
var linuxDoFirstUserID int64
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
|
||||||
|
VALUES ('legacy-linuxdo-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
|
||||||
|
RETURNING id`).Scan(&linuxDoFirstUserID))
|
||||||
|
|
||||||
|
var linuxDoSecondUserID int64
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
|
||||||
|
VALUES ('legacy-linuxdo-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
|
||||||
|
RETURNING id`).Scan(&linuxDoSecondUserID))
|
||||||
|
|
||||||
|
var wechatFirstUserID int64
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
|
||||||
|
VALUES ('legacy-wechat-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
|
||||||
|
RETURNING id`).Scan(&wechatFirstUserID))
|
||||||
|
|
||||||
|
var wechatSecondUserID int64
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
|
||||||
|
VALUES ('legacy-wechat-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
|
||||||
|
RETURNING id`).Scan(&wechatSecondUserID))
|
||||||
|
|
||||||
|
var linuxDoFirstLegacyID int64
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO user_external_identities (
|
||||||
|
user_id,
|
||||||
|
provider,
|
||||||
|
provider_user_id,
|
||||||
|
provider_union_id,
|
||||||
|
provider_username,
|
||||||
|
display_name,
|
||||||
|
metadata
|
||||||
|
) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-a', 'Legacy LinuxDo Conflict A', '{"source":"legacy"}')
|
||||||
|
RETURNING id
|
||||||
|
`, linuxDoFirstUserID).Scan(&linuxDoFirstLegacyID))
|
||||||
|
|
||||||
|
var linuxDoSecondLegacyID int64
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO user_external_identities (
|
||||||
|
user_id,
|
||||||
|
provider,
|
||||||
|
provider_user_id,
|
||||||
|
provider_union_id,
|
||||||
|
provider_username,
|
||||||
|
display_name,
|
||||||
|
metadata
|
||||||
|
) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-b', 'Legacy LinuxDo Conflict B', '{"source":"legacy"}')
|
||||||
|
RETURNING id
|
||||||
|
`, linuxDoSecondUserID).Scan(&linuxDoSecondLegacyID))
|
||||||
|
|
||||||
|
var wechatFirstLegacyID int64
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO user_external_identities (
|
||||||
|
user_id,
|
||||||
|
provider,
|
||||||
|
provider_user_id,
|
||||||
|
provider_union_id,
|
||||||
|
provider_username,
|
||||||
|
display_name,
|
||||||
|
metadata
|
||||||
|
) VALUES ($1, 'wechat', 'openid-conflict-a', 'union-conflict-subject', 'legacy-wechat-conflict-a', 'Legacy WeChat Conflict A', '{"channel":"oa","appid":"wx-conflict-a"}')
|
||||||
|
RETURNING id
|
||||||
|
`, wechatFirstUserID).Scan(&wechatFirstLegacyID))
|
||||||
|
|
||||||
|
var wechatSecondLegacyID int64
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO user_external_identities (
|
||||||
|
user_id,
|
||||||
|
provider,
|
||||||
|
provider_user_id,
|
||||||
|
provider_union_id,
|
||||||
|
provider_username,
|
||||||
|
display_name,
|
||||||
|
metadata
|
||||||
|
) VALUES ($1, 'wechat', 'openid-conflict-b', 'union-conflict-subject', 'legacy-wechat-conflict-b', 'Legacy WeChat Conflict B', '{"channel":"oa","appid":"wx-conflict-b"}')
|
||||||
|
RETURNING id
|
||||||
|
`, wechatSecondUserID).Scan(&wechatSecondLegacyID))
|
||||||
|
|
||||||
|
_, err = tx.ExecContext(ctx, string(migration115SQL))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = tx.ExecContext(ctx, string(migration116SQL))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var identityCount int
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM auth_identities
|
||||||
|
WHERE (provider_type = 'linuxdo' AND provider_key = 'linuxdo' AND provider_subject = 'linuxdo-conflict-subject')
|
||||||
|
OR (provider_type = 'wechat' AND provider_key = 'wechat-main' AND provider_subject = 'union-conflict-subject')
|
||||||
|
`).Scan(&identityCount))
|
||||||
|
require.Zero(t, identityCount)
|
||||||
|
|
||||||
|
var conflictReportCount int
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM auth_identity_migration_reports
|
||||||
|
WHERE report_type = 'legacy_external_identity_conflict'
|
||||||
|
AND report_key IN ($1, $2, $3, $4)
|
||||||
|
`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&conflictReportCount))
|
||||||
|
require.Equal(t, 4, conflictReportCount)
|
||||||
|
|
||||||
|
var winnerAttributedReportCount int
|
||||||
|
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM auth_identity_migration_reports
|
||||||
|
WHERE report_type = 'legacy_external_identity_conflict'
|
||||||
|
AND report_key IN ($1, $2, $3, $4)
|
||||||
|
AND details ->> 'existing_identity_id' IS NOT NULL
|
||||||
|
`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&winnerAttributedReportCount))
|
||||||
|
require.Zero(t, winnerAttributedReportCount)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAuthIdentityMigrationReportTypeWideningPreflightKeeps109And116SafeBefore121(t *testing.T) {
|
func TestAuthIdentityMigrationReportTypeWideningPreflightKeeps109And116SafeBefore121(t *testing.T) {
|
||||||
tx := testTx(t)
|
tx := testTx(t)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|||||||
@ -51,6 +51,8 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
|
|||||||
const migrationsAdvisoryLockID int64 = 694208311321144027
|
const migrationsAdvisoryLockID int64 = 694208311321144027
|
||||||
const migrationsLockRetryInterval = 500 * time.Millisecond
|
const migrationsLockRetryInterval = 500 * time.Millisecond
|
||||||
const nonTransactionalMigrationSuffix = "_notx.sql"
|
const nonTransactionalMigrationSuffix = "_notx.sql"
|
||||||
|
const paymentOrdersOutTradeNoUniqueMigration = "120_enforce_payment_orders_out_trade_no_unique_notx.sql"
|
||||||
|
const paymentOrdersOutTradeNoUniqueIndex = "paymentorder_out_trade_no_unique"
|
||||||
|
|
||||||
type migrationChecksumCompatibilityRule struct {
|
type migrationChecksumCompatibilityRule struct {
|
||||||
fileChecksum string
|
fileChecksum string
|
||||||
@ -65,9 +67,11 @@ var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibil
|
|||||||
"054_drop_legacy_cache_columns.sql": newMigrationChecksumCompatibilityRule("82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"),
|
"054_drop_legacy_cache_columns.sql": newMigrationChecksumCompatibilityRule("82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"),
|
||||||
"061_add_usage_log_request_type.sql": newMigrationChecksumCompatibilityRule("66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0", "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"),
|
"061_add_usage_log_request_type.sql": newMigrationChecksumCompatibilityRule("66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0", "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"),
|
||||||
"109_auth_identity_compat_backfill.sql": newMigrationChecksumCompatibilityRule("2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3", "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"),
|
"109_auth_identity_compat_backfill.sql": newMigrationChecksumCompatibilityRule("2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3", "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"),
|
||||||
|
"115_auth_identity_legacy_external_backfill.sql": newMigrationChecksumCompatibilityRule("022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f", "4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f"),
|
||||||
|
"116_auth_identity_legacy_external_safety_reports.sql": newMigrationChecksumCompatibilityRule("07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488", "f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877"),
|
||||||
"118_wechat_dual_mode_and_auth_source_defaults.sql": newMigrationChecksumCompatibilityRule("b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0", "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227"),
|
"118_wechat_dual_mode_and_auth_source_defaults.sql": newMigrationChecksumCompatibilityRule("b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0", "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227"),
|
||||||
"119_enforce_payment_orders_out_trade_no_unique.sql": newMigrationChecksumCompatibilityRule("0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e", "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"),
|
"119_enforce_payment_orders_out_trade_no_unique.sql": newMigrationChecksumCompatibilityRule("0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e", "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"),
|
||||||
"120_enforce_payment_orders_out_trade_no_unique_notx.sql": newMigrationChecksumCompatibilityRule("707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"),
|
"120_enforce_payment_orders_out_trade_no_unique_notx.sql": newMigrationChecksumCompatibilityRule("34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074", "e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61", "707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"),
|
||||||
"123_fix_legacy_auth_source_grant_on_signup_defaults.sql": newMigrationChecksumCompatibilityRule("2ce43c2cd89e9f9e1febd34a407ed9e84d177386c5544b6f02c1f58a21129f57", "6cd33422f215dcd1f486ab6f35c0ea5805d9ca69bb25906d94bc649156657145"),
|
"123_fix_legacy_auth_source_grant_on_signup_defaults.sql": newMigrationChecksumCompatibilityRule("2ce43c2cd89e9f9e1febd34a407ed9e84d177386c5544b6f02c1f58a21129f57", "6cd33422f215dcd1f486ab6f35c0ea5805d9ca69bb25906d94bc649156657145"),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -195,6 +199,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if nonTx {
|
if nonTx {
|
||||||
|
if err := prepareNonTransactionalMigration(ctx, db, name); err != nil {
|
||||||
|
return fmt.Errorf("prepare migration %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
// *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
|
// *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
|
||||||
// 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
|
// 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
|
||||||
statements := splitSQLStatements(content)
|
statements := splitSQLStatements(content)
|
||||||
@ -244,6 +252,88 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func prepareNonTransactionalMigration(ctx context.Context, db *sql.DB, name string) error {
|
||||||
|
switch name {
|
||||||
|
case paymentOrdersOutTradeNoUniqueMigration:
|
||||||
|
return preparePaymentOrdersOutTradeNoUniqueMigration(ctx, db)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func preparePaymentOrdersOutTradeNoUniqueMigration(ctx context.Context, db *sql.DB) error {
|
||||||
|
duplicates, err := findDuplicatePaymentOrderOutTradeNos(ctx, db)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("precheck duplicate out_trade_no: %w", err)
|
||||||
|
}
|
||||||
|
if len(duplicates) > 0 {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"duplicate out_trade_no values block %s; remediate duplicates before retrying: %s",
|
||||||
|
paymentOrdersOutTradeNoUniqueMigration,
|
||||||
|
strings.Join(duplicates, ", "),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
invalid, err := indexIsInvalid(ctx, db, paymentOrdersOutTradeNoUniqueIndex)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("check invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err)
|
||||||
|
}
|
||||||
|
if !invalid {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := db.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", paymentOrdersOutTradeNoUniqueIndex)); err != nil {
|
||||||
|
return fmt.Errorf("drop invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func findDuplicatePaymentOrderOutTradeNos(ctx context.Context, db *sql.DB) ([]string, error) {
|
||||||
|
rows, err := db.QueryContext(ctx, `
|
||||||
|
SELECT out_trade_no, COUNT(*) AS duplicate_count
|
||||||
|
FROM payment_orders
|
||||||
|
WHERE out_trade_no <> ''
|
||||||
|
GROUP BY out_trade_no
|
||||||
|
HAVING COUNT(*) > 1
|
||||||
|
ORDER BY duplicate_count DESC, out_trade_no
|
||||||
|
LIMIT 5
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
duplicates := make([]string, 0, 5)
|
||||||
|
for rows.Next() {
|
||||||
|
var outTradeNo string
|
||||||
|
var duplicateCount int
|
||||||
|
if err := rows.Scan(&outTradeNo, &duplicateCount); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
duplicates = append(duplicates, fmt.Sprintf("%s (count=%d)", outTradeNo, duplicateCount))
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return duplicates, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func indexIsInvalid(ctx context.Context, db *sql.DB, indexName string) (bool, error) {
|
||||||
|
var invalid bool
|
||||||
|
err := db.QueryRowContext(ctx, `
|
||||||
|
SELECT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM pg_class idx
|
||||||
|
JOIN pg_namespace ns ON ns.oid = idx.relnamespace
|
||||||
|
JOIN pg_index i ON i.indexrelid = idx.oid
|
||||||
|
WHERE ns.nspname = 'public'
|
||||||
|
AND idx.relname = $1
|
||||||
|
AND NOT i.indisvalid
|
||||||
|
)
|
||||||
|
`, indexName).Scan(&invalid)
|
||||||
|
return invalid, err
|
||||||
|
}
|
||||||
|
|
||||||
func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
||||||
hasLegacy, err := tableExists(ctx, db, "schema_migrations")
|
hasLegacy, err := tableExists(ctx, db, "schema_migrations")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -70,6 +70,24 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
|
|||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("115历史checksum可兼容修复后的legacy external backfill", func(t *testing.T) {
|
||||||
|
ok := isMigrationChecksumCompatible(
|
||||||
|
"115_auth_identity_legacy_external_backfill.sql",
|
||||||
|
"4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f",
|
||||||
|
"022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f",
|
||||||
|
)
|
||||||
|
require.True(t, ok)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("116历史checksum可兼容修复后的legacy external safety reports", func(t *testing.T) {
|
||||||
|
ok := isMigrationChecksumCompatible(
|
||||||
|
"116_auth_identity_legacy_external_safety_reports.sql",
|
||||||
|
"f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877",
|
||||||
|
"07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488",
|
||||||
|
)
|
||||||
|
require.True(t, ok)
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("119历史checksum可兼容占位文件", func(t *testing.T) {
|
t.Run("119历史checksum可兼容占位文件", func(t *testing.T) {
|
||||||
ok := isMigrationChecksumCompatible(
|
ok := isMigrationChecksumCompatible(
|
||||||
"119_enforce_payment_orders_out_trade_no_unique.sql",
|
"119_enforce_payment_orders_out_trade_no_unique.sql",
|
||||||
@ -79,6 +97,21 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
|
|||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("120多个历史checksum都可兼容新的notx修复版本", func(t *testing.T) {
|
||||||
|
for _, dbChecksum := range []string{
|
||||||
|
"e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61",
|
||||||
|
"707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22",
|
||||||
|
"04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a",
|
||||||
|
} {
|
||||||
|
ok := isMigrationChecksumCompatible(
|
||||||
|
"120_enforce_payment_orders_out_trade_no_unique_notx.sql",
|
||||||
|
dbChecksum,
|
||||||
|
"34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074",
|
||||||
|
)
|
||||||
|
require.True(t, ok)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("119未知checksum不兼容", func(t *testing.T) {
|
t.Run("119未知checksum不兼容", func(t *testing.T) {
|
||||||
ok := isMigrationChecksumCompatible(
|
ok := isMigrationChecksumCompatible(
|
||||||
"119_enforce_payment_orders_out_trade_no_unique.sql",
|
"119_enforce_payment_orders_out_trade_no_unique.sql",
|
||||||
|
|||||||
@ -96,6 +96,8 @@ func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
|
|||||||
|
|
||||||
func TestMigrationChecksumCompatibilityRules_CoverEditedUpgradeCompatibilityMigrations(t *testing.T) {
|
func TestMigrationChecksumCompatibilityRules_CoverEditedUpgradeCompatibilityMigrations(t *testing.T) {
|
||||||
for _, name := range []string{
|
for _, name := range []string{
|
||||||
|
"115_auth_identity_legacy_external_backfill.sql",
|
||||||
|
"116_auth_identity_legacy_external_safety_reports.sql",
|
||||||
"118_wechat_dual_mode_and_auth_source_defaults.sql",
|
"118_wechat_dual_mode_and_auth_source_defaults.sql",
|
||||||
"120_enforce_payment_orders_out_trade_no_unique_notx.sql",
|
"120_enforce_payment_orders_out_trade_no_unique_notx.sql",
|
||||||
"123_fix_legacy_auth_source_grant_on_signup_defaults.sql",
|
"123_fix_legacy_auth_source_grant_on_signup_defaults.sql",
|
||||||
|
|||||||
@ -116,6 +116,84 @@ CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b);
|
|||||||
require.NoError(t, mock.ExpectationsWereMet())
|
require.NoError(t, mock.ExpectationsWereMet())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_FailsFastOnDuplicatePrecheck(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { _ = db.Close() }()
|
||||||
|
|
||||||
|
prepareMigrationsBootstrapExpectations(mock)
|
||||||
|
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
|
||||||
|
WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql").
|
||||||
|
WillReturnError(sql.ErrNoRows)
|
||||||
|
mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders").
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"}).AddRow("dup-out-trade-no", 2))
|
||||||
|
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
||||||
|
WithArgs(migrationsAdvisoryLockID).
|
||||||
|
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||||
|
|
||||||
|
fsys := fstest.MapFS{
|
||||||
|
"120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{
|
||||||
|
Data: []byte(`
|
||||||
|
CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
|
||||||
|
ON payment_orders (out_trade_no)
|
||||||
|
WHERE out_trade_no <> '';
|
||||||
|
|
||||||
|
DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
|
||||||
|
`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = applyMigrationsFS(context.Background(), db, fsys)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "duplicate out_trade_no")
|
||||||
|
require.Contains(t, err.Error(), "dup-out-trade-no")
|
||||||
|
require.NoError(t, mock.ExpectationsWereMet())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_DropsInvalidIndexBeforeRetry(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { _ = db.Close() }()
|
||||||
|
|
||||||
|
prepareMigrationsBootstrapExpectations(mock)
|
||||||
|
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
|
||||||
|
WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql").
|
||||||
|
WillReturnError(sql.ErrNoRows)
|
||||||
|
mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders").
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"}))
|
||||||
|
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||||
|
WithArgs("paymentorder_out_trade_no_unique").
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||||
|
mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique").
|
||||||
|
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||||
|
mock.ExpectExec("CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique").
|
||||||
|
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||||
|
mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no").
|
||||||
|
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||||
|
mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
|
||||||
|
WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql", sqlmock.AnyArg()).
|
||||||
|
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||||
|
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
||||||
|
WithArgs(migrationsAdvisoryLockID).
|
||||||
|
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||||
|
|
||||||
|
fsys := fstest.MapFS{
|
||||||
|
"120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{
|
||||||
|
Data: []byte(`
|
||||||
|
CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
|
||||||
|
ON payment_orders (out_trade_no)
|
||||||
|
WHERE out_trade_no <> '';
|
||||||
|
|
||||||
|
DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
|
||||||
|
`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = applyMigrationsFS(context.Background(), db, fsys)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, mock.ExpectationsWereMet())
|
||||||
|
}
|
||||||
|
|
||||||
func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) {
|
func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) {
|
||||||
db, mock, err := sqlmock.New()
|
db, mock, err := sqlmock.New()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@ -93,6 +93,19 @@ func TestMigrationsRunner_AuthIdentityAndPaymentSchemaStayAligned(t *testing.T)
|
|||||||
tx := testTx(t)
|
tx := testTx(t)
|
||||||
|
|
||||||
requireColumn(t, tx, "auth_identity_migration_reports", "report_type", "character varying", 80, false)
|
requireColumn(t, tx, "auth_identity_migration_reports", "report_type", "character varying", 80, false)
|
||||||
|
requireColumn(t, tx, "users", "signup_source", "character varying", 20, false)
|
||||||
|
requireColumnDefaultContains(t, tx, "users", "signup_source", "email")
|
||||||
|
requireConstraintDefinitionContains(
|
||||||
|
t,
|
||||||
|
tx,
|
||||||
|
"users",
|
||||||
|
"users_signup_source_check",
|
||||||
|
"signup_source",
|
||||||
|
"'email'",
|
||||||
|
"'linuxdo'",
|
||||||
|
"'wechat'",
|
||||||
|
"'oidc'",
|
||||||
|
)
|
||||||
|
|
||||||
requireForeignKeyOnDelete(t, tx, "auth_identities", "user_id", "users", "CASCADE")
|
requireForeignKeyOnDelete(t, tx, "auth_identities", "user_id", "users", "CASCADE")
|
||||||
requireForeignKeyOnDelete(t, tx, "auth_identity_channels", "identity_id", "auth_identities", "CASCADE")
|
requireForeignKeyOnDelete(t, tx, "auth_identity_channels", "identity_id", "auth_identities", "CASCADE")
|
||||||
@ -195,6 +208,45 @@ LIMIT 1
|
|||||||
require.Equal(t, expected, actual, "unexpected ON DELETE action for %s.%s -> %s", table, column, refTable)
|
require.Equal(t, expected, actual, "unexpected ON DELETE action for %s.%s -> %s", table, column, refTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func requireConstraintDefinitionContains(t *testing.T, tx *sql.Tx, table, constraint string, fragments ...string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var def string
|
||||||
|
err := tx.QueryRowContext(context.Background(), `
|
||||||
|
SELECT pg_get_constraintdef(c.oid)
|
||||||
|
FROM pg_constraint c
|
||||||
|
JOIN pg_class tbl ON tbl.oid = c.conrelid
|
||||||
|
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
|
||||||
|
WHERE ns.nspname = 'public'
|
||||||
|
AND tbl.relname = $1
|
||||||
|
AND c.conname = $2
|
||||||
|
`, table, constraint).Scan(&def)
|
||||||
|
require.NoError(t, err, "query constraint definition for %s.%s", table, constraint)
|
||||||
|
|
||||||
|
for _, fragment := range fragments {
|
||||||
|
require.Contains(t, def, fragment, "expected constraint definition for %s.%s to contain %q", table, constraint, fragment)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func requireColumnDefaultContains(t *testing.T, tx *sql.Tx, table, column string, fragments ...string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var columnDefault sql.NullString
|
||||||
|
err := tx.QueryRowContext(context.Background(), `
|
||||||
|
SELECT column_default
|
||||||
|
FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = $1
|
||||||
|
AND column_name = $2
|
||||||
|
`, table, column).Scan(&columnDefault)
|
||||||
|
require.NoError(t, err, "query column_default for %s.%s", table, column)
|
||||||
|
require.True(t, columnDefault.Valid, "expected column_default for %s.%s", table, column)
|
||||||
|
|
||||||
|
for _, fragment := range fragments {
|
||||||
|
require.Contains(t, columnDefault.String, fragment, "expected default for %s.%s to contain %q", table, column, fragment)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
|
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|||||||
@ -4,11 +4,15 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"hash/fnv"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
"entgo.io/ent/dialect"
|
||||||
entsql "entgo.io/ent/dialect/sql"
|
entsql "entgo.io/ent/dialect/sql"
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||||
@ -120,6 +124,113 @@ type sqlQueryExecutor interface {
|
|||||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var repositoryScopedKeyLocks = newScopedKeyLockRegistry()
|
||||||
|
|
||||||
|
type scopedKeyLockRegistry struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
locks map[string]*scopedKeyLockEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
type scopedKeyLockEntry struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
refs int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newScopedKeyLockRegistry() *scopedKeyLockRegistry {
|
||||||
|
return &scopedKeyLockRegistry{
|
||||||
|
locks: make(map[string]*scopedKeyLockEntry),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *scopedKeyLockRegistry) lock(keys ...string) func() {
|
||||||
|
normalized := normalizeLockKeys(keys...)
|
||||||
|
if len(normalized) == 0 {
|
||||||
|
return func() {}
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := make([]*scopedKeyLockEntry, 0, len(normalized))
|
||||||
|
r.mu.Lock()
|
||||||
|
for _, key := range normalized {
|
||||||
|
entry := r.locks[key]
|
||||||
|
if entry == nil {
|
||||||
|
entry = &scopedKeyLockEntry{}
|
||||||
|
r.locks[key] = entry
|
||||||
|
}
|
||||||
|
entry.refs++
|
||||||
|
entries = append(entries, entry)
|
||||||
|
}
|
||||||
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
entry.mu.Lock()
|
||||||
|
}
|
||||||
|
|
||||||
|
return func() {
|
||||||
|
for i := len(entries) - 1; i >= 0; i-- {
|
||||||
|
entries[i].mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
for idx, key := range normalized {
|
||||||
|
entry := entries[idx]
|
||||||
|
entry.refs--
|
||||||
|
if entry.refs == 0 {
|
||||||
|
delete(r.locks, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeLockKeys(keys ...string) []string {
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
deduped := make(map[string]struct{}, len(keys))
|
||||||
|
for _, key := range keys {
|
||||||
|
trimmed := strings.TrimSpace(key)
|
||||||
|
if trimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
deduped[trimmed] = struct{}{}
|
||||||
|
}
|
||||||
|
if len(deduped) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := make([]string, 0, len(deduped))
|
||||||
|
for key := range deduped {
|
||||||
|
normalized = append(normalized, key)
|
||||||
|
}
|
||||||
|
sort.Strings(normalized)
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
func advisoryLockHash(key string) int64 {
|
||||||
|
hasher := fnv.New64a()
|
||||||
|
_, _ = hasher.Write([]byte(key))
|
||||||
|
return int64(hasher.Sum64())
|
||||||
|
}
|
||||||
|
|
||||||
|
func lockRepositoryScopedKeys(ctx context.Context, client *dbent.Client, exec sqlQueryExecutor, keys ...string) (func(), error) {
|
||||||
|
release := repositoryScopedKeyLocks.lock(keys...)
|
||||||
|
normalized := normalizeLockKeys(keys...)
|
||||||
|
if len(normalized) == 0 || client == nil || exec == nil || client.Driver().Dialect() != dialect.Postgres {
|
||||||
|
return release, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, key := range normalized {
|
||||||
|
rows, err := exec.QueryContext(ctx, "SELECT pg_advisory_xact_lock($1)", advisoryLockHash(key))
|
||||||
|
if err != nil {
|
||||||
|
release()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_ = rows.Close()
|
||||||
|
}
|
||||||
|
return release, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
|
func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
|
||||||
if dbent.TxFromContext(ctx) != nil {
|
if dbent.TxFromContext(ctx) != nil {
|
||||||
return fn(ctx)
|
return fn(ctx)
|
||||||
@ -329,7 +440,11 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
targetProviderKey := canonicalizeCompatibleIdentityProviderKey(canonical.ProviderType, identity.ProviderKey, canonical.ProviderKey)
|
||||||
update := client.AuthIdentity.UpdateOneID(identity.ID)
|
update := client.AuthIdentity.UpdateOneID(identity.ID)
|
||||||
|
if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, identity.ProviderKey) {
|
||||||
|
update = update.SetProviderKey(targetProviderKey)
|
||||||
|
}
|
||||||
if input.Metadata != nil {
|
if input.Metadata != nil {
|
||||||
update = update.SetMetadata(copyMetadata(input.Metadata))
|
update = update.SetMetadata(copyMetadata(input.Metadata))
|
||||||
}
|
}
|
||||||
@ -378,8 +493,12 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
targetProviderKey := canonicalizeCompatibleIdentityProviderKey(input.Channel.ProviderType, channel.ProviderKey, input.Channel.ProviderKey)
|
||||||
update := client.AuthIdentityChannel.UpdateOneID(channel.ID).
|
update := client.AuthIdentityChannel.UpdateOneID(channel.ID).
|
||||||
SetIdentityID(identity.ID)
|
SetIdentityID(identity.ID)
|
||||||
|
if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, channel.ProviderKey) {
|
||||||
|
update = update.SetProviderKey(targetProviderKey)
|
||||||
|
}
|
||||||
if input.ChannelMetadata != nil {
|
if input.ChannelMetadata != nil {
|
||||||
update = update.SetMetadata(copyMetadata(input.ChannelMetadata))
|
update = update.SetMetadata(copyMetadata(input.ChannelMetadata))
|
||||||
}
|
}
|
||||||
@ -418,13 +537,52 @@ func compatibleIdentityProviderKeys(providerType, providerKey string) []string {
|
|||||||
return keys
|
return keys
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func canonicalizeCompatibleIdentityProviderKey(providerType, existingKey, requestedKey string) string {
|
||||||
|
providerType = strings.TrimSpace(strings.ToLower(providerType))
|
||||||
|
existingKey = strings.TrimSpace(existingKey)
|
||||||
|
requestedKey = strings.TrimSpace(requestedKey)
|
||||||
|
if providerType != "wechat" {
|
||||||
|
if requestedKey != "" {
|
||||||
|
return requestedKey
|
||||||
|
}
|
||||||
|
return existingKey
|
||||||
|
}
|
||||||
|
if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") {
|
||||||
|
return "wechat-main"
|
||||||
|
}
|
||||||
|
if requestedKey != "" {
|
||||||
|
return requestedKey
|
||||||
|
}
|
||||||
|
return existingKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func compatibleIdentityProviderKeyRank(providerType, providerKey string) int {
|
||||||
|
providerType = strings.TrimSpace(strings.ToLower(providerType))
|
||||||
|
providerKey = strings.TrimSpace(providerKey)
|
||||||
|
if providerType != "wechat" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case strings.EqualFold(providerKey, "wechat-main"):
|
||||||
|
return 0
|
||||||
|
case strings.EqualFold(providerKey, "wechat"):
|
||||||
|
return 2
|
||||||
|
default:
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func selectOwnedCompatibleIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
|
func selectOwnedCompatibleIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
|
||||||
|
var selected *dbent.AuthIdentity
|
||||||
for _, record := range records {
|
for _, record := range records {
|
||||||
if record.UserID == userID {
|
if record.UserID != userID {
|
||||||
return record
|
continue
|
||||||
|
}
|
||||||
|
if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
|
||||||
|
selected = record
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return selected
|
||||||
}
|
}
|
||||||
|
|
||||||
func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) bool {
|
func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) bool {
|
||||||
@ -437,12 +595,16 @@ func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func selectOwnedCompatibleChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
|
func selectOwnedCompatibleChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
|
||||||
|
var selected *dbent.AuthIdentityChannel
|
||||||
for _, record := range records {
|
for _, record := range records {
|
||||||
if record.Edges.Identity != nil && record.Edges.Identity.UserID == userID {
|
if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID {
|
||||||
return record
|
continue
|
||||||
|
}
|
||||||
|
if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
|
||||||
|
selected = record
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return selected
|
||||||
}
|
}
|
||||||
|
|
||||||
func hasCompatibleChannelConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
|
func hasCompatibleChannelConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
|
||||||
@ -479,51 +641,70 @@ ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
|
func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
|
||||||
client := clientFromContext(ctx, r.client)
|
var result *dbent.IdentityAdoptionDecision
|
||||||
if input.IdentityID != nil && *input.IdentityID > 0 {
|
err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
|
||||||
if _, err := client.IdentityAdoptionDecision.Update().
|
client := clientFromContext(txCtx, r.client)
|
||||||
Where(
|
releaseLocks, err := lockRepositoryScopedKeys(
|
||||||
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
|
txCtx,
|
||||||
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
|
client,
|
||||||
col := s.C(identityadoptiondecision.FieldPendingAuthSessionID)
|
txAwareSQLExecutor(txCtx, r.sql, r.client),
|
||||||
s.Where(entsql.Or(
|
identityAdoptionDecisionLockKeys(input.PendingAuthSessionID, input.IdentityID)...,
|
||||||
entsql.IsNull(col),
|
)
|
||||||
entsql.NEQ(col, input.PendingAuthSessionID),
|
if err != nil {
|
||||||
))
|
return err
|
||||||
}),
|
}
|
||||||
).
|
defer releaseLocks()
|
||||||
ClearIdentityID().
|
|
||||||
Save(ctx); err != nil {
|
if input.IdentityID != nil && *input.IdentityID > 0 {
|
||||||
return nil, err
|
if _, err := client.IdentityAdoptionDecision.Update().
|
||||||
|
Where(
|
||||||
|
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
|
||||||
|
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
|
||||||
|
col := s.C(identityadoptiondecision.FieldPendingAuthSessionID)
|
||||||
|
s.Where(entsql.Or(
|
||||||
|
entsql.IsNull(col),
|
||||||
|
entsql.NEQ(col, input.PendingAuthSessionID),
|
||||||
|
))
|
||||||
|
}),
|
||||||
|
).
|
||||||
|
ClearIdentityID().
|
||||||
|
Save(txCtx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
current, err := client.IdentityAdoptionDecision.Query().
|
|
||||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
|
|
||||||
Only(ctx)
|
|
||||||
if err != nil && !dbent.IsNotFound(err) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
now := time.Now().UTC()
|
|
||||||
if current == nil {
|
|
||||||
create := client.IdentityAdoptionDecision.Create().
|
create := client.IdentityAdoptionDecision.Create().
|
||||||
SetPendingAuthSessionID(input.PendingAuthSessionID).
|
SetPendingAuthSessionID(input.PendingAuthSessionID).
|
||||||
SetAdoptDisplayName(input.AdoptDisplayName).
|
SetAdoptDisplayName(input.AdoptDisplayName).
|
||||||
SetAdoptAvatar(input.AdoptAvatar).
|
SetAdoptAvatar(input.AdoptAvatar).
|
||||||
SetDecidedAt(now)
|
SetDecidedAt(time.Now().UTC())
|
||||||
if input.IdentityID != nil {
|
if input.IdentityID != nil && *input.IdentityID > 0 {
|
||||||
create = create.SetIdentityID(*input.IdentityID)
|
create = create.SetIdentityID(*input.IdentityID)
|
||||||
}
|
}
|
||||||
return create.Save(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
update := client.IdentityAdoptionDecision.UpdateOneID(current.ID).
|
decisionID, err := create.
|
||||||
SetAdoptDisplayName(input.AdoptDisplayName).
|
OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID).
|
||||||
SetAdoptAvatar(input.AdoptAvatar)
|
UpdateNewValues().
|
||||||
if input.IdentityID != nil {
|
ID(txCtx)
|
||||||
update = update.SetIdentityID(*input.IdentityID)
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err = client.IdentityAdoptionDecision.Get(txCtx, decisionID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
return update.Save(ctx)
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func identityAdoptionDecisionLockKeys(pendingAuthSessionID int64, identityID *int64) []string {
|
||||||
|
keys := []string{fmt.Sprintf("identity-adoption:pending:%d", pendingAuthSessionID)}
|
||||||
|
if identityID != nil && *identityID > 0 {
|
||||||
|
keys = append(keys, fmt.Sprintf("identity-adoption:identity:%d", *identityID))
|
||||||
|
}
|
||||||
|
return keys
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) {
|
func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) {
|
||||||
|
|||||||
@ -0,0 +1,212 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUserRepositoryBindAuthIdentityToUserCanonicalizesLegacyWeChatAlias(t *testing.T) {
|
||||||
|
repo, client := newUserEntRepo(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
user := &service.User{
|
||||||
|
Email: "wechat-legacy@example.com",
|
||||||
|
Username: "wechat-legacy",
|
||||||
|
PasswordHash: "hash",
|
||||||
|
Role: service.RoleUser,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
}
|
||||||
|
require.NoError(t, repo.Create(ctx, user))
|
||||||
|
|
||||||
|
legacyIdentity, err := client.AuthIdentity.Create().
|
||||||
|
SetUserID(user.ID).
|
||||||
|
SetProviderType("wechat").
|
||||||
|
SetProviderKey("wechat").
|
||||||
|
SetProviderSubject("union-legacy-123").
|
||||||
|
SetMetadata(map[string]any{"source": "legacy-alias"}).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
legacyChannel, err := client.AuthIdentityChannel.Create().
|
||||||
|
SetIdentityID(legacyIdentity.ID).
|
||||||
|
SetProviderType("wechat").
|
||||||
|
SetProviderKey("wechat").
|
||||||
|
SetChannel("oa").
|
||||||
|
SetChannelAppID("wx-app-legacy").
|
||||||
|
SetChannelSubject("openid-legacy-123").
|
||||||
|
SetMetadata(map[string]any{"scene": "legacy-alias"}).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
bound, err := repo.BindAuthIdentityToUser(ctx, BindAuthIdentityInput{
|
||||||
|
UserID: user.ID,
|
||||||
|
Canonical: AuthIdentityKey{
|
||||||
|
ProviderType: "wechat",
|
||||||
|
ProviderKey: "wechat-main",
|
||||||
|
ProviderSubject: "union-legacy-123",
|
||||||
|
},
|
||||||
|
Channel: &AuthIdentityChannelKey{
|
||||||
|
ProviderType: "wechat",
|
||||||
|
ProviderKey: "wechat-main",
|
||||||
|
Channel: "oa",
|
||||||
|
ChannelAppID: "wx-app-legacy",
|
||||||
|
ChannelSubject: "openid-legacy-123",
|
||||||
|
},
|
||||||
|
Metadata: map[string]any{"source": "canonical-bind"},
|
||||||
|
ChannelMetadata: map[string]any{"scene": "canonical-bind"},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, bound)
|
||||||
|
require.NotNil(t, bound.Identity)
|
||||||
|
require.NotNil(t, bound.Channel)
|
||||||
|
require.Equal(t, legacyIdentity.ID, bound.Identity.ID)
|
||||||
|
require.Equal(t, legacyChannel.ID, bound.Channel.ID)
|
||||||
|
require.Equal(t, "wechat-main", bound.Identity.ProviderKey)
|
||||||
|
require.Equal(t, "wechat-main", bound.Channel.ProviderKey)
|
||||||
|
|
||||||
|
reloadedIdentity, err := client.AuthIdentity.Get(ctx, legacyIdentity.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "wechat-main", reloadedIdentity.ProviderKey)
|
||||||
|
require.Equal(t, "canonical-bind", reloadedIdentity.Metadata["source"])
|
||||||
|
|
||||||
|
reloadedChannel, err := client.AuthIdentityChannel.Get(ctx, legacyChannel.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "wechat-main", reloadedChannel.ProviderKey)
|
||||||
|
require.Equal(t, "canonical-bind", reloadedChannel.Metadata["scene"])
|
||||||
|
|
||||||
|
identityCount, err := client.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.UserIDEQ(user.ID),
|
||||||
|
authidentity.ProviderTypeEQ("wechat"),
|
||||||
|
authidentity.ProviderSubjectEQ("union-legacy-123"),
|
||||||
|
).
|
||||||
|
Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, identityCount)
|
||||||
|
|
||||||
|
channelCount, err := client.AuthIdentityChannel.Query().
|
||||||
|
Where(
|
||||||
|
authidentitychannel.ProviderTypeEQ("wechat"),
|
||||||
|
authidentitychannel.ChannelEQ("oa"),
|
||||||
|
authidentitychannel.ChannelAppIDEQ("wx-app-legacy"),
|
||||||
|
authidentitychannel.ChannelSubjectEQ("openid-legacy-123"),
|
||||||
|
).
|
||||||
|
Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, channelCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserRepositoryUpsertIdentityAdoptionDecisionIsIdempotentUnderConcurrency(t *testing.T) {
|
||||||
|
repo, client := newUserEntRepo(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
user := &service.User{
|
||||||
|
Email: "repo-adoption@example.com",
|
||||||
|
Username: "repo-adoption",
|
||||||
|
PasswordHash: "hash",
|
||||||
|
Role: service.RoleUser,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
}
|
||||||
|
require.NoError(t, repo.Create(ctx, user))
|
||||||
|
|
||||||
|
identity, err := client.AuthIdentity.Create().
|
||||||
|
SetUserID(user.ID).
|
||||||
|
SetProviderType("wechat").
|
||||||
|
SetProviderKey("wechat-main").
|
||||||
|
SetProviderSubject("union-repo-adoption").
|
||||||
|
SetMetadata(map[string]any{}).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
session, err := client.PendingAuthSession.Create().
|
||||||
|
SetSessionToken("pending-repo-adoption").
|
||||||
|
SetIntent("bind_current_user").
|
||||||
|
SetProviderType("wechat").
|
||||||
|
SetProviderKey("wechat-main").
|
||||||
|
SetProviderSubject("union-repo-adoption").
|
||||||
|
SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)).
|
||||||
|
SetUpstreamIdentityClaims(map[string]any{"provider_subject": "union-repo-adoption"}).
|
||||||
|
SetLocalFlowState(map[string]any{"step": "pending"}).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
firstCreateStarted := make(chan struct{})
|
||||||
|
releaseFirstCreate := make(chan struct{})
|
||||||
|
var firstCreate sync.Once
|
||||||
|
client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator {
|
||||||
|
return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
|
||||||
|
blocked := false
|
||||||
|
if m.Op().Is(dbent.OpCreate) {
|
||||||
|
firstCreate.Do(func() {
|
||||||
|
blocked = true
|
||||||
|
close(firstCreateStarted)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if blocked {
|
||||||
|
<-releaseFirstCreate
|
||||||
|
}
|
||||||
|
return next.Mutate(ctx, m)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
type adoptionResult struct {
|
||||||
|
decision *dbent.IdentityAdoptionDecision
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
input := IdentityAdoptionDecisionInput{
|
||||||
|
PendingAuthSessionID: session.ID,
|
||||||
|
IdentityID: &identity.ID,
|
||||||
|
AdoptDisplayName: true,
|
||||||
|
AdoptAvatar: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
results := make(chan adoptionResult, 2)
|
||||||
|
go func() {
|
||||||
|
decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input)
|
||||||
|
results <- adoptionResult{decision: decision, err: err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-firstCreateStarted
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input)
|
||||||
|
results <- adoptionResult{decision: decision, err: err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
close(releaseFirstCreate)
|
||||||
|
|
||||||
|
first := <-results
|
||||||
|
second := <-results
|
||||||
|
|
||||||
|
require.NoError(t, first.err)
|
||||||
|
require.NoError(t, second.err)
|
||||||
|
require.NotNil(t, first.decision)
|
||||||
|
require.NotNil(t, second.decision)
|
||||||
|
require.Equal(t, first.decision.ID, second.decision.ID)
|
||||||
|
|
||||||
|
count, err := client.IdentityAdoptionDecision.Query().
|
||||||
|
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
|
||||||
|
Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, count)
|
||||||
|
|
||||||
|
loaded, err := client.IdentityAdoptionDecision.Query().
|
||||||
|
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
|
||||||
|
Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, loaded.IdentityID)
|
||||||
|
require.Equal(t, identity.ID, *loaded.IdentityID)
|
||||||
|
require.True(t, loaded.AdoptDisplayName)
|
||||||
|
require.True(t, loaded.AdoptAvatar)
|
||||||
|
}
|
||||||
@ -43,9 +43,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
|||||||
if userIn == nil {
|
if userIn == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if err := r.ensureNormalizedEmailAvailable(ctx, 0, userIn.Email); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 统一使用 ent 的事务:保证用户与允许分组的更新原子化,
|
// 统一使用 ent 的事务:保证用户与允许分组的更新原子化,
|
||||||
// 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。
|
// 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。
|
||||||
@ -55,9 +52,11 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
var txClient *dbent.Client
|
var txClient *dbent.Client
|
||||||
|
txCtx := ctx
|
||||||
if err == nil {
|
if err == nil {
|
||||||
defer func() { _ = tx.Rollback() }()
|
defer func() { _ = tx.Rollback() }()
|
||||||
txClient = tx.Client()
|
txClient = tx.Client()
|
||||||
|
txCtx = dbent.NewTxContext(ctx, tx)
|
||||||
} else {
|
} else {
|
||||||
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
|
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
|
||||||
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
|
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
|
||||||
@ -67,6 +66,21 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
releaseEmailLock, err := lockRepositoryScopedKeys(
|
||||||
|
txCtx,
|
||||||
|
txClient,
|
||||||
|
txAwareSQLExecutor(txCtx, r.sql, r.client),
|
||||||
|
normalizedEmailUniquenessLockKey(userIn.Email),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer releaseEmailLock()
|
||||||
|
|
||||||
|
if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, 0, userIn.Email); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
created, err := txClient.User.Create().
|
created, err := txClient.User.Create().
|
||||||
SetEmail(userIn.Email).
|
SetEmail(userIn.Email).
|
||||||
SetUsername(userIn.Username).
|
SetUsername(userIn.Username).
|
||||||
@ -79,15 +93,15 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
|||||||
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
|
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
|
||||||
SetNillableLastLoginAt(userIn.LastLoginAt).
|
SetNillableLastLoginAt(userIn.LastLoginAt).
|
||||||
SetNillableLastActiveAt(userIn.LastActiveAt).
|
SetNillableLastActiveAt(userIn.LastActiveAt).
|
||||||
Save(ctx)
|
Save(txCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
|
if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, created.ID, userIn.AllowedGroups); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := ensureEmailAuthIdentityWithClient(ctx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
|
if err := ensureEmailAuthIdentityWithClient(txCtx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -149,9 +163,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
|||||||
if userIn == nil {
|
if userIn == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if err := r.ensureNormalizedEmailAvailable(ctx, userIn.ID, userIn.Email); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。
|
// 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。
|
||||||
tx, err := r.client.Tx(ctx)
|
tx, err := r.client.Tx(ctx)
|
||||||
@ -160,9 +171,11 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
var txClient *dbent.Client
|
var txClient *dbent.Client
|
||||||
|
txCtx := ctx
|
||||||
if err == nil {
|
if err == nil {
|
||||||
defer func() { _ = tx.Rollback() }()
|
defer func() { _ = tx.Rollback() }()
|
||||||
txClient = tx.Client()
|
txClient = tx.Client()
|
||||||
|
txCtx = dbent.NewTxContext(ctx, tx)
|
||||||
} else {
|
} else {
|
||||||
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
|
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
|
||||||
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
|
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
|
||||||
@ -171,7 +184,23 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
|||||||
txClient = r.client
|
txClient = r.client
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
existing, err := clientFromContext(ctx, txClient).User.Get(ctx, userIn.ID)
|
|
||||||
|
releaseEmailLock, err := lockRepositoryScopedKeys(
|
||||||
|
txCtx,
|
||||||
|
txClient,
|
||||||
|
txAwareSQLExecutor(txCtx, r.sql, r.client),
|
||||||
|
normalizedEmailUniquenessLockKey(userIn.Email),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer releaseEmailLock()
|
||||||
|
|
||||||
|
if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, userIn.ID, userIn.Email); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
existing, err := clientFromContext(txCtx, txClient).User.Get(txCtx, userIn.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||||
}
|
}
|
||||||
@ -203,15 +232,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
|||||||
if userIn.BalanceNotifyThreshold == nil {
|
if userIn.BalanceNotifyThreshold == nil {
|
||||||
updateOp = updateOp.ClearBalanceNotifyThreshold()
|
updateOp = updateOp.ClearBalanceNotifyThreshold()
|
||||||
}
|
}
|
||||||
updated, err := updateOp.Save(ctx)
|
updated, err := updateOp.Save(txCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
|
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
|
if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := replaceEmailAuthIdentityWithClient(ctx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
|
if err := replaceEmailAuthIdentityWithClient(txCtx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -711,7 +740,16 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, userID int64, email string) error {
|
func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, userID int64, email string) error {
|
||||||
matches, err := r.client.User.Query().
|
return ensureNormalizedEmailAvailableWithClient(ctx, clientFromContext(ctx, r.client), userID, email)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureNormalizedEmailAvailableWithClient(ctx context.Context, client *dbent.Client, userID int64, email string) error {
|
||||||
|
client = clientFromContext(ctx, client)
|
||||||
|
if client == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
matches, err := client.User.Query().
|
||||||
Where(userEmailLookupPredicate(email)).
|
Where(userEmailLookupPredicate(email)).
|
||||||
All(ctx)
|
All(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -726,7 +764,7 @@ func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, use
|
|||||||
}
|
}
|
||||||
|
|
||||||
func userEmailLookupPredicate(email string) predicate.User {
|
func userEmailLookupPredicate(email string) predicate.User {
|
||||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
normalized := normalizeEmailLookupValue(email)
|
||||||
if normalized == "" {
|
if normalized == "" {
|
||||||
return dbuser.EmailEQ(email)
|
return dbuser.EmailEQ(email)
|
||||||
}
|
}
|
||||||
@ -740,6 +778,18 @@ func userEmailLookupPredicate(email string) predicate.User {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeEmailLookupValue(email string) string {
|
||||||
|
return strings.ToLower(strings.TrimSpace(email))
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizedEmailUniquenessLockKey(email string) string {
|
||||||
|
normalized := normalizeEmailLookupValue(email)
|
||||||
|
if normalized == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return "users:normalized-email:" + normalized
|
||||||
|
}
|
||||||
|
|
||||||
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
|
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
|
||||||
client := clientFromContext(ctx, r.client)
|
client := clientFromContext(ctx, r.client)
|
||||||
err := client.UserAllowedGroup.Create().
|
err := client.UserAllowedGroup.Create().
|
||||||
@ -874,11 +924,14 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func userSignupSourceOrDefault(signupSource string) string {
|
func userSignupSourceOrDefault(signupSource string) string {
|
||||||
signupSource = strings.TrimSpace(signupSource)
|
switch strings.TrimSpace(strings.ToLower(signupSource)) {
|
||||||
if signupSource == "" {
|
case "", "email":
|
||||||
|
return "email"
|
||||||
|
case "linuxdo", "wechat", "oidc":
|
||||||
|
return strings.TrimSpace(strings.ToLower(signupSource))
|
||||||
|
default:
|
||||||
return "email"
|
return "email"
|
||||||
}
|
}
|
||||||
return signupSource
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// marshalExtraEmails serializes notify email entries to JSON for storage.
|
// marshalExtraEmails serializes notify email entries to JSON for storage.
|
||||||
|
|||||||
@ -3,7 +3,10 @@ package repository
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||||
@ -18,9 +21,10 @@ import (
|
|||||||
func newUserEntRepo(t *testing.T) (*userRepository, *dbent.Client) {
|
func newUserEntRepo(t *testing.T) (*userRepository, *dbent.Client) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
db, err := sql.Open("sqlite", "file:user_repo_email_lookup?mode=memory&cache=shared")
|
db, err := sql.Open("sqlite", fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", t.Name()))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() { _ = db.Close() })
|
t.Cleanup(func() { _ = db.Close() })
|
||||||
|
db.SetMaxOpenConns(10)
|
||||||
|
|
||||||
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -144,3 +148,80 @@ func TestUserRepositoryGetByEmailReportsNormalizedEmailConflict(t *testing.T) {
|
|||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.ErrorContains(t, err, "normalized email lookup matched multiple users")
|
require.ErrorContains(t, err, "normalized email lookup matched multiple users")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUserRepositoryCreateSerializesNormalizedEmailConflictsUnderConcurrency(t *testing.T) {
|
||||||
|
repo, client := newUserEntRepo(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
firstCreateStarted := make(chan struct{})
|
||||||
|
releaseFirstCreate := make(chan struct{})
|
||||||
|
var firstCreate sync.Once
|
||||||
|
client.User.Use(func(next dbent.Mutator) dbent.Mutator {
|
||||||
|
return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
|
||||||
|
blocked := false
|
||||||
|
if m.Op().Is(dbent.OpCreate) {
|
||||||
|
firstCreate.Do(func() {
|
||||||
|
blocked = true
|
||||||
|
close(firstCreateStarted)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if blocked {
|
||||||
|
<-releaseFirstCreate
|
||||||
|
}
|
||||||
|
return next.Mutate(ctx, m)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
type createResult struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
results := make(chan createResult, 2)
|
||||||
|
go func() {
|
||||||
|
results <- createResult{err: repo.Create(ctx, &service.User{
|
||||||
|
Email: " Race@Example.com ",
|
||||||
|
Username: "race-user-1",
|
||||||
|
PasswordHash: "hash",
|
||||||
|
Role: service.RoleUser,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
})}
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-firstCreateStarted
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
results <- createResult{err: repo.Create(ctx, &service.User{
|
||||||
|
Email: "race@example.com",
|
||||||
|
Username: "race-user-2",
|
||||||
|
PasswordHash: "hash",
|
||||||
|
Role: service.RoleUser,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
})}
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
close(releaseFirstCreate)
|
||||||
|
|
||||||
|
first := <-results
|
||||||
|
second := <-results
|
||||||
|
|
||||||
|
errors := []error{first.err, second.err}
|
||||||
|
successes := 0
|
||||||
|
conflicts := 0
|
||||||
|
for _, err := range errors {
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
successes++
|
||||||
|
case err == service.ErrEmailExists:
|
||||||
|
conflicts++
|
||||||
|
default:
|
||||||
|
t.Fatalf("unexpected create error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.Equal(t, 1, successes)
|
||||||
|
require.Equal(t, 1, conflicts)
|
||||||
|
|
||||||
|
count, err := client.User.Query().Where(userEmailLookupPredicate("race@example.com")).Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, count)
|
||||||
|
}
|
||||||
|
|||||||
@ -14,10 +14,14 @@ import (
|
|||||||
|
|
||||||
func normalizeOAuthSignupSource(signupSource string) string {
|
func normalizeOAuthSignupSource(signupSource string) string {
|
||||||
signupSource = strings.TrimSpace(strings.ToLower(signupSource))
|
signupSource = strings.TrimSpace(strings.ToLower(signupSource))
|
||||||
if signupSource == "" {
|
switch signupSource {
|
||||||
|
case "", "email":
|
||||||
|
return "email"
|
||||||
|
case "linuxdo", "wechat", "oidc":
|
||||||
|
return signupSource
|
||||||
|
default:
|
||||||
return "email"
|
return "email"
|
||||||
}
|
}
|
||||||
return signupSource
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth
|
// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth
|
||||||
@ -136,10 +140,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
|
|||||||
return nil, nil, fmt.Errorf("hash password: %w", err)
|
return nil, nil, fmt.Errorf("hash password: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
signupSource = strings.TrimSpace(strings.ToLower(signupSource))
|
signupSource = normalizeOAuthSignupSource(signupSource)
|
||||||
if signupSource == "" {
|
|
||||||
signupSource = "email"
|
|
||||||
}
|
|
||||||
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||||
|
|
||||||
user := &User{
|
user := &User{
|
||||||
@ -149,6 +150,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
|
|||||||
Balance: grantPlan.Balance,
|
Balance: grantPlan.Balance,
|
||||||
Concurrency: grantPlan.Concurrency,
|
Concurrency: grantPlan.Concurrency,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
|
SignupSource: signupSource,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||||
|
|||||||
@ -191,6 +191,80 @@ func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFai
|
|||||||
require.Empty(t, redeemRepo.updateCalls)
|
require.Empty(t, redeemRepo.updateCalls)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *testing.T) {
|
||||||
|
userRepo := &userRepoStub{nextID: 42}
|
||||||
|
emailCache := &emailCacheStub{
|
||||||
|
data: &VerificationCodeData{
|
||||||
|
Code: "246810",
|
||||||
|
Attempts: 0,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authService := newOAuthEmailFlowAuthService(
|
||||||
|
userRepo,
|
||||||
|
&redeemCodeRepoStub{},
|
||||||
|
&refreshTokenCacheStub{},
|
||||||
|
map[string]string{
|
||||||
|
SettingKeyRegistrationEnabled: "true",
|
||||||
|
SettingKeyEmailVerifyEnabled: "true",
|
||||||
|
},
|
||||||
|
emailCache,
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
|
||||||
|
context.Background(),
|
||||||
|
"fresh@example.com",
|
||||||
|
"secret-123",
|
||||||
|
"246810",
|
||||||
|
"",
|
||||||
|
" OIDC ",
|
||||||
|
)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, tokenPair)
|
||||||
|
require.NotNil(t, user)
|
||||||
|
require.Len(t, userRepo.created, 1)
|
||||||
|
require.Equal(t, "oidc", userRepo.created[0].SignupSource)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing.T) {
|
||||||
|
userRepo := &userRepoStub{nextID: 43}
|
||||||
|
emailCache := &emailCacheStub{
|
||||||
|
data: &VerificationCodeData{
|
||||||
|
Code: "246810",
|
||||||
|
Attempts: 0,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authService := newOAuthEmailFlowAuthService(
|
||||||
|
userRepo,
|
||||||
|
&redeemCodeRepoStub{},
|
||||||
|
&refreshTokenCacheStub{},
|
||||||
|
map[string]string{
|
||||||
|
SettingKeyRegistrationEnabled: "true",
|
||||||
|
SettingKeyEmailVerifyEnabled: "true",
|
||||||
|
},
|
||||||
|
emailCache,
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
|
||||||
|
context.Background(),
|
||||||
|
"fallback@example.com",
|
||||||
|
"secret-123",
|
||||||
|
"246810",
|
||||||
|
"",
|
||||||
|
"github",
|
||||||
|
)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, tokenPair)
|
||||||
|
require.NotNil(t, user)
|
||||||
|
require.Len(t, userRepo.created, 1)
|
||||||
|
require.Equal(t, "email", userRepo.created[0].SignupSource)
|
||||||
|
}
|
||||||
|
|
||||||
func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) {
|
func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) {
|
||||||
userRepo := &userRepoStub{}
|
userRepo := &userRepoStub{}
|
||||||
redeemRepo := &redeemCodeRepoStub{
|
redeemRepo := &redeemCodeRepoStub{
|
||||||
|
|||||||
@ -5,10 +5,15 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"hash/fnv"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"entgo.io/ent/dialect"
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
|
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
|
||||||
@ -75,6 +80,122 @@ type AuthPendingIdentityService struct {
|
|||||||
entClient *dbent.Client
|
entClient *dbent.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var authPendingIdentityScopedKeyLocks = newAuthPendingIdentityScopedKeyLockRegistry()
|
||||||
|
|
||||||
|
type authPendingIdentityScopedKeyLockRegistry struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
locks map[string]*authPendingIdentityScopedKeyLockEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
type authPendingIdentityScopedKeyLockEntry struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
refs int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAuthPendingIdentityScopedKeyLockRegistry() *authPendingIdentityScopedKeyLockRegistry {
|
||||||
|
return &authPendingIdentityScopedKeyLockRegistry{
|
||||||
|
locks: make(map[string]*authPendingIdentityScopedKeyLockEntry),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *authPendingIdentityScopedKeyLockRegistry) lock(keys ...string) func() {
|
||||||
|
normalized := normalizeAuthPendingIdentityLockKeys(keys...)
|
||||||
|
if len(normalized) == 0 {
|
||||||
|
return func() {}
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := make([]*authPendingIdentityScopedKeyLockEntry, 0, len(normalized))
|
||||||
|
r.mu.Lock()
|
||||||
|
for _, key := range normalized {
|
||||||
|
entry := r.locks[key]
|
||||||
|
if entry == nil {
|
||||||
|
entry = &authPendingIdentityScopedKeyLockEntry{}
|
||||||
|
r.locks[key] = entry
|
||||||
|
}
|
||||||
|
entry.refs++
|
||||||
|
entries = append(entries, entry)
|
||||||
|
}
|
||||||
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
entry.mu.Lock()
|
||||||
|
}
|
||||||
|
|
||||||
|
return func() {
|
||||||
|
for i := len(entries) - 1; i >= 0; i-- {
|
||||||
|
entries[i].mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
for idx, key := range normalized {
|
||||||
|
entry := entries[idx]
|
||||||
|
entry.refs--
|
||||||
|
if entry.refs == 0 {
|
||||||
|
delete(r.locks, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeAuthPendingIdentityLockKeys(keys ...string) []string {
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
deduped := make(map[string]struct{}, len(keys))
|
||||||
|
for _, key := range keys {
|
||||||
|
trimmed := strings.TrimSpace(key)
|
||||||
|
if trimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
deduped[trimmed] = struct{}{}
|
||||||
|
}
|
||||||
|
if len(deduped) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := make([]string, 0, len(deduped))
|
||||||
|
for key := range deduped {
|
||||||
|
normalized = append(normalized, key)
|
||||||
|
}
|
||||||
|
sort.Strings(normalized)
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
func authPendingIdentityAdvisoryLockHash(key string) int64 {
|
||||||
|
hasher := fnv.New64a()
|
||||||
|
_, _ = hasher.Write([]byte(key))
|
||||||
|
return int64(hasher.Sum64())
|
||||||
|
}
|
||||||
|
|
||||||
|
func lockAuthPendingIdentityKeys(ctx context.Context, client *dbent.Client, keys ...string) (func(), error) {
|
||||||
|
release := authPendingIdentityScopedKeyLocks.lock(keys...)
|
||||||
|
normalized := normalizeAuthPendingIdentityLockKeys(keys...)
|
||||||
|
if len(normalized) == 0 || client == nil || client.Driver().Dialect() != dialect.Postgres {
|
||||||
|
return release, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, key := range normalized {
|
||||||
|
var rows entsql.Rows
|
||||||
|
if err := client.Driver().Query(ctx, "SELECT pg_advisory_xact_lock($1)", []any{authPendingIdentityAdvisoryLockHash(key)}, &rows); err != nil {
|
||||||
|
release()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_ = rows.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return release, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func pendingIdentityAdoptionLockKeys(pendingAuthSessionID int64, identityID *int64) []string {
|
||||||
|
keys := []string{fmt.Sprintf("pending-auth-adoption:pending:%d", pendingAuthSessionID)}
|
||||||
|
if identityID != nil && *identityID > 0 {
|
||||||
|
keys = append(keys, fmt.Sprintf("pending-auth-adoption:identity:%d", *identityID))
|
||||||
|
}
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService {
|
func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService {
|
||||||
return &AuthPendingIdentityService{entClient: entClient}
|
return &AuthPendingIdentityService{entClient: entClient}
|
||||||
}
|
}
|
||||||
@ -324,8 +445,29 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context,
|
|||||||
return nil, fmt.Errorf("pending auth ent client is not configured")
|
return nil, fmt.Errorf("pending auth ent client is not configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tx, err := s.entClient.Tx(ctx)
|
||||||
|
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
client := s.entClient
|
||||||
|
txCtx := ctx
|
||||||
|
if err == nil {
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
client = tx.Client()
|
||||||
|
txCtx = dbent.NewTxContext(ctx, tx)
|
||||||
|
} else if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
|
||||||
|
client = existingTx.Client()
|
||||||
|
}
|
||||||
|
|
||||||
|
releaseLocks, err := lockAuthPendingIdentityKeys(txCtx, client, pendingIdentityAdoptionLockKeys(input.PendingAuthSessionID, input.IdentityID)...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer releaseLocks()
|
||||||
|
|
||||||
if input.IdentityID != nil && *input.IdentityID > 0 {
|
if input.IdentityID != nil && *input.IdentityID > 0 {
|
||||||
if _, err := s.entClient.IdentityAdoptionDecision.Update().
|
if _, err := client.IdentityAdoptionDecision.Update().
|
||||||
Where(
|
Where(
|
||||||
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
|
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
|
||||||
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
|
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
|
||||||
@ -337,36 +479,40 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context,
|
|||||||
}),
|
}),
|
||||||
).
|
).
|
||||||
ClearIdentityID().
|
ClearIdentityID().
|
||||||
Save(ctx); err != nil {
|
Save(txCtx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
existing, err := s.entClient.IdentityAdoptionDecision.Query().
|
create := client.IdentityAdoptionDecision.Create().
|
||||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
|
SetPendingAuthSessionID(input.PendingAuthSessionID).
|
||||||
Only(ctx)
|
SetAdoptDisplayName(input.AdoptDisplayName).
|
||||||
if err != nil && !dbent.IsNotFound(err) {
|
SetAdoptAvatar(input.AdoptAvatar).
|
||||||
return nil, err
|
SetDecidedAt(time.Now().UTC())
|
||||||
}
|
if input.IdentityID != nil && *input.IdentityID > 0 {
|
||||||
if existing == nil {
|
create = create.SetIdentityID(*input.IdentityID)
|
||||||
create := s.entClient.IdentityAdoptionDecision.Create().
|
|
||||||
SetPendingAuthSessionID(input.PendingAuthSessionID).
|
|
||||||
SetAdoptDisplayName(input.AdoptDisplayName).
|
|
||||||
SetAdoptAvatar(input.AdoptAvatar).
|
|
||||||
SetDecidedAt(time.Now().UTC())
|
|
||||||
if input.IdentityID != nil {
|
|
||||||
create = create.SetIdentityID(*input.IdentityID)
|
|
||||||
}
|
|
||||||
return create.Save(ctx)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
update := s.entClient.IdentityAdoptionDecision.UpdateOneID(existing.ID).
|
decisionID, err := create.
|
||||||
SetAdoptDisplayName(input.AdoptDisplayName).
|
OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID).
|
||||||
SetAdoptAvatar(input.AdoptAvatar)
|
UpdateNewValues().
|
||||||
if input.IdentityID != nil {
|
ID(txCtx)
|
||||||
update = update.SetIdentityID(*input.IdentityID)
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
return update.Save(ctx)
|
|
||||||
|
decision, err := client.IdentityAdoptionDecision.Get(txCtx, decisionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if tx != nil {
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return decision, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func copyPendingMap(in map[string]any) map[string]any {
|
func copyPendingMap(in map[string]any) map[string]any {
|
||||||
|
|||||||
@ -5,6 +5,7 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -259,6 +260,107 @@ func TestAuthPendingIdentityService_UpsertAdoptionDecision_ReassignsExistingIden
|
|||||||
require.Nil(t, reloadedFirst.IdentityID)
|
require.Nil(t, reloadedFirst.IdentityID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthPendingIdentityService_UpsertAdoptionDecision_IsIdempotentUnderConcurrency(t *testing.T) {
|
||||||
|
svc, client := newAuthPendingIdentityServiceTestClient(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
user, err := client.User.Create().
|
||||||
|
SetEmail("adoption-concurrent@example.com").
|
||||||
|
SetPasswordHash("hash").
|
||||||
|
SetRole(RoleUser).
|
||||||
|
SetStatus(StatusActive).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
identity, err := client.AuthIdentity.Create().
|
||||||
|
SetUserID(user.ID).
|
||||||
|
SetProviderType("wechat").
|
||||||
|
SetProviderKey("wechat-main").
|
||||||
|
SetProviderSubject("union-concurrent").
|
||||||
|
SetMetadata(map[string]any{}).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
|
||||||
|
Intent: "bind_current_user",
|
||||||
|
Identity: PendingAuthIdentityKey{
|
||||||
|
ProviderType: "wechat",
|
||||||
|
ProviderKey: "wechat-main",
|
||||||
|
ProviderSubject: "union-concurrent",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
firstCreateStarted := make(chan struct{})
|
||||||
|
releaseFirstCreate := make(chan struct{})
|
||||||
|
var firstCreate sync.Once
|
||||||
|
client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator {
|
||||||
|
return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
|
||||||
|
blocked := false
|
||||||
|
if m.Op().Is(dbent.OpCreate) {
|
||||||
|
firstCreate.Do(func() {
|
||||||
|
blocked = true
|
||||||
|
close(firstCreateStarted)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if blocked {
|
||||||
|
<-releaseFirstCreate
|
||||||
|
}
|
||||||
|
return next.Mutate(ctx, m)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
type adoptionResult struct {
|
||||||
|
decision *dbent.IdentityAdoptionDecision
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
input := PendingIdentityAdoptionDecisionInput{
|
||||||
|
PendingAuthSessionID: session.ID,
|
||||||
|
IdentityID: &identity.ID,
|
||||||
|
AdoptDisplayName: true,
|
||||||
|
AdoptAvatar: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
results := make(chan adoptionResult, 2)
|
||||||
|
go func() {
|
||||||
|
decision, err := svc.UpsertAdoptionDecision(ctx, input)
|
||||||
|
results <- adoptionResult{decision: decision, err: err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-firstCreateStarted
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
decision, err := svc.UpsertAdoptionDecision(ctx, input)
|
||||||
|
results <- adoptionResult{decision: decision, err: err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
close(releaseFirstCreate)
|
||||||
|
|
||||||
|
first := <-results
|
||||||
|
second := <-results
|
||||||
|
|
||||||
|
require.NoError(t, first.err)
|
||||||
|
require.NoError(t, second.err)
|
||||||
|
require.NotNil(t, first.decision)
|
||||||
|
require.NotNil(t, second.decision)
|
||||||
|
require.Equal(t, first.decision.ID, second.decision.ID)
|
||||||
|
|
||||||
|
count, err := client.IdentityAdoptionDecision.Query().
|
||||||
|
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
|
||||||
|
Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, count)
|
||||||
|
|
||||||
|
loaded, err := client.IdentityAdoptionDecision.Query().
|
||||||
|
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
|
||||||
|
Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, loaded.IdentityID)
|
||||||
|
require.Equal(t, identity.ID, *loaded.IdentityID)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAuthPendingIdentityService_UpsertAdoptionDecision_ClearsLegacyNullSessionReference(t *testing.T) {
|
func TestAuthPendingIdentityService_UpsertAdoptionDecision_ClearsLegacyNullSessionReference(t *testing.T) {
|
||||||
t.Skip("legacy NULL pending_auth_session_id rows only exist in production PostgreSQL history; sqlite unit schema rejects NULL")
|
t.Skip("legacy NULL pending_auth_session_id rows only exist in production PostgreSQL history; sqlite unit schema rejects NULL")
|
||||||
|
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -489,6 +490,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
|||||||
Balance: grantPlan.Balance,
|
Balance: grantPlan.Balance,
|
||||||
Concurrency: grantPlan.Concurrency,
|
Concurrency: grantPlan.Concurrency,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
|
SignupSource: signupSource,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||||
@ -599,6 +601,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
|||||||
Balance: grantPlan.Balance,
|
Balance: grantPlan.Balance,
|
||||||
Concurrency: grantPlan.Concurrency,
|
Concurrency: grantPlan.Concurrency,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
|
SignupSource: signupSource,
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.entClient != nil && invitationRedeemCode != nil {
|
if s.entClient != nil && invitationRedeemCode != nil {
|
||||||
@ -1048,7 +1051,7 @@ func (s *AuthService) GenerateToken(user *User) (string, error) {
|
|||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
Email: user.Email,
|
Email: user.Email,
|
||||||
Role: user.Role,
|
Role: user.Role,
|
||||||
TokenVersion: user.TokenVersion,
|
TokenVersion: resolvedTokenVersion(user),
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
||||||
IssuedAt: jwt.NewNumericDate(now),
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
@ -1114,7 +1117,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
|
|||||||
|
|
||||||
// Security: Check TokenVersion to prevent refreshing revoked tokens
|
// Security: Check TokenVersion to prevent refreshing revoked tokens
|
||||||
// This ensures tokens issued before a password change cannot be refreshed
|
// This ensures tokens issued before a password change cannot be refreshed
|
||||||
if claims.TokenVersion != user.TokenVersion {
|
if claims.TokenVersion != resolvedTokenVersion(user) {
|
||||||
return "", ErrTokenRevoked
|
return "", ErrTokenRevoked
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1342,7 +1345,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
|
|||||||
|
|
||||||
data := &RefreshTokenData{
|
data := &RefreshTokenData{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
TokenVersion: user.TokenVersion,
|
TokenVersion: resolvedTokenVersion(user),
|
||||||
FamilyID: familyID,
|
FamilyID: familyID,
|
||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
ExpiresAt: now.Add(ttl),
|
ExpiresAt: now.Add(ttl),
|
||||||
@ -1422,7 +1425,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 检查TokenVersion(密码更改后所有Token失效)
|
// 检查TokenVersion(密码更改后所有Token失效)
|
||||||
if data.TokenVersion != user.TokenVersion {
|
if data.TokenVersion != resolvedTokenVersion(user) {
|
||||||
// TokenVersion不匹配,撤销整个Token家族
|
// TokenVersion不匹配,撤销整个Token家族
|
||||||
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
|
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
|
||||||
return nil, ErrTokenRevoked
|
return nil, ErrTokenRevoked
|
||||||
@ -1492,3 +1495,14 @@ func hashToken(token string) string {
|
|||||||
hash := sha256.Sum256([]byte(token))
|
hash := sha256.Sum256([]byte(token))
|
||||||
return hex.EncodeToString(hash[:])
|
return hex.EncodeToString(hash[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func resolvedTokenVersion(user *User) int64 {
|
||||||
|
if user == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
material := strings.ToLower(strings.TrimSpace(user.Email)) + "\n" + user.PasswordHash
|
||||||
|
sum := sha256.Sum256([]byte(material))
|
||||||
|
fingerprint := int64(binary.BigEndian.Uint64(sum[:8]) & 0x7fffffffffffffff)
|
||||||
|
return user.TokenVersion ^ fingerprint
|
||||||
|
}
|
||||||
|
|||||||
@ -814,6 +814,20 @@ func parseCustomMenuItemURLs(raw string) []string {
|
|||||||
return urls
|
return urls
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func oidcUsePKCECompatibilityDefault(base config.OIDCConnectConfig) bool {
|
||||||
|
if base.UsePKCEExplicit {
|
||||||
|
return base.UsePKCE
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool {
|
||||||
|
if base.ValidateIDTokenExplicit {
|
||||||
|
return base.ValidateIDToken
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateSettings 更新系统设置
|
// UpdateSettings 更新系统设置
|
||||||
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
|
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
|
||||||
updates, err := s.buildSystemSettingsUpdates(ctx, settings)
|
updates, err := s.buildSystemSettingsUpdates(ctx, settings)
|
||||||
@ -1479,6 +1493,17 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
|||||||
return fmt.Errorf("check existing settings: %w", err)
|
return fmt.Errorf("check existing settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oidcUsePKCEDefault := true
|
||||||
|
oidcValidateIDTokenDefault := true
|
||||||
|
if s != nil && s.cfg != nil {
|
||||||
|
if s.cfg.OIDC.UsePKCEExplicit {
|
||||||
|
oidcUsePKCEDefault = s.cfg.OIDC.UsePKCE
|
||||||
|
}
|
||||||
|
if s.cfg.OIDC.ValidateIDTokenExplicit {
|
||||||
|
oidcValidateIDTokenDefault = s.cfg.OIDC.ValidateIDToken
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 初始化默认设置
|
// 初始化默认设置
|
||||||
defaults := map[string]string{
|
defaults := map[string]string{
|
||||||
SettingKeyRegistrationEnabled: "true",
|
SettingKeyRegistrationEnabled: "true",
|
||||||
@ -1523,8 +1548,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
|||||||
SettingKeyOIDCConnectRedirectURL: "",
|
SettingKeyOIDCConnectRedirectURL: "",
|
||||||
SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
|
SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
|
||||||
SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
|
SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
|
||||||
SettingKeyOIDCConnectUsePKCE: "true",
|
SettingKeyOIDCConnectUsePKCE: strconv.FormatBool(oidcUsePKCEDefault),
|
||||||
SettingKeyOIDCConnectValidateIDToken: "true",
|
SettingKeyOIDCConnectValidateIDToken: strconv.FormatBool(oidcValidateIDTokenDefault),
|
||||||
SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
|
SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
|
||||||
SettingKeyOIDCConnectClockSkewSeconds: "120",
|
SettingKeyOIDCConnectClockSkewSeconds: "120",
|
||||||
SettingKeyOIDCConnectRequireEmailVerified: "false",
|
SettingKeyOIDCConnectRequireEmailVerified: "false",
|
||||||
@ -1767,12 +1792,12 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
|||||||
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
|
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
|
||||||
result.OIDCConnectUsePKCE = raw == "true"
|
result.OIDCConnectUsePKCE = raw == "true"
|
||||||
} else {
|
} else {
|
||||||
result.OIDCConnectUsePKCE = oidcBase.UsePKCE
|
result.OIDCConnectUsePKCE = oidcUsePKCECompatibilityDefault(oidcBase)
|
||||||
}
|
}
|
||||||
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
|
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
|
||||||
result.OIDCConnectValidateIDToken = raw == "true"
|
result.OIDCConnectValidateIDToken = raw == "true"
|
||||||
} else {
|
} else {
|
||||||
result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken
|
result.OIDCConnectValidateIDToken = oidcValidateIDTokenCompatibilityDefault(oidcBase)
|
||||||
}
|
}
|
||||||
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
|
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
|
||||||
result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v)
|
result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v)
|
||||||
@ -2482,9 +2507,13 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
|
|||||||
}
|
}
|
||||||
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
|
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
|
||||||
effective.UsePKCE = raw == "true"
|
effective.UsePKCE = raw == "true"
|
||||||
|
} else {
|
||||||
|
effective.UsePKCE = oidcUsePKCECompatibilityDefault(effective)
|
||||||
}
|
}
|
||||||
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
|
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
|
||||||
effective.ValidateIDToken = raw == "true"
|
effective.ValidateIDToken = raw == "true"
|
||||||
|
} else {
|
||||||
|
effective.ValidateIDToken = oidcValidateIDTokenCompatibilityDefault(effective)
|
||||||
}
|
}
|
||||||
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
|
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
|
||||||
effective.AllowedSigningAlgs = strings.TrimSpace(v)
|
effective.AllowedSigningAlgs = strings.TrimSpace(v)
|
||||||
|
|||||||
@ -118,8 +118,10 @@ func TestSettingService_ParseSettings_PreservesOptionalOIDCCompatibilityFlags(t
|
|||||||
func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValues(t *testing.T) {
|
func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValues(t *testing.T) {
|
||||||
svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{
|
svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{
|
||||||
OIDC: config.OIDCConnectConfig{
|
OIDC: config.OIDCConnectConfig{
|
||||||
UsePKCE: true,
|
UsePKCE: true,
|
||||||
ValidateIDToken: true,
|
UsePKCEExplicit: true,
|
||||||
|
ValidateIDToken: true,
|
||||||
|
ValidateIDTokenExplicit: true,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -131,6 +133,22 @@ func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValue
|
|||||||
require.True(t, got.OIDCConnectValidateIDToken)
|
require.True(t, got.OIDCConnectValidateIDToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSettingService_ParseSettings_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) {
|
||||||
|
svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{
|
||||||
|
OIDC: config.OIDCConnectConfig{
|
||||||
|
UsePKCE: true,
|
||||||
|
ValidateIDToken: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
got := svc.parseSettings(map[string]string{
|
||||||
|
SettingKeyOIDCConnectEnabled: "true",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.False(t, got.OIDCConnectUsePKCE)
|
||||||
|
require.False(t, got.OIDCConnectValidateIDToken)
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) {
|
func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) {
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
OIDC: config.OIDCConnectConfig{
|
OIDC: config.OIDCConnectConfig{
|
||||||
@ -163,6 +181,42 @@ func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTok
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *testing.T) {
|
func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
OIDC: config.OIDCConnectConfig{
|
||||||
|
Enabled: true,
|
||||||
|
ProviderName: "OIDC",
|
||||||
|
ClientID: "oidc-client",
|
||||||
|
ClientSecret: "oidc-secret",
|
||||||
|
IssuerURL: "https://issuer.example.com",
|
||||||
|
AuthorizeURL: "https://issuer.example.com/auth",
|
||||||
|
TokenURL: "https://issuer.example.com/token",
|
||||||
|
UserInfoURL: "https://issuer.example.com/userinfo",
|
||||||
|
JWKSURL: "https://issuer.example.com/jwks",
|
||||||
|
RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
|
||||||
|
FrontendRedirectURL: "/auth/oidc/callback",
|
||||||
|
Scopes: "openid email profile",
|
||||||
|
TokenAuthMethod: "client_secret_post",
|
||||||
|
UsePKCE: true,
|
||||||
|
UsePKCEExplicit: true,
|
||||||
|
ValidateIDToken: true,
|
||||||
|
ValidateIDTokenExplicit: true,
|
||||||
|
AllowedSigningAlgs: "RS256",
|
||||||
|
ClockSkewSeconds: 120,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
repo := &settingOIDCRepoStub{values: map[string]string{
|
||||||
|
SettingKeyOIDCConnectEnabled: "true",
|
||||||
|
}}
|
||||||
|
svc := NewSettingService(repo, cfg)
|
||||||
|
|
||||||
|
got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, got.UsePKCE)
|
||||||
|
require.True(t, got.ValidateIDToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOIDCConnectOAuthConfig_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) {
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
OIDC: config.OIDCConnectConfig{
|
OIDC: config.OIDCConnectConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
@ -192,6 +246,6 @@ func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *t
|
|||||||
|
|
||||||
got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
|
got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.True(t, got.UsePKCE)
|
require.False(t, got.UsePKCE)
|
||||||
require.True(t, got.ValidateIDToken)
|
require.False(t, got.ValidateIDToken)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -38,23 +38,22 @@ VALUES
|
|||||||
('auth_source_default_email_balance', '0'),
|
('auth_source_default_email_balance', '0'),
|
||||||
('auth_source_default_email_concurrency', '5'),
|
('auth_source_default_email_concurrency', '5'),
|
||||||
('auth_source_default_email_subscriptions', '[]'),
|
('auth_source_default_email_subscriptions', '[]'),
|
||||||
('auth_source_default_email_grant_on_signup', 'true'),
|
('auth_source_default_email_grant_on_signup', 'false'),
|
||||||
('auth_source_default_email_grant_on_first_bind', 'false'),
|
('auth_source_default_email_grant_on_first_bind', 'false'),
|
||||||
('auth_source_default_linuxdo_balance', '0'),
|
('auth_source_default_linuxdo_balance', '0'),
|
||||||
('auth_source_default_linuxdo_concurrency', '5'),
|
('auth_source_default_linuxdo_concurrency', '5'),
|
||||||
('auth_source_default_linuxdo_subscriptions', '[]'),
|
('auth_source_default_linuxdo_subscriptions', '[]'),
|
||||||
('auth_source_default_linuxdo_grant_on_signup', 'true'),
|
('auth_source_default_linuxdo_grant_on_signup', 'false'),
|
||||||
('auth_source_default_linuxdo_grant_on_first_bind', 'false'),
|
('auth_source_default_linuxdo_grant_on_first_bind', 'false'),
|
||||||
('auth_source_default_oidc_balance', '0'),
|
('auth_source_default_oidc_balance', '0'),
|
||||||
('auth_source_default_oidc_concurrency', '5'),
|
('auth_source_default_oidc_concurrency', '5'),
|
||||||
('auth_source_default_oidc_subscriptions', '[]'),
|
('auth_source_default_oidc_subscriptions', '[]'),
|
||||||
('auth_source_default_oidc_grant_on_signup', 'true'),
|
('auth_source_default_oidc_grant_on_signup', 'false'),
|
||||||
('auth_source_default_oidc_grant_on_first_bind', 'false'),
|
('auth_source_default_oidc_grant_on_first_bind', 'false'),
|
||||||
('auth_source_default_wechat_balance', '0'),
|
('auth_source_default_wechat_balance', '0'),
|
||||||
('auth_source_default_wechat_concurrency', '5'),
|
('auth_source_default_wechat_concurrency', '5'),
|
||||||
('auth_source_default_wechat_subscriptions', '[]'),
|
('auth_source_default_wechat_subscriptions', '[]'),
|
||||||
('auth_source_default_wechat_grant_on_signup', 'true'),
|
('auth_source_default_wechat_grant_on_signup', 'false'),
|
||||||
('auth_source_default_wechat_grant_on_first_bind', 'false'),
|
('auth_source_default_wechat_grant_on_first_bind', 'false'),
|
||||||
('force_email_on_third_party_signup', 'false')
|
('force_email_on_third_party_signup', 'false')
|
||||||
ON CONFLICT (key) DO NOTHING;
|
ON CONFLICT (key) DO NOTHING;
|
||||||
|
|
||||||
|
|||||||
@ -31,6 +31,41 @@ BEGIN
|
|||||||
END IF;
|
END IF;
|
||||||
|
|
||||||
EXECUTE $sql$
|
EXECUTE $sql$
|
||||||
|
WITH legacy AS (
|
||||||
|
SELECT
|
||||||
|
uei.id,
|
||||||
|
uei.user_id,
|
||||||
|
BTRIM(uei.provider_user_id) AS provider_user_id,
|
||||||
|
BTRIM(uei.provider_username) AS provider_username,
|
||||||
|
BTRIM(uei.display_name) AS display_name,
|
||||||
|
public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
|
||||||
|
uei.created_at,
|
||||||
|
uei.updated_at
|
||||||
|
FROM user_external_identities AS uei
|
||||||
|
JOIN users AS u ON u.id = uei.user_id
|
||||||
|
WHERE u.deleted_at IS NULL
|
||||||
|
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo'
|
||||||
|
AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
|
||||||
|
),
|
||||||
|
legacy_subjects AS (
|
||||||
|
SELECT
|
||||||
|
provider_user_id AS provider_subject,
|
||||||
|
COUNT(DISTINCT user_id) AS distinct_user_count
|
||||||
|
FROM legacy
|
||||||
|
GROUP BY provider_user_id
|
||||||
|
),
|
||||||
|
canonical_legacy AS (
|
||||||
|
SELECT
|
||||||
|
legacy.*,
|
||||||
|
ROW_NUMBER() OVER (
|
||||||
|
PARTITION BY legacy.provider_user_id
|
||||||
|
ORDER BY COALESCE(legacy.updated_at, legacy.created_at, NOW()) DESC, legacy.id DESC
|
||||||
|
) AS canonical_row_num
|
||||||
|
FROM legacy
|
||||||
|
JOIN legacy_subjects AS subjects
|
||||||
|
ON subjects.provider_subject = legacy.provider_user_id
|
||||||
|
AND subjects.distinct_user_count = 1
|
||||||
|
)
|
||||||
INSERT INTO auth_identities (
|
INSERT INTO auth_identities (
|
||||||
user_id,
|
user_id,
|
||||||
provider_type,
|
provider_type,
|
||||||
@ -52,11 +87,18 @@ SELECT
|
|||||||
'display_name', legacy.display_name,
|
'display_name', legacy.display_name,
|
||||||
'migration', '115_auth_identity_legacy_external_backfill'
|
'migration', '115_auth_identity_legacy_external_backfill'
|
||||||
)
|
)
|
||||||
FROM (
|
FROM canonical_legacy AS legacy
|
||||||
|
WHERE legacy.canonical_row_num = 1
|
||||||
|
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
|
||||||
|
$sql$;
|
||||||
|
|
||||||
|
EXECUTE $sql$
|
||||||
|
WITH legacy AS (
|
||||||
SELECT
|
SELECT
|
||||||
uei.id,
|
uei.id,
|
||||||
uei.user_id,
|
uei.user_id,
|
||||||
BTRIM(uei.provider_user_id) AS provider_user_id,
|
BTRIM(uei.provider_user_id) AS provider_user_id,
|
||||||
|
BTRIM(uei.provider_union_id) AS provider_union_id,
|
||||||
BTRIM(uei.provider_username) AS provider_username,
|
BTRIM(uei.provider_username) AS provider_username,
|
||||||
BTRIM(uei.display_name) AS display_name,
|
BTRIM(uei.display_name) AS display_name,
|
||||||
public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
|
public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
|
||||||
@ -65,13 +107,28 @@ FROM (
|
|||||||
FROM user_external_identities AS uei
|
FROM user_external_identities AS uei
|
||||||
JOIN users AS u ON u.id = uei.user_id
|
JOIN users AS u ON u.id = uei.user_id
|
||||||
WHERE u.deleted_at IS NULL
|
WHERE u.deleted_at IS NULL
|
||||||
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo'
|
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
|
||||||
AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
|
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
|
||||||
) AS legacy
|
),
|
||||||
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
|
legacy_subjects AS (
|
||||||
$sql$;
|
SELECT
|
||||||
|
provider_union_id AS provider_subject,
|
||||||
EXECUTE $sql$
|
COUNT(DISTINCT user_id) AS distinct_user_count
|
||||||
|
FROM legacy
|
||||||
|
GROUP BY provider_union_id
|
||||||
|
),
|
||||||
|
canonical_legacy AS (
|
||||||
|
SELECT
|
||||||
|
legacy.*,
|
||||||
|
ROW_NUMBER() OVER (
|
||||||
|
PARTITION BY legacy.provider_union_id
|
||||||
|
ORDER BY COALESCE(legacy.updated_at, legacy.created_at, NOW()) DESC, legacy.id DESC
|
||||||
|
) AS canonical_row_num
|
||||||
|
FROM legacy
|
||||||
|
JOIN legacy_subjects AS subjects
|
||||||
|
ON subjects.provider_subject = legacy.provider_union_id
|
||||||
|
AND subjects.distinct_user_count = 1
|
||||||
|
)
|
||||||
INSERT INTO auth_identities (
|
INSERT INTO auth_identities (
|
||||||
user_id,
|
user_id,
|
||||||
provider_type,
|
provider_type,
|
||||||
@ -96,27 +153,36 @@ SELECT
|
|||||||
'display_name', legacy.display_name,
|
'display_name', legacy.display_name,
|
||||||
'migration', '115_auth_identity_legacy_external_backfill'
|
'migration', '115_auth_identity_legacy_external_backfill'
|
||||||
)
|
)
|
||||||
FROM (
|
FROM canonical_legacy AS legacy
|
||||||
SELECT
|
WHERE legacy.canonical_row_num = 1
|
||||||
uei.id,
|
|
||||||
uei.user_id,
|
|
||||||
BTRIM(uei.provider_user_id) AS provider_user_id,
|
|
||||||
BTRIM(uei.provider_union_id) AS provider_union_id,
|
|
||||||
BTRIM(uei.provider_username) AS provider_username,
|
|
||||||
BTRIM(uei.display_name) AS display_name,
|
|
||||||
public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
|
|
||||||
uei.created_at,
|
|
||||||
uei.updated_at
|
|
||||||
FROM user_external_identities AS uei
|
|
||||||
JOIN users AS u ON u.id = uei.user_id
|
|
||||||
WHERE u.deleted_at IS NULL
|
|
||||||
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
|
|
||||||
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
|
|
||||||
) AS legacy
|
|
||||||
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
|
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
|
||||||
$sql$;
|
$sql$;
|
||||||
|
|
||||||
EXECUTE $sql$
|
EXECUTE $sql$
|
||||||
|
WITH legacy AS (
|
||||||
|
SELECT
|
||||||
|
uei.user_id,
|
||||||
|
BTRIM(uei.provider_user_id) AS provider_user_id,
|
||||||
|
BTRIM(uei.provider_union_id) AS provider_union_id,
|
||||||
|
BTRIM(COALESCE(meta.metadata_json ->> 'channel', '')) AS channel,
|
||||||
|
BTRIM(COALESCE(meta.metadata_json ->> 'channel_app_id', meta.metadata_json ->> 'appid', meta.metadata_json ->> 'app_id', '')) AS channel_app_id,
|
||||||
|
meta.metadata_json
|
||||||
|
FROM user_external_identities AS uei
|
||||||
|
JOIN users AS u ON u.id = uei.user_id
|
||||||
|
CROSS JOIN LATERAL (
|
||||||
|
SELECT public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
|
||||||
|
) AS meta
|
||||||
|
WHERE u.deleted_at IS NULL
|
||||||
|
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
|
||||||
|
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
|
||||||
|
),
|
||||||
|
legacy_subjects AS (
|
||||||
|
SELECT
|
||||||
|
provider_union_id AS provider_subject,
|
||||||
|
COUNT(DISTINCT user_id) AS distinct_user_count
|
||||||
|
FROM legacy
|
||||||
|
GROUP BY provider_union_id
|
||||||
|
)
|
||||||
INSERT INTO auth_identity_channels (
|
INSERT INTO auth_identity_channels (
|
||||||
identity_id,
|
identity_id,
|
||||||
provider_type,
|
provider_type,
|
||||||
@ -138,23 +204,10 @@ SELECT
|
|||||||
'unionid', legacy.provider_union_id,
|
'unionid', legacy.provider_union_id,
|
||||||
'migration', '115_auth_identity_legacy_external_backfill'
|
'migration', '115_auth_identity_legacy_external_backfill'
|
||||||
)
|
)
|
||||||
FROM (
|
FROM legacy
|
||||||
SELECT
|
JOIN legacy_subjects AS subjects
|
||||||
uei.user_id,
|
ON subjects.provider_subject = legacy.provider_union_id
|
||||||
BTRIM(uei.provider_user_id) AS provider_user_id,
|
AND subjects.distinct_user_count = 1
|
||||||
BTRIM(uei.provider_union_id) AS provider_union_id,
|
|
||||||
BTRIM(COALESCE(meta.metadata_json ->> 'channel', '')) AS channel,
|
|
||||||
BTRIM(COALESCE(meta.metadata_json ->> 'channel_app_id', meta.metadata_json ->> 'appid', meta.metadata_json ->> 'app_id', '')) AS channel_app_id,
|
|
||||||
meta.metadata_json
|
|
||||||
FROM user_external_identities AS uei
|
|
||||||
JOIN users AS u ON u.id = uei.user_id
|
|
||||||
CROSS JOIN LATERAL (
|
|
||||||
SELECT public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
|
|
||||||
) AS meta
|
|
||||||
WHERE u.deleted_at IS NULL
|
|
||||||
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
|
|
||||||
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
|
|
||||||
) AS legacy
|
|
||||||
JOIN auth_identities AS ai
|
JOIN auth_identities AS ai
|
||||||
ON ai.user_id = legacy.user_id
|
ON ai.user_id = legacy.user_id
|
||||||
AND ai.provider_type = 'wechat'
|
AND ai.provider_type = 'wechat'
|
||||||
|
|||||||
@ -74,6 +74,82 @@ $sql$;
|
|||||||
|
|
||||||
EXECUTE $sql$
|
EXECUTE $sql$
|
||||||
INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
|
INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
|
||||||
|
SELECT
|
||||||
|
'legacy_external_identity_conflict',
|
||||||
|
'legacy_external_identity:' || legacy.id::text,
|
||||||
|
legacy.metadata_json || jsonb_build_object(
|
||||||
|
'legacy_identity_id', legacy.id,
|
||||||
|
'legacy_user_id', legacy.user_id,
|
||||||
|
'provider_type', legacy.provider_type,
|
||||||
|
'provider_key', legacy.provider_key,
|
||||||
|
'provider_subject', legacy.provider_subject,
|
||||||
|
'conflicting_legacy_user_ids', ambiguous.conflicting_legacy_user_ids,
|
||||||
|
'reason', 'legacy canonical identity subject belongs to multiple legacy users and cannot be auto-resolved',
|
||||||
|
'migration', '116_auth_identity_legacy_external_safety_reports'
|
||||||
|
)
|
||||||
|
FROM (
|
||||||
|
SELECT
|
||||||
|
uei.id,
|
||||||
|
uei.user_id,
|
||||||
|
LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
|
||||||
|
CASE
|
||||||
|
WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
|
||||||
|
ELSE 'linuxdo'
|
||||||
|
END AS provider_key,
|
||||||
|
CASE
|
||||||
|
WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
|
||||||
|
ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
|
||||||
|
END AS provider_subject,
|
||||||
|
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
|
||||||
|
FROM user_external_identities AS uei
|
||||||
|
JOIN users AS u ON u.id = uei.user_id
|
||||||
|
WHERE u.deleted_at IS NULL
|
||||||
|
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
|
||||||
|
AND (
|
||||||
|
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
|
||||||
|
OR
|
||||||
|
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
|
||||||
|
)
|
||||||
|
) AS legacy
|
||||||
|
JOIN (
|
||||||
|
SELECT
|
||||||
|
provider_type,
|
||||||
|
provider_key,
|
||||||
|
provider_subject,
|
||||||
|
to_jsonb(array_agg(DISTINCT user_id ORDER BY user_id)) AS conflicting_legacy_user_ids
|
||||||
|
FROM (
|
||||||
|
SELECT
|
||||||
|
uei.user_id,
|
||||||
|
LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
|
||||||
|
CASE
|
||||||
|
WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
|
||||||
|
ELSE 'linuxdo'
|
||||||
|
END AS provider_key,
|
||||||
|
CASE
|
||||||
|
WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
|
||||||
|
ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
|
||||||
|
END AS provider_subject
|
||||||
|
FROM user_external_identities AS uei
|
||||||
|
JOIN users AS u ON u.id = uei.user_id
|
||||||
|
WHERE u.deleted_at IS NULL
|
||||||
|
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
|
||||||
|
AND (
|
||||||
|
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
|
||||||
|
OR
|
||||||
|
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
|
||||||
|
)
|
||||||
|
) AS legacy_subjects
|
||||||
|
GROUP BY provider_type, provider_key, provider_subject
|
||||||
|
HAVING COUNT(DISTINCT user_id) > 1
|
||||||
|
) AS ambiguous
|
||||||
|
ON ambiguous.provider_type = legacy.provider_type
|
||||||
|
AND ambiguous.provider_key = legacy.provider_key
|
||||||
|
AND ambiguous.provider_subject = legacy.provider_subject
|
||||||
|
ON CONFLICT (report_type, report_key) DO NOTHING;
|
||||||
|
$sql$;
|
||||||
|
|
||||||
|
EXECUTE $sql$
|
||||||
|
INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
|
||||||
SELECT
|
SELECT
|
||||||
'legacy_external_identity_conflict',
|
'legacy_external_identity_conflict',
|
||||||
'legacy_external_identity:' || legacy.id::text,
|
'legacy_external_identity:' || legacy.id::text,
|
||||||
@ -116,6 +192,39 @@ FROM (
|
|||||||
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
|
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
|
||||||
)
|
)
|
||||||
) AS legacy
|
) AS legacy
|
||||||
|
JOIN (
|
||||||
|
SELECT
|
||||||
|
provider_type,
|
||||||
|
provider_key,
|
||||||
|
provider_subject
|
||||||
|
FROM (
|
||||||
|
SELECT
|
||||||
|
uei.user_id,
|
||||||
|
LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
|
||||||
|
CASE
|
||||||
|
WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
|
||||||
|
ELSE 'linuxdo'
|
||||||
|
END AS provider_key,
|
||||||
|
CASE
|
||||||
|
WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
|
||||||
|
ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
|
||||||
|
END AS provider_subject
|
||||||
|
FROM user_external_identities AS uei
|
||||||
|
JOIN users AS u ON u.id = uei.user_id
|
||||||
|
WHERE u.deleted_at IS NULL
|
||||||
|
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
|
||||||
|
AND (
|
||||||
|
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
|
||||||
|
OR
|
||||||
|
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
|
||||||
|
)
|
||||||
|
) AS legacy_subjects
|
||||||
|
GROUP BY provider_type, provider_key, provider_subject
|
||||||
|
HAVING COUNT(DISTINCT user_id) = 1
|
||||||
|
) AS clear_subjects
|
||||||
|
ON clear_subjects.provider_type = legacy.provider_type
|
||||||
|
AND clear_subjects.provider_key = legacy.provider_key
|
||||||
|
AND clear_subjects.provider_subject = legacy.provider_subject
|
||||||
JOIN auth_identities AS ai
|
JOIN auth_identities AS ai
|
||||||
ON ai.provider_type = legacy.provider_type
|
ON ai.provider_type = legacy.provider_type
|
||||||
AND ai.provider_key = legacy.provider_key
|
AND ai.provider_key = legacy.provider_key
|
||||||
@ -125,29 +234,7 @@ ON CONFLICT (report_type, report_key) DO NOTHING;
|
|||||||
$sql$;
|
$sql$;
|
||||||
|
|
||||||
EXECUTE $sql$
|
EXECUTE $sql$
|
||||||
INSERT INTO auth_identities (
|
WITH legacy AS (
|
||||||
user_id,
|
|
||||||
provider_type,
|
|
||||||
provider_key,
|
|
||||||
provider_subject,
|
|
||||||
verified_at,
|
|
||||||
metadata
|
|
||||||
)
|
|
||||||
SELECT
|
|
||||||
legacy.user_id,
|
|
||||||
legacy.provider_type,
|
|
||||||
legacy.provider_key,
|
|
||||||
legacy.provider_subject,
|
|
||||||
legacy.verified_at,
|
|
||||||
legacy.metadata_json || jsonb_build_object(
|
|
||||||
'legacy_identity_id', legacy.id,
|
|
||||||
'provider_user_id', legacy.provider_user_id,
|
|
||||||
'provider_union_id', NULLIF(legacy.provider_union_id, ''),
|
|
||||||
'provider_username', legacy.provider_username,
|
|
||||||
'display_name', legacy.display_name,
|
|
||||||
'migration', '116_auth_identity_legacy_external_safety_reports'
|
|
||||||
)
|
|
||||||
FROM (
|
|
||||||
SELECT
|
SELECT
|
||||||
uei.id,
|
uei.id,
|
||||||
uei.user_id,
|
uei.user_id,
|
||||||
@ -175,12 +262,58 @@ FROM (
|
|||||||
OR
|
OR
|
||||||
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
|
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
|
||||||
)
|
)
|
||||||
) AS legacy
|
),
|
||||||
|
clear_subjects AS (
|
||||||
|
SELECT
|
||||||
|
provider_type,
|
||||||
|
provider_key,
|
||||||
|
provider_subject
|
||||||
|
FROM legacy
|
||||||
|
GROUP BY provider_type, provider_key, provider_subject
|
||||||
|
HAVING COUNT(DISTINCT user_id) = 1
|
||||||
|
),
|
||||||
|
canonical_legacy AS (
|
||||||
|
SELECT
|
||||||
|
legacy.*,
|
||||||
|
ROW_NUMBER() OVER (
|
||||||
|
PARTITION BY legacy.provider_type, legacy.provider_key, legacy.provider_subject
|
||||||
|
ORDER BY legacy.verified_at DESC, legacy.id DESC
|
||||||
|
) AS canonical_row_num
|
||||||
|
FROM legacy
|
||||||
|
JOIN clear_subjects
|
||||||
|
ON clear_subjects.provider_type = legacy.provider_type
|
||||||
|
AND clear_subjects.provider_key = legacy.provider_key
|
||||||
|
AND clear_subjects.provider_subject = legacy.provider_subject
|
||||||
|
)
|
||||||
|
INSERT INTO auth_identities (
|
||||||
|
user_id,
|
||||||
|
provider_type,
|
||||||
|
provider_key,
|
||||||
|
provider_subject,
|
||||||
|
verified_at,
|
||||||
|
metadata
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
legacy.user_id,
|
||||||
|
legacy.provider_type,
|
||||||
|
legacy.provider_key,
|
||||||
|
legacy.provider_subject,
|
||||||
|
legacy.verified_at,
|
||||||
|
legacy.metadata_json || jsonb_build_object(
|
||||||
|
'legacy_identity_id', legacy.id,
|
||||||
|
'provider_user_id', legacy.provider_user_id,
|
||||||
|
'provider_union_id', NULLIF(legacy.provider_union_id, ''),
|
||||||
|
'provider_username', legacy.provider_username,
|
||||||
|
'display_name', legacy.display_name,
|
||||||
|
'migration', '116_auth_identity_legacy_external_safety_reports'
|
||||||
|
)
|
||||||
|
FROM canonical_legacy AS legacy
|
||||||
LEFT JOIN auth_identities AS ai
|
LEFT JOIN auth_identities AS ai
|
||||||
ON ai.provider_type = legacy.provider_type
|
ON ai.provider_type = legacy.provider_type
|
||||||
AND ai.provider_key = legacy.provider_key
|
AND ai.provider_key = legacy.provider_key
|
||||||
AND ai.provider_subject = legacy.provider_subject
|
AND ai.provider_subject = legacy.provider_subject
|
||||||
WHERE ai.id IS NULL
|
WHERE legacy.canonical_row_num = 1
|
||||||
|
AND ai.id IS NULL
|
||||||
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
|
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
|
||||||
$sql$;
|
$sql$;
|
||||||
|
|
||||||
@ -225,6 +358,19 @@ FROM (
|
|||||||
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
|
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
|
||||||
AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
|
AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
|
||||||
) AS legacy
|
) AS legacy
|
||||||
|
JOIN (
|
||||||
|
SELECT
|
||||||
|
BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_subject
|
||||||
|
FROM user_external_identities AS uei
|
||||||
|
JOIN users AS u ON u.id = uei.user_id
|
||||||
|
WHERE u.deleted_at IS NULL
|
||||||
|
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
|
||||||
|
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
|
||||||
|
AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
|
||||||
|
GROUP BY BTRIM(COALESCE(uei.provider_union_id, ''))
|
||||||
|
HAVING COUNT(DISTINCT uei.user_id) = 1
|
||||||
|
) AS clear_subjects
|
||||||
|
ON clear_subjects.provider_subject = legacy.provider_union_id
|
||||||
JOIN auth_identities AS legacy_ai
|
JOIN auth_identities AS legacy_ai
|
||||||
ON legacy_ai.user_id = legacy.user_id
|
ON legacy_ai.user_id = legacy.user_id
|
||||||
AND legacy_ai.provider_type = 'wechat'
|
AND legacy_ai.provider_type = 'wechat'
|
||||||
@ -245,6 +391,33 @@ ON CONFLICT (report_type, report_key) DO NOTHING;
|
|||||||
$sql$;
|
$sql$;
|
||||||
|
|
||||||
EXECUTE $sql$
|
EXECUTE $sql$
|
||||||
|
WITH legacy AS (
|
||||||
|
SELECT
|
||||||
|
uei.user_id,
|
||||||
|
BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
|
||||||
|
BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
|
||||||
|
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
|
||||||
|
BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel,
|
||||||
|
BTRIM(COALESCE(
|
||||||
|
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id',
|
||||||
|
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid',
|
||||||
|
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id',
|
||||||
|
''
|
||||||
|
)) AS channel_app_id
|
||||||
|
FROM user_external_identities AS uei
|
||||||
|
JOIN users AS u ON u.id = uei.user_id
|
||||||
|
WHERE u.deleted_at IS NULL
|
||||||
|
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
|
||||||
|
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
|
||||||
|
AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
|
||||||
|
),
|
||||||
|
clear_subjects AS (
|
||||||
|
SELECT
|
||||||
|
provider_union_id AS provider_subject
|
||||||
|
FROM legacy
|
||||||
|
GROUP BY provider_union_id
|
||||||
|
HAVING COUNT(DISTINCT user_id) = 1
|
||||||
|
)
|
||||||
INSERT INTO auth_identity_channels (
|
INSERT INTO auth_identity_channels (
|
||||||
identity_id,
|
identity_id,
|
||||||
provider_type,
|
provider_type,
|
||||||
@ -266,26 +439,9 @@ SELECT
|
|||||||
'unionid', legacy.provider_union_id,
|
'unionid', legacy.provider_union_id,
|
||||||
'migration', '116_auth_identity_legacy_external_safety_reports'
|
'migration', '116_auth_identity_legacy_external_safety_reports'
|
||||||
)
|
)
|
||||||
FROM (
|
FROM legacy
|
||||||
SELECT
|
JOIN clear_subjects
|
||||||
uei.user_id,
|
ON clear_subjects.provider_subject = legacy.provider_union_id
|
||||||
BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
|
|
||||||
BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
|
|
||||||
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
|
|
||||||
BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel,
|
|
||||||
BTRIM(COALESCE(
|
|
||||||
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id',
|
|
||||||
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid',
|
|
||||||
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id',
|
|
||||||
''
|
|
||||||
)) AS channel_app_id
|
|
||||||
FROM user_external_identities AS uei
|
|
||||||
JOIN users AS u ON u.id = uei.user_id
|
|
||||||
WHERE u.deleted_at IS NULL
|
|
||||||
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
|
|
||||||
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
|
|
||||||
AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
|
|
||||||
) AS legacy
|
|
||||||
JOIN auth_identities AS legacy_ai
|
JOIN auth_identities AS legacy_ai
|
||||||
ON legacy_ai.user_id = legacy.user_id
|
ON legacy_ai.user_id = legacy.user_id
|
||||||
AND legacy_ai.provider_type = 'wechat'
|
AND legacy_ai.provider_type = 'wechat'
|
||||||
|
|||||||
@ -1,3 +1,68 @@
|
|||||||
-- Intentionally left as a no-op.
|
-- Auto-backfill untouched migration 110 signup-grant defaults to the corrected false value.
|
||||||
-- Legacy installs may have intentionally kept the original signup grant defaults,
|
-- Rows still matching the migration-110 default payload and timestamp window are treated as
|
||||||
-- and we cannot distinguish those cases safely from untouched migration 110 rows.
|
-- untouched legacy defaults; any remaining legacy true values are reported for manual review.
|
||||||
|
|
||||||
|
WITH migration_110 AS (
|
||||||
|
SELECT applied_at
|
||||||
|
FROM schema_migrations
|
||||||
|
WHERE filename = '110_pending_auth_and_provider_default_grants.sql'
|
||||||
|
),
|
||||||
|
providers AS (
|
||||||
|
SELECT provider_type
|
||||||
|
FROM (
|
||||||
|
VALUES ('email'), ('linuxdo'), ('oidc'), ('wechat')
|
||||||
|
) AS providers(provider_type)
|
||||||
|
),
|
||||||
|
legacy_provider_defaults AS (
|
||||||
|
SELECT providers.provider_type
|
||||||
|
FROM providers
|
||||||
|
CROSS JOIN migration_110
|
||||||
|
JOIN settings balance
|
||||||
|
ON balance.key = 'auth_source_default_' || providers.provider_type || '_balance'
|
||||||
|
JOIN settings concurrency
|
||||||
|
ON concurrency.key = 'auth_source_default_' || providers.provider_type || '_concurrency'
|
||||||
|
JOIN settings subscriptions
|
||||||
|
ON subscriptions.key = 'auth_source_default_' || providers.provider_type || '_subscriptions'
|
||||||
|
JOIN settings grant_on_signup
|
||||||
|
ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup'
|
||||||
|
JOIN settings grant_on_first_bind
|
||||||
|
ON grant_on_first_bind.key = 'auth_source_default_' || providers.provider_type || '_grant_on_first_bind'
|
||||||
|
WHERE balance.value = '0'
|
||||||
|
AND concurrency.value = '5'
|
||||||
|
AND subscriptions.value = '[]'
|
||||||
|
AND grant_on_signup.value = 'true'
|
||||||
|
AND grant_on_first_bind.value = 'false'
|
||||||
|
AND balance.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
|
||||||
|
AND concurrency.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
|
||||||
|
AND subscriptions.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
|
||||||
|
AND grant_on_signup.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
|
||||||
|
AND grant_on_first_bind.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
|
||||||
|
),
|
||||||
|
updated_signup_grants AS (
|
||||||
|
UPDATE settings
|
||||||
|
SET
|
||||||
|
value = 'false',
|
||||||
|
updated_at = NOW()
|
||||||
|
FROM legacy_provider_defaults
|
||||||
|
WHERE settings.key = 'auth_source_default_' || legacy_provider_defaults.provider_type || '_grant_on_signup'
|
||||||
|
AND settings.value = 'true'
|
||||||
|
RETURNING legacy_provider_defaults.provider_type
|
||||||
|
)
|
||||||
|
INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
|
||||||
|
SELECT
|
||||||
|
'legacy_auth_source_signup_grant_review',
|
||||||
|
providers.provider_type,
|
||||||
|
jsonb_build_object(
|
||||||
|
'provider_type', providers.provider_type,
|
||||||
|
'current_value', grant_on_signup.value,
|
||||||
|
'auto_backfilled', FALSE,
|
||||||
|
'reason', 'legacy_true_default_not_auto_backfilled'
|
||||||
|
)
|
||||||
|
FROM providers
|
||||||
|
JOIN settings grant_on_signup
|
||||||
|
ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup'
|
||||||
|
LEFT JOIN updated_signup_grants
|
||||||
|
ON updated_signup_grants.provider_type = providers.provider_type
|
||||||
|
WHERE grant_on_signup.value = 'true'
|
||||||
|
AND updated_signup_grants.provider_type IS NULL
|
||||||
|
ON CONFLICT (report_type, report_key) DO NOTHING;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user