Merge pull request #1799 from IanShaw027/rebuild/auth-identity-foundation

fix(auth,payment,profile): 修复认证身份和支付系统的后续问题
This commit is contained in:
Wesley Liddick 2026-04-22 18:18:39 +08:00 committed by GitHub
commit ddf80f5ea1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
140 changed files with 11053 additions and 1202 deletions

View File

@ -28,6 +28,26 @@ jobs:
working-directory: backend
run: make test-integration
frontend:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 9
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: '20'
cache: 'pnpm'
cache-dependency-path: frontend/pnpm-lock.yaml
- name: Install frontend dependencies
working-directory: frontend
run: pnpm install --frozen-lockfile
- name: Frontend typecheck and critical vitest
run: make test-frontend
golangci-lint:
runs-on: ubuntu-latest
steps:
@ -46,4 +66,4 @@ jobs:
with:
version: v2.9
args: --timeout=30m
working-directory: backend
working-directory: backend

View File

@ -1,4 +1,12 @@
.PHONY: build build-backend build-frontend build-datamanagementd test test-backend test-frontend test-datamanagementd secret-scan
.PHONY: build build-backend build-frontend build-datamanagementd test test-backend test-frontend test-frontend-critical test-datamanagementd secret-scan
FRONTEND_CRITICAL_VITEST := \
src/views/auth/__tests__/LinuxDoCallbackView.spec.ts \
src/views/auth/__tests__/WechatCallbackView.spec.ts \
src/views/user/__tests__/PaymentView.spec.ts \
src/views/user/__tests__/PaymentResultView.spec.ts \
src/components/user/profile/__tests__/ProfileInfoCard.spec.ts \
src/views/admin/__tests__/SettingsView.spec.ts
# 一键编译前后端
build: build-backend build-frontend
@ -24,6 +32,10 @@ test-backend:
test-frontend:
@pnpm --dir frontend run lint:check
@pnpm --dir frontend run typecheck
@$(MAKE) test-frontend-critical
test-frontend-critical:
@pnpm --dir frontend exec vitest run $(FRONTEND_CRITICAL_VITEST)
test-datamanagementd:
@cd datamanagement && go test ./...

View File

@ -42,10 +42,18 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
- **Smart Scheduling** - Intelligent account selection with sticky sessions
- **Concurrency Control** - Per-user and per-account concurrency limits
- **Rate Limiting** - Configurable request and token rate limits
- **Built-in Payment System** - Supports EasyPay, Alipay, WeChat Pay, and Stripe for user self-service top-up, no separate payment service needed ([Configuration Guide](docs/PAYMENT.md))
- **Built-in Payment System** - Supports EasyPay, Alipay, WeChat Pay, and Stripe for user self-service top-up, no separate payment service needed ([Payment Setup](#payment))
- **Admin Dashboard** - Web interface for monitoring and management
- **External System Integration** - Embed external systems (e.g. ticketing) via iframe to extend the admin dashboard
## Payment
Sub2API includes the payment system in the main service. No standalone payment service or separate payment guide is required.
- Supported providers: EasyPay, Alipay, WeChat Pay, Stripe
- The frontend keeps user-facing methods unified; admins choose the backing source in `Admin -> Settings -> Payment`
- Callback URLs are generated from the site domain when configuring providers
## ❤️ Sponsors
> [Want to appear here?](mailto:support@pincc.ai)
@ -109,7 +117,7 @@ Community projects that extend or integrate with Sub2API:
| Project | Description | Features |
|---------|-------------|----------|
| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~Self-service payment system~~ | **Now Built-in** — Payment is now integrated into Sub2API, no separate deployment needed. See [Payment Configuration Guide](docs/PAYMENT.md) |
| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~Self-service payment system~~ | **Now Built-in** — Payment is now integrated into Sub2API, no separate deployment needed. See [Payment Setup](#payment) |
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | Mobile admin console | Cross-platform app (iOS/Android/Web) for user management, account management, monitoring dashboard, and multi-backend switching; built with Expo + React Native |
## Tech Stack

View File

@ -41,10 +41,18 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
- **智能调度** - 智能账号选择,支持粘性会话
- **并发控制** - 用户级和账号级并发限制
- **速率限制** - 可配置的请求和 Token 速率限制
- **内置支付系统** - 支持 EasyPay 易支付、支付宝官方、微信官方、Stripe用户自助充值无需独立部署支付服务[配置指南](docs/PAYMENT_CN.md)
- **内置支付系统** - 支持 EasyPay 易支付、支付宝官方、微信官方、Stripe用户自助充值无需独立部署支付服务[支付说明](#支付)
- **管理后台** - Web 界面进行监控和管理
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如工单等),扩展管理后台功能
## 支付
Sub2API 已将支付系统集成到主服务中,无需独立支付服务,也不再依赖单独的支付配置文档。
- 支持服务商EasyPay 易支付、支付宝官方、微信官方、Stripe
- 前台统一展示用户可见支付方式,管理员在 `管理后台 -> 设置 -> 支付` 里选择对应来源
- 添加服务商时会基于站点域名生成回调地址
## ❤️ 赞助商
> [想出现在这里?](mailto:support@pincc.ai)
@ -108,7 +116,7 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
| 项目 | 说明 | 功能 |
|------|------|------|
| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~自助支付系统~~ | **已内置** — 支付功能已集成到 Sub2API 中,无需独立部署。详见 [支付配置指南](docs/PAYMENT_CN.md) |
| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~自助支付系统~~ | **已内置** — 支付功能已集成到 Sub2API 中,无需独立部署。详见 [支付说明](#支付) |
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | 移动端管理控制台 | 跨平台应用iOS/Android/Web支持用户管理、账号管理、监控看板、多后端切换基于 Expo + React Native 构建 |
## 技术栈

View File

@ -42,10 +42,18 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
- **スマートスケジューリング** - スティッキーセッション付きのインテリジェントなアカウント選択
- **同時実行制御** - ユーザーごと・アカウントごとの同時実行数制限
- **レート制限** - 設定可能なリクエスト数およびトークンレート制限
- **内蔵決済システム** - EasyPay、Alipay、WeChat Pay、Stripe に対応。ユーザーのセルフサービスチャージが可能で、別途決済サービスのデプロイは不要([設定ガイド](docs/PAYMENT.md)
- **内蔵決済システム** - EasyPay、Alipay、WeChat Pay、Stripe に対応。ユーザーのセルフサービスチャージが可能で、別途決済サービスのデプロイは不要([決済案内](#決済)
- **管理ダッシュボード** - 監視・管理のための Web インターフェース
- **外部システム連携** - 外部システム(チケット管理など)を iframe 経由で管理ダッシュボードに埋め込み可能
## 決済
Sub2API の決済機能は本体に統合されています。独立した決済サービスや別個の決済ガイドは不要です。
- 対応プロバイダー: EasyPay、Alipay、WeChat Pay、Stripe
- フロントエンドではユーザー向け決済方法を統一表示し、管理者は `管理画面 -> 設定 -> 決済` で実際の接続先を選択します
- プロバイダー設定時のコールバック URL はサイトドメインから自動生成されます
## ❤️ スポンサー
> [こちらに掲載しませんか?](mailto:support@pincc.ai)
@ -108,7 +116,7 @@ Sub2API を拡張・統合するコミュニティプロジェクト:
| プロジェクト | 説明 | 機能 |
|---------|-------------|----------|
| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~セルフサービス決済システム~~ | **内蔵済み** — 決済機能は Sub2API に統合されました。別途デプロイは不要です。[決済設定ガイド](docs/PAYMENT.md)をご参照ください |
| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~セルフサービス決済システム~~ | **内蔵済み** — 決済機能は Sub2API に統合されました。別途デプロイは不要です。[決済案内](#決済)をご参照ください |
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | モバイル管理コンソール | ユーザー管理、アカウント管理、監視ダッシュボード、マルチバックエンド切り替えが可能なクロスプラットフォームアプリiOS/Android/Web。Expo + React Native で構築 |
## 技術スタック

View File

@ -0,0 +1,73 @@
package migrate
import (
"testing"
"entgo.io/ent/dialect/entsql"
entschema "entgo.io/ent/dialect/sql/schema"
"github.com/stretchr/testify/require"
)
func TestAuthIdentityFoundationForeignKeyOnDeleteActions(t *testing.T) {
require.Equal(
t,
entschema.Cascade,
findForeignKeyBySymbol(t, AuthIdentitiesTable, "auth_identities_users_auth_identities").OnDelete,
)
require.Equal(
t,
entschema.Cascade,
findForeignKeyBySymbol(t, AuthIdentityChannelsTable, "auth_identity_channels_auth_identities_channels").OnDelete,
)
require.Equal(
t,
entschema.Cascade,
findForeignKeyBySymbol(t, IdentityAdoptionDecisionsTable, "identity_adoption_decisions_pending_auth_sessions_adoption_decision").OnDelete,
)
require.Equal(
t,
entschema.SetNull,
findForeignKeyBySymbol(t, PendingAuthSessionsTable, "pending_auth_sessions_users_pending_auth_sessions").OnDelete,
)
require.Equal(
t,
entschema.SetNull,
findForeignKeyBySymbol(t, IdentityAdoptionDecisionsTable, "identity_adoption_decisions_auth_identities_adoption_decisions").OnDelete,
)
}
func TestPaymentOrdersOutTradeNoPartialUniqueIndex(t *testing.T) {
idx := findIndexByName(t, PaymentOrdersTable, "paymentorder_out_trade_no")
require.True(t, idx.Unique)
require.Len(t, idx.Columns, 1)
require.Equal(t, "out_trade_no", idx.Columns[0].Name)
require.NotNil(t, idx.Annotation)
require.Equal(t, (&entsql.IndexAnnotation{Where: "out_trade_no <> ''"}).Where, idx.Annotation.Where)
}
func findForeignKeyBySymbol(t *testing.T, table *entschema.Table, symbol string) *entschema.ForeignKey {
t.Helper()
for _, fk := range table.ForeignKeys {
if fk.Symbol == symbol {
return fk
}
}
require.Failf(t, "missing foreign key", "table %s should include foreign key %s", table.Name, symbol)
return nil
}
func findIndexByName(t *testing.T, table *entschema.Table, name string) *entschema.Index {
t.Helper()
for _, idx := range table.Indexes {
if idx.Name == name {
return idx
}
}
require.Failf(t, "missing index", "table %s should include index %s", table.Name, name)
return nil
}

View File

@ -361,7 +361,7 @@ var (
Symbol: "auth_identities_users_auth_identities",
Columns: []*schema.Column{AuthIdentitiesColumns[9]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
OnDelete: schema.Cascade,
},
},
Indexes: []*schema.Index{
@ -405,7 +405,7 @@ var (
Symbol: "auth_identity_channels_auth_identities_channels",
Columns: []*schema.Column{AuthIdentityChannelsColumns[9]},
RefColumns: []*schema.Column{AuthIdentitiesColumns[0]},
OnDelete: schema.NoAction,
OnDelete: schema.Cascade,
},
},
Indexes: []*schema.Index{
@ -595,7 +595,7 @@ var (
Symbol: "identity_adoption_decisions_pending_auth_sessions_adoption_decision",
Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]},
RefColumns: []*schema.Column{PendingAuthSessionsColumns[0]},
OnDelete: schema.NoAction,
OnDelete: schema.Cascade,
},
},
Indexes: []*schema.Index{
@ -692,8 +692,11 @@ var (
Indexes: []*schema.Index{
{
Name: "paymentorder_out_trade_no",
Unique: false,
Unique: true,
Columns: []*schema.Column{PaymentOrdersColumns[8]},
Annotation: &entsql.IndexAnnotation{
Where: "out_trade_no <> ''",
},
},
{
Name: "paymentorder_user_id",

View File

@ -79,7 +79,8 @@ func (AuthIdentity) Edges() []ent.Edge {
Field("user_id").
Required().
Unique(),
edge.To("channels", AuthIdentityChannel.Type),
edge.To("channels", AuthIdentityChannel.Type).
Annotations(entsql.OnDelete(entsql.Cascade)),
edge.To("adoption_decisions", IdentityAdoptionDecision.Type),
}
}

View File

@ -3,7 +3,9 @@ package schema
import (
"testing"
"entgo.io/ent"
"entgo.io/ent/entc/load"
"entgo.io/ent/schema/field"
"github.com/stretchr/testify/require"
)
@ -74,6 +76,17 @@ func TestAuthIdentityFoundationSchemas(t *testing.T) {
userSchema := requireSchema(t, schemas, "User")
requireSchemaFields(t, userSchema, "signup_source", "last_login_at", "last_active_at")
signupSource := requireSchemaField(t, userSchema, "signup_source")
require.Equal(t, field.TypeString, signupSource.Info.Type)
require.True(t, signupSource.Default)
require.Equal(t, "email", signupSource.DefaultValue)
require.Equal(t, 1, signupSource.Validators)
validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source")
for _, value := range []string{"email", "linuxdo", "wechat", "oidc"} {
require.NoError(t, validator(value))
}
require.Error(t, validator("github"))
}
func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema {
@ -98,6 +111,37 @@ func requireSchemaFields(t *testing.T, schema *load.Schema, names ...string) {
}
}
func requireSchemaField(t *testing.T, schema *load.Schema, name string) *load.Field {
t.Helper()
for _, schemaField := range schema.Fields {
if schemaField.Name == name {
return schemaField
}
}
require.Failf(t, "missing schema field", "schema %s should include field %s", schema.Name, name)
return nil
}
func requireStringFieldValidator(t *testing.T, fields []ent.Field, name string) func(string) error {
t.Helper()
for _, entField := range fields {
descriptor := entField.Descriptor()
if descriptor.Name != name {
continue
}
require.NotEmpty(t, descriptor.Validators, "field %s should include a validator", name)
validator, ok := descriptor.Validators[0].(func(string) error)
require.True(t, ok, "field %s validator should be func(string) error", name)
return validator
}
require.Failf(t, "missing field validator", "schema should include field %s", name)
return nil
}
func requireHasUniqueIndex(t *testing.T, schema *load.Schema, fields ...string) {
t.Helper()

View File

@ -185,7 +185,9 @@ func (PaymentOrder) Edges() []ent.Edge {
func (PaymentOrder) Indexes() []ent.Index {
return []ent.Index{
index.Fields("out_trade_no"),
index.Fields("out_trade_no").
Unique().
Annotations(entsql.IndexWhere("out_trade_no <> ''")),
index.Fields("user_id"),
index.Fields("status"),
index.Fields("expires_at"),

View File

@ -119,6 +119,7 @@ func (PendingAuthSession) Edges() []ent.Edge {
Field("target_user_id").
Unique(),
edge.To("adoption_decision", IdentityAdoptionDecision.Type).
Annotations(entsql.OnDelete(entsql.Cascade)).
Unique(),
}
}

View File

@ -1,6 +1,8 @@
package schema
import (
"fmt"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/domain"
@ -73,7 +75,14 @@ func (User) Fields() []ent.Field {
Optional().
Nillable(),
field.String("signup_source").
MaxLen(20).
Validate(func(value string) error {
switch value {
case "email", "linuxdo", "wechat", "oidc":
return nil
default:
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc")
}
}).
Default("email"),
field.Time("last_login_at").
Optional().
@ -115,7 +124,8 @@ func (User) Edges() []ent.Edge {
edge.To("attribute_values", UserAttributeValue.Type),
edge.To("promo_code_usages", PromoCodeUsage.Type),
edge.To("payment_orders", PaymentOrder.Type),
edge.To("auth_identities", AuthIdentity.Type),
edge.To("auth_identities", AuthIdentity.Type).
Annotations(entsql.OnDelete(entsql.Cascade)),
edge.To("pending_auth_sessions", PendingAuthSession.Type),
}
}

View File

@ -70,6 +70,7 @@ type Config struct {
JWT JWTConfig `mapstructure:"jwt"`
Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
WeChat WeChatConnectConfig `mapstructure:"wechat_connect"`
OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
@ -190,26 +191,47 @@ type LinuxDoConnectConfig struct {
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
type WeChatConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
AppID string `mapstructure:"app_id"`
AppSecret string `mapstructure:"app_secret"`
OpenAppID string `mapstructure:"open_app_id"`
OpenAppSecret string `mapstructure:"open_app_secret"`
MPAppID string `mapstructure:"mp_app_id"`
MPAppSecret string `mapstructure:"mp_app_secret"`
MobileAppID string `mapstructure:"mobile_app_id"`
MobileAppSecret string `mapstructure:"mobile_app_secret"`
OpenEnabled bool `mapstructure:"open_enabled"`
MPEnabled bool `mapstructure:"mp_enabled"`
MobileEnabled bool `mapstructure:"mobile_enabled"`
Mode string `mapstructure:"mode"`
Scopes string `mapstructure:"scopes"`
RedirectURL string `mapstructure:"redirect_url"`
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"`
}
type OIDCConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
IssuerURL string `mapstructure:"issuer_url"`
DiscoveryURL string `mapstructure:"discovery_url"`
AuthorizeURL string `mapstructure:"authorize_url"`
TokenURL string `mapstructure:"token_url"`
UserInfoURL string `mapstructure:"userinfo_url"`
JWKSURL string `mapstructure:"jwks_url"`
Scopes string `mapstructure:"scopes"` // 默认 "openid email profile"
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
UsePKCE bool `mapstructure:"use_pkce"`
ValidateIDToken bool `mapstructure:"validate_id_token"`
AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256"
ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120
RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false
Enabled bool `mapstructure:"enabled"`
ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
IssuerURL string `mapstructure:"issuer_url"`
DiscoveryURL string `mapstructure:"discovery_url"`
AuthorizeURL string `mapstructure:"authorize_url"`
TokenURL string `mapstructure:"token_url"`
UserInfoURL string `mapstructure:"userinfo_url"`
JWKSURL string `mapstructure:"jwks_url"`
Scopes string `mapstructure:"scopes"` // 默认 "openid email profile"
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
UsePKCE bool `mapstructure:"use_pkce"`
ValidateIDToken bool `mapstructure:"validate_id_token"`
UsePKCEExplicit bool `mapstructure:"-" yaml:"-"`
ValidateIDTokenExplicit bool `mapstructure:"-" yaml:"-"`
AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256"
ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120
RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
// 为空时,服务端会尝试一组常见字段名。
@ -218,6 +240,225 @@ type OIDCConnectConfig struct {
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
const (
defaultWeChatConnectMode = "open"
defaultWeChatConnectScopes = "snsapi_login"
defaultWeChatConnectFrontendRedirect = "/auth/wechat/callback"
)
func firstNonEmptyString(values ...string) string {
for _, value := range values {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
return ""
}
func normalizeWeChatConnectMode(raw string) string {
switch strings.ToLower(strings.TrimSpace(raw)) {
case "mp":
return "mp"
case "mobile":
return "mobile"
default:
return defaultWeChatConnectMode
}
}
func normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled bool, mode string) string {
mode = normalizeWeChatConnectMode(mode)
switch mode {
case "open":
if openEnabled {
return "open"
}
case "mp":
if mpEnabled {
return "mp"
}
case "mobile":
if mobileEnabled {
return "mobile"
}
}
switch {
case openEnabled:
return "open"
case mpEnabled:
return "mp"
case mobileEnabled:
return "mobile"
default:
return mode
}
}
func defaultWeChatConnectScopesForMode(mode string) string {
switch normalizeWeChatConnectMode(mode) {
case "mp":
return "snsapi_userinfo"
case "mobile":
return ""
default:
return defaultWeChatConnectScopes
}
}
func normalizeWeChatConnectScopes(raw, mode string) string {
switch normalizeWeChatConnectMode(mode) {
case "mp":
switch strings.TrimSpace(raw) {
case "snsapi_base":
return "snsapi_base"
case "snsapi_userinfo":
return "snsapi_userinfo"
default:
return defaultWeChatConnectScopesForMode(mode)
}
case "mobile":
return ""
default:
return defaultWeChatConnectScopes
}
}
func shouldApplyLegacyWeChatEnv(configKey, envKey string) bool {
if viper.InConfig(configKey) {
return false
}
_, hasNewEnv := os.LookupEnv(envKey)
return !hasNewEnv
}
func hasExplicitConfigOrEnv(configKey, envKey string) bool {
if viper.InConfig(configKey) {
return true
}
_, ok := os.LookupEnv(envKey)
return ok
}
func applyLegacyWeChatConnectEnvCompatibility(cfg *WeChatConnectConfig) {
if cfg == nil {
return
}
legacyOpenAppID := ""
if shouldApplyLegacyWeChatEnv("wechat_connect.open_app_id", "WECHAT_CONNECT_OPEN_APP_ID") &&
shouldApplyLegacyWeChatEnv("wechat_connect.app_id", "WECHAT_CONNECT_APP_ID") {
legacyOpenAppID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID"))
if legacyOpenAppID != "" {
cfg.OpenAppID = legacyOpenAppID
}
}
legacyOpenAppSecret := ""
if shouldApplyLegacyWeChatEnv("wechat_connect.open_app_secret", "WECHAT_CONNECT_OPEN_APP_SECRET") &&
shouldApplyLegacyWeChatEnv("wechat_connect.app_secret", "WECHAT_CONNECT_APP_SECRET") {
legacyOpenAppSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET"))
if legacyOpenAppSecret != "" {
cfg.OpenAppSecret = legacyOpenAppSecret
}
}
legacyMPAppID := ""
if shouldApplyLegacyWeChatEnv("wechat_connect.mp_app_id", "WECHAT_CONNECT_MP_APP_ID") &&
shouldApplyLegacyWeChatEnv("wechat_connect.app_id", "WECHAT_CONNECT_APP_ID") {
legacyMPAppID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID"))
if legacyMPAppID != "" {
cfg.MPAppID = legacyMPAppID
}
}
legacyMPAppSecret := ""
if shouldApplyLegacyWeChatEnv("wechat_connect.mp_app_secret", "WECHAT_CONNECT_MP_APP_SECRET") &&
shouldApplyLegacyWeChatEnv("wechat_connect.app_secret", "WECHAT_CONNECT_APP_SECRET") {
legacyMPAppSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET"))
if legacyMPAppSecret != "" {
cfg.MPAppSecret = legacyMPAppSecret
}
}
if shouldApplyLegacyWeChatEnv("wechat_connect.frontend_redirect_url", "WECHAT_CONNECT_FRONTEND_REDIRECT_URL") {
if legacyFrontend := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL")); legacyFrontend != "" {
cfg.FrontendRedirectURL = legacyFrontend
}
}
hasLegacyOpen := legacyOpenAppID != "" && legacyOpenAppSecret != ""
hasLegacyMP := legacyMPAppID != "" && legacyMPAppSecret != ""
if shouldApplyLegacyWeChatEnv("wechat_connect.enabled", "WECHAT_CONNECT_ENABLED") && (hasLegacyOpen || hasLegacyMP) {
cfg.Enabled = true
}
if shouldApplyLegacyWeChatEnv("wechat_connect.open_enabled", "WECHAT_CONNECT_OPEN_ENABLED") && hasLegacyOpen {
cfg.OpenEnabled = true
}
if shouldApplyLegacyWeChatEnv("wechat_connect.mp_enabled", "WECHAT_CONNECT_MP_ENABLED") && hasLegacyMP {
cfg.MPEnabled = true
}
if shouldApplyLegacyWeChatEnv("wechat_connect.mode", "WECHAT_CONNECT_MODE") {
switch {
case hasLegacyMP && !hasLegacyOpen:
cfg.Mode = "mp"
case hasLegacyOpen:
cfg.Mode = "open"
}
}
if shouldApplyLegacyWeChatEnv("wechat_connect.scopes", "WECHAT_CONNECT_SCOPES") {
switch {
case hasLegacyMP && !hasLegacyOpen:
cfg.Scopes = defaultWeChatConnectScopesForMode("mp")
case hasLegacyOpen:
cfg.Scopes = defaultWeChatConnectScopesForMode("open")
}
}
}
func normalizeWeChatConnectConfig(cfg *WeChatConnectConfig) {
if cfg == nil {
return
}
cfg.AppID = strings.TrimSpace(cfg.AppID)
cfg.AppSecret = strings.TrimSpace(cfg.AppSecret)
cfg.OpenAppID = strings.TrimSpace(cfg.OpenAppID)
cfg.OpenAppSecret = strings.TrimSpace(cfg.OpenAppSecret)
cfg.MPAppID = strings.TrimSpace(cfg.MPAppID)
cfg.MPAppSecret = strings.TrimSpace(cfg.MPAppSecret)
cfg.MobileAppID = strings.TrimSpace(cfg.MobileAppID)
cfg.MobileAppSecret = strings.TrimSpace(cfg.MobileAppSecret)
cfg.Mode = normalizeWeChatConnectMode(cfg.Mode)
cfg.RedirectURL = strings.TrimSpace(cfg.RedirectURL)
cfg.FrontendRedirectURL = strings.TrimSpace(cfg.FrontendRedirectURL)
cfg.AppID = firstNonEmptyString(cfg.AppID, cfg.OpenAppID, cfg.MPAppID, cfg.MobileAppID)
cfg.AppSecret = firstNonEmptyString(cfg.AppSecret, cfg.OpenAppSecret, cfg.MPAppSecret, cfg.MobileAppSecret)
cfg.OpenAppID = firstNonEmptyString(cfg.OpenAppID, cfg.AppID)
cfg.OpenAppSecret = firstNonEmptyString(cfg.OpenAppSecret, cfg.AppSecret)
cfg.MPAppID = firstNonEmptyString(cfg.MPAppID, cfg.AppID)
cfg.MPAppSecret = firstNonEmptyString(cfg.MPAppSecret, cfg.AppSecret)
cfg.MobileAppID = firstNonEmptyString(cfg.MobileAppID, cfg.AppID)
cfg.MobileAppSecret = firstNonEmptyString(cfg.MobileAppSecret, cfg.AppSecret)
if !cfg.OpenEnabled && !cfg.MPEnabled && !cfg.MobileEnabled && cfg.Enabled {
switch cfg.Mode {
case "mp":
cfg.MPEnabled = true
case "mobile":
cfg.MobileEnabled = true
default:
cfg.OpenEnabled = true
}
}
cfg.Mode = normalizeWeChatConnectStoredMode(cfg.OpenEnabled, cfg.MPEnabled, cfg.MobileEnabled, cfg.Mode)
cfg.Scopes = normalizeWeChatConnectScopes(cfg.Scopes, cfg.Mode)
if cfg.FrontendRedirectURL == "" {
cfg.FrontendRedirectURL = defaultWeChatConnectFrontendRedirect
}
}
// TokenRefreshConfig OAuth token自动刷新配置
type TokenRefreshConfig struct {
// 是否启用自动刷新
@ -1012,6 +1253,8 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
applyLegacyWeChatConnectEnvCompatibility(&cfg.WeChat)
normalizeWeChatConnectConfig(&cfg.WeChat)
cfg.OIDC.ProviderName = strings.TrimSpace(cfg.OIDC.ProviderName)
cfg.OIDC.ClientID = strings.TrimSpace(cfg.OIDC.ClientID)
cfg.OIDC.ClientSecret = strings.TrimSpace(cfg.OIDC.ClientSecret)
@ -1029,6 +1272,8 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.OIDC.UserInfoEmailPath = strings.TrimSpace(cfg.OIDC.UserInfoEmailPath)
cfg.OIDC.UserInfoIDPath = strings.TrimSpace(cfg.OIDC.UserInfoIDPath)
cfg.OIDC.UserInfoUsernamePath = strings.TrimSpace(cfg.OIDC.UserInfoUsernamePath)
cfg.OIDC.UsePKCEExplicit = hasExplicitConfigOrEnv("oidc_connect.use_pkce", "OIDC_CONNECT_USE_PKCE")
cfg.OIDC.ValidateIDTokenExplicit = hasExplicitConfigOrEnv("oidc_connect.validate_id_token", "OIDC_CONNECT_VALIDATE_ID_TOKEN")
cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
@ -1207,6 +1452,24 @@ func setDefaults() {
viper.SetDefault("linuxdo_connect.userinfo_id_path", "")
viper.SetDefault("linuxdo_connect.userinfo_username_path", "")
// WeChat Connect OAuth 登录
viper.SetDefault("wechat_connect.enabled", false)
viper.SetDefault("wechat_connect.app_id", "")
viper.SetDefault("wechat_connect.app_secret", "")
viper.SetDefault("wechat_connect.open_app_id", "")
viper.SetDefault("wechat_connect.open_app_secret", "")
viper.SetDefault("wechat_connect.mp_app_id", "")
viper.SetDefault("wechat_connect.mp_app_secret", "")
viper.SetDefault("wechat_connect.mobile_app_id", "")
viper.SetDefault("wechat_connect.mobile_app_secret", "")
viper.SetDefault("wechat_connect.open_enabled", false)
viper.SetDefault("wechat_connect.mp_enabled", false)
viper.SetDefault("wechat_connect.mobile_enabled", false)
viper.SetDefault("wechat_connect.mode", defaultWeChatConnectMode)
viper.SetDefault("wechat_connect.scopes", defaultWeChatConnectScopes)
viper.SetDefault("wechat_connect.redirect_url", "")
viper.SetDefault("wechat_connect.frontend_redirect_url", defaultWeChatConnectFrontendRedirect)
// Generic OIDC OAuth 登录
viper.SetDefault("oidc_connect.enabled", false)
viper.SetDefault("oidc_connect.provider_name", "OIDC")
@ -1222,7 +1485,7 @@ func setDefaults() {
viper.SetDefault("oidc_connect.redirect_url", "")
viper.SetDefault("oidc_connect.frontend_redirect_url", "/auth/oidc/callback")
viper.SetDefault("oidc_connect.token_auth_method", "client_secret_post")
viper.SetDefault("oidc_connect.use_pkce", false)
viper.SetDefault("oidc_connect.use_pkce", true)
viper.SetDefault("oidc_connect.validate_id_token", true)
viper.SetDefault("oidc_connect.allowed_signing_algs", "RS256,ES256,PS256")
viper.SetDefault("oidc_connect.clock_skew_seconds", 120)
@ -1613,9 +1876,6 @@ func (c *Config) Validate() error {
return fmt.Errorf("security.csp.policy is required when CSP is enabled")
}
if c.LinuxDo.Enabled {
if !c.LinuxDo.UsePKCE {
return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.enabled=true")
}
if strings.TrimSpace(c.LinuxDo.ClientID) == "" {
return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true")
}
@ -1667,13 +1927,46 @@ func (c *Config) Validate() error {
warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL)
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL)
}
if c.WeChat.Enabled {
weChat := c.WeChat
normalizeWeChatConnectConfig(&weChat)
if weChat.OpenEnabled {
if strings.TrimSpace(weChat.OpenAppID) == "" {
return fmt.Errorf("wechat_connect.open_app_id is required when wechat_connect.open_enabled=true")
}
if strings.TrimSpace(weChat.OpenAppSecret) == "" {
return fmt.Errorf("wechat_connect.open_app_secret is required when wechat_connect.open_enabled=true")
}
}
if weChat.MPEnabled {
if strings.TrimSpace(weChat.MPAppID) == "" {
return fmt.Errorf("wechat_connect.mp_app_id is required when wechat_connect.mp_enabled=true")
}
if strings.TrimSpace(weChat.MPAppSecret) == "" {
return fmt.Errorf("wechat_connect.mp_app_secret is required when wechat_connect.mp_enabled=true")
}
}
if weChat.MobileEnabled {
if strings.TrimSpace(weChat.MobileAppID) == "" {
return fmt.Errorf("wechat_connect.mobile_app_id is required when wechat_connect.mobile_enabled=true")
}
if strings.TrimSpace(weChat.MobileAppSecret) == "" {
return fmt.Errorf("wechat_connect.mobile_app_secret is required when wechat_connect.mobile_enabled=true")
}
}
if v := strings.TrimSpace(weChat.RedirectURL); v != "" {
if err := ValidateAbsoluteHTTPURL(v); err != nil {
return fmt.Errorf("wechat_connect.redirect_url invalid: %w", err)
}
warnIfInsecureURL("wechat_connect.redirect_url", v)
}
if err := ValidateFrontendRedirectURL(weChat.FrontendRedirectURL); err != nil {
return fmt.Errorf("wechat_connect.frontend_redirect_url invalid: %w", err)
}
warnIfInsecureURL("wechat_connect.frontend_redirect_url", weChat.FrontendRedirectURL)
}
if c.OIDC.Enabled {
if !c.OIDC.UsePKCE {
return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.enabled=true")
}
if !c.OIDC.ValidateIDToken {
return fmt.Errorf("oidc_connect.validate_id_token must be true when oidc_connect.enabled=true")
}
if strings.TrimSpace(c.OIDC.ClientID) == "" {
return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true")
}

View File

@ -225,6 +225,52 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) {
}
}
func TestLoadWeChatConnectConfigFromLegacyEnv(t *testing.T) {
resetViperWithJWTSecret(t)
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app")
t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wx-mp-secret")
t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/legacy-callback")
cfg, err := Load()
require.NoError(t, err)
require.True(t, cfg.WeChat.Enabled)
require.True(t, cfg.WeChat.OpenEnabled)
require.True(t, cfg.WeChat.MPEnabled)
require.False(t, cfg.WeChat.MobileEnabled)
require.Equal(t, "open", cfg.WeChat.Mode)
require.Equal(t, "wx-open-app", cfg.WeChat.OpenAppID)
require.Equal(t, "wx-open-secret", cfg.WeChat.OpenAppSecret)
require.Equal(t, "wx-mp-app", cfg.WeChat.MPAppID)
require.Equal(t, "wx-mp-secret", cfg.WeChat.MPAppSecret)
require.Equal(t, "/auth/wechat/legacy-callback", cfg.WeChat.FrontendRedirectURL)
}
func TestLoadDefaultOIDCSecurityDefaults(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
require.NoError(t, err)
require.True(t, cfg.OIDC.UsePKCE)
require.True(t, cfg.OIDC.ValidateIDToken)
require.False(t, cfg.OIDC.UsePKCEExplicit)
require.False(t, cfg.OIDC.ValidateIDTokenExplicit)
}
func TestLoadExplicitOIDCSecurityDefaultsFromEnvMarksFlagsExplicit(t *testing.T) {
resetViperWithJWTSecret(t)
t.Setenv("OIDC_CONNECT_USE_PKCE", "false")
t.Setenv("OIDC_CONNECT_VALIDATE_ID_TOKEN", "false")
cfg, err := Load()
require.NoError(t, err)
require.False(t, cfg.OIDC.UsePKCE)
require.False(t, cfg.OIDC.ValidateIDToken)
require.True(t, cfg.OIDC.UsePKCEExplicit)
require.True(t, cfg.OIDC.ValidateIDTokenExplicit)
}
func TestLoadForcedCodexInstructionsTemplate(t *testing.T) {
resetViperWithJWTSecret(t)
@ -346,7 +392,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
}
}
func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
func TestValidateLinuxDoAllowsDisablingPKCEForCompatibility(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
@ -363,11 +409,8 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
cfg.LinuxDo.UsePKCE = false
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error when token_auth_method=none and use_pkce=false, got nil")
}
if !strings.Contains(err.Error(), "linuxdo_connect.use_pkce") {
t.Fatalf("Validate() expected use_pkce error, got: %v", err)
if err != nil {
t.Fatalf("Validate() expected LinuxDo config without PKCE to pass for compatibility, got: %v", err)
}
}
@ -427,6 +470,35 @@ func TestValidateOIDCAllowsIssuerOnlyEndpointsWithDiscoveryFallback(t *testing.T
}
}
func TestValidateOIDCAllowsExplicitCompatibilityOverridesForPKCEAndIDTokenValidation(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.OIDC.Enabled = true
cfg.OIDC.ClientID = "oidc-client"
cfg.OIDC.ClientSecret = "oidc-secret"
cfg.OIDC.IssuerURL = "https://issuer.example.com"
cfg.OIDC.AuthorizeURL = "https://issuer.example.com/auth"
cfg.OIDC.TokenURL = "https://issuer.example.com/token"
cfg.OIDC.UserInfoURL = "https://issuer.example.com/userinfo"
cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback"
cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback"
cfg.OIDC.Scopes = "openid email profile"
cfg.OIDC.UsePKCE = false
cfg.OIDC.ValidateIDToken = false
cfg.OIDC.JWKSURL = ""
cfg.OIDC.AllowedSigningAlgs = ""
err = cfg.Validate()
if err != nil {
t.Fatalf("Validate() expected OIDC config without PKCE/id_token validation to pass for compatibility, got: %v", err)
}
}
func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
resetViperWithJWTSecret(t)

View File

@ -304,8 +304,8 @@ type UpdateSettingsRequest struct {
OIDCConnectRedirectURL string `json:"oidc_connect_redirect_url"`
OIDCConnectFrontendRedirectURL string `json:"oidc_connect_frontend_redirect_url"`
OIDCConnectTokenAuthMethod string `json:"oidc_connect_token_auth_method"`
OIDCConnectUsePKCE bool `json:"oidc_connect_use_pkce"`
OIDCConnectValidateIDToken bool `json:"oidc_connect_validate_id_token"`
OIDCConnectUsePKCE *bool `json:"oidc_connect_use_pkce"`
OIDCConnectValidateIDToken *bool `json:"oidc_connect_validate_id_token"`
OIDCConnectAllowedSigningAlgs string `json:"oidc_connect_allowed_signing_algs"`
OIDCConnectClockSkewSeconds int `json:"oidc_connect_clock_skew_seconds"`
OIDCConnectRequireEmailVerified bool `json:"oidc_connect_require_email_verified"`
@ -565,6 +565,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.WeChatConnectScopes = strings.TrimSpace(req.WeChatConnectScopes)
req.WeChatConnectRedirectURL = strings.TrimSpace(req.WeChatConnectRedirectURL)
req.WeChatConnectFrontendRedirectURL = strings.TrimSpace(req.WeChatConnectFrontendRedirectURL)
req.WeChatConnectAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectAppID, previousSettings.WeChatConnectAppID))
req.WeChatConnectRedirectURL = strings.TrimSpace(firstNonEmpty(req.WeChatConnectRedirectURL, previousSettings.WeChatConnectRedirectURL))
req.WeChatConnectFrontendRedirectURL = strings.TrimSpace(firstNonEmpty(req.WeChatConnectFrontendRedirectURL, previousSettings.WeChatConnectFrontendRedirectURL))
if req.WeChatConnectMode == "" {
req.WeChatConnectMode = strings.ToLower(strings.TrimSpace(previousSettings.WeChatConnectMode))
}
if req.WeChatConnectScopes == "" {
req.WeChatConnectScopes = strings.TrimSpace(previousSettings.WeChatConnectScopes)
}
if req.WeChatConnectMPEnabled && req.WeChatConnectMobileEnabled {
response.BadRequest(c, "WeChat Official Account and Mobile App cannot be enabled at the same time")
@ -598,9 +607,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
req.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectOpenAppID, req.WeChatConnectAppID))
req.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMPAppID, req.WeChatConnectAppID))
req.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMobileAppID, req.WeChatConnectAppID))
req.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectOpenAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectOpenAppID, previousSettings.WeChatConnectAppID))
req.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMPAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectMPAppID, previousSettings.WeChatConnectAppID))
req.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMobileAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectMobileAppID, previousSettings.WeChatConnectAppID))
if req.WeChatConnectOpenAppSecret == "" {
req.WeChatConnectOpenAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectOpenAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret))
@ -653,24 +662,31 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.WeChatConnectScopes = service.DefaultWeChatConnectScopesForMode(req.WeChatConnectMode)
}
}
if req.WeChatConnectRedirectURL == "" {
response.BadRequest(c, "WeChat Redirect URL is required when enabled")
return
}
if err := config.ValidateAbsoluteHTTPURL(req.WeChatConnectRedirectURL); err != nil {
response.BadRequest(c, "WeChat Redirect URL must be an absolute http(s) URL")
return
}
if req.WeChatConnectFrontendRedirectURL == "" {
req.WeChatConnectFrontendRedirectURL = "/auth/wechat/callback"
}
if err := config.ValidateFrontendRedirectURL(req.WeChatConnectFrontendRedirectURL); err != nil {
response.BadRequest(c, "WeChat Frontend Redirect URL is invalid")
return
if req.WeChatConnectOpenEnabled || req.WeChatConnectMPEnabled {
if req.WeChatConnectRedirectURL == "" {
response.BadRequest(c, "WeChat Redirect URL is required when web oauth is enabled")
return
}
if err := config.ValidateAbsoluteHTTPURL(req.WeChatConnectRedirectURL); err != nil {
response.BadRequest(c, "WeChat Redirect URL must be an absolute http(s) URL")
return
}
if req.WeChatConnectFrontendRedirectURL == "" {
req.WeChatConnectFrontendRedirectURL = "/auth/wechat/callback"
}
if err := config.ValidateFrontendRedirectURL(req.WeChatConnectFrontendRedirectURL); err != nil {
response.BadRequest(c, "WeChat Frontend Redirect URL is invalid")
return
}
}
}
// Generic OIDC 参数验证
oidcUsePKCE, oidcValidateIDToken, err := h.settingService.OIDCSecurityWriteDefaults(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
if req.OIDCConnectEnabled {
req.OIDCConnectProviderName = strings.TrimSpace(req.OIDCConnectProviderName)
req.OIDCConnectClientID = strings.TrimSpace(req.OIDCConnectClientID)
@ -689,10 +705,35 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.OIDCConnectUserInfoEmailPath = strings.TrimSpace(req.OIDCConnectUserInfoEmailPath)
req.OIDCConnectUserInfoIDPath = strings.TrimSpace(req.OIDCConnectUserInfoIDPath)
req.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(req.OIDCConnectUserInfoUsernamePath)
if req.OIDCConnectProviderName == "" {
req.OIDCConnectProviderName = "OIDC"
req.OIDCConnectProviderName = strings.TrimSpace(firstNonEmpty(req.OIDCConnectProviderName, previousSettings.OIDCConnectProviderName, "OIDC"))
req.OIDCConnectClientID = strings.TrimSpace(firstNonEmpty(req.OIDCConnectClientID, previousSettings.OIDCConnectClientID))
req.OIDCConnectIssuerURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectIssuerURL, previousSettings.OIDCConnectIssuerURL))
req.OIDCConnectDiscoveryURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectDiscoveryURL, previousSettings.OIDCConnectDiscoveryURL))
req.OIDCConnectAuthorizeURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectAuthorizeURL, previousSettings.OIDCConnectAuthorizeURL))
req.OIDCConnectTokenURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectTokenURL, previousSettings.OIDCConnectTokenURL))
req.OIDCConnectUserInfoURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoURL, previousSettings.OIDCConnectUserInfoURL))
req.OIDCConnectJWKSURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectJWKSURL, previousSettings.OIDCConnectJWKSURL))
req.OIDCConnectScopes = strings.TrimSpace(firstNonEmpty(req.OIDCConnectScopes, previousSettings.OIDCConnectScopes, "openid email profile"))
req.OIDCConnectRedirectURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectRedirectURL, previousSettings.OIDCConnectRedirectURL))
req.OIDCConnectFrontendRedirectURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectFrontendRedirectURL, previousSettings.OIDCConnectFrontendRedirectURL, "/auth/oidc/callback"))
req.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(firstNonEmpty(req.OIDCConnectTokenAuthMethod, previousSettings.OIDCConnectTokenAuthMethod, "client_secret_post")))
req.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(firstNonEmpty(req.OIDCConnectAllowedSigningAlgs, previousSettings.OIDCConnectAllowedSigningAlgs, "RS256,ES256,PS256"))
req.OIDCConnectUserInfoEmailPath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoEmailPath, previousSettings.OIDCConnectUserInfoEmailPath))
req.OIDCConnectUserInfoIDPath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoIDPath, previousSettings.OIDCConnectUserInfoIDPath))
req.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoUsernamePath, previousSettings.OIDCConnectUserInfoUsernamePath))
if req.OIDCConnectUsePKCE != nil {
oidcUsePKCE = *req.OIDCConnectUsePKCE
}
if req.OIDCConnectValidateIDToken != nil {
oidcValidateIDToken = *req.OIDCConnectValidateIDToken
}
if req.OIDCConnectClockSkewSeconds == 0 {
req.OIDCConnectClockSkewSeconds = previousSettings.OIDCConnectClockSkewSeconds
if req.OIDCConnectClockSkewSeconds == 0 {
req.OIDCConnectClockSkewSeconds = 120
}
}
if req.OIDCConnectClientID == "" {
response.BadRequest(c, "OIDC Client ID is required when enabled")
return
@ -749,14 +790,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.BadRequest(c, "OIDC scopes must contain openid")
return
}
if !req.OIDCConnectUsePKCE {
response.BadRequest(c, "OIDC PKCE must be enabled")
return
}
if !req.OIDCConnectValidateIDToken {
response.BadRequest(c, "OIDC ID Token validation must be enabled")
return
}
switch req.OIDCConnectTokenAuthMethod {
case "", "client_secret_post", "client_secret_basic", "none":
default:
@ -767,7 +800,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600")
return
}
if req.OIDCConnectAllowedSigningAlgs == "" {
if oidcValidateIDToken && req.OIDCConnectAllowedSigningAlgs == "" {
response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
return
}
@ -1048,8 +1081,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
OIDCConnectRedirectURL: req.OIDCConnectRedirectURL,
OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL,
OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod,
OIDCConnectUsePKCE: req.OIDCConnectUsePKCE,
OIDCConnectValidateIDToken: req.OIDCConnectValidateIDToken,
OIDCConnectUsePKCE: oidcUsePKCE,
OIDCConnectValidateIDToken: oidcValidateIDToken,
OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs,
OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds,
OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified,

View File

@ -247,6 +247,163 @@ func TestSettingHandler_UpdateSettings_PersistsPaymentVisibleMethodsAndAdvancedS
require.Equal(t, true, data["openai_advanced_scheduler_enabled"])
}
func TestSettingHandler_UpdateSettings_PreservesLegacyBlankPaymentVisibleMethodSource(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
values: map[string]string{
service.SettingKeyPromoCodeEnabled: "true",
service.SettingPaymentVisibleMethodAlipayEnabled: "true",
service.SettingPaymentVisibleMethodAlipaySource: "",
service.SettingPaymentVisibleMethodWxpayEnabled: "false",
service.SettingPaymentVisibleMethodWxpaySource: "",
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": false,
}
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "", repo.values[service.SettingPaymentVisibleMethodAlipaySource])
require.Equal(t, "true", repo.values[service.SettingPaymentVisibleMethodAlipayEnabled])
}
func TestSettingHandler_UpdateSettings_PersistsExplicitFalseOIDCCompatibilityFlags(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
values: map[string]string{
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeyOIDCConnectEnabled: "true",
service.SettingKeyOIDCConnectProviderName: "OIDC",
service.SettingKeyOIDCConnectClientID: "oidc-client",
service.SettingKeyOIDCConnectClientSecret: "oidc-secret",
service.SettingKeyOIDCConnectIssuerURL: "https://issuer.example.com",
service.SettingKeyOIDCConnectAuthorizeURL: "https://issuer.example.com/auth",
service.SettingKeyOIDCConnectTokenURL: "https://issuer.example.com/token",
service.SettingKeyOIDCConnectUserInfoURL: "https://issuer.example.com/userinfo",
service.SettingKeyOIDCConnectJWKSURL: "https://issuer.example.com/jwks",
service.SettingKeyOIDCConnectScopes: "openid email profile",
service.SettingKeyOIDCConnectRedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
service.SettingKeyOIDCConnectUsePKCE: "true",
service.SettingKeyOIDCConnectValidateIDToken: "true",
service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256",
service.SettingKeyOIDCConnectClockSkewSeconds: "120",
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
"oidc_connect_enabled": true,
"oidc_connect_use_pkce": false,
"oidc_connect_validate_id_token": false,
"oidc_connect_allowed_signing_algs": "",
}
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectUsePKCE])
require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectValidateIDToken])
var resp response.Response
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
require.Equal(t, false, data["oidc_connect_use_pkce"])
require.Equal(t, false, data["oidc_connect_validate_id_token"])
}
func TestSettingHandler_UpdateSettings_DoesNotSolidifyImplicitOIDCSecurityDefaultsOnLegacyUpgrade(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
values: map[string]string{
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeyOIDCConnectEnabled: "true",
service.SettingKeyOIDCConnectProviderName: "OIDC",
service.SettingKeyOIDCConnectClientID: "oidc-client",
service.SettingKeyOIDCConnectClientSecret: "oidc-secret",
service.SettingKeyOIDCConnectIssuerURL: "https://issuer.example.com",
service.SettingKeyOIDCConnectAuthorizeURL: "https://issuer.example.com/auth",
service.SettingKeyOIDCConnectTokenURL: "https://issuer.example.com/token",
service.SettingKeyOIDCConnectUserInfoURL: "https://issuer.example.com/userinfo",
service.SettingKeyOIDCConnectJWKSURL: "https://issuer.example.com/jwks",
service.SettingKeyOIDCConnectScopes: "openid email profile",
service.SettingKeyOIDCConnectRedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256",
service.SettingKeyOIDCConnectClockSkewSeconds: "120",
service.SettingKeyOIDCConnectRequireEmailVerified: "false",
service.SettingKeyOIDCConnectUserInfoEmailPath: "",
service.SettingKeyOIDCConnectUserInfoIDPath: "",
service.SettingKeyOIDCConnectUserInfoUsernamePath: "",
},
}
svc := service.NewSettingService(repo, &config.Config{
Default: config.DefaultConfig{UserConcurrency: 5},
OIDC: config.OIDCConnectConfig{
Enabled: true,
ProviderName: "OIDC",
ClientID: "oidc-client",
ClientSecret: "oidc-secret",
IssuerURL: "https://issuer.example.com",
AuthorizeURL: "https://issuer.example.com/auth",
TokenURL: "https://issuer.example.com/token",
UserInfoURL: "https://issuer.example.com/userinfo",
JWKSURL: "https://issuer.example.com/jwks",
Scopes: "openid email profile",
RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
FrontendRedirectURL: "/auth/oidc/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: true,
ValidateIDToken: true,
AllowedSigningAlgs: "RS256",
ClockSkewSeconds: 120,
},
})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
"oidc_connect_enabled": true,
}
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectUsePKCE])
require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectValidateIDToken])
}
func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{

View File

@ -29,18 +29,19 @@ func TestAuthHandlerGetCurrentUserReturnsProfileCompatibilityFields(t *testing.T
AvatarURL: "https://cdn.example.com/linuxdo.png",
AvatarSource: "remote_url",
},
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "linuxdo-subject-31",
VerifiedAt: &verifiedAt,
Metadata: map[string]any{
"username": "linuxdo-handle",
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "linuxdo-subject-31",
VerifiedAt: &verifiedAt,
Metadata: map[string]any{
"username": "linuxdo-handle",
"avatar_url": "https://cdn.example.com/linuxdo.png",
},
},
},
},
}
}
handler := &AuthHandler{
userService: service.NewUserService(repo, nil, nil, nil),

View File

@ -78,9 +78,24 @@ type AuthResponse struct {
User *dto.User `json:"user"`
}
func ensureLoginUserActive(user *service.User) error {
if user == nil {
return infraerrors.Unauthorized("INVALID_USER", "user not found")
}
if !user.IsActive() {
return service.ErrUserNotActive
}
return nil
}
// respondWithTokenPair 生成 Token 对并返回认证响应
// 如果 Token 对生成失败,回退到只返回 Access Token向后兼容
func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) {
if err := ensureLoginUserActive(user); err != nil {
response.ErrorFrom(c, err)
return
}
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
if err != nil {
slog.Error("failed to generate token pair", "error", err, "user_id", user.ID)
@ -293,6 +308,10 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if err := ensureLoginUserActive(user); err != nil {
response.ErrorFrom(c, err)
return
}
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
response.ErrorFrom(c, err)
@ -678,6 +697,8 @@ func (h *AuthHandler) Logout(c *gin.Context) {
// 不影响登出流程
}
}
h.consumePendingOAuthSessionOnLogout(c)
clearOAuthLogoutCookies(c)
response.Success(c, LogoutResponse{
Message: "Logged out successfully",
@ -698,7 +719,7 @@ func (h *AuthHandler) RevokeAllSessions(c *gin.Context) {
return
}
if err := h.authService.RevokeAllUserSessions(c.Request.Context(), subject.UserID); err != nil {
if err := h.authService.RevokeAllUserTokens(c.Request.Context(), subject.UserID); err != nil {
slog.Error("failed to revoke all sessions", "user_id", subject.UserID, "error", err)
response.InternalError(c, "Failed to revoke sessions")
return

View File

@ -123,13 +123,16 @@ func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) {
clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
}
verifier, err := oauth.GenerateCodeVerifier()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err))
return
codeChallenge := ""
if cfg.UsePKCE {
verifier, err := oauth.GenerateCodeVerifier()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err))
return
}
codeChallenge = oauth.GenerateCodeChallenge(verifier)
setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie)
}
codeChallenge := oauth.GenerateCodeChallenge(verifier)
setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie)
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
@ -200,10 +203,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
intent, _ := readCookieDecoded(c, linuxDoOAuthIntentCookieName)
intent = normalizeOAuthIntent(intent)
codeVerifier, _ := readCookieDecoded(c, linuxDoOAuthVerifierCookie)
if codeVerifier == "" {
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
return
codeVerifier := ""
if cfg.UsePKCE {
codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie)
if codeVerifier == "" {
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
return
}
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
@ -292,25 +298,16 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
return
}
if existingIdentityUser != nil {
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
if err != nil {
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentLogin,
Identity: identityKey,
TargetUserID: &user.ID,
TargetUserID: &existingIdentityUser.ID,
ResolvedEmail: existingIdentityUser.Email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
"expires_in": tokenPair.ExpiresIn,
"token_type": "Bearer",
"redirect": redirectTo,
"redirect": redirectTo,
},
}); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
@ -358,15 +355,20 @@ func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email stri
}
userEntity, err := client.User.Query().
Where(dbuser.EmailEqualFold(email)).
Only(ctx)
Where(userNormalizedEmailPredicate(email)).
Order(dbent.Asc(dbuser.FieldID)).
All(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, nil
}
return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err)
}
return userEntity, nil
switch len(userEntity) {
case 0:
return nil, nil
case 1:
return userEntity[0], nil
default:
return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users")
}
}
func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
@ -414,9 +416,15 @@ func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
completionResponse["choice_reason"] = "force_email_on_signup"
}
var targetUserID *int64
if compatEmailUser != nil && compatEmailUser.ID > 0 {
targetUserID = &compatEmailUser.ID
}
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentLogin,
Identity: identity,
TargetUserID: targetUserID,
ResolvedEmail: resolvedChoiceEmail,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
@ -472,6 +480,15 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
response.ErrorFrom(c, err)
return
} else if handled {
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
return
} else {
session = updatedSession
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
@ -484,12 +501,16 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
client := h.entClient()
if client == nil {
response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
return
}
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil {
respondPendingOAuthBindingApplyError(c, err)
return
}
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
AdoptDisplayName: req.AdoptDisplayName,
AdoptAvatar: req.AdoptAvatar,
})
@ -497,17 +518,16 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil {
respondPendingOAuthBindingApplyError(c, err)
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
@ -546,7 +566,9 @@ func linuxDoExchangeCode(
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
form.Set("code_verifier", codeVerifier)
if strings.TrimSpace(codeVerifier) != "" {
form.Set("code_verifier", codeVerifier)
}
r := client.R().
SetContext(ctx).
@ -699,8 +721,10 @@ func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, cod
q.Set("scope", cfg.Scopes)
}
q.Set("state", state)
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
if strings.TrimSpace(codeChallenge) != "" {
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
}
u.RawQuery = q.Encode()
return u.String(), nil
@ -937,7 +961,19 @@ func clearOAuthBindAccessTokenCookie(c *gin.Context, secure bool) {
Value: "",
Path: oauthBindAccessTokenCookiePath,
MaxAge: -1,
HttpOnly: false,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func setOAuthBindAccessTokenCookie(c *gin.Context, token string, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: oauthBindAccessTokenCookieName,
Value: url.QueryEscape(strings.TrimSpace(token)),
Path: oauthBindAccessTokenCookiePath,
MaxAge: linuxDoOAuthCookieMaxAgeSec,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
@ -1021,6 +1057,26 @@ func (h *AuthHandler) buildOAuthBindUserCookieFromContext(c *gin.Context) (strin
return buildOAuthBindUserCookieValue(*userID, h.oauthBindCookieSecret())
}
func (h *AuthHandler) PrepareOAuthBindAccessTokenCookie(c *gin.Context) {
const bearerPrefix = "Bearer "
authHeader := strings.TrimSpace(c.GetHeader("Authorization"))
if !strings.HasPrefix(strings.ToLower(authHeader), strings.ToLower(bearerPrefix)) {
response.ErrorFrom(c, infraerrors.Unauthorized("UNAUTHORIZED", "authentication required"))
return
}
token := strings.TrimSpace(authHeader[len(bearerPrefix):])
if token == "" {
response.ErrorFrom(c, infraerrors.Unauthorized("UNAUTHORIZED", "authentication required"))
return
}
setOAuthBindAccessTokenCookie(c, token, isRequestHTTPS(c))
c.Status(http.StatusNoContent)
c.Writer.WriteHeaderNow()
}
func (h *AuthHandler) resolveOAuthBindTargetUserID(c *gin.Context) (*int64, error) {
if subject, ok := servermiddleware.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
return &subject.UserID, nil

View File

@ -5,6 +5,7 @@ import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
@ -170,6 +171,80 @@ func TestLinuxDoOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) {
require.Equal(t, int64(42), userID)
}
func TestLinuxDoOAuthStartOmitsPKCEWhenDisabled(t *testing.T) {
handler := newLinuxDoOAuthTestHandler(t, false, config.LinuxDoConnectConfig{
Enabled: true,
ClientID: "linuxdo-client",
ClientSecret: "linuxdo-secret",
AuthorizeURL: "https://connect.linux.do/oauth/authorize",
TokenURL: "https://connect.linux.do/oauth/token",
UserInfoURL: "https://connect.linux.do/api/user",
Scopes: "read",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
FrontendRedirectURL: "/auth/linuxdo/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: false,
})
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/start?redirect=/dashboard", nil)
handler.LinuxDoOAuthStart(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.NotContains(t, recorder.Header().Get("Location"), "code_challenge=")
require.Nil(t, findCookie(recorder.Result().Cookies(), linuxDoOAuthVerifierCookie))
}
func TestLinuxDoOAuthCallbackAllowsMissingVerifierWhenPKCEDisabled(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
require.NoError(t, r.ParseForm())
require.Empty(t, r.PostForm.Get("code_verifier"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
case "/userinfo":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"compat-subject","username":"linuxdo_user","name":"LinuxDo Display"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
Enabled: true,
ClientID: "linuxdo-client",
ClientSecret: "linuxdo-secret",
AuthorizeURL: upstream.URL + "/authorize",
TokenURL: upstream.URL + "/token",
UserInfoURL: upstream.URL + "/userinfo",
Scopes: "read",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
FrontendRedirectURL: "/auth/linuxdo/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: false,
})
t.Cleanup(func() { _ = client.Close() })
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=linuxdo-code&state=state-123", nil)
req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-123"))
req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
c.Request = req
handler.LinuxDoOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
require.NotNil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
}
func TestLinuxDoOAuthBindStartAcceptsAccessTokenCookie(t *testing.T) {
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
Enabled: true,
@ -226,6 +301,27 @@ func TestLinuxDoOAuthBindStartAcceptsAccessTokenCookie(t *testing.T) {
require.Equal(t, -1, accessTokenCookie.MaxAge)
}
func TestPrepareOAuthBindAccessTokenCookieSetsHttpOnlyCookie(t *testing.T) {
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{})
t.Cleanup(func() { _ = client.Close() })
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/bind-token", nil)
req.Header.Set("Authorization", "Bearer access-token-value")
c.Request = req
handler.PrepareOAuthBindAccessTokenCookie(c)
require.Equal(t, http.StatusNoContent, recorder.Code)
accessTokenCookie := findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName)
require.NotNil(t, accessTokenCookie)
require.Equal(t, oauthBindAccessTokenCookiePath, accessTokenCookie.Path)
require.Equal(t, linuxDoOAuthCookieMaxAgeSec, accessTokenCookie.MaxAge)
require.True(t, accessTokenCookie.HttpOnly)
require.Equal(t, url.QueryEscape("access-token-value"), accessTokenCookie.Value)
}
func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
@ -305,10 +401,81 @@ func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t
completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
require.True(t, ok)
require.Equal(t, "/dashboard", completion["redirect"])
require.NotEmpty(t, completion["access_token"])
_, hasAccessToken := completion["access_token"]
require.False(t, hasAccessToken)
_, hasRefreshToken := completion["refresh_token"]
require.False(t, hasRefreshToken)
require.Nil(t, completion["error"])
}
func TestLinuxDoOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
case "/userinfo":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"654","username":"linuxdo_disabled","name":"LinuxDo Disabled"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
Enabled: true,
ClientID: "linuxdo-client",
ClientSecret: "linuxdo-secret",
AuthorizeURL: upstream.URL + "/authorize",
TokenURL: upstream.URL + "/token",
UserInfoURL: upstream.URL + "/userinfo",
Scopes: "read",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
FrontendRedirectURL: "/auth/linuxdo/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: true,
})
t.Cleanup(func() { _ = client.Close() })
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail(linuxDoSyntheticEmail("654")).
SetUsername("disabled-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusDisabled).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(existingUser.ID).
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("654").
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-disabled&state=state-disabled", nil)
req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-disabled"))
req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-disabled"))
req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled"))
c.Request = req
handler.LinuxDoOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE")
count, err := client.PendingAuthSession.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, count)
}
func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
@ -341,7 +508,7 @@ func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *test
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail("legacy@example.com").
SetEmail(" Legacy@Example.com ").
SetUsername("legacy-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
@ -372,16 +539,17 @@ func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *test
Only(ctx)
require.NoError(t, err)
require.Equal(t, oauthIntentLogin, session.Intent)
require.Nil(t, session.TargetUserID)
require.Equal(t, existingUser.Email, session.ResolvedEmail)
require.NotNil(t, session.TargetUserID)
require.Equal(t, existingUser.ID, *session.TargetUserID)
require.Equal(t, strings.TrimSpace(existingUser.Email), session.ResolvedEmail)
require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
require.True(t, ok)
require.Equal(t, "/dashboard", completion["redirect"])
require.Equal(t, oauthPendingChoiceStep, completion["step"])
require.Equal(t, existingUser.Email, completion["email"])
require.Equal(t, existingUser.Email, completion["existing_account_email"])
require.Equal(t, strings.TrimSpace(existingUser.Email), completion["email"])
require.Equal(t, strings.TrimSpace(existingUser.Email), completion["existing_account_email"])
require.Equal(t, true, completion["existing_account_bindable"])
require.Equal(t, "compat_email_match", completion["choice_reason"])
_, hasAccessToken := completion["access_token"]
@ -658,6 +826,186 @@ func TestCompleteLinuxDoOAuthRegistrationRejectsAdoptExistingUserSession(t *test
require.Nil(t, storedSession.ConsumedAt)
}
func TestCompleteLinuxDoOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("linuxdo-complete-choice-session").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-choice-subject-1").
SetResolvedEmail("linuxdo-choice-subject-1@linuxdo-connect.invalid").
SetBrowserSessionKey("linuxdo-choice-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"step": oauthPendingChoiceStep,
"redirect": "/dashboard",
"email": "fresh@example.com",
"resolved_email": "fresh@example.com",
"force_email_on_signup": true,
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-choice-browser")})
c.Request = req
handler.CompleteLinuxDoOAuthRegistration(c)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.Equal(t, "pending_session", responseData["auth_result"])
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
require.Equal(t, true, responseData["force_email_on_signup"])
require.Empty(t, responseData["access_token"])
userCount, err := client.User.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestCompleteLinuxDoOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("linuxdo-complete-no-adoption-session").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-subject-no-adoption").
SetResolvedEmail("linuxdo-subject-no-adoption@linuxdo-connect.invalid").
SetBrowserSessionKey("linuxdo-browser-no-adoption").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
"suggested_display_name": "LinuxDo Legacy",
"suggested_avatar_url": "https://cdn.example/linuxdo-legacy.png",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser-no-adoption")})
c.Request = req
handler.CompleteLinuxDoOAuthRegistration(c)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.NotEmpty(t, responseData["access_token"])
require.NotEmpty(t, responseData["refresh_token"])
userEntity, err := client.User.Query().
Where(dbuser.EmailEQ(session.ResolvedEmail)).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "linuxdo_user", userEntity.Username)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("linuxdo"),
authidentity.ProviderKeyEQ("linuxdo"),
authidentity.ProviderSubjectEQ("linuxdo-subject-no-adoption"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, userEntity.ID, identity.UserID)
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, decision.IdentityID)
require.Equal(t, identity.ID, *decision.IdentityID)
require.False(t, decision.AdoptDisplayName)
require.False(t, decision.AdoptAvatar)
}
func TestCompleteLinuxDoOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
existingOwner, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(existingOwner.ID).
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-conflict-subject").
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("linuxdo-complete-conflict-session").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-conflict-subject").
SetResolvedEmail("linuxdo-conflict-subject@linuxdo-connect.invalid").
SetBrowserSessionKey("linuxdo-conflict-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-conflict-browser")})
c.Request = req
handler.CompleteLinuxDoOAuthRegistration(c)
require.Equal(t, http.StatusConflict, recorder.Code)
payload := decodeJSONBody(t, recorder)
require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"])
userCount, err := client.User.Query().
Where(dbuser.EmailEQ("linuxdo-conflict-subject@linuxdo-connect.invalid")).
Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler {
t.Helper()
handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg)

View File

@ -0,0 +1,68 @@
package handler
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestLogoutClearsOAuthStateCookiesAndConsumesPendingSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("logout-pending-session-token").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example").
SetProviderSubject("logout-subject-123").
SetBrowserSessionKey("logout-browser-session-key").
SetResolvedEmail("logout@example.com").
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil)
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("logout-browser-session-key")})
req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-access-token"})
req.AddCookie(&http.Cookie{Name: linuxDoOAuthStateCookieName, Value: encodeCookieValue("linuxdo-state")})
req.AddCookie(&http.Cookie{Name: oidcOAuthStateCookieName, Value: encodeCookieValue("oidc-state")})
req.AddCookie(&http.Cookie{Name: wechatOAuthStateCookieName, Value: encodeCookieValue("wechat-state")})
req.AddCookie(&http.Cookie{Name: wechatPaymentOAuthStateName, Value: encodeCookieValue("wechat-payment-state")})
ginCtx.Request = req
handler.Logout(ginCtx)
require.Equal(t, http.StatusOK, recorder.Code)
cookies := recorder.Result().Cookies()
for _, name := range []string{
oauthPendingSessionCookieName,
oauthPendingBrowserCookieName,
oauthBindAccessTokenCookieName,
linuxDoOAuthStateCookieName,
oidcOAuthStateCookieName,
wechatOAuthStateCookieName,
wechatPaymentOAuthStateName,
} {
cookie := findCookie(cookies, name)
require.NotNil(t, cookie, name)
require.Equal(t, -1, cookie.MaxAge, name)
require.True(t, cookie.HttpOnly, name)
}
storedSession, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, storedSession.ConsumedAt)
}

View File

@ -265,16 +265,20 @@ func pendingSessionWantsInvitation(payload map[string]any) bool {
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required")
}
func pendingOAuthCompletionIncludesTokenPayload(payload map[string]any) bool {
if len(payload) == 0 {
func pendingOAuthCompletionCanIssueTokenPair(session *dbent.PendingAuthSession, payload map[string]any) bool {
if session == nil {
return false
}
for _, key := range []string{"access_token", "refresh_token"} {
if value := pendingSessionStringValue(payload, key); value != "" {
return true
}
if !strings.EqualFold(strings.TrimSpace(session.Intent), oauthIntentLogin) {
return false
}
return false
if session.TargetUserID == nil || *session.TargetUserID <= 0 {
return false
}
if pendingSessionWantsInvitation(payload) {
return false
}
return strings.TrimSpace(pendingSessionStringValue(payload, "step")) == ""
}
func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSession) error {
@ -294,6 +298,78 @@ func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSes
return nil
}
func buildLegacyCompleteRegistrationPendingResponse(
session *dbent.PendingAuthSession,
forceEmailOnSignup bool,
emailVerificationRequired bool,
) map[string]any {
completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, map[string]any{
"step": oauthPendingChoiceStep,
"adoption_required": true,
"create_account_allowed": true,
"force_email_on_signup": forceEmailOnSignup,
}))
if email := strings.TrimSpace(session.ResolvedEmail); email != "" {
if _, exists := completionResponse["email"]; !exists {
completionResponse["email"] = email
}
if _, exists := completionResponse["resolved_email"]; !exists {
completionResponse["resolved_email"] = email
}
}
if _, exists := completionResponse["choice_reason"]; !exists {
switch {
case forceEmailOnSignup:
completionResponse["choice_reason"] = "force_email_on_signup"
case emailVerificationRequired:
completionResponse["choice_reason"] = "email_verification_required"
default:
completionResponse["choice_reason"] = "third_party_signup"
}
}
return completionResponse
}
func (h *AuthHandler) legacyCompleteRegistrationSessionStatus(
c *gin.Context,
session *dbent.PendingAuthSession,
) (*dbent.PendingAuthSession, bool, error) {
if session == nil {
return nil, false, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
payload := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil))
if step := pendingSessionStringValue(payload, "step"); step != "" {
return session, true, nil
}
emailVerificationRequired := h != nil && h.authService != nil && h.authService.IsEmailVerifyEnabled(c.Request.Context())
forceEmailOnSignup := h.isForceEmailOnThirdPartySignup(c.Request.Context())
if !emailVerificationRequired && !forceEmailOnSignup {
return session, false, nil
}
client := h.entClient()
if client == nil {
return nil, false, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
updatedSession, err := updatePendingOAuthSessionProgress(
c.Request.Context(),
client,
session,
strings.TrimSpace(session.Intent),
strings.TrimSpace(session.ResolvedEmail),
nil,
buildLegacyCompleteRegistrationPendingResponse(session, forceEmailOnSignup, emailVerificationRequired),
)
if err != nil {
return nil, false, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err)
}
return updatedSession, true, nil
}
func (r oauthAdoptionDecisionRequest) hasDecision() bool {
return r.AdoptDisplayName != nil || r.AdoptAvatar != nil
}
@ -376,15 +452,7 @@ func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity servic
}
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
}
userEntity, err := client.User.Get(ctx, record.UserID)
if err != nil {
if dbent.IsNotFound(err) {
return nil, nil
}
return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err)
}
return userEntity, nil
return findActiveUserByID(ctx, client, record.UserID)
}
func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") }
@ -439,7 +507,7 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
email := strings.TrimSpace(strings.ToLower(req.Email))
if existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email); err == nil && existingUser != nil {
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email)
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
if err != nil {
response.ErrorFrom(c, err)
return
@ -624,6 +692,38 @@ func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email
return matches[0], nil
}
func ensurePendingOAuthRegistrationIdentityAvailable(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) error {
if client == nil || session == nil {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)),
authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)),
authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)),
).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil
}
return err
}
if identity == nil || identity.UserID <= 0 {
return nil
}
activeOwner, err := findActiveUserByID(ctx, client, identity.UserID)
if err != nil {
return err
}
if activeOwner != nil {
return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
}
return nil
}
func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
if session == nil {
return nil
@ -910,6 +1010,9 @@ func findActiveUserByID(ctx context.Context, client *dbent.Client, userID int64)
}
return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err)
}
if !strings.EqualFold(strings.TrimSpace(userEntity.Status), service.StatusActive) {
return nil, service.ErrUserNotActive
}
return userEntity, nil
}
@ -1123,6 +1226,38 @@ func consumePendingOAuthBrowserSessionTx(
return nil
}
func applyPendingOAuthAdoptionAndConsumeSession(
ctx context.Context,
client *dbent.Client,
authService *service.AuthService,
userService *service.UserService,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
userID int64,
) error {
if client == nil {
return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
if session == nil || userID <= 0 {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
tx, err := client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := applyPendingOAuthAdoption(txCtx, client, authService, userService, session, decision, &userID); err != nil {
return err
}
if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil {
return err
}
return tx.Commit()
}
func applyPendingOAuthAdoption(
ctx context.Context,
client *dbent.Client,
@ -1212,13 +1347,7 @@ func (h *AuthHandler) shouldSkipPendingOAuthAdoptionPrompt(
if session == nil || len(payload) == 0 {
return false, nil
}
if !strings.EqualFold(strings.TrimSpace(session.Intent), oauthIntentLogin) {
return false, nil
}
if !pendingOAuthCompletionIncludesTokenPayload(payload) {
return false, nil
}
if session.TargetUserID == nil || *session.TargetUserID <= 0 {
if !pendingOAuthCompletionCanIssueTokenPair(session, payload) {
return false, nil
}
if pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name") == "" &&
@ -1262,6 +1391,59 @@ func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.Au
return svc, session, clearCookies, nil
}
func (h *AuthHandler) consumePendingOAuthSessionOnLogout(c *gin.Context) {
if c == nil || c.Request == nil {
return
}
sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil || strings.TrimSpace(sessionToken) == "" {
return
}
browserSessionKey, err := readOAuthPendingBrowserCookie(c)
if err != nil || strings.TrimSpace(browserSessionKey) == "" {
return
}
svc, err := h.pendingIdentityService()
if err != nil {
return
}
_, _ = svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
}
func clearOAuthLogoutCookies(c *gin.Context) {
secureCookie := isRequestHTTPS(c)
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
clearOAuthBindAccessTokenCookie(c, secureCookie)
clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
clearCookie(c, linuxDoOAuthIntentCookieName, secureCookie)
clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
oidcClearCookie(c, oidcOAuthStateCookieName, secureCookie)
oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie)
oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie)
oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie)
oidcClearCookie(c, oidcOAuthIntentCookieName, secureCookie)
oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie)
wechatPaymentClearCookie(c, wechatPaymentOAuthStateName, secureCookie)
wechatPaymentClearCookie(c, wechatPaymentOAuthRedirect, secureCookie)
wechatPaymentClearCookie(c, wechatPaymentOAuthContextName, secureCookie)
wechatPaymentClearCookie(c, wechatPaymentOAuthScope, secureCookie)
}
func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H {
completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil))
payload := gin.H{
@ -1280,6 +1462,9 @@ func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gi
func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string]any {
normalized := clonePendingMap(payload)
for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} {
delete(normalized, key)
}
step := strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step")))
switch step {
case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required":
@ -1315,16 +1500,21 @@ func (h *AuthHandler) transitionPendingOAuthAccountToChoiceState(
c *gin.Context,
client *dbent.Client,
session *dbent.PendingAuthSession,
targetUser *dbent.User,
email string,
) (*dbent.PendingAuthSession, error) {
completionResponse := pendingOAuthChoiceCompletionResponse(session, email)
var targetUserID *int64
if targetUser != nil && targetUser.ID > 0 {
targetUserID = &targetUser.ID
}
session, err := updatePendingOAuthSessionProgress(
c.Request.Context(),
client,
session,
strings.TrimSpace(session.Intent),
email,
nil,
targetUserID,
completionResponse,
)
if err != nil {
@ -1438,6 +1628,10 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
response.ErrorFrom(c, err)
return
}
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
response.ErrorFrom(c, err)
return
}
if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
response.BadRequest(c, "Pending oauth session provider mismatch")
return
@ -1464,7 +1658,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
}
}
if existingUser != nil {
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email)
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
if err != nil {
response.ErrorFrom(c, err)
return
@ -1487,7 +1681,12 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
)
if err != nil {
if errors.Is(err, service.ErrEmailExists) {
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email)
existingUser, lookupErr := findUserByNormalizedEmail(c.Request.Context(), client, email)
if lookupErr != nil {
response.ErrorFrom(c, lookupErr)
return
}
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
if err != nil {
response.ErrorFrom(c, err)
return
@ -1649,6 +1848,27 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
}
}
applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims)
canIssueTokenPair := pendingOAuthCompletionCanIssueTokenPair(session, payload)
var loginUser *service.User
if canIssueTokenPair {
loginUser, err = h.userService.GetByID(c.Request.Context(), *session.TargetUserID)
if err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
if err := ensureLoginUserActive(loginUser); err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), loginUser); err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
}
skipAdoptionPrompt, err := h.shouldSkipPendingOAuthAdoptionPrompt(c.Request.Context(), session, payload)
if err != nil {
clearCookies()
@ -1658,25 +1878,6 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
if skipAdoptionPrompt {
delete(payload, "adoption_required")
}
if pendingOAuthCompletionIncludesTokenPayload(payload) {
if session.TargetUserID == nil || *session.TargetUserID <= 0 {
clearCookies()
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid"))
return
}
user, err := h.userService.GetByID(c.Request.Context(), *session.TargetUserID)
if err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
}
if pendingSessionWantsInvitation(payload) {
if adoptionDecision.hasDecision() {
@ -1724,6 +1925,20 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
return
}
if canIssueTokenPair {
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), loginUser, "")
if err != nil {
clearCookies()
response.InternalError(c, "Failed to generate token pair")
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), loginUser.ID)
payload["access_token"] = tokenPair.AccessToken
payload["refresh_token"] = tokenPair.RefreshToken
payload["expires_in"] = tokenPair.ExpiresIn
payload["token_type"] = "Bearer"
}
clearCookies()
response.Success(c, payload)
}

View File

@ -746,8 +746,8 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"access_token": "access-token",
"refresh_token": "refresh-token",
"access_token": "legacy-access-token",
"refresh_token": "legacy-refresh-token",
"expires_in": float64(3600),
"token_type": "Bearer",
"redirect": "/dashboard",
@ -769,13 +769,23 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo
require.Equal(t, http.StatusOK, recorder.Code)
payload := decodeJSONResponseData(t, recorder)
require.Equal(t, "access-token", payload["access_token"])
require.Equal(t, "refresh-token", payload["refresh_token"])
require.NotEmpty(t, payload["access_token"])
require.NotEmpty(t, payload["refresh_token"])
require.NotEqual(t, "legacy-access-token", payload["access_token"])
require.NotEqual(t, "legacy-refresh-token", payload["refresh_token"])
require.Equal(t, "/dashboard", payload["redirect"])
require.Equal(t, "Existing Login Example", payload["suggested_display_name"])
require.Equal(t, "https://cdn.example/existing-login.png", payload["suggested_avatar_url"])
require.NotContains(t, payload, "adoption_required")
accessToken, ok := payload["access_token"].(string)
require.True(t, ok)
claims, err := handler.authService.ValidateToken(accessToken)
require.NoError(t, err)
reloadedUser, err := handler.userService.GetByID(ctx, userEntity.ID)
require.NoError(t, err)
require.Equal(t, reloadedUser.TokenVersion, claims.TokenVersion)
decisionCount, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Count(ctx)
@ -785,6 +795,14 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.NotNil(t, storedSession.ConsumedAt)
completion, ok := storedSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
require.True(t, ok)
require.NotContains(t, completion, "access_token")
require.NotContains(t, completion, "refresh_token")
require.NotContains(t, completion, "expires_in")
require.NotContains(t, completion, "token_type")
require.Equal(t, "/dashboard", completion["redirect"])
}
func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload(t *testing.T) {
@ -841,6 +859,72 @@ func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayl
require.Nil(t, storedSession.ConsumedAt)
}
func TestExchangePendingOAuthCompletionRejectsDisabledTargetUser(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
userEntity, err := client.User.Create().
SetEmail("disabled-linked@example.com").
SetUsername("disabled-linked-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusDisabled).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("disabled-linked-session-token").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("disabled-linked-subject").
SetTargetUserID(userEntity.ID).
SetResolvedEmail(userEntity.Email).
SetBrowserSessionKey("disabled-linked-browser-session-key").
SetUpstreamIdentityClaims(map[string]any{
"suggested_display_name": "Disabled Linked User",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"redirect": "/dashboard",
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("disabled-linked-browser-session-key")})
ginCtx.Request = req
handler.ExchangePendingOAuthCompletion(ginCtx)
require.Equal(t, http.StatusForbidden, recorder.Code)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestNormalizePendingOAuthCompletionResponseScrubsLegacyTokenPayload(t *testing.T) {
payload := normalizePendingOAuthCompletionResponse(map[string]any{
"access_token": "legacy-access-token",
"refresh_token": "legacy-refresh-token",
"expires_in": float64(3600),
"token_type": "Bearer",
"redirect": "/dashboard",
})
require.NotContains(t, payload, "access_token")
require.NotContains(t, payload, "refresh_token")
require.NotContains(t, payload, "expires_in")
require.NotContains(t, payload, "token_type")
require.Equal(t, "/dashboard", payload["redirect"])
}
func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, true)
ctx := context.Background()
@ -969,7 +1053,7 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
ctx := context.Background()
_, err := client.User.Create().
existingUser, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
@ -1023,7 +1107,8 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Equal(t, oauthIntentLogin, storedSession.Intent)
require.Nil(t, storedSession.TargetUserID)
require.NotNil(t, storedSession.TargetUserID)
require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
require.Nil(t, storedSession.ConsumedAt)
@ -1042,7 +1127,7 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
ctx := context.Background()
_, err := client.User.Create().
existingUser, err := client.User.Create().
SetEmail(" Owner@Example.com ").
SetUsername("owner-user").
SetPasswordHash("hash").
@ -1088,7 +1173,8 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.TargetUserID)
require.NotNil(t, storedSession.TargetUserID)
require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
}
@ -1096,7 +1182,7 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
ctx := context.Background()
_, err := client.User.Create().
existingUser, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
@ -1144,7 +1230,8 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Equal(t, oauthIntentLogin, storedSession.Intent)
require.Nil(t, storedSession.TargetUserID)
require.NotNil(t, storedSession.TargetUserID)
require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
}
@ -1202,6 +1289,26 @@ func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T)
require.Nil(t, storedSession.ConsumedAt)
}
func TestLogoutClearsPendingOAuthAndBindCookies(t *testing.T) {
handler, _ := newOAuthPendingFlowTestHandler(t, false)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", bytes.NewBufferString(`{}`))
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue("pending-session-token")})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("pending-browser-key")})
req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-token"})
ginCtx.Request = req
handler.Logout(ginCtx)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName).MaxAge)
require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthPendingBrowserCookieName).MaxAge)
require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName).MaxAge)
}
func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, true, "fresh@example.com", "246810")
ctx := context.Background()
@ -1934,6 +2041,13 @@ func TestLogin2FACompletesPendingOAuthBindAndConsumesSession(t *testing.T) {
payload := decodeJSONResponseData(t, recorder)
require.NotEmpty(t, payload["access_token"])
require.NotEmpty(t, payload["refresh_token"])
accessToken, ok := payload["access_token"].(string)
require.True(t, ok)
claims, err := handler.authService.ValidateToken(accessToken)
require.NoError(t, err)
reloadedUser, err := handler.userService.GetByID(ctx, existingUser.ID)
require.NoError(t, err)
require.Equal(t, reloadedUser.TokenVersion, claims.TokenVersion)
identity, err := client.AuthIdentity.Query().
Where(

View File

@ -2,6 +2,7 @@ package handler
import (
"net/http"
"net/url"
"testing"
"github.com/stretchr/testify/require"
@ -37,3 +38,20 @@ func decodeCookieValueForTest(t *testing.T, value string) string {
require.NoError(t, err)
return decoded
}
func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) {
t.Helper()
require.NotEmpty(t, location)
parsed, err := url.Parse(location)
require.NoError(t, err)
rawValues := parsed.RawQuery
if rawValues == "" {
rawValues = parsed.Fragment
}
values, err := url.ParseQuery(rawValues)
require.NoError(t, err)
require.Equal(t, errorCode, values.Get("error"))
require.Equal(t, errorMessage, values.Get("error_message"))
}

View File

@ -157,21 +157,25 @@ func (h *AuthHandler) OIDCOAuthStart(c *gin.Context) {
}
codeChallenge := ""
verifier, genErr := oauth.GenerateCodeVerifier()
if genErr != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(genErr))
return
if cfg.UsePKCE {
verifier, genErr := oauth.GenerateCodeVerifier()
if genErr != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(genErr))
return
}
codeChallenge = oauth.GenerateCodeChallenge(verifier)
oidcSetCookie(c, oidcOAuthVerifierCookie, encodeCookieValue(verifier), oidcOAuthCookieMaxAgeSec, secureCookie)
}
codeChallenge = oauth.GenerateCodeChallenge(verifier)
oidcSetCookie(c, oidcOAuthVerifierCookie, encodeCookieValue(verifier), oidcOAuthCookieMaxAgeSec, secureCookie)
nonce := ""
nonce, err = oauth.GenerateState()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_NONCE_GEN_FAILED", "failed to generate oauth nonce").WithCause(err))
return
if cfg.ValidateIDToken {
nonce, err = oauth.GenerateState()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_NONCE_GEN_FAILED", "failed to generate oauth nonce").WithCause(err))
return
}
oidcSetCookie(c, oidcOAuthNonceCookie, encodeCookieValue(nonce), oidcOAuthCookieMaxAgeSec, secureCookie)
}
oidcSetCookie(c, oidcOAuthNonceCookie, encodeCookieValue(nonce), oidcOAuthCookieMaxAgeSec, secureCookie)
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
@ -244,17 +248,21 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
intent = normalizeOAuthIntent(intent)
codeVerifier := ""
codeVerifier, _ = readCookieDecoded(c, oidcOAuthVerifierCookie)
if codeVerifier == "" {
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
return
if cfg.UsePKCE {
codeVerifier, _ = readCookieDecoded(c, oidcOAuthVerifierCookie)
if codeVerifier == "" {
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
return
}
}
expectedNonce := ""
expectedNonce, _ = readCookieDecoded(c, oidcOAuthNonceCookie)
if expectedNonce == "" {
redirectOAuthError(c, frontendCallback, "missing_nonce", "missing oauth nonce", "")
return
if cfg.ValidateIDToken {
expectedNonce, _ = readCookieDecoded(c, oidcOAuthNonceCookie)
if expectedNonce == "" {
redirectOAuthError(c, frontendCallback, "missing_nonce", "missing oauth nonce", "")
return
}
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
@ -284,16 +292,19 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
if strings.TrimSpace(tokenResp.IDToken) == "" {
redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "")
return
}
var idClaims *oidcIDTokenClaims
if cfg.ValidateIDToken {
if strings.TrimSpace(tokenResp.IDToken) == "" {
redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "")
return
}
idClaims, err := oidcParseAndValidateIDToken(c.Request.Context(), cfg, tokenResp.IDToken, expectedNonce)
if err != nil {
log.Printf("[OIDC OAuth] id_token validation failed: %v", err)
redirectOAuthError(c, frontendCallback, "invalid_id_token", "failed to validate id_token", "")
return
idClaims, err = oidcParseAndValidateIDToken(c.Request.Context(), cfg, tokenResp.IDToken, expectedNonce)
if err != nil {
log.Printf("[OIDC OAuth] id_token validation failed: %v", err)
redirectOAuthError(c, frontendCallback, "invalid_id_token", "failed to validate id_token", "")
return
}
}
userInfoClaims, err := oidcFetchUserInfo(c.Request.Context(), cfg, tokenResp)
@ -303,7 +314,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
subject := strings.TrimSpace(idClaims.Subject)
subject := ""
if idClaims != nil {
subject = strings.TrimSpace(idClaims.Subject)
}
if subject == "" {
subject = strings.TrimSpace(userInfoClaims.Subject)
}
@ -311,7 +325,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
redirectOAuthError(c, frontendCallback, "missing_subject", "missing subject claim", "")
return
}
issuer := strings.TrimSpace(idClaims.Issuer)
issuer := ""
if idClaims != nil {
issuer = strings.TrimSpace(idClaims.Issuer)
}
if issuer == "" {
issuer = strings.TrimSpace(cfg.IssuerURL)
}
@ -321,21 +338,34 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
}
emailVerified := userInfoClaims.EmailVerified
if emailVerified == nil {
if emailVerified == nil && idClaims != nil {
emailVerified = idClaims.EmailVerified
}
if userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) {
if idClaims != nil && userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) {
redirectOAuthError(c, frontendCallback, "subject_mismatch", "userinfo subject does not match id_token", "")
return
}
identityKey := oidcIdentityKey(issuer, subject)
compatEmail := strings.TrimSpace(firstNonEmpty(userInfoClaims.Email, idClaims.Email))
compatEmail := strings.TrimSpace(userInfoClaims.Email)
if compatEmail == "" && idClaims != nil {
compatEmail = strings.TrimSpace(idClaims.Email)
}
email := oidcSyntheticEmailFromIdentityKey(identityKey)
username := firstNonEmpty(
userInfoClaims.Username,
idClaims.PreferredUsername,
idClaims.Name,
func() string {
if idClaims != nil {
return idClaims.PreferredUsername
}
return ""
}(),
func() string {
if idClaims != nil {
return idClaims.Name
}
return ""
}(),
oidcFallbackUsername(subject),
)
identityRef := service.PendingAuthIdentityKey{
@ -344,14 +374,19 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
ProviderSubject: subject,
}
upstreamClaims := map[string]any{
"email": email,
"username": username,
"subject": subject,
"issuer": issuer,
"email_verified": emailVerified != nil && *emailVerified,
"provider_fallback": strings.TrimSpace(cfg.ProviderName),
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
"suggested_avatar_url": userInfoClaims.AvatarURL,
"email": email,
"username": username,
"subject": subject,
"issuer": issuer,
"email_verified": emailVerified != nil && *emailVerified,
"provider_fallback": strings.TrimSpace(cfg.ProviderName),
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, func() string {
if idClaims != nil {
return idClaims.Name
}
return ""
}(), username),
"suggested_avatar_url": userInfoClaims.AvatarURL,
}
if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) {
upstreamClaims["compat_email"] = compatEmail
@ -387,25 +422,16 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
if existingIdentityUser != nil {
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
if err != nil {
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentLogin,
Identity: identityRef,
TargetUserID: &user.ID,
TargetUserID: &existingIdentityUser.ID,
ResolvedEmail: existingIdentityUser.Email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
"expires_in": tokenPair.ExpiresIn,
"token_type": "Bearer",
"redirect": redirectTo,
"redirect": redirectTo,
},
}); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
@ -537,10 +563,15 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession(
if compatEmailUser != nil {
resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email)
}
var targetUserID *int64
if compatEmailUser != nil && compatEmailUser.ID > 0 {
targetUserID = &compatEmailUser.ID
}
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentLogin,
Identity: identity,
TargetUserID: targetUserID,
ResolvedEmail: resolvedChoiceEmail,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
@ -596,6 +627,15 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
response.ErrorFrom(c, err)
return
} else if handled {
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
return
} else {
session = updatedSession
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
@ -608,12 +648,16 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
client := h.entClient()
if client == nil {
response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
return
}
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil {
respondPendingOAuthBindingApplyError(c, err)
return
}
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
AdoptDisplayName: req.AdoptDisplayName,
AdoptAvatar: req.AdoptAvatar,
})
@ -621,17 +665,16 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil {
respondPendingOAuthBindingApplyError(c, err)
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
@ -670,7 +713,9 @@ func oidcExchangeCode(
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
form.Set("code_verifier", codeVerifier)
if strings.TrimSpace(codeVerifier) != "" {
form.Set("code_verifier", codeVerifier)
}
r := client.R().
SetContext(ctx).
@ -872,9 +917,13 @@ func buildOIDCAuthorizeURL(cfg config.OIDCConnectConfig, state, nonce, codeChall
q.Set("scope", cfg.Scopes)
}
q.Set("state", state)
q.Set("nonce", nonce)
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
if strings.TrimSpace(nonce) != "" {
q.Set("nonce", nonce)
}
if strings.TrimSpace(codeChallenge) != "" {
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
}
u.RawQuery = q.Encode()
return u.String(), nil

View File

@ -186,6 +186,89 @@ func TestOIDCOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) {
require.Equal(t, int64(84), userID)
}
func TestOIDCOAuthStartOmitsPKCEAndNonceWhenDisabled(t *testing.T) {
handler := newOIDCOAuthTestHandler(t, false, config.OIDCConnectConfig{
Enabled: true,
ClientID: "oidc-client",
ClientSecret: "oidc-secret",
IssuerURL: "https://issuer.example.com",
AuthorizeURL: "https://issuer.example.com/oauth/authorize",
TokenURL: "https://issuer.example.com/oauth/token",
UserInfoURL: "https://issuer.example.com/oauth/userinfo",
Scopes: "openid profile email",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
FrontendRedirectURL: "/auth/oidc/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: false,
ValidateIDToken: false,
RequireEmailVerified: false,
})
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/start?redirect=/dashboard", nil)
handler.OIDCOAuthStart(c)
require.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
require.NotContains(t, location, "code_challenge=")
require.NotContains(t, location, "nonce=")
require.Nil(t, findCookie(recorder.Result().Cookies(), oidcOAuthVerifierCookie))
require.Nil(t, findCookie(recorder.Result().Cookies(), oidcOAuthNonceCookie))
}
func TestOIDCOAuthCallbackAllowsOptionalPKCEAndIDTokenValidation(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
require.NoError(t, r.ParseForm())
require.Empty(t, r.PostForm.Get("code_verifier"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"oidc-access","token_type":"Bearer","expires_in":3600}`))
case "/userinfo":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"sub":"oidc-subject-compat","preferred_username":"oidc_user","name":"OIDC Display","email":"oidc@example.com"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
handler, client := newOIDCOAuthHandlerAndClient(t, false, config.OIDCConnectConfig{
Enabled: true,
ClientID: "oidc-client",
ClientSecret: "oidc-secret",
IssuerURL: "https://issuer.example.com",
AuthorizeURL: upstream.URL + "/authorize",
TokenURL: upstream.URL + "/token",
UserInfoURL: upstream.URL + "/userinfo",
Scopes: "openid profile email",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
FrontendRedirectURL: "/auth/oidc/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: false,
ValidateIDToken: false,
RequireEmailVerified: false,
})
t.Cleanup(func() { _ = client.Close() })
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-123", nil)
req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-123"))
req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
c.Request = req
handler.OIDCOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
require.NotNil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
}
func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) {
cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
Subject: "oidc-subject-login",
@ -250,10 +333,63 @@ func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *t
completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
require.True(t, ok)
require.Equal(t, "/dashboard", completion["redirect"])
require.NotEmpty(t, completion["access_token"])
_, hasAccessToken := completion["access_token"]
require.False(t, hasAccessToken)
_, hasRefreshToken := completion["refresh_token"]
require.False(t, hasRefreshToken)
require.Nil(t, completion["error"])
}
func TestOIDCOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) {
cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
Subject: "oidc-disabled-subject",
PreferredUsername: "oidc_disabled",
DisplayName: "OIDC Disabled",
})
defer cleanup()
handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
t.Cleanup(func() { _ = client.Close() })
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail(oidcSyntheticEmailFromIdentityKey(oidcIdentityKey(cfg.IssuerURL, "oidc-disabled-subject"))).
SetUsername("disabled-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusDisabled).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(existingUser.ID).
SetProviderType("oidc").
SetProviderKey(cfg.IssuerURL).
SetProviderSubject("oidc-disabled-subject").
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-disabled", nil)
req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-disabled"))
req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-disabled"))
req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-disabled-subject"))
req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled"))
c.Request = req
handler.OIDCOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE")
count, err := client.PendingAuthSession.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, count)
}
func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) {
cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
Subject: "oidc-subject-compat",
@ -302,7 +438,8 @@ func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing
Only(ctx)
require.NoError(t, err)
require.Equal(t, oauthIntentLogin, session.Intent)
require.Nil(t, session.TargetUserID)
require.NotNil(t, session.TargetUserID)
require.Equal(t, existingUser.ID, *session.TargetUserID)
require.Equal(t, existingUser.Email, session.ResolvedEmail)
require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
@ -606,6 +743,189 @@ func TestCompleteOIDCOAuthRegistrationRejectsAdoptExistingUserSession(t *testing
require.Nil(t, storedSession.ConsumedAt)
}
func TestCompleteOIDCOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("oidc-complete-choice-session").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example.com").
SetProviderSubject("oidc-choice-subject-1").
SetResolvedEmail("oidc-choice-subject-1@oidc-connect.invalid").
SetBrowserSessionKey("oidc-choice-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
"issuer": "https://issuer.example.com",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"step": oauthPendingChoiceStep,
"redirect": "/dashboard",
"email": "fresh@example.com",
"resolved_email": "fresh@example.com",
"force_email_on_signup": true,
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-choice-browser")})
c.Request = req
handler.CompleteOIDCOAuthRegistration(c)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.Equal(t, "pending_session", responseData["auth_result"])
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
require.Equal(t, true, responseData["force_email_on_signup"])
require.Empty(t, responseData["access_token"])
userCount, err := client.User.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestCompleteOIDCOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("oidc-complete-no-adoption-session").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example.com").
SetProviderSubject("oidc-subject-no-adoption").
SetResolvedEmail("8c9f12b2a2e14b1db9efc08b27e0ef5c@oidc-connect.invalid").
SetBrowserSessionKey("oidc-browser-no-adoption").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
"issuer": "https://issuer.example.com",
"suggested_display_name": "OIDC Legacy",
"suggested_avatar_url": "https://cdn.example/oidc-legacy.png",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser-no-adoption")})
c.Request = req
handler.CompleteOIDCOAuthRegistration(c)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.NotEmpty(t, responseData["access_token"])
require.NotEmpty(t, responseData["refresh_token"])
userEntity, err := client.User.Query().
Where(dbuser.EmailEQ(session.ResolvedEmail)).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "oidc_user", userEntity.Username)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("oidc"),
authidentity.ProviderKeyEQ("https://issuer.example.com"),
authidentity.ProviderSubjectEQ("oidc-subject-no-adoption"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, userEntity.ID, identity.UserID)
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, decision.IdentityID)
require.Equal(t, identity.ID, *decision.IdentityID)
require.False(t, decision.AdoptDisplayName)
require.False(t, decision.AdoptAvatar)
}
func TestCompleteOIDCOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
existingOwner, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(existingOwner.ID).
SetProviderType("oidc").
SetProviderKey("https://issuer.example.com").
SetProviderSubject("oidc-conflict-subject").
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("oidc-complete-conflict-session").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example.com").
SetProviderSubject("oidc-conflict-subject").
SetResolvedEmail("f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid").
SetBrowserSessionKey("oidc-conflict-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
"issuer": "https://issuer.example.com",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-conflict-browser")})
c.Request = req
handler.CompleteOIDCOAuthRegistration(c)
require.Equal(t, http.StatusConflict, recorder.Code)
payload := decodeJSONBody(t, recorder)
require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"])
userCount, err := client.User.Query().
Where(dbuser.EmailEQ("f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid")).
Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
type oidcProviderFixture struct {
Subject string
PreferredUsername string

View File

@ -0,0 +1,61 @@
//go:build unit
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 29,
Email: "session@example.com",
Username: "session-user",
Role: service.RoleUser,
Status: service.StatusActive,
TokenVersion: 7,
},
}
refreshTokenCache := &userHandlerRefreshTokenCacheStub{}
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
handler := &AuthHandler{authService: authService}
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/auth/revoke-all-sessions", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 29})
handler.RevokeAllSessions(c)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, []int64{29}, refreshTokenCache.revokedUserIDs)
require.Equal(t, int64(8), repo.user.TokenVersion)
var resp struct {
Code int `json:"code"`
Data struct {
Message string `json:"message"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, "All sessions have been revoked. Please log in again.", resp.Data.Message)
}

View File

@ -279,12 +279,7 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
if err != nil {
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
}
if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil, &user.ID); err != nil {
if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, nil, nil, &existingIdentityUser.ID); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
@ -476,11 +471,12 @@ func (h *AuthHandler) WeChatPaymentOAuthCallback(c *gin.Context) {
}
func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService {
var legacyKey []byte
key, err := payment.ProvideEncryptionKey(h.cfg)
if err != nil {
return service.NewPaymentResumeService(nil)
if err == nil {
legacyKey = []byte(key)
}
return service.NewPaymentResumeService([]byte(key))
return service.NewLegacyAwarePaymentResumeService(legacyKey)
}
type completeWeChatOAuthRequest struct {
@ -530,6 +526,15 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
response.ErrorFrom(c, err)
return
} else if handled {
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
return
} else {
session = updatedSession
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
@ -547,7 +552,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
AdoptDisplayName: req.AdoptDisplayName,
AdoptAvatar: req.AdoptAvatar,
})
@ -823,7 +828,10 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
}
if user, err := singleWeChatIdentityUser(records); err != nil || user != nil {
return user, err
if err != nil || user == nil {
return user, err
}
return findActiveUserByID(ctx, client, user.ID)
}
}
@ -847,7 +855,10 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
return nil, infraerrors.InternalServer("AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
}
if user, err := singleWeChatChannelUser(records); err != nil || user != nil {
return user, err
if err != nil || user == nil {
return user, err
}
return findActiveUserByID(ctx, client, user.ID)
}
}
@ -866,7 +877,11 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
if err != nil {
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
}
return singleWeChatIdentityUser(records)
user, err := singleWeChatIdentityUser(records)
if err != nil || user == nil {
return user, err
}
return findActiveUserByID(ctx, client, user.ID)
}
func wechatCompatibleProviderKeys(providerKey string) []string {

View File

@ -213,6 +213,151 @@ func TestWeChatOAuthCallbackFallsBackToOpenIDWhenUnionIDMissingInSingleChannelMo
require.Equal(t, "third_party_signup", completion["choice_reason"])
}
func TestWeChatOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUserWithoutStoredTokens(t *testing.T) {
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
wechatOAuthAccessTokenURL = originalAccessTokenURL
wechatOAuthUserInfoURL = originalUserInfoURL
})
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
case strings.Contains(r.URL.Path, "/sns/userinfo"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat-login.png"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback"))
defer client.Close()
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail(wechatSyntheticEmail("union-456")).
SetUsername("wechat-existing-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(existingUser.ID).
SetProviderType("wechat").
SetProviderKey(wechatOAuthProviderKey).
SetProviderSubject("union-456").
SetMetadata(map[string]any{"username": "wechat-existing-user"}).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
req.Host = "api.example.com"
req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
c.Request = req
handler.WeChatOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Equal(t, "https://app.example.com/auth/wechat/callback", recorder.Header().Get("Location"))
sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
require.NotNil(t, sessionCookie)
session, err := client.PendingAuthSession.Query().
Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
Only(ctx)
require.NoError(t, err)
require.Equal(t, oauthIntentLogin, session.Intent)
require.NotNil(t, session.TargetUserID)
require.Equal(t, existingUser.ID, *session.TargetUserID)
require.Equal(t, existingUser.Email, session.ResolvedEmail)
completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
require.Equal(t, "/dashboard", completion["redirect"])
_, hasAccessToken := completion["access_token"]
require.False(t, hasAccessToken)
_, hasRefreshToken := completion["refresh_token"]
require.False(t, hasRefreshToken)
}
func TestWeChatOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) {
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
wechatOAuthAccessTokenURL = originalAccessTokenURL
wechatOAuthUserInfoURL = originalUserInfoURL
})
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-disabled","unionid":"union-disabled","scope":"snsapi_login"}`))
case strings.Contains(r.URL.Path, "/sns/userinfo"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"openid":"openid-disabled","unionid":"union-disabled","nickname":"Disabled WeChat","headimgurl":"https://cdn.example/disabled.png"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
handler, client := newWeChatOAuthTestHandler(t, false)
defer client.Close()
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail(wechatSyntheticEmail("union-disabled")).
SetUsername("disabled-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusDisabled).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(existingUser.ID).
SetProviderType("wechat").
SetProviderKey(wechatOAuthProviderKey).
SetProviderSubject("union-disabled").
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-disabled", nil)
req.Host = "api.example.com"
req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-disabled"))
req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled"))
c.Request = req
handler.WeChatOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE")
count, err := client.PendingAuthSession.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, count)
}
func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) {
originalAccessTokenURL := wechatOAuthAccessTokenURL
t.Cleanup(func() {
@ -233,6 +378,7 @@ func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T)
handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("mp", "wx-mp-app", "wx-mp-secret", "/auth/wechat/callback"))
defer client.Close()
handler.cfg.Totp.EncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
handler.cfg.Totp.EncryptionKeyConfigured = true
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@ -270,6 +416,67 @@ func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T)
require.Equal(t, "/purchase?from=wechat", claims.RedirectTo)
}
func TestWeChatPaymentOAuthCallbackUsesExplicitPaymentResumeSigningKeyWhenMixedKeysConfigured(t *testing.T) {
originalAccessTokenURL := wechatOAuthAccessTokenURL
t.Cleanup(func() {
wechatOAuthAccessTokenURL = originalAccessTokenURL
})
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/sns/oauth2/access_token") {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-mixed-key","scope":"snsapi_base"}`))
return
}
http.NotFound(w, r)
}))
defer upstream.Close()
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("mp", "wx-mp-app", "wx-mp-secret", "/auth/wechat/callback"))
defer client.Close()
legacyKeyHex := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
explicitSigningKey := "explicit-payment-resume-signing-key"
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", explicitSigningKey)
handler.cfg.Totp.EncryptionKey = legacyKeyHex
handler.cfg.Totp.EncryptionKeyConfigured = true
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/payment/callback?code=wechat-code&state=state-mixed", nil)
req.Host = "api.example.com"
req.AddCookie(encodedCookie(wechatPaymentOAuthStateName, "state-mixed"))
req.AddCookie(encodedCookie(wechatPaymentOAuthRedirect, "/purchase?from=wechat"))
req.AddCookie(encodedCookie(wechatPaymentOAuthContextName, `{"payment_type":"wxpay","amount":"18.8","order_type":"subscription","plan_id":9}`))
req.AddCookie(encodedCookie(wechatPaymentOAuthScope, "snsapi_base"))
c.Request = req
handler.WeChatPaymentOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
parsed, err := url.Parse(location)
require.NoError(t, err)
fragment, err := url.ParseQuery(parsed.Fragment)
require.NoError(t, err)
token := fragment.Get("wechat_resume_token")
require.NotEmpty(t, token)
claims, err := service.NewPaymentResumeService([]byte(explicitSigningKey)).ParseWeChatPaymentResumeToken(token)
require.NoError(t, err)
require.Equal(t, "openid-mixed-key", claims.OpenID)
require.Equal(t, payment.TypeWxpay, claims.PaymentType)
require.Equal(t, "18.8", claims.Amount)
require.Equal(t, payment.OrderTypeSubscription, claims.OrderType)
require.EqualValues(t, 9, claims.PlanID)
require.Equal(t, "/purchase?from=wechat", claims.RedirectTo)
_, err = service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")).ParseWeChatPaymentResumeToken(token)
require.Error(t, err)
}
func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *testing.T) {
testCases := []struct {
name string
@ -620,7 +827,7 @@ func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *tes
require.Zero(t, count)
}
func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) {
func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSessionReturnsPendingSession(t *testing.T) {
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
@ -693,27 +900,32 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
require.Equal(t, http.StatusOK, completeRecorder.Code)
responseData := decodeJSONBody(t, completeRecorder)
require.NotEmpty(t, responseData["access_token"])
require.Equal(t, "pending_session", responseData["auth_result"])
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
require.Equal(t, true, responseData["adoption_required"])
require.Empty(t, responseData["access_token"])
userEntity, err := client.User.Query().
Where(dbuser.EmailEQ("wechat-union-456@wechat-connect.invalid")).
consumed, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(pendingSession.ID)).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "WeChat Display", userEntity.Username)
require.Nil(t, consumed.ConsumedAt)
identity, err := client.AuthIdentity.Query().
userCount, err := client.User.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("wechat"),
authidentity.ProviderKeyEQ("wechat-main"),
authidentity.ProviderSubjectEQ("union-456"),
).
Only(ctx)
Count(ctx)
require.NoError(t, err)
require.Equal(t, userEntity.ID, identity.UserID)
require.Equal(t, "WeChat Display", identity.Metadata["display_name"])
require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"])
require.Zero(t, identityCount)
channel, err := client.AuthIdentityChannel.Query().
channelCount, err := client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ("wechat"),
authidentitychannel.ProviderKeyEQ("wechat-main"),
@ -721,25 +933,82 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
authidentitychannel.ChannelAppIDEQ("wx-open-app"),
authidentitychannel.ChannelSubjectEQ("openid-123"),
).
Count(ctx)
require.NoError(t, err)
require.Zero(t, channelCount)
decisionCount, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)).
Count(ctx)
require.NoError(t, err)
require.Zero(t, decisionCount)
}
func TestCompleteWeChatOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("wechat-complete-no-adoption-session").
SetIntent("login").
SetProviderType("wechat").
SetProviderKey(wechatOAuthProviderKey).
SetProviderSubject("wechat-subject-no-adoption").
SetResolvedEmail("wechat-subject-no-adoption@wechat-connect.invalid").
SetBrowserSessionKey("wechat-browser-no-adoption").
SetUpstreamIdentityClaims(map[string]any{
"username": "wechat_user",
"suggested_display_name": "WeChat Legacy",
"suggested_avatar_url": "https://cdn.example/wechat-legacy.png",
"mode": "open",
"channel": "open",
"channel_app_id": "wx-open-app",
"channel_subject": "openid-legacy",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
completeCtx, _ := gin.CreateTestContext(recorder)
completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
completeReq.Header.Set("Content-Type", "application/json")
completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-browser-no-adoption")})
completeCtx.Request = completeReq
handler.CompleteWeChatOAuthRegistration(completeCtx)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.NotEmpty(t, responseData["access_token"])
require.NotEmpty(t, responseData["refresh_token"])
userEntity, err := client.User.Query().
Where(dbuser.EmailEQ(session.ResolvedEmail)).
Only(ctx)
require.NoError(t, err)
require.Equal(t, identity.ID, channel.IdentityID)
require.Equal(t, "union-456", channel.Metadata["unionid"])
require.Equal(t, "wechat_user", userEntity.Username)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("wechat"),
authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
authidentity.ProviderSubjectEQ("wechat-subject-no-adoption"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, userEntity.ID, identity.UserID)
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)).
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, decision.IdentityID)
require.Equal(t, identity.ID, *decision.IdentityID)
require.True(t, decision.AdoptDisplayName)
require.True(t, decision.AdoptAvatar)
consumed, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(pendingSession.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
require.False(t, decision.AdoptDisplayName)
require.False(t, decision.AdoptAvatar)
}
func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) {
@ -901,6 +1170,62 @@ func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testi
require.Nil(t, storedSession.ConsumedAt)
}
func TestCompleteWeChatOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
handler, client := newWeChatOAuthTestHandler(t, false)
defer client.Close()
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("wechat-complete-choice-session").
SetIntent("login").
SetProviderType("wechat").
SetProviderKey("wechat-main").
SetProviderSubject("wechat-choice-subject-1").
SetResolvedEmail("wechat-choice-subject-1@wechat-connect.invalid").
SetBrowserSessionKey("wechat-choice-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "wechat_user",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"step": oauthPendingChoiceStep,
"redirect": "/dashboard",
"email": "fresh@example.com",
"resolved_email": "fresh@example.com",
"force_email_on_signup": true,
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
completeCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-choice-browser")})
completeCtx.Request = req
handler.CompleteWeChatOAuthRegistration(completeCtx)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.Equal(t, "pending_session", responseData["auth_result"])
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
require.Equal(t, true, responseData["force_email_on_signup"])
require.Empty(t, responseData["access_token"])
userCount, err := client.User.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) {
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
@ -1083,18 +1408,6 @@ func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool,
}, client
}
func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) {
t.Helper()
parsed, err := url.Parse(location)
require.NoError(t, err)
fragment, err := url.ParseQuery(parsed.Fragment)
require.NoError(t, err)
require.Equal(t, errorCode, fragment.Get("error"))
require.Equal(t, errorMessage, fragment.Get("error_message"))
}
type wechatOAuthSettingRepoStub struct {
values map[string]string
}

View File

@ -2,9 +2,9 @@ package handler
import (
"fmt"
"net/http"
"strconv"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
@ -454,29 +454,65 @@ func (h *PaymentHandler) VerifyOrder(c *gin.Context) {
// PublicOrderResult is the limited order info returned by the public verify endpoint.
// No user details are exposed — only payment status information.
type PublicOrderResult struct {
ID int64 `json:"id"`
OutTradeNo string `json:"out_trade_no"`
Amount float64 `json:"amount"`
PayAmount float64 `json:"pay_amount"`
PaymentType string `json:"payment_type"`
OrderType string `json:"order_type"`
Status string `json:"status"`
ID int64 `json:"id"`
OutTradeNo string `json:"out_trade_no"`
Amount float64 `json:"amount"`
PayAmount float64 `json:"pay_amount"`
FeeRate float64 `json:"fee_rate"`
PaymentType string `json:"payment_type"`
OrderType string `json:"order_type"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
PaidAt *time.Time `json:"paid_at,omitempty"`
CompletedAt *time.Time `json:"completed_at,omitempty"`
RefundAmount float64 `json:"refund_amount"`
RefundReason *string `json:"refund_reason,omitempty"`
RefundRequestedAt *time.Time `json:"refund_requested_at,omitempty"`
RefundRequestedBy *string `json:"refund_requested_by,omitempty"`
RefundRequestReason *string `json:"refund_request_reason,omitempty"`
PlanID *int64 `json:"plan_id,omitempty"`
}
var errPaymentPublicOrderVerifyRemoved = infraerrors.New(
http.StatusGone,
"PAYMENT_PUBLIC_ORDER_VERIFY_REMOVED",
"public payment order verification by out_trade_no has been removed; use resume_token recovery instead",
).WithMetadata(map[string]string{
"replacement_endpoint": "/api/v1/payment/public/orders/resolve",
"replacement_field": "resume_token",
})
func buildPublicOrderResult(order *dbent.PaymentOrder) PublicOrderResult {
return PublicOrderResult{
ID: order.ID,
OutTradeNo: order.OutTradeNo,
Amount: order.Amount,
PayAmount: order.PayAmount,
FeeRate: order.FeeRate,
PaymentType: order.PaymentType,
OrderType: order.OrderType,
Status: order.Status,
CreatedAt: order.CreatedAt,
ExpiresAt: order.ExpiresAt,
PaidAt: order.PaidAt,
CompletedAt: order.CompletedAt,
RefundAmount: order.RefundAmount,
RefundReason: order.RefundReason,
RefundRequestedAt: order.RefundRequestedAt,
RefundRequestedBy: order.RefundRequestedBy,
RefundRequestReason: order.RefundRequestReason,
PlanID: order.PlanID,
}
}
// VerifyOrderPublic is kept as a compatibility shim for the removed anonymous
// out_trade_no lookup endpoint and always returns HTTP 410 Gone.
// VerifyOrderPublic keeps the legacy anonymous out_trade_no lookup available as
// a compatibility path for older result pages and staggered deploys.
// POST /api/v1/payment/public/orders/verify
func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) {
response.ErrorFrom(c, errPaymentPublicOrderVerifyRemoved)
var req VerifyOrderRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
order, err := h.paymentService.VerifyOrderPublic(c.Request.Context(), req.OutTradeNo)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, buildPublicOrderResult(order))
}
// ResolveOrderPublicByResumeToken resolves a payment order from a signed resume token.
@ -493,15 +529,7 @@ func (h *PaymentHandler) ResolveOrderPublicByResumeToken(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
response.Success(c, PublicOrderResult{
ID: order.ID,
OutTradeNo: order.OutTradeNo,
Amount: order.Amount,
PayAmount: order.PayAmount,
PaymentType: order.PaymentType,
OrderType: order.OrderType,
Status: order.Status,
})
response.Success(c, buildPublicOrderResult(order))
}
// requireAuth extracts the authenticated subject from the context.

View File

@ -4,16 +4,17 @@ package handler
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
@ -74,7 +75,7 @@ func TestApplyWeChatPaymentResumeClaimsRejectsPaymentTypeMismatch(t *testing.T)
}
}
func TestVerifyOrderPublicReturnsGone(t *testing.T) {
func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
@ -90,6 +91,257 @@ func TestVerifyOrderPublicReturnsGone(t *testing.T) {
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
user, err := client.User.Create().
SetEmail("public-verify@example.com").
SetPasswordHash("hash").
SetUsername("public-verify-user").
Save(context.Background())
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(90.64).
SetFeeRate(0.03).
SetRechargeCode("PUBLIC-VERIFY").
SetOutTradeNo("legacy-order-no").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-public-verify").
SetOrderType(payment.OrderTypeBalance).
SetStatus(service.OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(context.Background())
require.NoError(t, err)
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(
http.MethodPost,
"/api/v1/payment/public/orders/verify",
bytes.NewBufferString(`{"out_trade_no":"legacy-order-no"}`),
)
ctx.Request.Header.Set("Content-Type", "application/json")
h.VerifyOrderPublic(ctx)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data struct {
ID int64 `json:"id"`
OutTradeNo string `json:"out_trade_no"`
Amount float64 `json:"amount"`
PayAmount float64 `json:"pay_amount"`
FeeRate float64 `json:"fee_rate"`
PaymentType string `json:"payment_type"`
OrderType string `json:"order_type"`
Status string `json:"status"`
RefundAmount float64 `json:"refund_amount"`
CreatedAt string `json:"created_at"`
ExpiresAt string `json:"expires_at"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, order.ID, resp.Data.ID)
require.Equal(t, "legacy-order-no", resp.Data.OutTradeNo)
require.Equal(t, 90.64, resp.Data.PayAmount)
require.Equal(t, 0.03, resp.Data.FeeRate)
require.Equal(t, payment.TypeAlipay, resp.Data.PaymentType)
require.Equal(t, payment.OrderTypeBalance, resp.Data.OrderType)
require.Equal(t, service.OrderStatusPending, resp.Data.Status)
require.Equal(t, 0.0, resp.Data.RefundAmount)
require.NotEmpty(t, resp.Data.CreatedAt)
require.NotEmpty(t, resp.Data.ExpiresAt)
}
func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
db, err := sql.Open("sqlite", "file:payment_handler_public_resolve?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
user, err := client.User.Create().
SetEmail("public-resolve@example.com").
SetPasswordHash("hash").
SetUsername("public-resolve-user").
Save(context.Background())
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(100).
SetPayAmount(103).
SetFeeRate(0.03).
SetRechargeCode("PUBLIC-RESOLVE").
SetOutTradeNo("resolve-order-no").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-public-resolve").
SetOrderType(payment.OrderTypeBalance).
SetStatus(service.OrderStatusPaid).
SetExpiresAt(time.Now().Add(time.Hour)).
SetPaidAt(time.Now()).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(context.Background())
require.NoError(t, err)
resumeSvc := service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := resumeSvc.CreateToken(service.ResumeTokenClaims{
OrderID: order.ID,
UserID: user.ID,
PaymentType: payment.TypeAlipay,
CanonicalReturnURL: "https://app.example.com/payment/result",
})
require.NoError(t, err)
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(
http.MethodPost,
"/api/v1/payment/public/orders/resolve",
bytes.NewBufferString(`{"resume_token":"`+token+`"}`),
)
ctx.Request.Header.Set("Content-Type", "application/json")
h.ResolveOrderPublicByResumeToken(ctx)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data map[string]any `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, float64(order.ID), resp.Data["id"])
require.Equal(t, "resolve-order-no", resp.Data["out_trade_no"])
require.Equal(t, 100.0, resp.Data["amount"])
require.Equal(t, 103.0, resp.Data["pay_amount"])
require.Equal(t, 0.03, resp.Data["fee_rate"])
require.Equal(t, payment.TypeAlipay, resp.Data["payment_type"])
require.Equal(t, payment.OrderTypeBalance, resp.Data["order_type"])
require.Equal(t, service.OrderStatusPaid, resp.Data["status"])
require.Contains(t, resp.Data, "created_at")
require.Contains(t, resp.Data, "expires_at")
require.Contains(t, resp.Data, "refund_amount")
}
func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
db, err := sql.Open("sqlite", "file:payment_handler_public_resolve_mismatch?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
user, err := client.User.Create().
SetEmail("public-resolve-mismatch@example.com").
SetPasswordHash("hash").
SetUsername("public-resolve-mismatch-user").
Save(context.Background())
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(100).
SetPayAmount(103).
SetFeeRate(0.03).
SetRechargeCode("PUBLIC-RESOLVE-MISMATCH").
SetOutTradeNo("resolve-order-mismatch-no").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-public-resolve-mismatch").
SetOrderType(payment.OrderTypeBalance).
SetStatus(service.OrderStatusPaid).
SetExpiresAt(time.Now().Add(time.Hour)).
SetPaidAt(time.Now()).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(context.Background())
require.NoError(t, err)
resumeSvc := service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := resumeSvc.CreateToken(service.ResumeTokenClaims{
OrderID: order.ID,
UserID: user.ID + 999,
PaymentType: payment.TypeAlipay,
CanonicalReturnURL: "https://app.example.com/payment/result",
})
require.NoError(t, err)
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(
http.MethodPost,
"/api/v1/payment/public/orders/resolve",
bytes.NewBufferString(`{"resume_token":"`+token+`"}`),
)
ctx.Request.Header.Set("Content-Type", "application/json")
h.ResolveOrderPublicByResumeToken(ctx)
require.Equal(t, http.StatusBadRequest, recorder.Code)
var resp struct {
Code int `json:"code"`
Reason string `json:"reason"`
Message string `json:"message"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, http.StatusBadRequest, resp.Code)
require.Equal(t, "INVALID_RESUME_TOKEN", resp.Reason)
}
func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
gin.SetMode(gin.TestMode)
db, err := sql.Open("sqlite", "file:payment_handler_public_verify_blank?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
@ -98,17 +350,19 @@ func TestVerifyOrderPublicReturnsGone(t *testing.T) {
ctx.Request = httptest.NewRequest(
http.MethodPost,
"/api/v1/payment/public/orders/verify",
bytes.NewBufferString(`{"out_trade_no":"legacy-order-no"}`),
bytes.NewBufferString(`{"out_trade_no":" "}`),
)
ctx.Request.Header.Set("Content-Type", "application/json")
h.VerifyOrderPublic(ctx)
require.Equal(t, http.StatusGone, recorder.Code)
require.Equal(t, http.StatusBadRequest, recorder.Code)
var resp response.Response
var resp struct {
Code int `json:"code"`
Reason string `json:"reason"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, http.StatusGone, resp.Code)
require.Equal(t, "PAYMENT_PUBLIC_ORDER_VERIFY_REMOVED", resp.Reason)
require.Contains(t, resp.Message, "removed")
require.Equal(t, http.StatusBadRequest, resp.Code)
require.Equal(t, "INVALID_OUT_TRADE_NO", resp.Reason)
}

View File

@ -249,7 +249,7 @@ func (h *UserHandler) UnbindIdentity(c *gin.Context) {
return
}
updatedUser, err := h.userService.UnbindUserAuthProvider(
updatedUser, unbound, err := h.userService.UnbindUserAuthProviderWithResult(
c.Request.Context(),
subject.UserID,
c.Param("provider"),
@ -258,6 +258,12 @@ func (h *UserHandler) UnbindIdentity(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if unbound && h.authService != nil {
if err := h.authService.RevokeAllUserTokens(c.Request.Context(), subject.UserID); err != nil {
response.ErrorFrom(c, err)
return
}
}
profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
if err != nil {
@ -504,8 +510,12 @@ func inferUserProfileSources(user *service.User, identities service.UserIdentity
thirdParty := thirdPartyIdentityProviders(identities)
var avatarSource *userProfileSourceContext
if strings.TrimSpace(user.AvatarURL) != "" && len(thirdParty) == 1 {
avatarSource = buildUserProfileSourceContext(thirdParty[0].Provider)
avatarValue := strings.TrimSpace(user.AvatarURL)
for _, summary := range thirdParty {
if avatarValue != "" && avatarValue == strings.TrimSpace(summary.AvatarURL) {
avatarSource = buildUserProfileSourceContext(summary.Provider)
break
}
}
usernameValue := strings.TrimSpace(user.Username)
@ -516,9 +526,6 @@ func inferUserProfileSources(user *service.User, identities service.UserIdentity
break
}
}
if usernameSource == nil && usernameValue != "" && len(thirdParty) == 1 {
usernameSource = buildUserProfileSourceContext(thirdParty[0].Provider)
}
profileSources := map[string]*userProfileSourceContext{}
if avatarSource != nil {

View File

@ -253,7 +253,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
require.Equal(t, "https://issuer.example.com", resp.Data.Identities.OIDC.ProviderKey)
require.False(t, resp.Data.Identities.WeChat.Bound)
require.True(t, resp.Data.Identities.WeChat.CanBind)
require.Contains(t, resp.Data.Identities.WeChat.BindStartPath, "/api/v1/auth/oauth/wechat/start")
require.Contains(t, resp.Data.Identities.WeChat.BindStartPath, "/api/v1/auth/oauth/wechat/bind/start")
}
func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
@ -270,18 +270,19 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
AvatarURL: "https://cdn.example.com/linuxdo.png",
AvatarSource: "remote_url",
},
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "linuxdo-subject-21",
VerifiedAt: &verifiedAt,
Metadata: map[string]any{
"username": "linuxdo-handle",
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "linuxdo-subject-21",
VerifiedAt: &verifiedAt,
Metadata: map[string]any{
"username": "linuxdo-handle",
"avatar_url": "https://cdn.example.com/linuxdo.png",
},
},
},
},
}
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
recorder := httptest.NewRecorder()
@ -331,10 +332,102 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
require.Equal(t, "linuxdo", usernameSource["source"])
}
func TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIdentityMetadata(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 22,
Email: "edited-profile@example.com",
Username: "custom-name",
Role: service.RoleUser,
Status: service.StatusActive,
AvatarURL: "https://cdn.example.com/custom.png",
AvatarSource: "remote_url",
},
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "linuxdo-subject-22",
Metadata: map[string]any{
"username": "linuxdo-handle",
"avatar_url": "https://cdn.example.com/linuxdo.png",
},
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 22})
handler.GetProfile(c)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data map[string]any `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.NotContains(t, resp.Data, "avatar_source")
require.NotContains(t, resp.Data, "username_source")
require.NotContains(t, resp.Data, "profile_sources")
}
type userHandlerEmailCacheStub struct {
data *service.VerificationCodeData
}
type userHandlerRefreshTokenCacheStub struct {
revokedUserIDs []int64
}
func (s *userHandlerRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
return nil
}
func (s *userHandlerRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
return nil, service.ErrRefreshTokenNotFound
}
func (s *userHandlerRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
return nil
}
func (s *userHandlerRefreshTokenCacheStub) DeleteUserRefreshTokens(_ context.Context, userID int64) error {
s.revokedUserIDs = append(s.revokedUserIDs, userID)
return nil
}
func (s *userHandlerRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
return nil
}
func (s *userHandlerRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
return nil
}
func (s *userHandlerRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
return nil
}
func (s *userHandlerRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
return nil, nil
}
func (s *userHandlerRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
return nil, nil
}
func (s *userHandlerRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
return false, nil
}
func (s *userHandlerEmailCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) {
return s.data, nil
}
@ -495,6 +588,98 @@ func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) {
require.Equal(t, false, linuxdoBinding["bound"])
}
func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigured(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 23,
Email: "identity@example.com",
Username: "identity-user",
Role: service.RoleUser,
Status: service.StatusActive,
TokenVersion: 4,
},
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "email",
ProviderKey: "email",
ProviderSubject: "identity@example.com",
},
{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "linuxdo-subject-23",
},
},
}
refreshTokenCache := &userHandlerRefreshTokenCacheStub{}
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 23})
c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}}
handler.UnbindIdentity(c)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, []int64{23}, refreshTokenCache.revokedUserIDs)
require.Equal(t, int64(5), repo.user.TokenVersion)
}
func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 24,
Email: "identity@example.com",
Username: "identity-user",
Role: service.RoleUser,
Status: service.StatusActive,
TokenVersion: 4,
},
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "email",
ProviderKey: "email",
ProviderSubject: "identity@example.com",
},
},
}
refreshTokenCache := &userHandlerRefreshTokenCacheStub{}
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 24})
c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}}
handler.UnbindIdentity(c)
require.Equal(t, http.StatusOK, recorder.Code)
require.Empty(t, repo.unbound)
require.Empty(t, refreshTokenCache.revokedUserIDs)
require.Equal(t, int64(4), repo.user.TokenVersion)
}
func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t *testing.T) {
gin.SetMode(gin.TestMode)
@ -587,7 +772,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
require.Equal(t, "wechat", resp.Data.Provider)
require.Equal(t, "GET", resp.Data.Method)
require.True(t, resp.Data.UseBrowserRedirect)
require.Contains(t, resp.Data.AuthorizeURL, "/api/v1/auth/oauth/wechat/start")
require.Contains(t, resp.Data.AuthorizeURL, "/api/v1/auth/oauth/wechat/bind/start")
require.Contains(t, resp.Data.AuthorizeURL, "intent=bind_current_user")
require.Contains(t, resp.Data.AuthorizeURL, "redirect=%2Fsettings%2Fprofile")
}

View File

@ -60,11 +60,6 @@ const (
wxpayEventTransactionSuccess = "TRANSACTION.SUCCESS"
)
// WeChat Pay error codes.
const (
wxpayErrNoAuth = "NO_AUTH"
)
var (
wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) {
return svc.Prepay(ctx, req)
@ -200,14 +195,7 @@ func (w *Wxpay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequ
case wxpayModeJSAPI:
return w.prepayJSAPI(ctx, client, req, notifyURL, totalFen)
case wxpayModeH5:
resp, err := w.prepayH5(ctx, client, req, notifyURL, totalFen)
if err == nil {
return resp, nil
}
if strings.Contains(err.Error(), wxpayErrNoAuth) {
return nil, fmt.Errorf("wxpay h5 payments are not authorized for this merchant: %w", err)
}
return nil, err
return w.prepayH5(ctx, client, req, notifyURL, totalFen)
case wxpayModeNative:
return w.prepayNative(ctx, client, req, notifyURL, totalFen)
default:

View File

@ -8,6 +8,7 @@ import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"net/url"
"strings"
"testing"
@ -641,3 +642,68 @@ func TestCreatePaymentMobileH5IncludesConfiguredSceneInfo(t *testing.T) {
t.Fatalf("pay_url = %q, want redirect_url query appended", resp.PayURL)
}
}
func TestCreatePaymentMobileH5ReturnsNoAuthErrorWithoutNativeFallback(t *testing.T) {
origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment
origNativePrepay := wxpayNativePrepay
origH5Prepay := wxpayH5Prepay
t.Cleanup(func() {
wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay
wxpayNativePrepay = origNativePrepay
wxpayH5Prepay = origH5Prepay
})
jsapiCalls := 0
nativeCalls := 0
h5Calls := 0
wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) {
jsapiCalls++
return &jsapi.PrepayWithRequestPaymentResponse{}, nil, nil
}
wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) {
h5Calls++
return nil, nil, errors.New("NO_AUTH")
}
wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) {
nativeCalls++
return &native.PrepayResponse{
CodeUrl: core.String("weixin://wxpay/bizpayurl?pr=fallback-native"),
}, nil, nil
}
provider := &Wxpay{
config: map[string]string{
"appId": "wx123",
"mchId": "mch123",
},
coreClient: &core.Client{},
}
resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{
OrderID: "sub2_100",
Amount: "66.88",
PaymentType: payment.TypeWxpay,
Subject: "Balance Recharge",
NotifyURL: "https://merchant.example/payment/notify",
ClientIP: "203.0.113.10",
IsMobile: true,
})
if err == nil {
t.Fatal("expected no-auth error, got nil")
}
if jsapiCalls != 0 {
t.Fatalf("jsapi prepay calls = %d, want 0", jsapiCalls)
}
if h5Calls != 1 {
t.Fatalf("h5 prepay calls = %d, want 1", h5Calls)
}
if nativeCalls != 0 {
t.Fatalf("native prepay calls = %d, want 0", nativeCalls)
}
if resp != nil {
t.Fatalf("expected nil response, got %+v", resp)
}
if !strings.Contains(err.Error(), "NO_AUTH") {
t.Fatalf("error = %v, want NO_AUTH", err)
}
}

View File

@ -4,6 +4,7 @@ import (
"encoding/hex"
"fmt"
"log/slog"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
@ -19,11 +20,22 @@ type EncryptionKey []byte
// When the key is non-empty but invalid (bad hex or wrong length), an error is returned
// to prevent startup with a misconfigured encryption key.
func ProvideEncryptionKey(cfg *config.Config) (EncryptionKey, error) {
if cfg.Totp.EncryptionKey == "" {
if cfg == nil {
slog.Warn("payment encryption key not configured — encrypted payment config and resume signing will be unavailable")
return nil, nil
}
keyHex := strings.TrimSpace(cfg.Totp.EncryptionKey)
if keyHex == "" {
slog.Warn("payment encryption key not configured — encrypted payment config will be unavailable")
return nil, nil
}
key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
// Reject auto-generated TOTP keys for payment signing.
// They change across restarts/instances and can silently break resume-token flows.
if !cfg.Totp.EncryptionKeyConfigured {
slog.Warn("payment encryption/signing key is not explicitly configured; set TOTP_ENCRYPTION_KEY to enable payment resume tokens")
return nil, nil
}
key, err := hex.DecodeString(keyHex)
if err != nil {
return nil, fmt.Errorf("invalid payment encryption key (hex decode): %w", err)
}

View File

@ -0,0 +1,62 @@
package payment
import (
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
func TestProvideEncryptionKeySkipsAutoGeneratedTotpKey(t *testing.T) {
t.Parallel()
cfg := &config.Config{
Totp: config.TotpConfig{
EncryptionKey: strings.Repeat("a", 64),
EncryptionKeyConfigured: false,
},
}
key, err := ProvideEncryptionKey(cfg)
if err != nil {
t.Fatalf("ProvideEncryptionKey returned error: %v", err)
}
if len(key) != 0 {
t.Fatalf("encryption key len = %d, want 0", len(key))
}
}
func TestProvideEncryptionKeyUsesConfiguredTotpKey(t *testing.T) {
t.Parallel()
cfg := &config.Config{
Totp: config.TotpConfig{
EncryptionKey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
EncryptionKeyConfigured: true,
},
}
key, err := ProvideEncryptionKey(cfg)
if err != nil {
t.Fatalf("ProvideEncryptionKey returned error: %v", err)
}
if len(key) != 32 {
t.Fatalf("encryption key len = %d, want 32", len(key))
}
}
func TestProvideEncryptionKeyRejectsConfiguredInvalidLength(t *testing.T) {
t.Parallel()
cfg := &config.Config{
Totp: config.TotpConfig{
EncryptionKey: "abcd",
EncryptionKeyConfigured: true,
},
}
_, err := ProvideEncryptionKey(cfg)
if err == nil {
t.Fatal("expected error for invalid key length")
}
}

View File

@ -4,6 +4,7 @@ package repository
import (
"context"
"database/sql"
"os"
"path/filepath"
"strconv"
@ -20,32 +21,8 @@ func TestAuthIdentityLegacyExternalBackfillMigration(t *testing.T) {
migrationSQL, err := os.ReadFile(migrationPath)
require.NoError(t, err)
_, err = tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS user_external_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
auth_identities,
auth_identity_migration_reports,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`)
require.NoError(t, err)
prepareLegacyExternalIdentitiesTable(t, tx, ctx)
truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
var linuxDoUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
@ -218,32 +195,8 @@ func TestAuthIdentityLegacyExternalMigrations_ChainHandlesMalformedAndNonObjectM
migration116SQL, err := os.ReadFile(migration116Path)
require.NoError(t, err)
_, err = tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS user_external_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
auth_identities,
auth_identity_migration_reports,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`)
require.NoError(t, err)
prepareLegacyExternalIdentitiesTable(t, tx, ctx)
truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
var linuxDoMalformedUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
@ -408,32 +361,8 @@ func TestAuthIdentityLegacyExternalSafetyMigration_ReportsConflictsAndDowngrades
migrationSQL, err := os.ReadFile(migrationPath)
require.NoError(t, err)
_, err = tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS user_external_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
auth_identities,
auth_identity_migration_reports,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`)
require.NoError(t, err)
prepareLegacyExternalIdentitiesTable(t, tx, ctx)
truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
userIDs := make([]int64, 0, 8)
for _, email := range []string{
@ -643,6 +572,388 @@ FROM auth_identity_migration_reports
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
`).Scan(&afterCount))
`).Scan(&afterCount))
require.Equal(t, beforeCount, afterCount)
}
func TestAuthIdentityLegacyExternalBackfillMigration_SkipsAmbiguousCanonicalSubjects(t *testing.T) {
tx := testTx(t)
ctx := context.Background()
migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
migrationSQL, err := os.ReadFile(migrationPath)
require.NoError(t, err)
prepareLegacyExternalIdentitiesTable(t, tx, ctx)
truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
var linuxDoFirstUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoFirstUserID))
var linuxDoSecondUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoSecondUserID))
var wechatFirstUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatFirstUserID))
var wechatSecondUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatSecondUserID))
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-a', 'Legacy LinuxDo Ambiguous A', '{"source":"legacy"}')
RETURNING id
`, linuxDoFirstUserID).Scan(new(int64)))
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-b', 'Legacy LinuxDo Ambiguous B', '{"source":"legacy"}')
RETURNING id
`, linuxDoSecondUserID).Scan(new(int64)))
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-ambiguous-a', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-a', 'Legacy WeChat Ambiguous A', '{"channel":"oa","appid":"wx-ambiguous-a"}')
RETURNING id
`, wechatFirstUserID).Scan(new(int64)))
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-ambiguous-b', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-b', 'Legacy WeChat Ambiguous B', '{"channel":"oa","appid":"wx-ambiguous-b"}')
RETURNING id
`, wechatSecondUserID).Scan(new(int64)))
_, err = tx.ExecContext(ctx, string(migrationSQL))
require.NoError(t, err)
var linuxDoIdentityCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identities
WHERE provider_type = 'linuxdo'
AND provider_key = 'linuxdo'
AND provider_subject = 'linuxdo-ambiguous-subject'
`).Scan(&linuxDoIdentityCount))
require.Zero(t, linuxDoIdentityCount)
var wechatIdentityCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identities
WHERE provider_type = 'wechat'
AND provider_key = 'wechat-main'
AND provider_subject = 'union-ambiguous-subject'
`).Scan(&wechatIdentityCount))
require.Zero(t, wechatIdentityCount)
var wechatChannelCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_channels
WHERE provider_type = 'wechat'
AND provider_key = 'wechat-main'
AND channel = 'oa'
AND channel_app_id IN ('wx-ambiguous-a', 'wx-ambiguous-b')
`).Scan(&wechatChannelCount))
require.Zero(t, wechatChannelCount)
}
func TestAuthIdentityLegacyExternalMigrations_ReportAmbiguousCanonicalSubjectsWithoutWinnerAttribution(t *testing.T) {
tx := testTx(t)
ctx := context.Background()
migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
migration115SQL, err := os.ReadFile(migration115Path)
require.NoError(t, err)
migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
migration116SQL, err := os.ReadFile(migration116Path)
require.NoError(t, err)
prepareLegacyExternalIdentitiesTable(t, tx, ctx)
truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
var linuxDoFirstUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoFirstUserID))
var linuxDoSecondUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoSecondUserID))
var wechatFirstUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatFirstUserID))
var wechatSecondUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatSecondUserID))
var linuxDoFirstLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-a', 'Legacy LinuxDo Conflict A', '{"source":"legacy"}')
RETURNING id
`, linuxDoFirstUserID).Scan(&linuxDoFirstLegacyID))
var linuxDoSecondLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-b', 'Legacy LinuxDo Conflict B', '{"source":"legacy"}')
RETURNING id
`, linuxDoSecondUserID).Scan(&linuxDoSecondLegacyID))
var wechatFirstLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-conflict-a', 'union-conflict-subject', 'legacy-wechat-conflict-a', 'Legacy WeChat Conflict A', '{"channel":"oa","appid":"wx-conflict-a"}')
RETURNING id
`, wechatFirstUserID).Scan(&wechatFirstLegacyID))
var wechatSecondLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-conflict-b', 'union-conflict-subject', 'legacy-wechat-conflict-b', 'Legacy WeChat Conflict B', '{"channel":"oa","appid":"wx-conflict-b"}')
RETURNING id
`, wechatSecondUserID).Scan(&wechatSecondLegacyID))
_, err = tx.ExecContext(ctx, string(migration115SQL))
require.NoError(t, err)
_, err = tx.ExecContext(ctx, string(migration116SQL))
require.NoError(t, err)
var identityCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identities
WHERE (provider_type = 'linuxdo' AND provider_key = 'linuxdo' AND provider_subject = 'linuxdo-conflict-subject')
OR (provider_type = 'wechat' AND provider_key = 'wechat-main' AND provider_subject = 'union-conflict-subject')
`).Scan(&identityCount))
require.Zero(t, identityCount)
var conflictReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_conflict'
AND report_key IN ($1, $2, $3, $4)
`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&conflictReportCount))
require.Equal(t, 4, conflictReportCount)
var winnerAttributedReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_conflict'
AND report_key IN ($1, $2, $3, $4)
AND details ->> 'existing_identity_id' IS NOT NULL
`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&winnerAttributedReportCount))
require.Zero(t, winnerAttributedReportCount)
}
func TestAuthIdentityMigrationReportTypeWideningPreflightKeeps109And116SafeBefore121(t *testing.T) {
tx := testTx(t)
ctx := context.Background()
migration108aPath := filepath.Join("..", "..", "migrations", "108a_widen_auth_identity_migration_report_type.sql")
migration108aSQL, err := os.ReadFile(migration108aPath)
require.NoError(t, err)
migration109Path := filepath.Join("..", "..", "migrations", "109_auth_identity_compat_backfill.sql")
migration109SQL, err := os.ReadFile(migration109Path)
require.NoError(t, err)
migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
migration116SQL, err := os.ReadFile(migration116Path)
require.NoError(t, err)
prepareLegacyExternalIdentitiesTable(t, tx, ctx)
truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
_, err = tx.ExecContext(ctx, `
ALTER TABLE auth_identity_migration_reports
ALTER COLUMN report_type TYPE VARCHAR(40);
`)
require.NoError(t, err)
var oidcSyntheticUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('oidc-before-121@oidc-connect.invalid', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&oidcSyntheticUserID))
var linuxdoLegacyUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-before-121@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxdoLegacyUserID))
var invalidMetadataLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-before-121', NULL, 'legacy-linuxdo-before-121', 'Legacy LinuxDo Before 121', '{invalid')
RETURNING id
`, linuxdoLegacyUserID).Scan(&invalidMetadataLegacyID))
_, err = tx.ExecContext(ctx, string(migration108aSQL))
require.NoError(t, err)
_, err = tx.ExecContext(ctx, string(migration109SQL))
require.NoError(t, err)
_, err = tx.ExecContext(ctx, string(migration116SQL))
require.NoError(t, err)
var reportTypeWidth int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT character_maximum_length
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = 'auth_identity_migration_reports'
AND column_name = 'report_type'
`).Scan(&reportTypeWidth))
require.Equal(t, 80, reportTypeWidth)
var oidcSyntheticRecoveryReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'oidc_synthetic_email_requires_manual_recovery'
AND report_key = $1
`, strconv.FormatInt(oidcSyntheticUserID, 10)).Scan(&oidcSyntheticRecoveryReportCount))
require.Equal(t, 1, oidcSyntheticRecoveryReportCount)
var invalidMetadataReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
AND report_key = $1
`, "legacy_external_identity:"+strconv.FormatInt(invalidMetadataLegacyID, 10)).Scan(&invalidMetadataReportCount))
require.Equal(t, 1, invalidMetadataReportCount)
}
func prepareLegacyExternalIdentitiesTable(t *testing.T, tx *sql.Tx, ctx context.Context) {
t.Helper()
_, err := tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS user_external_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
`)
require.NoError(t, err)
}
func truncateAuthIdentityLegacyFixtureTables(t *testing.T, tx *sql.Tx, ctx context.Context) {
t.Helper()
_, err := tx.ExecContext(ctx, `
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
pending_auth_sessions,
auth_identities,
auth_identity_migration_reports,
user_provider_default_grants,
user_avatars,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`)
require.NoError(t, err)
}

View File

@ -51,34 +51,30 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
const migrationsAdvisoryLockID int64 = 694208311321144027
const migrationsLockRetryInterval = 500 * time.Millisecond
const nonTransactionalMigrationSuffix = "_notx.sql"
const paymentOrdersOutTradeNoUniqueMigration = "120_enforce_payment_orders_out_trade_no_unique_notx.sql"
const paymentOrdersOutTradeNoUniqueIndex = "paymentorder_out_trade_no_unique"
type migrationChecksumCompatibilityRule struct {
fileChecksum string
acceptedDBChecksum map[string]struct{}
acceptedChecksums map[string]struct{}
}
// migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。
// 规则必须同时匹配「迁移名 + 当前文件 checksum + 历史库 checksum」才会放行避免放宽全局校验。
// 规则必须同时匹配「迁移名 + 数据库 checksum + 当前文件 checksum」且两者都落在该迁移的已知版本集合内才会放行
// 避免放宽全局校验,也允许将误改的历史 migration 回滚为已发布版本而不要求人工修 checksum。
var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibilityRule{
"054_drop_legacy_cache_columns.sql": {
fileChecksum: "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
acceptedDBChecksum: map[string]struct{}{
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {},
},
},
"061_add_usage_log_request_type.sql": {
fileChecksum: "66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
acceptedDBChecksum: map[string]struct{}{
"08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0": {},
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {},
},
},
"109_auth_identity_compat_backfill.sql": {
fileChecksum: "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
acceptedDBChecksum: map[string]struct{}{
"2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3": {},
},
},
"054_drop_legacy_cache_columns.sql": newMigrationChecksumCompatibilityRule("82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"),
"061_add_usage_log_request_type.sql": newMigrationChecksumCompatibilityRule("66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0", "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"),
"109_auth_identity_compat_backfill.sql": newMigrationChecksumCompatibilityRule("0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace", "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"),
"110_pending_auth_and_provider_default_grants.sql": newMigrationChecksumCompatibilityRule("32cf87ee787b1bb36b5c691367c96eee37518fa3eed6f3322cf68795e3745279", "e3d1f433be2b564cfbdc549adf98fce13c5c7b363ebc20fd05b765d0563b0925"),
"112_add_payment_order_provider_key_snapshot.sql": newMigrationChecksumCompatibilityRule("b75f8f56d39455682787696a3d92ad25b055444ca328fb7fca9a460a15d68d99", "ffd3e8a2c9295fa9cbefefd629a78268877e5b51bc970a82d9b3f46ec4ebd15e"),
"115_auth_identity_legacy_external_backfill.sql": newMigrationChecksumCompatibilityRule("022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f", "4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f"),
"116_auth_identity_legacy_external_safety_reports.sql": newMigrationChecksumCompatibilityRule("07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488", "f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877"),
"118_wechat_dual_mode_and_auth_source_defaults.sql": newMigrationChecksumCompatibilityRule("b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0", "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227", "a38243ca0a72c3a01c0a92b7986423054d6133c0399441f853b99802852720fb"),
"119_enforce_payment_orders_out_trade_no_unique.sql": newMigrationChecksumCompatibilityRule("0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e", "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"),
"120_enforce_payment_orders_out_trade_no_unique_notx.sql": newMigrationChecksumCompatibilityRule("34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074", "e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61", "707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"),
"123_fix_legacy_auth_source_grant_on_signup_defaults.sql": newMigrationChecksumCompatibilityRule("2ce43c2cd89e9f9e1febd34a407ed9e84d177386c5544b6f02c1f58a21129f57", "6cd33422f215dcd1f486ab6f35c0ea5805d9ca69bb25906d94bc649156657145"),
}
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
@ -205,6 +201,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
}
if nonTx {
if err := prepareNonTransactionalMigration(ctx, db, name); err != nil {
return fmt.Errorf("prepare migration %s: %w", name, err)
}
// *_notx.sql用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
// 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
statements := splitSQLStatements(content)
@ -254,6 +254,90 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
return nil
}
func prepareNonTransactionalMigration(ctx context.Context, db *sql.DB, name string) error {
switch name {
case paymentOrdersOutTradeNoUniqueMigration:
return preparePaymentOrdersOutTradeNoUniqueMigration(ctx, db)
default:
return nil
}
}
func preparePaymentOrdersOutTradeNoUniqueMigration(ctx context.Context, db *sql.DB) error {
duplicates, err := findDuplicatePaymentOrderOutTradeNos(ctx, db)
if err != nil {
return fmt.Errorf("precheck duplicate out_trade_no: %w", err)
}
if len(duplicates) > 0 {
return fmt.Errorf(
"duplicate out_trade_no values block %s; remediate duplicates before retrying: %s",
paymentOrdersOutTradeNoUniqueMigration,
strings.Join(duplicates, ", "),
)
}
invalid, err := indexIsInvalid(ctx, db, paymentOrdersOutTradeNoUniqueIndex)
if err != nil {
return fmt.Errorf("check invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err)
}
if !invalid {
return nil
}
if _, err := db.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", paymentOrdersOutTradeNoUniqueIndex)); err != nil {
return fmt.Errorf("drop invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err)
}
return nil
}
func findDuplicatePaymentOrderOutTradeNos(ctx context.Context, db *sql.DB) ([]string, error) {
rows, err := db.QueryContext(ctx, `
SELECT out_trade_no, COUNT(*) AS duplicate_count
FROM payment_orders
WHERE out_trade_no <> ''
GROUP BY out_trade_no
HAVING COUNT(*) > 1
ORDER BY duplicate_count DESC, out_trade_no
LIMIT 5
`)
if err != nil {
return nil, err
}
defer func() {
_ = rows.Close()
}()
duplicates := make([]string, 0, 5)
for rows.Next() {
var outTradeNo string
var duplicateCount int
if err := rows.Scan(&outTradeNo, &duplicateCount); err != nil {
return nil, err
}
duplicates = append(duplicates, fmt.Sprintf("%s (count=%d)", outTradeNo, duplicateCount))
}
if err := rows.Err(); err != nil {
return nil, err
}
return duplicates, nil
}
func indexIsInvalid(ctx context.Context, db *sql.DB, indexName string) (bool, error) {
var invalid bool
err := db.QueryRowContext(ctx, `
SELECT EXISTS (
SELECT 1
FROM pg_class idx
JOIN pg_namespace ns ON ns.oid = idx.relnamespace
JOIN pg_index i ON i.indexrelid = idx.oid
WHERE ns.nspname = 'public'
AND idx.relname = $1
AND NOT i.indisvalid
)
`, indexName).Scan(&invalid)
return invalid, err
}
func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error {
hasLegacy, err := tableExists(ctx, db, "schema_migrations")
if err != nil {
@ -328,16 +412,33 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
return version, version, hash, nil
}
func checksumSet(values ...string) map[string]struct{} {
out := make(map[string]struct{}, len(values))
for _, value := range values {
out[value] = struct{}{}
}
return out
}
func newMigrationChecksumCompatibilityRule(fileChecksum string, acceptedDBChecksums ...string) migrationChecksumCompatibilityRule {
return migrationChecksumCompatibilityRule{
fileChecksum: fileChecksum,
acceptedDBChecksum: checksumSet(acceptedDBChecksums...),
acceptedChecksums: checksumSet(append([]string{fileChecksum}, acceptedDBChecksums...)...),
}
}
func isMigrationChecksumCompatible(name, dbChecksum, fileChecksum string) bool {
rule, ok := migrationChecksumCompatibilityRules[name]
if !ok {
return false
}
if rule.fileChecksum != fileChecksum {
_, dbOK := rule.acceptedChecksums[dbChecksum]
if !dbOK {
return false
}
_, ok = rule.acceptedDBChecksum[dbChecksum]
return ok
_, fileOK := rule.acceptedChecksums[fileChecksum]
return fileOK
}
func validateMigrationExecutionMode(name, content string) (bool, error) {

View File

@ -55,9 +55,110 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
t.Run("109历史checksum可兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"109_auth_identity_compat_backfill.sql",
"2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3",
"551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
"0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace",
)
require.True(t, ok)
})
t.Run("109当前checksum可兼容历史checksum", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"109_auth_identity_compat_backfill.sql",
"551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
"0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace",
)
require.True(t, ok)
})
t.Run("109回滚到历史文件后仍兼容已应用的新checksum", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"109_auth_identity_compat_backfill.sql",
"0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace",
"551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
)
require.True(t, ok)
})
t.Run("110历史checksum可兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"110_pending_auth_and_provider_default_grants.sql",
"e3d1f433be2b564cfbdc549adf98fce13c5c7b363ebc20fd05b765d0563b0925",
"32cf87ee787b1bb36b5c691367c96eee37518fa3eed6f3322cf68795e3745279",
)
require.True(t, ok)
})
t.Run("112历史checksum可兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"112_add_payment_order_provider_key_snapshot.sql",
"ffd3e8a2c9295fa9cbefefd629a78268877e5b51bc970a82d9b3f46ec4ebd15e",
"b75f8f56d39455682787696a3d92ad25b055444ca328fb7fca9a460a15d68d99",
)
require.True(t, ok)
})
t.Run("115历史checksum可兼容修复后的legacy external backfill", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"115_auth_identity_legacy_external_backfill.sql",
"4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f",
"022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f",
)
require.True(t, ok)
})
t.Run("116历史checksum可兼容修复后的legacy external safety reports", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"116_auth_identity_legacy_external_safety_reports.sql",
"f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877",
"07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488",
)
require.True(t, ok)
})
t.Run("119历史checksum可兼容占位文件", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"119_enforce_payment_orders_out_trade_no_unique.sql",
"ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34",
"0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e",
)
require.True(t, ok)
})
t.Run("118多个历史checksum都可兼容当前版本", func(t *testing.T) {
for _, dbChecksum := range []string{
"a38243ca0a72c3a01c0a92b7986423054d6133c0399441f853b99802852720fb",
"e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227",
} {
ok := isMigrationChecksumCompatible(
"118_wechat_dual_mode_and_auth_source_defaults.sql",
dbChecksum,
"b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0",
)
require.True(t, ok)
}
})
t.Run("120多个历史checksum都可兼容新的notx修复版本", func(t *testing.T) {
for _, dbChecksum := range []string{
"e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61",
"707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22",
"04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a",
} {
ok := isMigrationChecksumCompatible(
"120_enforce_payment_orders_out_trade_no_unique_notx.sql",
dbChecksum,
"34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074",
)
require.True(t, ok)
}
})
t.Run("119未知checksum不兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"119_enforce_payment_orders_out_trade_no_unique.sql",
"ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34",
"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
)
require.False(t, ok)
})
}

View File

@ -94,6 +94,24 @@ func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
require.True(t, isMigrationChecksumCompatible(name, accepted, rule.fileChecksum))
}
func TestMigrationChecksumCompatibilityRules_CoverEditedUpgradeCompatibilityMigrations(t *testing.T) {
for _, name := range []string{
"109_auth_identity_compat_backfill.sql",
"110_pending_auth_and_provider_default_grants.sql",
"112_add_payment_order_provider_key_snapshot.sql",
"115_auth_identity_legacy_external_backfill.sql",
"116_auth_identity_legacy_external_safety_reports.sql",
"118_wechat_dual_mode_and_auth_source_defaults.sql",
"120_enforce_payment_orders_out_trade_no_unique_notx.sql",
"123_fix_legacy_auth_source_grant_on_signup_defaults.sql",
} {
rule, ok := migrationChecksumCompatibilityRules[name]
require.Truef(t, ok, "missing compatibility rule for %s", name)
require.NotEmpty(t, rule.fileChecksum)
require.NotEmpty(t, rule.acceptedDBChecksum)
}
}
func TestEnsureAtlasBaselineAligned(t *testing.T) {
t.Run("skip_when_no_legacy_table", func(t *testing.T) {
db, mock, err := sqlmock.New()

View File

@ -116,6 +116,84 @@ CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b);
require.NoError(t, mock.ExpectationsWereMet())
}
func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_FailsFastOnDuplicatePrecheck(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
prepareMigrationsBootstrapExpectations(mock)
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql").
WillReturnError(sql.ErrNoRows)
mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders").
WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"}).AddRow("dup-out-trade-no", 2))
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnResult(sqlmock.NewResult(0, 1))
fsys := fstest.MapFS{
"120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{
Data: []byte(`
CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
ON payment_orders (out_trade_no)
WHERE out_trade_no <> '';
DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
`),
},
}
err = applyMigrationsFS(context.Background(), db, fsys)
require.Error(t, err)
require.Contains(t, err.Error(), "duplicate out_trade_no")
require.Contains(t, err.Error(), "dup-out-trade-no")
require.NoError(t, mock.ExpectationsWereMet())
}
func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_DropsInvalidIndexBeforeRetry(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
prepareMigrationsBootstrapExpectations(mock)
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql").
WillReturnError(sql.ErrNoRows)
mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders").
WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"}))
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("paymentorder_out_trade_no_unique").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql", sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnResult(sqlmock.NewResult(0, 1))
fsys := fstest.MapFS{
"120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{
Data: []byte(`
CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
ON payment_orders (out_trade_no)
WHERE out_trade_no <> '';
DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
`),
},
}
err = applyMigrationsFS(context.Background(), db, fsys)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)

View File

@ -89,6 +89,35 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
}
func TestMigrationsRunner_AuthIdentityAndPaymentSchemaStayAligned(t *testing.T) {
tx := testTx(t)
requireColumn(t, tx, "auth_identity_migration_reports", "report_type", "character varying", 80, false)
requireColumn(t, tx, "users", "signup_source", "character varying", 20, false)
requireColumnDefaultContains(t, tx, "users", "signup_source", "email")
requireConstraintDefinitionContains(
t,
tx,
"users",
"users_signup_source_check",
"signup_source",
"'email'",
"'linuxdo'",
"'wechat'",
"'oidc'",
)
requireForeignKeyOnDelete(t, tx, "auth_identities", "user_id", "users", "CASCADE")
requireForeignKeyOnDelete(t, tx, "auth_identity_channels", "identity_id", "auth_identities", "CASCADE")
requireForeignKeyOnDelete(t, tx, "pending_auth_sessions", "target_user_id", "users", "SET NULL")
requireForeignKeyOnDelete(t, tx, "identity_adoption_decisions", "pending_auth_session_id", "pending_auth_sessions", "CASCADE")
requireForeignKeyOnDelete(t, tx, "identity_adoption_decisions", "identity_id", "auth_identities", "SET NULL")
requireIndex(t, tx, "payment_orders", "paymentorder_out_trade_no")
requirePartialUniqueIndexDefinition(t, tx, "payment_orders", "paymentorder_out_trade_no", "out_trade_no", "WHERE")
requireIndexAbsent(t, tx, "payment_orders", "paymentorder_out_trade_no_unique")
}
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
t.Helper()
@ -106,6 +135,118 @@ SELECT EXISTS (
require.True(t, exists, "expected index %s on %s", index, table)
}
func requireIndexAbsent(t *testing.T, tx *sql.Tx, table, index string) {
t.Helper()
var exists bool
err := tx.QueryRowContext(context.Background(), `
SELECT EXISTS (
SELECT 1
FROM pg_indexes
WHERE schemaname = 'public'
AND tablename = $1
AND indexname = $2
)
`, table, index).Scan(&exists)
require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
require.False(t, exists, "expected index %s on %s to be absent", index, table)
}
func requirePartialUniqueIndexDefinition(t *testing.T, tx *sql.Tx, table, index string, fragments ...string) {
t.Helper()
var (
unique bool
def string
)
err := tx.QueryRowContext(context.Background(), `
SELECT
i.indisunique,
pg_get_indexdef(i.indexrelid)
FROM pg_class idx
JOIN pg_index i ON i.indexrelid = idx.oid
JOIN pg_class tbl ON tbl.oid = i.indrelid
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
WHERE ns.nspname = 'public'
AND tbl.relname = $1
AND idx.relname = $2
`, table, index).Scan(&unique, &def)
require.NoError(t, err, "query index definition for %s.%s", table, index)
require.True(t, unique, "expected index %s on %s to be unique", index, table)
for _, fragment := range fragments {
require.Contains(t, def, fragment, "expected index definition for %s.%s to contain %q", table, index, fragment)
}
}
func requireForeignKeyOnDelete(t *testing.T, tx *sql.Tx, table, column, refTable, expected string) {
t.Helper()
var actual string
err := tx.QueryRowContext(context.Background(), `
SELECT CASE c.confdeltype
WHEN 'a' THEN 'NO ACTION'
WHEN 'r' THEN 'RESTRICT'
WHEN 'c' THEN 'CASCADE'
WHEN 'n' THEN 'SET NULL'
WHEN 'd' THEN 'SET DEFAULT'
END
FROM pg_constraint c
JOIN pg_class tbl ON tbl.oid = c.conrelid
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
JOIN pg_class ref_tbl ON ref_tbl.oid = c.confrelid
JOIN pg_attribute attr ON attr.attrelid = tbl.oid AND attr.attnum = ANY(c.conkey)
WHERE ns.nspname = 'public'
AND c.contype = 'f'
AND tbl.relname = $1
AND attr.attname = $2
AND ref_tbl.relname = $3
LIMIT 1
`, table, column, refTable).Scan(&actual)
require.NoError(t, err, "query foreign key action for %s.%s -> %s", table, column, refTable)
require.Equal(t, expected, actual, "unexpected ON DELETE action for %s.%s -> %s", table, column, refTable)
}
func requireConstraintDefinitionContains(t *testing.T, tx *sql.Tx, table, constraint string, fragments ...string) {
t.Helper()
var def string
err := tx.QueryRowContext(context.Background(), `
SELECT pg_get_constraintdef(c.oid)
FROM pg_constraint c
JOIN pg_class tbl ON tbl.oid = c.conrelid
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
WHERE ns.nspname = 'public'
AND tbl.relname = $1
AND c.conname = $2
`, table, constraint).Scan(&def)
require.NoError(t, err, "query constraint definition for %s.%s", table, constraint)
for _, fragment := range fragments {
require.Contains(t, def, fragment, "expected constraint definition for %s.%s to contain %q", table, constraint, fragment)
}
}
func requireColumnDefaultContains(t *testing.T, tx *sql.Tx, table, column string, fragments ...string) {
t.Helper()
var columnDefault sql.NullString
err := tx.QueryRowContext(context.Background(), `
SELECT column_default
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = $1
AND column_name = $2
`, table, column).Scan(&columnDefault)
require.NoError(t, err, "query column_default for %s.%s", table, column)
require.True(t, columnDefault.Valid, "expected column_default for %s.%s", table, column)
for _, fragment := range fragments {
require.Contains(t, columnDefault.String, fragment, "expected default for %s.%s to contain %q", table, column, fragment)
}
}
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
t.Helper()

View File

@ -4,11 +4,15 @@ import (
"context"
"database/sql"
"fmt"
"hash/fnv"
"reflect"
"sort"
"strings"
"sync"
"time"
"unsafe"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
@ -120,6 +124,113 @@ type sqlQueryExecutor interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
var repositoryScopedKeyLocks = newScopedKeyLockRegistry()
type scopedKeyLockRegistry struct {
mu sync.Mutex
locks map[string]*scopedKeyLockEntry
}
type scopedKeyLockEntry struct {
mu sync.Mutex
refs int
}
func newScopedKeyLockRegistry() *scopedKeyLockRegistry {
return &scopedKeyLockRegistry{
locks: make(map[string]*scopedKeyLockEntry),
}
}
func (r *scopedKeyLockRegistry) lock(keys ...string) func() {
normalized := normalizeLockKeys(keys...)
if len(normalized) == 0 {
return func() {}
}
entries := make([]*scopedKeyLockEntry, 0, len(normalized))
r.mu.Lock()
for _, key := range normalized {
entry := r.locks[key]
if entry == nil {
entry = &scopedKeyLockEntry{}
r.locks[key] = entry
}
entry.refs++
entries = append(entries, entry)
}
r.mu.Unlock()
for _, entry := range entries {
entry.mu.Lock()
}
return func() {
for i := len(entries) - 1; i >= 0; i-- {
entries[i].mu.Unlock()
}
r.mu.Lock()
defer r.mu.Unlock()
for idx, key := range normalized {
entry := entries[idx]
entry.refs--
if entry.refs == 0 {
delete(r.locks, key)
}
}
}
}
func normalizeLockKeys(keys ...string) []string {
if len(keys) == 0 {
return nil
}
deduped := make(map[string]struct{}, len(keys))
for _, key := range keys {
trimmed := strings.TrimSpace(key)
if trimmed == "" {
continue
}
deduped[trimmed] = struct{}{}
}
if len(deduped) == 0 {
return nil
}
normalized := make([]string, 0, len(deduped))
for key := range deduped {
normalized = append(normalized, key)
}
sort.Strings(normalized)
return normalized
}
func advisoryLockHash(key string) int64 {
hasher := fnv.New64a()
_, _ = hasher.Write([]byte(key))
return int64(hasher.Sum64())
}
func lockRepositoryScopedKeys(ctx context.Context, client *dbent.Client, exec sqlQueryExecutor, keys ...string) (func(), error) {
release := repositoryScopedKeyLocks.lock(keys...)
normalized := normalizeLockKeys(keys...)
if len(normalized) == 0 || client == nil || exec == nil || client.Driver().Dialect() != dialect.Postgres {
return release, nil
}
for _, key := range normalized {
rows, err := exec.QueryContext(ctx, "SELECT pg_advisory_xact_lock($1)", advisoryLockHash(key))
if err != nil {
release()
return nil, err
}
_ = rows.Close()
}
return release, nil
}
func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
if dbent.TxFromContext(ctx) != nil {
return fn(ctx)
@ -301,17 +412,18 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
client := clientFromContext(txCtx, r.client)
canonical := input.Canonical
identity, err := client.AuthIdentity.Query().
identityRecords, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(strings.TrimSpace(canonical.ProviderType)),
authidentity.ProviderKeyEQ(strings.TrimSpace(canonical.ProviderKey)),
authidentity.ProviderKeyIn(compatibleIdentityProviderKeys(canonical.ProviderType, canonical.ProviderKey)...),
authidentity.ProviderSubjectEQ(strings.TrimSpace(canonical.ProviderSubject)),
).
Only(txCtx)
if err != nil && !dbent.IsNotFound(err) {
All(txCtx)
if err != nil {
return err
}
if identity != nil && identity.UserID != input.UserID {
identity := selectOwnedCompatibleIdentity(identityRecords, input.UserID)
if identity == nil && hasCompatibleIdentityConflict(identityRecords, input.UserID) {
return ErrAuthIdentityOwnershipConflict
}
if identity == nil {
@ -328,7 +440,11 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
return err
}
} else {
targetProviderKey := canonicalizeCompatibleIdentityProviderKey(canonical.ProviderType, identity.ProviderKey, canonical.ProviderKey)
update := client.AuthIdentity.UpdateOneID(identity.ID)
if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, identity.ProviderKey) {
update = update.SetProviderKey(targetProviderKey)
}
if input.Metadata != nil {
update = update.SetMetadata(copyMetadata(input.Metadata))
}
@ -346,20 +462,21 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
var channel *dbent.AuthIdentityChannel
if input.Channel != nil {
channel, err = client.AuthIdentityChannel.Query().
channelRecords, err := client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ(strings.TrimSpace(input.Channel.ProviderType)),
authidentitychannel.ProviderKeyEQ(strings.TrimSpace(input.Channel.ProviderKey)),
authidentitychannel.ProviderKeyIn(compatibleIdentityProviderKeys(input.Channel.ProviderType, input.Channel.ProviderKey)...),
authidentitychannel.ChannelEQ(strings.TrimSpace(input.Channel.Channel)),
authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(input.Channel.ChannelAppID)),
authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(input.Channel.ChannelSubject)),
).
WithIdentity().
Only(txCtx)
if err != nil && !dbent.IsNotFound(err) {
All(txCtx)
if err != nil {
return err
}
if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != input.UserID {
channel = selectOwnedCompatibleChannel(channelRecords, input.UserID)
if channel == nil && hasCompatibleChannelConflict(channelRecords, input.UserID) {
return ErrAuthIdentityChannelOwnershipConflict
}
if channel == nil {
@ -376,8 +493,12 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
return err
}
} else {
targetProviderKey := canonicalizeCompatibleIdentityProviderKey(input.Channel.ProviderType, channel.ProviderKey, input.Channel.ProviderKey)
update := client.AuthIdentityChannel.UpdateOneID(channel.ID).
SetIdentityID(identity.ID)
if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, channel.ProviderKey) {
update = update.SetProviderKey(targetProviderKey)
}
if input.ChannelMetadata != nil {
update = update.SetMetadata(copyMetadata(input.ChannelMetadata))
}
@ -397,6 +518,104 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
return result, nil
}
func compatibleIdentityProviderKeys(providerType, providerKey string) []string {
providerType = strings.TrimSpace(strings.ToLower(providerType))
providerKey = strings.TrimSpace(providerKey)
if providerKey == "" {
return []string{providerKey}
}
if providerType != "wechat" {
return []string{providerKey}
}
keys := []string{providerKey}
if !strings.EqualFold(providerKey, "wechat-main") {
keys = append(keys, "wechat-main")
}
if !strings.EqualFold(providerKey, "wechat") {
keys = append(keys, "wechat")
}
return keys
}
func canonicalizeCompatibleIdentityProviderKey(providerType, existingKey, requestedKey string) string {
providerType = strings.TrimSpace(strings.ToLower(providerType))
existingKey = strings.TrimSpace(existingKey)
requestedKey = strings.TrimSpace(requestedKey)
if providerType != "wechat" {
if requestedKey != "" {
return requestedKey
}
return existingKey
}
if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") {
return "wechat-main"
}
if requestedKey != "" {
return requestedKey
}
return existingKey
}
func compatibleIdentityProviderKeyRank(providerType, providerKey string) int {
providerType = strings.TrimSpace(strings.ToLower(providerType))
providerKey = strings.TrimSpace(providerKey)
if providerType != "wechat" {
return 0
}
switch {
case strings.EqualFold(providerKey, "wechat-main"):
return 0
case strings.EqualFold(providerKey, "wechat"):
return 2
default:
return 1
}
}
func selectOwnedCompatibleIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
var selected *dbent.AuthIdentity
for _, record := range records {
if record.UserID != userID {
continue
}
if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
selected = record
}
}
return selected
}
func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) bool {
for _, record := range records {
if record.UserID != userID {
return true
}
}
return false
}
func selectOwnedCompatibleChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
var selected *dbent.AuthIdentityChannel
for _, record := range records {
if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID {
continue
}
if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
selected = record
}
}
return selected
}
func hasCompatibleChannelConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
for _, record := range records {
if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
return true
}
}
return false
}
func (r *userRepository) RecordProviderGrant(ctx context.Context, input ProviderGrantRecordInput) (bool, error) {
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
if exec == nil {
@ -422,51 +641,70 @@ ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
}
func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
client := clientFromContext(ctx, r.client)
if input.IdentityID != nil && *input.IdentityID > 0 {
if _, err := client.IdentityAdoptionDecision.Update().
Where(
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
col := s.C(identityadoptiondecision.FieldPendingAuthSessionID)
s.Where(entsql.Or(
entsql.IsNull(col),
entsql.NEQ(col, input.PendingAuthSessionID),
))
}),
).
ClearIdentityID().
Save(ctx); err != nil {
return nil, err
var result *dbent.IdentityAdoptionDecision
err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
client := clientFromContext(txCtx, r.client)
releaseLocks, err := lockRepositoryScopedKeys(
txCtx,
client,
txAwareSQLExecutor(txCtx, r.sql, r.client),
identityAdoptionDecisionLockKeys(input.PendingAuthSessionID, input.IdentityID)...,
)
if err != nil {
return err
}
defer releaseLocks()
if input.IdentityID != nil && *input.IdentityID > 0 {
if _, err := client.IdentityAdoptionDecision.Update().
Where(
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
col := s.C(identityadoptiondecision.FieldPendingAuthSessionID)
s.Where(entsql.Or(
entsql.IsNull(col),
entsql.NEQ(col, input.PendingAuthSessionID),
))
}),
).
ClearIdentityID().
Save(txCtx); err != nil {
return err
}
}
}
current, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return nil, err
}
now := time.Now().UTC()
if current == nil {
create := client.IdentityAdoptionDecision.Create().
SetPendingAuthSessionID(input.PendingAuthSessionID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar).
SetDecidedAt(now)
if input.IdentityID != nil {
SetDecidedAt(time.Now().UTC())
if input.IdentityID != nil && *input.IdentityID > 0 {
create = create.SetIdentityID(*input.IdentityID)
}
return create.Save(ctx)
}
update := client.IdentityAdoptionDecision.UpdateOneID(current.ID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar)
if input.IdentityID != nil {
update = update.SetIdentityID(*input.IdentityID)
decisionID, err := create.
OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID).
UpdateNewValues().
ID(txCtx)
if err != nil {
return err
}
result, err = client.IdentityAdoptionDecision.Get(txCtx, decisionID)
return err
})
if err != nil {
return nil, err
}
return update.Save(ctx)
return result, nil
}
func identityAdoptionDecisionLockKeys(pendingAuthSessionID int64, identityID *int64) []string {
keys := []string{fmt.Sprintf("identity-adoption:pending:%d", pendingAuthSessionID)}
if identityID != nil && *identityID > 0 {
keys = append(keys, fmt.Sprintf("identity-adoption:identity:%d", *identityID))
}
return keys
}
func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) {

View File

@ -10,6 +10,8 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
)
@ -186,6 +188,79 @@ func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAn
s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict)
}
func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_ReusesLegacyWeChatAliasRecords() {
user := s.mustCreateUser("wechat-legacy-alias")
legacyIdentity, err := s.client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("wechat").
SetProviderKey("wechat").
SetProviderSubject("union-legacy-123").
SetMetadata(map[string]any{"source": "legacy-alias"}).
Save(s.ctx)
s.Require().NoError(err)
legacyChannel, err := s.client.AuthIdentityChannel.Create().
SetIdentityID(legacyIdentity.ID).
SetProviderType("wechat").
SetProviderKey("wechat").
SetChannel("oa").
SetChannelAppID("wx-app-legacy").
SetChannelSubject("openid-legacy-123").
SetMetadata(map[string]any{"scene": "legacy-alias"}).
Save(s.ctx)
s.Require().NoError(err)
bound, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
UserID: user.ID,
Canonical: AuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-main",
ProviderSubject: "union-legacy-123",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "wechat",
ProviderKey: "wechat-main",
Channel: "oa",
ChannelAppID: "wx-app-legacy",
ChannelSubject: "openid-legacy-123",
},
Metadata: map[string]any{"source": "canonical-bind"},
ChannelMetadata: map[string]any{"scene": "canonical-bind"},
})
s.Require().NoError(err)
s.Require().NotNil(bound)
s.Require().NotNil(bound.Identity)
s.Require().NotNil(bound.Channel)
s.Require().Equal(legacyIdentity.ID, bound.Identity.ID)
s.Require().Equal(legacyChannel.ID, bound.Channel.ID)
s.Require().Equal("wechat-main", bound.Identity.ProviderKey)
s.Require().Equal("wechat-main", bound.Channel.ProviderKey)
s.Require().Equal("canonical-bind", bound.Identity.Metadata["source"])
s.Require().Equal("canonical-bind", bound.Channel.Metadata["scene"])
identityCount, err := s.client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("wechat"),
authidentity.ProviderSubjectEQ("union-legacy-123"),
).
Count(s.ctx)
s.Require().NoError(err)
s.Require().Equal(1, identityCount)
channelCount, err := s.client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ("wechat"),
authidentitychannel.ChannelEQ("oa"),
authidentitychannel.ChannelAppIDEQ("wx-app-legacy"),
authidentitychannel.ChannelSubjectEQ("openid-legacy-123"),
).
Count(s.ctx)
s.Require().NoError(err)
s.Require().Equal(1, channelCount)
}
func (s *UserProfileIdentityRepoSuite) TestCreateAuthIdentity_RejectsChannelProviderMismatch() {
user := s.mustCreateUser("provider-mismatch-create")

View File

@ -0,0 +1,212 @@
package repository
import (
"context"
"sync"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestUserRepositoryBindAuthIdentityToUserCanonicalizesLegacyWeChatAlias(t *testing.T) {
repo, client := newUserEntRepo(t)
ctx := context.Background()
user := &service.User{
Email: "wechat-legacy@example.com",
Username: "wechat-legacy",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, user))
legacyIdentity, err := client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("wechat").
SetProviderKey("wechat").
SetProviderSubject("union-legacy-123").
SetMetadata(map[string]any{"source": "legacy-alias"}).
Save(ctx)
require.NoError(t, err)
legacyChannel, err := client.AuthIdentityChannel.Create().
SetIdentityID(legacyIdentity.ID).
SetProviderType("wechat").
SetProviderKey("wechat").
SetChannel("oa").
SetChannelAppID("wx-app-legacy").
SetChannelSubject("openid-legacy-123").
SetMetadata(map[string]any{"scene": "legacy-alias"}).
Save(ctx)
require.NoError(t, err)
bound, err := repo.BindAuthIdentityToUser(ctx, BindAuthIdentityInput{
UserID: user.ID,
Canonical: AuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-main",
ProviderSubject: "union-legacy-123",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "wechat",
ProviderKey: "wechat-main",
Channel: "oa",
ChannelAppID: "wx-app-legacy",
ChannelSubject: "openid-legacy-123",
},
Metadata: map[string]any{"source": "canonical-bind"},
ChannelMetadata: map[string]any{"scene": "canonical-bind"},
})
require.NoError(t, err)
require.NotNil(t, bound)
require.NotNil(t, bound.Identity)
require.NotNil(t, bound.Channel)
require.Equal(t, legacyIdentity.ID, bound.Identity.ID)
require.Equal(t, legacyChannel.ID, bound.Channel.ID)
require.Equal(t, "wechat-main", bound.Identity.ProviderKey)
require.Equal(t, "wechat-main", bound.Channel.ProviderKey)
reloadedIdentity, err := client.AuthIdentity.Get(ctx, legacyIdentity.ID)
require.NoError(t, err)
require.Equal(t, "wechat-main", reloadedIdentity.ProviderKey)
require.Equal(t, "canonical-bind", reloadedIdentity.Metadata["source"])
reloadedChannel, err := client.AuthIdentityChannel.Get(ctx, legacyChannel.ID)
require.NoError(t, err)
require.Equal(t, "wechat-main", reloadedChannel.ProviderKey)
require.Equal(t, "canonical-bind", reloadedChannel.Metadata["scene"])
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("wechat"),
authidentity.ProviderSubjectEQ("union-legacy-123"),
).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, identityCount)
channelCount, err := client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ("wechat"),
authidentitychannel.ChannelEQ("oa"),
authidentitychannel.ChannelAppIDEQ("wx-app-legacy"),
authidentitychannel.ChannelSubjectEQ("openid-legacy-123"),
).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, channelCount)
}
func TestUserRepositoryUpsertIdentityAdoptionDecisionIsIdempotentUnderConcurrency(t *testing.T) {
repo, client := newUserEntRepo(t)
ctx := context.Background()
user := &service.User{
Email: "repo-adoption@example.com",
Username: "repo-adoption",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, user))
identity, err := client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("wechat").
SetProviderKey("wechat-main").
SetProviderSubject("union-repo-adoption").
SetMetadata(map[string]any{}).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("pending-repo-adoption").
SetIntent("bind_current_user").
SetProviderType("wechat").
SetProviderKey("wechat-main").
SetProviderSubject("union-repo-adoption").
SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)).
SetUpstreamIdentityClaims(map[string]any{"provider_subject": "union-repo-adoption"}).
SetLocalFlowState(map[string]any{"step": "pending"}).
Save(ctx)
require.NoError(t, err)
firstCreateStarted := make(chan struct{})
releaseFirstCreate := make(chan struct{})
var firstCreate sync.Once
client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator {
return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
blocked := false
if m.Op().Is(dbent.OpCreate) {
firstCreate.Do(func() {
blocked = true
close(firstCreateStarted)
})
}
if blocked {
<-releaseFirstCreate
}
return next.Mutate(ctx, m)
})
})
type adoptionResult struct {
decision *dbent.IdentityAdoptionDecision
err error
}
input := IdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
IdentityID: &identity.ID,
AdoptDisplayName: true,
AdoptAvatar: true,
}
results := make(chan adoptionResult, 2)
go func() {
decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input)
results <- adoptionResult{decision: decision, err: err}
}()
<-firstCreateStarted
go func() {
decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input)
results <- adoptionResult{decision: decision, err: err}
}()
time.Sleep(100 * time.Millisecond)
close(releaseFirstCreate)
first := <-results
second := <-results
require.NoError(t, first.err)
require.NoError(t, second.err)
require.NotNil(t, first.decision)
require.NotNil(t, second.decision)
require.Equal(t, first.decision.ID, second.decision.ID)
count, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, count)
loaded, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, loaded.IdentityID)
require.Equal(t, identity.ID, *loaded.IdentityID)
require.True(t, loaded.AdoptDisplayName)
require.True(t, loaded.AdoptAvatar)
}

View File

@ -52,9 +52,11 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
}
var txClient *dbent.Client
txCtx := ctx
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
txCtx = dbent.NewTxContext(ctx, tx)
} else {
// 已处于外部事务中ErrTxStarted复用当前事务 client 并由调用方负责提交/回滚。
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
@ -64,6 +66,21 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
}
}
releaseEmailLock, err := lockRepositoryScopedKeys(
txCtx,
txClient,
txAwareSQLExecutor(txCtx, r.sql, r.client),
normalizedEmailUniquenessLockKey(userIn.Email),
)
if err != nil {
return err
}
defer releaseEmailLock()
if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, 0, userIn.Email); err != nil {
return err
}
created, err := txClient.User.Create().
SetEmail(userIn.Email).
SetUsername(userIn.Username).
@ -76,15 +93,15 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
SetNillableLastLoginAt(userIn.LastLoginAt).
SetNillableLastActiveAt(userIn.LastActiveAt).
Save(ctx)
Save(txCtx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, created.ID, userIn.AllowedGroups); err != nil {
return err
}
if err := ensureEmailAuthIdentityWithClient(ctx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
if err := ensureEmailAuthIdentityWithClient(txCtx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
return err
}
@ -154,9 +171,11 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
}
var txClient *dbent.Client
txCtx := ctx
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
txCtx = dbent.NewTxContext(ctx, tx)
} else {
// 已处于外部事务中ErrTxStarted复用当前事务 client 并由调用方负责提交/回滚。
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
@ -165,7 +184,23 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
txClient = r.client
}
}
existing, err := clientFromContext(ctx, txClient).User.Get(ctx, userIn.ID)
releaseEmailLock, err := lockRepositoryScopedKeys(
txCtx,
txClient,
txAwareSQLExecutor(txCtx, r.sql, r.client),
normalizedEmailUniquenessLockKey(userIn.Email),
)
if err != nil {
return err
}
defer releaseEmailLock()
if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, userIn.ID, userIn.Email); err != nil {
return err
}
existing, err := clientFromContext(txCtx, txClient).User.Get(txCtx, userIn.ID)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
@ -197,15 +232,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
if userIn.BalanceNotifyThreshold == nil {
updateOp = updateOp.ClearBalanceNotifyThreshold()
}
updated, err := updateOp.Save(ctx)
updated, err := updateOp.Save(txCtx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
}
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
return err
}
if err := replaceEmailAuthIdentityWithClient(ctx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
if err := replaceEmailAuthIdentityWithClient(txCtx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
return err
}
@ -704,8 +739,28 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx)
}
func ensureNormalizedEmailAvailableWithClient(ctx context.Context, client *dbent.Client, userID int64, email string) error {
client = clientFromContext(ctx, client)
if client == nil {
return nil
}
matches, err := client.User.Query().
Where(userEmailLookupPredicate(email)).
All(ctx)
if err != nil {
return err
}
for _, match := range matches {
if match.ID != userID {
return service.ErrEmailExists
}
}
return nil
}
func userEmailLookupPredicate(email string) predicate.User {
normalized := strings.ToLower(strings.TrimSpace(email))
normalized := normalizeEmailLookupValue(email)
if normalized == "" {
return dbuser.EmailEQ(email)
}
@ -719,6 +774,18 @@ func userEmailLookupPredicate(email string) predicate.User {
})
}
func normalizeEmailLookupValue(email string) string {
return strings.ToLower(strings.TrimSpace(email))
}
func normalizedEmailUniquenessLockKey(email string) string {
normalized := normalizeEmailLookupValue(email)
if normalized == "" {
return ""
}
return "users:normalized-email:" + normalized
}
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
client := clientFromContext(ctx, r.client)
err := client.UserAllowedGroup.Create().
@ -853,11 +920,14 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
}
func userSignupSourceOrDefault(signupSource string) string {
signupSource = strings.TrimSpace(signupSource)
if signupSource == "" {
switch strings.TrimSpace(strings.ToLower(signupSource)) {
case "", "email":
return "email"
case "linuxdo", "wechat", "oidc":
return strings.TrimSpace(strings.ToLower(signupSource))
default:
return "email"
}
return signupSource
}
// marshalExtraEmails serializes notify email entries to JSON for storage.

View File

@ -3,7 +3,10 @@ package repository
import (
"context"
"database/sql"
"fmt"
"sync"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
@ -18,9 +21,10 @@ import (
func newUserEntRepo(t *testing.T) (*userRepository, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:user_repo_email_lookup?mode=memory&cache=shared")
db, err := sql.Open("sqlite", fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", t.Name()))
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
db.SetMaxOpenConns(10)
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
@ -67,3 +71,157 @@ func TestUserRepositoryExistsByEmailNormalizesLegacySpacingAndCase(t *testing.T)
require.NoError(t, err)
require.True(t, exists)
}
func TestUserRepositoryCreateRejectsNormalizedEmailDuplicate(t *testing.T) {
repo, _ := newUserEntRepo(t)
ctx := context.Background()
err := repo.Create(ctx, &service.User{
Email: " Existing@Example.com ",
Username: "existing-user",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
})
require.NoError(t, err)
err = repo.Create(ctx, &service.User{
Email: "existing@example.com",
Username: "duplicate-user",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
})
require.ErrorIs(t, err, service.ErrEmailExists)
}
func TestUserRepositoryUpdateRejectsNormalizedEmailDuplicate(t *testing.T) {
repo, _ := newUserEntRepo(t)
ctx := context.Background()
first := &service.User{
Email: " Existing@Example.com ",
Username: "existing-user",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, first))
second := &service.User{
Email: "second@example.com",
Username: "second-user",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, second))
second.Email = " existing@example.com "
err := repo.Update(ctx, second)
require.ErrorIs(t, err, service.ErrEmailExists)
}
func TestUserRepositoryGetByEmailReportsNormalizedEmailConflict(t *testing.T) {
repo, client := newUserEntRepo(t)
ctx := context.Background()
_, err := client.User.Create().
SetEmail("Conflict@Example.com").
SetUsername("conflict-user-1").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = client.User.Create().
SetEmail(" conflict@example.com ").
SetUsername("conflict-user-2").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = repo.GetByEmail(ctx, "conflict@example.com")
require.Error(t, err)
require.ErrorContains(t, err, "normalized email lookup matched multiple users")
}
func TestUserRepositoryCreateSerializesNormalizedEmailConflictsUnderConcurrency(t *testing.T) {
repo, client := newUserEntRepo(t)
ctx := context.Background()
firstCreateStarted := make(chan struct{})
releaseFirstCreate := make(chan struct{})
var firstCreate sync.Once
client.User.Use(func(next dbent.Mutator) dbent.Mutator {
return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
blocked := false
if m.Op().Is(dbent.OpCreate) {
firstCreate.Do(func() {
blocked = true
close(firstCreateStarted)
})
}
if blocked {
<-releaseFirstCreate
}
return next.Mutate(ctx, m)
})
})
type createResult struct {
err error
}
results := make(chan createResult, 2)
go func() {
results <- createResult{err: repo.Create(ctx, &service.User{
Email: " Race@Example.com ",
Username: "race-user-1",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
})}
}()
<-firstCreateStarted
go func() {
results <- createResult{err: repo.Create(ctx, &service.User{
Email: "race@example.com",
Username: "race-user-2",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
})}
}()
time.Sleep(100 * time.Millisecond)
close(releaseFirstCreate)
first := <-results
second := <-results
errors := []error{first.err, second.err}
successes := 0
conflicts := 0
for _, err := range errors {
switch err {
case nil:
successes++
case service.ErrEmailExists:
conflicts++
default:
t.Fatalf("unexpected create error: %v", err)
}
}
require.Equal(t, 1, successes)
require.Equal(t, 1, conflicts)
count, err := client.User.Query().Where(userEmailLookupPredicate("race@example.com")).Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, count)
}

View File

@ -85,7 +85,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"oidc": {
"provider": "oidc",
@ -93,7 +93,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"wechat": {
"provider": "wechat",
@ -101,7 +101,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"identity_bindings": {
@ -122,7 +122,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"oidc": {
"provider": "oidc",
@ -130,7 +130,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"wechat": {
"provider": "wechat",
@ -138,7 +138,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"auth_bindings": {
@ -159,7 +159,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"oidc": {
"provider": "oidc",
@ -167,7 +167,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"wechat": {
"provider": "wechat",
@ -175,7 +175,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"run_mode": "standard"
@ -784,6 +784,198 @@ func TestAPIContracts(t *testing.T) {
}
}`,
},
{
name: "GET /api/v1/admin/settings falls back to config oauth defaults",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
deps.cfg.OIDC = config.OIDCConnectConfig{
Enabled: true,
ProviderName: "ConfigOIDC",
ClientID: "oidc-config-client",
ClientSecret: "oidc-config-secret",
IssuerURL: "https://issuer.example.com",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
FrontendRedirectURL: "/auth/oidc/callback",
Scopes: "openid email profile",
TokenAuthMethod: "client_secret_post",
UsePKCE: true,
ValidateIDToken: true,
AllowedSigningAlgs: "RS256,ES256,PS256",
ClockSkewSeconds: 120,
}
deps.cfg.WeChat = config.WeChatConnectConfig{
Enabled: true,
OpenEnabled: true,
OpenAppID: "wx-open-config",
OpenAppSecret: "wx-open-secret",
Mode: "open",
Scopes: "snsapi_login",
FrontendRedirectURL: "/auth/wechat/callback",
}
deps.settingRepo.SetAll(map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeyRegistrationEmailSuffixWhitelist: "[]",
})
},
method: http.MethodGet,
path: "/api/v1/admin/settings",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": {
"registration_enabled": true,
"email_verify_enabled": false,
"registration_email_suffix_whitelist": [],
"promo_code_enabled": true,
"password_reset_enabled": false,
"frontend_url": "",
"invitation_code_enabled": false,
"totp_enabled": false,
"totp_encryption_key_configured": false,
"smtp_host": "",
"smtp_port": 587,
"smtp_username": "",
"smtp_password_configured": false,
"smtp_from_email": "",
"smtp_from_name": "",
"smtp_use_tls": false,
"turnstile_enabled": false,
"turnstile_site_key": "",
"turnstile_secret_key_configured": false,
"linuxdo_connect_enabled": false,
"linuxdo_connect_client_id": "",
"linuxdo_connect_client_secret_configured": false,
"linuxdo_connect_redirect_url": "",
"oidc_connect_enabled": true,
"oidc_connect_provider_name": "ConfigOIDC",
"oidc_connect_client_id": "oidc-config-client",
"oidc_connect_client_secret_configured": true,
"oidc_connect_issuer_url": "https://issuer.example.com",
"oidc_connect_discovery_url": "",
"oidc_connect_authorize_url": "",
"oidc_connect_token_url": "",
"oidc_connect_userinfo_url": "",
"oidc_connect_jwks_url": "",
"oidc_connect_scopes": "openid email profile",
"oidc_connect_redirect_url": "https://api.example.com/api/v1/auth/oauth/oidc/callback",
"oidc_connect_frontend_redirect_url": "/auth/oidc/callback",
"oidc_connect_token_auth_method": "client_secret_post",
"oidc_connect_use_pkce": true,
"oidc_connect_validate_id_token": true,
"oidc_connect_allowed_signing_algs": "RS256,ES256,PS256",
"oidc_connect_clock_skew_seconds": 120,
"oidc_connect_require_email_verified": false,
"oidc_connect_userinfo_email_path": "",
"oidc_connect_userinfo_id_path": "",
"oidc_connect_userinfo_username_path": "",
"site_name": "Sub2API",
"site_logo": "",
"site_subtitle": "Subscription to API Conversion Platform",
"api_base_url": "",
"contact_info": "",
"doc_url": "",
"home_content": "",
"hide_ccs_import_button": false,
"purchase_subscription_enabled": false,
"purchase_subscription_url": "",
"table_default_page_size": 20,
"table_page_size_options": [10, 20, 50],
"custom_menu_items": [],
"custom_endpoints": [],
"default_concurrency": 0,
"default_balance": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
"fallback_model_openai": "gpt-4o",
"fallback_model_gemini": "gemini-2.5-pro",
"fallback_model_antigravity": "gemini-2.5-pro",
"enable_identity_patch": true,
"identity_patch_prompt": "",
"ops_monitoring_enabled": false,
"ops_realtime_monitoring_enabled": true,
"ops_query_mode_default": "auto",
"ops_metrics_interval_seconds": 60,
"min_claude_code_version": "",
"max_claude_code_version": "",
"allow_ungrouped_key_scheduling": false,
"backend_mode_enabled": false,
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
"enable_cch_signing": false,
"web_search_emulation_enabled": false,
"payment_visible_method_alipay_source": "",
"payment_visible_method_wxpay_source": "",
"payment_visible_method_alipay_enabled": false,
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": false,
"payment_enabled": false,
"payment_min_amount": 0,
"payment_max_amount": 0,
"payment_daily_limit": 0,
"payment_order_timeout_minutes": 0,
"payment_max_pending_orders": 0,
"payment_enabled_types": null,
"payment_balance_disabled": false,
"payment_balance_recharge_multiplier": 0,
"payment_recharge_fee_rate": 0,
"payment_load_balance_strategy": "",
"payment_product_name_prefix": "",
"payment_product_name_suffix": "",
"payment_help_image_url": "",
"payment_help_text": "",
"payment_cancel_rate_limit_enabled": false,
"payment_cancel_rate_limit_max": 0,
"payment_cancel_rate_limit_window": 0,
"payment_cancel_rate_limit_unit": "",
"payment_cancel_rate_limit_window_mode": "",
"balance_low_notify_enabled": false,
"account_quota_notify_enabled": false,
"balance_low_notify_threshold": 0,
"balance_low_notify_recharge_url": "",
"account_quota_notify_emails": [],
"wechat_connect_enabled": true,
"wechat_connect_app_id": "wx-open-config",
"wechat_connect_app_secret_configured": true,
"wechat_connect_mode": "open",
"wechat_connect_open_enabled": true,
"wechat_connect_open_app_id": "wx-open-config",
"wechat_connect_open_app_secret_configured": true,
"wechat_connect_mp_enabled": false,
"wechat_connect_mp_app_id": "wx-open-config",
"wechat_connect_mp_app_secret_configured": true,
"wechat_connect_mobile_enabled": false,
"wechat_connect_mobile_app_id": "wx-open-config",
"wechat_connect_mobile_app_secret_configured": true,
"wechat_connect_redirect_url": "",
"wechat_connect_frontend_redirect_url": "/auth/wechat/callback",
"wechat_connect_scopes": "snsapi_login",
"auth_source_default_email_balance": 0,
"auth_source_default_email_concurrency": 5,
"auth_source_default_email_subscriptions": [],
"auth_source_default_email_grant_on_signup": false,
"auth_source_default_email_grant_on_first_bind": false,
"auth_source_default_linuxdo_balance": 0,
"auth_source_default_linuxdo_concurrency": 5,
"auth_source_default_linuxdo_subscriptions": [],
"auth_source_default_linuxdo_grant_on_signup": false,
"auth_source_default_linuxdo_grant_on_first_bind": false,
"auth_source_default_oidc_balance": 0,
"auth_source_default_oidc_concurrency": 5,
"auth_source_default_oidc_subscriptions": [],
"auth_source_default_oidc_grant_on_signup": false,
"auth_source_default_oidc_grant_on_first_bind": false,
"auth_source_default_wechat_balance": 0,
"auth_source_default_wechat_concurrency": 5,
"auth_source_default_wechat_subscriptions": [],
"auth_source_default_wechat_grant_on_signup": false,
"auth_source_default_wechat_grant_on_first_bind": false,
"force_email_on_third_party_signup": false
}
}`,
},
{
name: "POST /api/v1/admin/accounts/bulk-update",
method: http.MethodPost,
@ -827,6 +1019,7 @@ func TestAPIContracts(t *testing.T) {
type contractDeps struct {
now time.Time
router http.Handler
cfg *config.Config
apiKeyRepo *stubApiKeyRepo
groupRepo *stubGroupRepo
userSubRepo *stubUserSubscriptionRepo
@ -947,6 +1140,7 @@ func newContractDeps(t *testing.T) *contractDeps {
return &contractDeps{
now: now,
router: r,
cfg: cfg,
apiKeyRepo: apiKeyRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,

View File

@ -27,23 +27,50 @@ func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFun
}
}
func backendModeAllowsAuthPath(path string) bool {
path = strings.ToLower(strings.TrimSpace(path))
for _, suffix := range []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"} {
if strings.HasSuffix(path, suffix) {
return true
}
}
for _, suffix := range []string{
"/auth/oauth/linuxdo/callback",
"/auth/oauth/wechat/callback",
"/auth/oauth/wechat/payment/callback",
"/auth/oauth/oidc/callback",
"/auth/oauth/linuxdo/complete-registration",
"/auth/oauth/wechat/complete-registration",
"/auth/oauth/oidc/complete-registration",
"/auth/oauth/linuxdo/create-account",
"/auth/oauth/wechat/create-account",
"/auth/oauth/oidc/create-account",
"/auth/oauth/linuxdo/bind-login",
"/auth/oauth/wechat/bind-login",
"/auth/oauth/oidc/bind-login",
} {
if strings.HasSuffix(path, suffix) {
return true
}
}
return strings.Contains(path, "/auth/oauth/pending/")
}
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
// Allows: login, login/2fa, logout, refresh (admin needs these).
// Blocks: register, forgot-password, reset-password, OAuth, etc.
// Allows the minimal auth surface admins still need in backend mode, including
// OAuth callbacks and pending continuations. Handler-level backend mode checks
// still enforce admin-only login and forbid self-service registration.
func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc {
return func(c *gin.Context) {
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
c.Next()
return
}
path := c.Request.URL.Path
// Allow login, 2FA, logout, refresh, public settings
allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"}
for _, suffix := range allowedSuffixes {
if strings.HasSuffix(path, suffix) {
c.Next()
return
}
if backendModeAllowsAuthPath(c.Request.URL.Path) {
c.Next()
return
}
response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.")
c.Abort()

View File

@ -198,6 +198,96 @@ func TestBackendModeAuthGuard(t *testing.T) {
path: "/api/v1/auth/refresh",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_linuxdo_oauth_start",
enabled: "true",
path: "/api/v1/auth/oauth/linuxdo/start",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_allows_linuxdo_oauth_callback",
enabled: "true",
path: "/api/v1/auth/oauth/linuxdo/callback",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_wechat_oauth_start",
enabled: "true",
path: "/api/v1/auth/oauth/wechat/start",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_allows_wechat_oauth_callback",
enabled: "true",
path: "/api/v1/auth/oauth/wechat/callback",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_wechat_payment_oauth_start",
enabled: "true",
path: "/api/v1/auth/oauth/wechat/payment/start",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_allows_wechat_payment_oauth_callback",
enabled: "true",
path: "/api/v1/auth/oauth/wechat/payment/callback",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_oidc_oauth_start",
enabled: "true",
path: "/api/v1/auth/oauth/oidc/start",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_allows_oidc_oauth_callback",
enabled: "true",
path: "/api/v1/auth/oauth/oidc/callback",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_oauth_pending_exchange",
enabled: "true",
path: "/api/v1/auth/oauth/pending/exchange",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_oauth_pending_send_verify_code",
enabled: "true",
path: "/api/v1/auth/oauth/pending/send-verify-code",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_oauth_pending_create_account",
enabled: "true",
path: "/api/v1/auth/oauth/pending/create-account",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_oauth_pending_bind_login",
enabled: "true",
path: "/api/v1/auth/oauth/pending/bind-login",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_provider_bind_login",
enabled: "true",
path: "/api/v1/auth/oauth/oidc/bind-login",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_provider_create_account",
enabled: "true",
path: "/api/v1/auth/oauth/wechat/create-account",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_legacy_complete_registration",
enabled: "true",
path: "/api/v1/auth/oauth/linuxdo/complete-registration",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_register",
enabled: "true",

View File

@ -63,8 +63,20 @@ func RegisterAuthRoutes(
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.LinuxDoOAuthStart(c)
})
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
auth.GET("/oauth/wechat/start", h.Auth.WeChatOAuthStart)
auth.GET("/oauth/wechat/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.WeChatOAuthStart(c)
})
auth.GET("/oauth/wechat/callback", h.Auth.WeChatOAuthCallback)
auth.GET("/oauth/wechat/payment/start", h.Auth.WeChatPaymentOAuthStart)
auth.GET("/oauth/wechat/payment/callback", h.Auth.WeChatPaymentOAuthCallback)
@ -129,6 +141,12 @@ func RegisterAuthRoutes(
h.Auth.CreateWeChatOAuthAccount,
)
auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart)
auth.GET("/oauth/oidc/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.OIDCOAuthStart(c)
})
auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback)
auth.POST("/oauth/oidc/complete-registration",
rateLimiter.LimitWithOptions("oauth-oidc-complete", 10, time.Minute, middleware.RateLimitOptions{
@ -164,23 +182,6 @@ func RegisterAuthRoutes(
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 撤销所有会话(需要认证)
authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions)
authenticated.GET("/auth/oauth/linuxdo/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.LinuxDoOAuthStart(c)
})
authenticated.GET("/auth/oauth/oidc/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.OIDCOAuthStart(c)
})
authenticated.GET("/auth/oauth/wechat/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.WeChatOAuthStart(c)
})
authenticated.POST("/auth/oauth/bind-token", h.Auth.PrepareOAuthBindAccessTokenCookie)
}
}

View File

@ -44,9 +44,9 @@ func RegisterPaymentRoutes(
}
// --- Public payment endpoints (no auth) ---
// Signed resume-token recovery is the supported public lookup path.
// The legacy anonymous out_trade_no verify endpoint is kept only as a
// compatibility shim that returns HTTP 410 Gone.
// Signed resume-token recovery is the preferred public lookup path.
// The legacy anonymous out_trade_no verify endpoint remains available as a
// persisted-state compatibility path for staggered upgrades.
public := v1.Group("/payment/public")
{
public.POST("/orders/verify", paymentHandler.VerifyOrderPublic)

View File

@ -419,6 +419,7 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co
// testOpenAIAccountConnection tests an OpenAI account's connection
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
ctx := c.Request.Context()
_ = prompt
// Default to openai.DefaultTestModel for OpenAI testing
testModelID := modelID

View File

@ -879,6 +879,8 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
if providerKey == "" || providerSubject == "" {
return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required")
}
canonicalProviderKey := canonicalAdminAuthIdentityProviderKey(providerType, "", providerKey)
compatibleProviderKeys := compatibleAdminAuthIdentityProviderKeys(providerType, providerKey)
var issuer *string
if input.Issuer != nil {
@ -900,25 +902,26 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
}
defer func() { _ = tx.Rollback() }()
identity, err := tx.AuthIdentity.Query().
identityRecords, err := tx.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(providerType),
authidentity.ProviderKeyEQ(providerKey),
authidentity.ProviderKeyIn(compatibleProviderKeys...),
authidentity.ProviderSubjectEQ(providerSubject),
).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
All(ctx)
if err != nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
}
if identity != nil && identity.UserID != userID {
if hasAdminAuthIdentityOwnershipConflict(identityRecords, userID) {
return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
}
identity := selectOwnedAdminAuthIdentity(identityRecords, userID)
if identity == nil {
create := tx.AuthIdentity.Create().
SetUserID(userID).
SetProviderType(providerType).
SetProviderKey(providerKey).
SetProviderKey(canonicalProviderKey).
SetProviderSubject(providerSubject).
SetVerifiedAt(verifiedAt)
if issuer != nil {
@ -932,7 +935,9 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
}
} else {
update := tx.AuthIdentity.UpdateOneID(identity.ID).SetVerifiedAt(verifiedAt)
update := tx.AuthIdentity.UpdateOneID(identity.ID).
SetVerifiedAt(verifiedAt).
SetProviderKey(canonicalProviderKey)
if issuer != nil {
update = update.SetIssuer(*issuer)
}
@ -947,27 +952,28 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
var channel *dbent.AuthIdentityChannel
if channelInput != nil {
channel, err = tx.AuthIdentityChannel.Query().
channelRecords, err := tx.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ(providerType),
authidentitychannel.ProviderKeyEQ(providerKey),
authidentitychannel.ProviderKeyIn(compatibleProviderKeys...),
authidentitychannel.ChannelEQ(channelInput.Channel),
authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID),
authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject),
).
WithIdentity().
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
All(ctx)
if err != nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
}
if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID {
if hasAdminAuthIdentityChannelOwnershipConflict(channelRecords, userID) {
return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
}
channel = selectOwnedAdminAuthIdentityChannel(channelRecords, userID)
if channel == nil {
create := tx.AuthIdentityChannel.Create().
SetIdentityID(identity.ID).
SetProviderType(providerType).
SetProviderKey(providerKey).
SetProviderKey(canonicalProviderKey).
SetChannel(channelInput.Channel).
SetChannelAppID(channelInput.ChannelAppID).
SetChannelSubject(channelInput.ChannelSubject)
@ -979,7 +985,9 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
}
} else {
update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).SetIdentityID(identity.ID)
update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).
SetIdentityID(identity.ID).
SetProviderKey(canonicalProviderKey)
if channelInput.Metadata != nil {
update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
}
@ -996,6 +1004,105 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
return buildAdminBoundAuthIdentity(identity, channel), nil
}
func compatibleAdminAuthIdentityProviderKeys(providerType, providerKey string) []string {
providerType = strings.TrimSpace(strings.ToLower(providerType))
providerKey = strings.TrimSpace(providerKey)
if providerKey == "" {
return []string{providerKey}
}
if providerType != "wechat" {
return []string{providerKey}
}
keys := []string{providerKey}
if !strings.EqualFold(providerKey, "wechat-main") {
keys = append(keys, "wechat-main")
}
if !strings.EqualFold(providerKey, "wechat") {
keys = append(keys, "wechat")
}
return keys
}
func canonicalAdminAuthIdentityProviderKey(providerType, existingKey, requestedKey string) string {
providerType = strings.TrimSpace(strings.ToLower(providerType))
existingKey = strings.TrimSpace(existingKey)
requestedKey = strings.TrimSpace(requestedKey)
if providerType != "wechat" {
if requestedKey != "" {
return requestedKey
}
return existingKey
}
if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") {
return "wechat-main"
}
if requestedKey != "" {
return requestedKey
}
return existingKey
}
func adminAuthIdentityProviderKeyRank(providerType, providerKey string) int {
providerType = strings.TrimSpace(strings.ToLower(providerType))
providerKey = strings.TrimSpace(providerKey)
if providerType != "wechat" {
return 0
}
switch {
case strings.EqualFold(providerKey, "wechat-main"):
return 0
case strings.EqualFold(providerKey, "wechat"):
return 2
default:
return 1
}
}
func selectOwnedAdminAuthIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
var selected *dbent.AuthIdentity
for _, record := range records {
if record.UserID != userID {
continue
}
if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
selected = record
}
}
return selected
}
func hasAdminAuthIdentityOwnershipConflict(records []*dbent.AuthIdentity, userID int64) bool {
for _, record := range records {
if record.UserID != userID {
return true
}
}
return false
}
func selectOwnedAdminAuthIdentityChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
var selected *dbent.AuthIdentityChannel
for _, record := range records {
if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID {
continue
}
if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
selected = record
}
}
return selected
}
func hasAdminAuthIdentityChannelOwnershipConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
for _, record := range records {
if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
return true
}
}
return false
}
func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput {
if input == nil {
return nil

View File

@ -188,6 +188,93 @@ func TestAdminServiceBindUserAuthIdentityIsIdempotentForSameUser(t *testing.T) {
require.Equal(t, "second", identities[0].Metadata["source"])
}
func TestAdminServiceBindUserAuthIdentityReusesLegacyWeChatAliasRecords(t *testing.T) {
client := newAdminServiceAuthIdentityBindingTestClient(t)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("wechat-alias@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
legacyIdentity, err := client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("wechat").
SetProviderKey("wechat").
SetProviderSubject("union-legacy-123").
SetMetadata(map[string]any{"source": "legacy"}).
Save(ctx)
require.NoError(t, err)
legacyChannel, err := client.AuthIdentityChannel.Create().
SetIdentityID(legacyIdentity.ID).
SetProviderType("wechat").
SetProviderKey("wechat").
SetChannel("open").
SetChannelAppID("wx-open").
SetChannelSubject("openid-legacy-123").
SetMetadata(map[string]any{"scene": "legacy"}).
Save(ctx)
require.NoError(t, err)
svc := &adminServiceImpl{
userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
entClient: client,
}
result, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
ProviderType: "wechat",
ProviderKey: "wechat-main",
ProviderSubject: "union-legacy-123",
Metadata: map[string]any{"source": "admin-repair"},
Channel: &AdminBindAuthIdentityChannelInput{
Channel: "open",
ChannelAppID: "wx-open",
ChannelSubject: "openid-legacy-123",
Metadata: map[string]any{"scene": "admin-repair"},
},
})
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "wechat-main", result.ProviderKey)
require.NotNil(t, result.Channel)
require.Equal(t, "open", result.Channel.Channel)
identity, err := client.AuthIdentity.Get(ctx, legacyIdentity.ID)
require.NoError(t, err)
require.Equal(t, "wechat-main", identity.ProviderKey)
require.Equal(t, "admin-repair", identity.Metadata["source"])
channel, err := client.AuthIdentityChannel.Get(ctx, legacyChannel.ID)
require.NoError(t, err)
require.Equal(t, "wechat-main", channel.ProviderKey)
require.Equal(t, legacyIdentity.ID, channel.IdentityID)
require.Equal(t, "admin-repair", channel.Metadata["scene"])
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("wechat"),
authidentity.ProviderSubjectEQ("union-legacy-123"),
).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, identityCount)
channelCount, err := client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ("wechat"),
authidentitychannel.ChannelEQ("open"),
authidentitychannel.ChannelAppIDEQ("wx-open"),
authidentitychannel.ChannelSubjectEQ("openid-legacy-123"),
).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, channelCount)
}
func TestAdminServiceBindUserAuthIdentityRejectsInvalidProviderType(t *testing.T) {
client := newAdminServiceAuthIdentityBindingTestClient(t)
ctx := context.Background()

View File

@ -11,6 +11,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// BindEmailIdentity verifies and binds a local email/password identity to the
@ -69,6 +70,7 @@ func (s *AuthService) BindEmailIdentity(
if err := s.updateBoundEmailIdentityTx(ctx, currentUser, normalizedEmail, hashedPassword, firstRealEmailBind); err != nil {
return nil, err
}
s.revokeEmailIdentitySessions(ctx, userID)
return currentUser, nil
}
@ -87,6 +89,7 @@ func (s *AuthService) BindEmailIdentity(
}
}
s.revokeEmailIdentitySessions(ctx, userID)
return currentUser, nil
}
@ -219,6 +222,12 @@ func (s *AuthService) updateBoundEmailIdentityWithClient(
return nil
}
func (s *AuthService) revokeEmailIdentitySessions(ctx context.Context, userID int64) {
if err := s.RevokeAllUserSessions(ctx, userID); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh sessions after email identity bind for user %d: %v", userID, err)
}
}
func replaceBoundEmailAuthIdentityWithClient(
ctx context.Context,
client *dbent.Client,

View File

@ -14,10 +14,14 @@ import (
func normalizeOAuthSignupSource(signupSource string) string {
signupSource = strings.TrimSpace(strings.ToLower(signupSource))
if signupSource == "" {
switch signupSource {
case "", "email":
return "email"
case "linuxdo", "wechat", "oidc":
return signupSource
default:
return "email"
}
return signupSource
}
// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth
@ -136,10 +140,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
return nil, nil, fmt.Errorf("hash password: %w", err)
}
signupSource = strings.TrimSpace(strings.ToLower(signupSource))
if signupSource == "" {
signupSource = "email"
}
signupSource = normalizeOAuthSignupSource(signupSource)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
user := &User{
@ -149,6 +150,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
Status: StatusActive,
SignupSource: signupSource,
}
if err := s.userRepo.Create(ctx, user); err != nil {

View File

@ -191,6 +191,80 @@ func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFai
require.Empty(t, redeemRepo.updateCalls)
}
func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *testing.T) {
userRepo := &userRepoStub{nextID: 42}
emailCache := &emailCacheStub{
data: &VerificationCodeData{
Code: "246810",
Attempts: 0,
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
},
}
authService := newOAuthEmailFlowAuthService(
userRepo,
&redeemCodeRepoStub{},
&refreshTokenCacheStub{},
map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
},
emailCache,
)
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
context.Background(),
"fresh@example.com",
"secret-123",
"246810",
"",
" OIDC ",
)
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.NotNil(t, user)
require.Len(t, userRepo.created, 1)
require.Equal(t, "oidc", userRepo.created[0].SignupSource)
}
func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing.T) {
userRepo := &userRepoStub{nextID: 43}
emailCache := &emailCacheStub{
data: &VerificationCodeData{
Code: "246810",
Attempts: 0,
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
},
}
authService := newOAuthEmailFlowAuthService(
userRepo,
&redeemCodeRepoStub{},
&refreshTokenCacheStub{},
map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
},
emailCache,
)
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
context.Background(),
"fallback@example.com",
"secret-123",
"246810",
"",
"github",
)
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.NotNil(t, user)
require.Len(t, userRepo.created, 1)
require.Equal(t, "email", userRepo.created[0].SignupSource)
}
func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) {
userRepo := &userRepoStub{}
redeemRepo := &redeemCodeRepoStub{

View File

@ -5,10 +5,15 @@ import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"hash/fnv"
"sort"
"strings"
"sync"
"time"
"entgo.io/ent/dialect"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
@ -75,6 +80,122 @@ type AuthPendingIdentityService struct {
entClient *dbent.Client
}
var authPendingIdentityScopedKeyLocks = newAuthPendingIdentityScopedKeyLockRegistry()
type authPendingIdentityScopedKeyLockRegistry struct {
mu sync.Mutex
locks map[string]*authPendingIdentityScopedKeyLockEntry
}
type authPendingIdentityScopedKeyLockEntry struct {
mu sync.Mutex
refs int
}
func newAuthPendingIdentityScopedKeyLockRegistry() *authPendingIdentityScopedKeyLockRegistry {
return &authPendingIdentityScopedKeyLockRegistry{
locks: make(map[string]*authPendingIdentityScopedKeyLockEntry),
}
}
func (r *authPendingIdentityScopedKeyLockRegistry) lock(keys ...string) func() {
normalized := normalizeAuthPendingIdentityLockKeys(keys...)
if len(normalized) == 0 {
return func() {}
}
entries := make([]*authPendingIdentityScopedKeyLockEntry, 0, len(normalized))
r.mu.Lock()
for _, key := range normalized {
entry := r.locks[key]
if entry == nil {
entry = &authPendingIdentityScopedKeyLockEntry{}
r.locks[key] = entry
}
entry.refs++
entries = append(entries, entry)
}
r.mu.Unlock()
for _, entry := range entries {
entry.mu.Lock()
}
return func() {
for i := len(entries) - 1; i >= 0; i-- {
entries[i].mu.Unlock()
}
r.mu.Lock()
defer r.mu.Unlock()
for idx, key := range normalized {
entry := entries[idx]
entry.refs--
if entry.refs == 0 {
delete(r.locks, key)
}
}
}
}
func normalizeAuthPendingIdentityLockKeys(keys ...string) []string {
if len(keys) == 0 {
return nil
}
deduped := make(map[string]struct{}, len(keys))
for _, key := range keys {
trimmed := strings.TrimSpace(key)
if trimmed == "" {
continue
}
deduped[trimmed] = struct{}{}
}
if len(deduped) == 0 {
return nil
}
normalized := make([]string, 0, len(deduped))
for key := range deduped {
normalized = append(normalized, key)
}
sort.Strings(normalized)
return normalized
}
func authPendingIdentityAdvisoryLockHash(key string) int64 {
hasher := fnv.New64a()
_, _ = hasher.Write([]byte(key))
return int64(hasher.Sum64())
}
func lockAuthPendingIdentityKeys(ctx context.Context, client *dbent.Client, keys ...string) (func(), error) {
release := authPendingIdentityScopedKeyLocks.lock(keys...)
normalized := normalizeAuthPendingIdentityLockKeys(keys...)
if len(normalized) == 0 || client == nil || client.Driver().Dialect() != dialect.Postgres {
return release, nil
}
for _, key := range normalized {
var rows entsql.Rows
if err := client.Driver().Query(ctx, "SELECT pg_advisory_xact_lock($1)", []any{authPendingIdentityAdvisoryLockHash(key)}, &rows); err != nil {
release()
return nil, err
}
_ = rows.Close()
}
return release, nil
}
func pendingIdentityAdoptionLockKeys(pendingAuthSessionID int64, identityID *int64) []string {
keys := []string{fmt.Sprintf("pending-auth-adoption:pending:%d", pendingAuthSessionID)}
if identityID != nil && *identityID > 0 {
keys = append(keys, fmt.Sprintf("pending-auth-adoption:identity:%d", *identityID))
}
return keys
}
func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService {
return &AuthPendingIdentityService{entClient: entClient}
}
@ -236,16 +357,66 @@ func (s *AuthPendingIdentityService) consumeSession(
return nil, err
}
sanitizedLocalFlowState := sanitizePendingAuthLocalFlowState(session.LocalFlowState)
now := time.Now().UTC()
updated, err := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
Where(
pendingauthsession.ConsumedAtIsNil(),
pendingauthsession.ExpiresAtGTE(now),
pendingauthsession.Or(
pendingauthsession.CompletionCodeExpiresAtIsNil(),
pendingauthsession.CompletionCodeExpiresAtGTE(now),
),
).
SetConsumedAt(now).
SetLocalFlowState(sanitizedLocalFlowState).
SetCompletionCodeHash("").
ClearCompletionCodeExpiresAt().
Save(ctx)
if err != nil {
ClearCompletionCodeExpiresAt()
if expectedBrowserSessionKey := strings.TrimSpace(session.BrowserSessionKey); expectedBrowserSessionKey != "" {
update = update.Where(pendingauthsession.BrowserSessionKeyEQ(expectedBrowserSessionKey))
}
updated, err := update.Save(ctx)
if err == nil {
return updated, nil
}
if !dbent.IsNotFound(err) {
return nil, err
}
return updated, nil
current, currentErr := s.entClient.PendingAuthSession.Get(ctx, session.ID)
if currentErr != nil {
if dbent.IsNotFound(currentErr) {
return nil, ErrPendingAuthSessionNotFound
}
return nil, currentErr
}
if err := validatePendingSessionState(current, browserSessionKey, expiredErr, consumedErr); err != nil {
return nil, err
}
return nil, consumedErr
}
func sanitizePendingAuthLocalFlowState(localFlowState map[string]any) map[string]any {
sanitized := copyPendingMap(localFlowState)
if len(sanitized) == 0 {
return sanitized
}
rawCompletion, ok := sanitized["completion_response"]
if !ok {
return sanitized
}
completion, ok := rawCompletion.(map[string]any)
if !ok {
return sanitized
}
cleanedCompletion := copyPendingMap(completion)
for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} {
delete(cleanedCompletion, key)
}
sanitized["completion_response"] = cleanedCompletion
return sanitized
}
func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error {
@ -274,8 +445,29 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context,
return nil, fmt.Errorf("pending auth ent client is not configured")
}
tx, err := s.entClient.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return nil, err
}
client := s.entClient
txCtx := ctx
if err == nil {
defer func() { _ = tx.Rollback() }()
client = tx.Client()
txCtx = dbent.NewTxContext(ctx, tx)
} else if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
client = existingTx.Client()
}
releaseLocks, err := lockAuthPendingIdentityKeys(txCtx, client, pendingIdentityAdoptionLockKeys(input.PendingAuthSessionID, input.IdentityID)...)
if err != nil {
return nil, err
}
defer releaseLocks()
if input.IdentityID != nil && *input.IdentityID > 0 {
if _, err := s.entClient.IdentityAdoptionDecision.Update().
if _, err := client.IdentityAdoptionDecision.Update().
Where(
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
@ -287,36 +479,40 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context,
}),
).
ClearIdentityID().
Save(ctx); err != nil {
Save(txCtx); err != nil {
return nil, err
}
}
existing, err := s.entClient.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return nil, err
}
if existing == nil {
create := s.entClient.IdentityAdoptionDecision.Create().
SetPendingAuthSessionID(input.PendingAuthSessionID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar).
SetDecidedAt(time.Now().UTC())
if input.IdentityID != nil {
create = create.SetIdentityID(*input.IdentityID)
}
return create.Save(ctx)
create := client.IdentityAdoptionDecision.Create().
SetPendingAuthSessionID(input.PendingAuthSessionID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar).
SetDecidedAt(time.Now().UTC())
if input.IdentityID != nil && *input.IdentityID > 0 {
create = create.SetIdentityID(*input.IdentityID)
}
update := s.entClient.IdentityAdoptionDecision.UpdateOneID(existing.ID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar)
if input.IdentityID != nil {
update = update.SetIdentityID(*input.IdentityID)
decisionID, err := create.
OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID).
UpdateNewValues().
ID(txCtx)
if err != nil {
return nil, err
}
return update.Save(ctx)
decision, err := client.IdentityAdoptionDecision.Get(txCtx, decisionID)
if err != nil {
return nil, err
}
if tx != nil {
if err := tx.Commit(); err != nil {
return nil, err
}
}
return decision, nil
}
func copyPendingMap(in map[string]any) map[string]any {

View File

@ -5,6 +5,7 @@ package service
import (
"context"
"database/sql"
"sync"
"testing"
"time"
@ -259,6 +260,107 @@ func TestAuthPendingIdentityService_UpsertAdoptionDecision_ReassignsExistingIden
require.Nil(t, reloadedFirst.IdentityID)
}
func TestAuthPendingIdentityService_UpsertAdoptionDecision_IsIdempotentUnderConcurrency(t *testing.T) {
svc, client := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("adoption-concurrent@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
identity, err := client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("wechat").
SetProviderKey("wechat-main").
SetProviderSubject("union-concurrent").
SetMetadata(map[string]any{}).
Save(ctx)
require.NoError(t, err)
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "bind_current_user",
Identity: PendingAuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-main",
ProviderSubject: "union-concurrent",
},
})
require.NoError(t, err)
firstCreateStarted := make(chan struct{})
releaseFirstCreate := make(chan struct{})
var firstCreate sync.Once
client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator {
return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
blocked := false
if m.Op().Is(dbent.OpCreate) {
firstCreate.Do(func() {
blocked = true
close(firstCreateStarted)
})
}
if blocked {
<-releaseFirstCreate
}
return next.Mutate(ctx, m)
})
})
type adoptionResult struct {
decision *dbent.IdentityAdoptionDecision
err error
}
input := PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
IdentityID: &identity.ID,
AdoptDisplayName: true,
AdoptAvatar: true,
}
results := make(chan adoptionResult, 2)
go func() {
decision, err := svc.UpsertAdoptionDecision(ctx, input)
results <- adoptionResult{decision: decision, err: err}
}()
<-firstCreateStarted
go func() {
decision, err := svc.UpsertAdoptionDecision(ctx, input)
results <- adoptionResult{decision: decision, err: err}
}()
time.Sleep(100 * time.Millisecond)
close(releaseFirstCreate)
first := <-results
second := <-results
require.NoError(t, first.err)
require.NoError(t, second.err)
require.NotNil(t, first.decision)
require.NotNil(t, second.decision)
require.Equal(t, first.decision.ID, second.decision.ID)
count, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, count)
loaded, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, loaded.IdentityID)
require.Equal(t, identity.ID, *loaded.IdentityID)
}
func TestAuthPendingIdentityService_UpsertAdoptionDecision_ClearsLegacyNullSessionReference(t *testing.T) {
t.Skip("legacy NULL pending_auth_session_id rows only exist in production PostgreSQL history; sqlite unit schema rejects NULL")
@ -356,3 +458,69 @@ func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) {
_, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
require.ErrorIs(t, err, ErrPendingAuthSessionConsumed)
}
func TestAuthPendingIdentityService_ConsumeBrowserSessionRejectsStaleLoadedSessionReplay(t *testing.T) {
svc, _ := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "login",
Identity: PendingAuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "stale-replay-subject",
},
BrowserSessionKey: "browser-session",
})
require.NoError(t, err)
loaded, err := svc.getBrowserSession(ctx, session.SessionToken)
require.NoError(t, err)
consumed, err := svc.consumeSession(ctx, loaded, "browser-session", ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
_, err = svc.consumeSession(ctx, loaded, "browser-session", ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
require.ErrorIs(t, err, ErrPendingAuthSessionConsumed)
}
func TestAuthPendingIdentityService_ConsumeBrowserSessionScrubsLegacyCompletionTokens(t *testing.T) {
svc, client := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "login",
Identity: PendingAuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "legacy-token-subject",
},
BrowserSessionKey: "browser-session",
LocalFlowState: map[string]any{
"completion_response": map[string]any{
"access_token": "legacy-access-token",
"refresh_token": "legacy-refresh-token",
"expires_in": float64(3600),
"token_type": "Bearer",
"redirect": "/dashboard",
},
},
})
require.NoError(t, err)
consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
stored, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
completion, ok := stored.LocalFlowState["completion_response"].(map[string]any)
require.True(t, ok)
require.NotContains(t, completion, "access_token")
require.NotContains(t, completion, "refresh_token")
require.NotContains(t, completion, "expires_in")
require.NotContains(t, completion, "token_type")
require.Equal(t, "/dashboard", completion["redirect"])
}

View File

@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
@ -489,6 +490,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
Status: StatusActive,
SignupSource: signupSource,
}
if err := s.userRepo.Create(ctx, newUser); err != nil {
@ -599,6 +601,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
Status: StatusActive,
SignupSource: signupSource,
}
if s.entClient != nil && invitationRedeemCode != nil {
@ -1048,7 +1051,7 @@ func (s *AuthService) GenerateToken(user *User) (string, error) {
UserID: user.ID,
Email: user.Email,
Role: user.Role,
TokenVersion: user.TokenVersion,
TokenVersion: resolvedTokenVersion(user),
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt),
IssuedAt: jwt.NewNumericDate(now),
@ -1114,7 +1117,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// Security: Check TokenVersion to prevent refreshing revoked tokens
// This ensures tokens issued before a password change cannot be refreshed
if claims.TokenVersion != user.TokenVersion {
if claims.TokenVersion != resolvedTokenVersion(user) {
return "", ErrTokenRevoked
}
@ -1342,7 +1345,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
data := &RefreshTokenData{
UserID: user.ID,
TokenVersion: user.TokenVersion,
TokenVersion: resolvedTokenVersion(user),
FamilyID: familyID,
CreatedAt: now,
ExpiresAt: now.Add(ttl),
@ -1422,7 +1425,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
}
// 检查TokenVersion密码更改后所有Token失效
if data.TokenVersion != user.TokenVersion {
if data.TokenVersion != resolvedTokenVersion(user) {
// TokenVersion不匹配撤销整个Token家族
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
return nil, ErrTokenRevoked
@ -1467,8 +1470,42 @@ func (s *AuthService) RevokeAllUserSessions(ctx context.Context, userID int64) e
return s.refreshTokenCache.DeleteUserRefreshTokens(ctx, userID)
}
// RevokeAllUserTokens invalidates both stateless access tokens and refresh sessions.
// Access/refresh token verification both depend on TokenVersion, so bumping it provides
// immediate revocation even if refresh-token cache cleanup later fails.
func (s *AuthService) RevokeAllUserTokens(ctx context.Context, userID int64) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("get user: %w", err)
}
user.TokenVersion++
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user: %w", err)
}
if err := s.RevokeAllUserSessions(ctx, userID); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh sessions after token invalidation for user %d: %v", userID, err)
}
return nil
}
// hashToken 计算Token的SHA256哈希
func hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}
func resolvedTokenVersion(user *User) int64 {
if user == nil {
return 0
}
if user.TokenVersionResolved {
return user.TokenVersion
}
material := strings.ToLower(strings.TrimSpace(user.Email)) + "\n" + user.PasswordHash
sum := sha256.Sum256([]byte(material))
fingerprint := int64(binary.BigEndian.Uint64(sum[:8]) & 0x7fffffffffffffff)
return user.TokenVersion ^ fingerprint
}

View File

@ -6,6 +6,7 @@ import (
"context"
"database/sql"
"errors"
"sync"
"testing"
"time"
@ -13,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
@ -54,6 +56,16 @@ func newAuthServiceForEmailBind(
settings map[string]string,
emailCache service.EmailCache,
defaultSubAssigner service.DefaultSubscriptionAssigner,
) (*service.AuthService, service.UserRepository, *dbent.Client) {
return newAuthServiceForEmailBindWithRefreshCache(t, settings, emailCache, defaultSubAssigner, nil)
}
func newAuthServiceForEmailBindWithRefreshCache(
t *testing.T,
settings map[string]string,
emailCache service.EmailCache,
defaultSubAssigner service.DefaultSubscriptionAssigner,
refreshTokenCache service.RefreshTokenCache,
) (*service.AuthService, service.UserRepository, *dbent.Client) {
t.Helper()
@ -98,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
emailSvc = service.NewEmailService(settingRepo, emailCache)
}
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
return svc, repo, client
}
@ -427,6 +439,61 @@ func TestAuthServiceBindEmailIdentity_RejectsWrongCurrentPasswordForBoundEmail(t
require.Equal(t, 0, newIdentityCount)
}
func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *testing.T) {
ctx := context.Background()
cache := &emailBindCacheStub{
data: &service.VerificationCodeData{
Code: "123456",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
},
}
refreshTokenCache := newEmailBindRefreshTokenCacheStub()
userRepo := newEmailBindUserRepoStub(&service.User{
ID: 41,
Email: "legacy-user" + service.OIDCConnectSyntheticEmailDomain,
Username: "legacy-user",
PasswordHash: "old-hash",
Role: service.RoleUser,
Status: service.StatusActive,
TokenVersion: 4,
})
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-bind-email-secret",
ExpireHour: 1,
AccessTokenExpireMinutes: 60,
RefreshTokenExpireDays: 7,
},
}
emailService := service.NewEmailService(nil, cache)
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil)
oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{
ID: 41,
Email: "legacy-user" + service.OIDCConnectSyntheticEmailDomain,
Role: service.RoleUser,
Status: service.StatusActive,
TokenVersion: 4,
}, "")
require.NoError(t, err)
updatedUser, err := svc.BindEmailIdentity(ctx, 41, "new@example.com", "123456", "new-password")
require.NoError(t, err)
require.NotNil(t, updatedUser)
storedUser, err := userRepo.GetByID(ctx, 41)
require.NoError(t, err)
require.Equal(t, "new@example.com", storedUser.Email)
require.True(t, svc.CheckPassword("new-password", storedUser.PasswordHash))
_, err = svc.RefreshToken(ctx, oldTokenPair.AccessToken)
require.ErrorIs(t, err, service.ErrTokenRevoked)
_, err = svc.RefreshTokenPair(ctx, oldTokenPair.RefreshToken)
require.True(t, errors.Is(err, service.ErrTokenRevoked) || errors.Is(err, service.ErrRefreshTokenInvalid))
}
type emailBindSettingRepoStub struct {
values map[string]string
}
@ -527,3 +594,260 @@ func (s *emailBindCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int6
func (s *emailBindCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
return 0, nil
}
type emailBindRefreshTokenCacheStub struct {
mu sync.Mutex
tokens map[string]*service.RefreshTokenData
userSets map[int64]map[string]struct{}
families map[string]map[string]struct{}
}
func newEmailBindRefreshTokenCacheStub() *emailBindRefreshTokenCacheStub {
return &emailBindRefreshTokenCacheStub{
tokens: make(map[string]*service.RefreshTokenData),
userSets: make(map[int64]map[string]struct{}),
families: make(map[string]map[string]struct{}),
}
}
func (s *emailBindRefreshTokenCacheStub) StoreRefreshToken(_ context.Context, tokenHash string, data *service.RefreshTokenData, _ time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
cloned := *data
s.tokens[tokenHash] = &cloned
return nil
}
func (s *emailBindRefreshTokenCacheStub) GetRefreshToken(_ context.Context, tokenHash string) (*service.RefreshTokenData, error) {
s.mu.Lock()
defer s.mu.Unlock()
data, ok := s.tokens[tokenHash]
if !ok {
return nil, service.ErrRefreshTokenNotFound
}
cloned := *data
return &cloned, nil
}
func (s *emailBindRefreshTokenCacheStub) DeleteRefreshToken(_ context.Context, tokenHash string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.tokens, tokenHash)
for _, tokenSet := range s.userSets {
delete(tokenSet, tokenHash)
}
for _, tokenSet := range s.families {
delete(tokenSet, tokenHash)
}
return nil
}
func (s *emailBindRefreshTokenCacheStub) DeleteUserRefreshTokens(_ context.Context, userID int64) error {
s.mu.Lock()
defer s.mu.Unlock()
for tokenHash := range s.userSets[userID] {
delete(s.tokens, tokenHash)
for _, tokenSet := range s.families {
delete(tokenSet, tokenHash)
}
}
delete(s.userSets, userID)
return nil
}
func (s *emailBindRefreshTokenCacheStub) DeleteTokenFamily(_ context.Context, familyID string) error {
s.mu.Lock()
defer s.mu.Unlock()
for tokenHash := range s.families[familyID] {
delete(s.tokens, tokenHash)
for _, tokenSet := range s.userSets {
delete(tokenSet, tokenHash)
}
}
delete(s.families, familyID)
return nil
}
func (s *emailBindRefreshTokenCacheStub) AddToUserTokenSet(_ context.Context, userID int64, tokenHash string, _ time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.userSets[userID] == nil {
s.userSets[userID] = make(map[string]struct{})
}
s.userSets[userID][tokenHash] = struct{}{}
return nil
}
func (s *emailBindRefreshTokenCacheStub) AddToFamilyTokenSet(_ context.Context, familyID string, tokenHash string, _ time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.families[familyID] == nil {
s.families[familyID] = make(map[string]struct{})
}
s.families[familyID][tokenHash] = struct{}{}
return nil
}
func (s *emailBindRefreshTokenCacheStub) GetUserTokenHashes(_ context.Context, userID int64) ([]string, error) {
s.mu.Lock()
defer s.mu.Unlock()
tokenSet := s.userSets[userID]
out := make([]string, 0, len(tokenSet))
for tokenHash := range tokenSet {
out = append(out, tokenHash)
}
return out, nil
}
func (s *emailBindRefreshTokenCacheStub) GetFamilyTokenHashes(_ context.Context, familyID string) ([]string, error) {
s.mu.Lock()
defer s.mu.Unlock()
tokenSet := s.families[familyID]
out := make([]string, 0, len(tokenSet))
for tokenHash := range tokenSet {
out = append(out, tokenHash)
}
return out, nil
}
func (s *emailBindRefreshTokenCacheStub) IsTokenInFamily(_ context.Context, familyID string, tokenHash string) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
_, ok := s.families[familyID][tokenHash]
return ok, nil
}
type emailBindUserRepoStub struct {
mu sync.Mutex
usersByID map[int64]*service.User
usersByEmail map[string]*service.User
}
func newEmailBindUserRepoStub(user *service.User) *emailBindUserRepoStub {
cloned := cloneEmailBindUser(user)
return &emailBindUserRepoStub{
usersByID: map[int64]*service.User{
cloned.ID: cloned,
},
usersByEmail: map[string]*service.User{
cloned.Email: cloned,
},
}
}
func (s *emailBindUserRepoStub) Create(context.Context, *service.User) error { return nil }
func (s *emailBindUserRepoStub) GetByID(_ context.Context, id int64) (*service.User, error) {
s.mu.Lock()
defer s.mu.Unlock()
user, ok := s.usersByID[id]
if !ok {
return nil, service.ErrUserNotFound
}
return cloneEmailBindUser(user), nil
}
func (s *emailBindUserRepoStub) GetByEmail(_ context.Context, email string) (*service.User, error) {
s.mu.Lock()
defer s.mu.Unlock()
user, ok := s.usersByEmail[email]
if !ok {
return nil, service.ErrUserNotFound
}
return cloneEmailBindUser(user), nil
}
func (s *emailBindUserRepoStub) GetFirstAdmin(context.Context) (*service.User, error) {
panic("unexpected GetFirstAdmin call")
}
func (s *emailBindUserRepoStub) Update(_ context.Context, user *service.User) error {
s.mu.Lock()
defer s.mu.Unlock()
existing, ok := s.usersByID[user.ID]
if !ok {
return service.ErrUserNotFound
}
delete(s.usersByEmail, existing.Email)
cloned := cloneEmailBindUser(user)
s.usersByID[user.ID] = cloned
s.usersByEmail[cloned.Email] = cloned
return nil
}
func (s *emailBindUserRepoStub) Delete(context.Context, int64) error { return nil }
func (s *emailBindUserRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
return nil, nil
}
func (s *emailBindUserRepoStub) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
panic("unexpected UpsertUserAvatar call")
}
func (s *emailBindUserRepoStub) DeleteUserAvatar(context.Context, int64) error {
panic("unexpected DeleteUserAvatar call")
}
func (s *emailBindUserRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *emailBindUserRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *emailBindUserRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
return map[int64]*time.Time{}, nil
}
func (s *emailBindUserRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
return nil, nil
}
func (s *emailBindUserRepoStub) UpdateUserLastActiveAt(context.Context, int64, time.Time) error {
return nil
}
func (s *emailBindUserRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
func (s *emailBindUserRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
func (s *emailBindUserRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
_, ok := s.usersByEmail[email]
return ok, nil
}
func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
func (s *emailBindUserRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error {
return nil
}
func (s *emailBindUserRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil
}
func (s *emailBindUserRepoStub) ListUserAuthIdentities(context.Context, int64) ([]service.UserAuthIdentityRecord, error) {
return nil, nil
}
func (s *emailBindUserRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error {
return nil
}
func (s *emailBindUserRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (s *emailBindUserRepoStub) EnableTotp(context.Context, int64) error { return nil }
func (s *emailBindUserRepoStub) DisableTotp(context.Context, int64) error { return nil }
func cloneEmailBindUser(user *service.User) *service.User {
if user == nil {
return nil
}
cloned := *user
return &cloned
}

View File

@ -20,7 +20,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return nil, fmt.Errorf("query provider instances: %w", err)
}
typeInstances := pcGroupByPaymentType(instances)
typeInstances = pcApplyEnabledVisibleMethodInstances(typeInstances, instances)
typeInstances = s.pcApplyEnabledVisibleMethodInstances(ctx, typeInstances, instances)
resp := &MethodLimitsResponse{
Methods: make(map[string]MethodLimits, len(typeInstances)),
}
@ -32,7 +32,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return resp, nil
}
func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance {
func (s *PaymentConfigService) pcApplyEnabledVisibleMethodInstances(ctx context.Context, typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance {
if len(typeInstances) == 0 {
return typeInstances
}
@ -44,11 +44,25 @@ func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.Paym
for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} {
matching := filterEnabledVisibleMethodInstances(instances, method)
if len(matching) != 1 {
providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching)
if err != nil {
delete(filtered, method)
continue
}
filtered[method] = []*dbent.PaymentProviderInstance{matching[0]}
if providerKey == "" {
if len(matching) == 0 {
delete(filtered, method)
continue
}
filtered[method] = matching
continue
}
selectedInstances := filterVisibleMethodInstancesByProviderKey(instances, method, providerKey)
if len(selectedInstances) == 0 {
delete(filtered, method)
continue
}
filtered[method] = selectedInstances
}
return filtered
}

View File

@ -6,6 +6,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/stretchr/testify/require"
)
func TestUnionFloat(t *testing.T) {
@ -301,7 +302,109 @@ func TestPcInstanceTypeLimits(t *testing.T) {
})
}
func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testing.T) {
func TestGetAvailableMethodLimitsUsesConfiguredVisibleMethodSource(t *testing.T) {
tests := []struct {
name string
sourceSetting string
wantAlipaySingleMin float64
wantAlipaySingleMax float64
wantGlobalMin float64
wantGlobalMax float64
}{
{
name: "official source",
sourceSetting: VisibleMethodSourceOfficialAlipay,
wantAlipaySingleMin: 10,
wantAlipaySingleMax: 100,
wantGlobalMin: 10,
wantGlobalMax: 300,
},
{
name: "easypay source",
sourceSetting: VisibleMethodSourceEasyPayAlipay,
wantAlipaySingleMin: 20,
wantAlipaySingleMax: 200,
wantGlobalMin: 20,
wantGlobalMax: 300,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("Official Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create official alipay instance: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName("EasyPay Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create easypay alipay instance: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("Official WeChat").
SetConfig("{}").
SetSupportedTypes("wxpay").
SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create official wxpay instance: %v", err)
}
svc := &PaymentConfigService{
entClient: client,
settingRepo: &paymentConfigSettingRepoStub{
values: map[string]string{
SettingPaymentVisibleMethodAlipaySource: tt.sourceSetting,
},
},
}
resp, err := svc.GetAvailableMethodLimits(ctx)
if err != nil {
t.Fatalf("GetAvailableMethodLimits returned error: %v", err)
}
alipayLimits, ok := resp.Methods[payment.TypeAlipay]
if !ok {
t.Fatalf("expected alipay limits to remain visible, got %v", resp.Methods)
}
if alipayLimits.SingleMin != tt.wantAlipaySingleMin || alipayLimits.SingleMax != tt.wantAlipaySingleMax {
t.Fatalf("alipay limits = %+v, want min=%v max=%v", alipayLimits, tt.wantAlipaySingleMin, tt.wantAlipaySingleMax)
}
wxpayLimits, ok := resp.Methods[payment.TypeWxpay]
if !ok {
t.Fatalf("expected wxpay limits to remain visible, got %v", resp.Methods)
}
if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 {
t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits)
}
if resp.GlobalMin != tt.wantGlobalMin || resp.GlobalMax != tt.wantGlobalMax {
t.Fatalf("global range = (%v, %v), want (%v, %v)", resp.GlobalMin, resp.GlobalMax, tt.wantGlobalMin, tt.wantGlobalMax)
}
})
}
}
func TestGetAvailableMethodLimitsPreservesLegacyCrossProviderBehaviorWhenVisibleMethodSourceMissing(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
@ -313,20 +416,18 @@ func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testi
SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create official alipay instance: %v", err)
}
require.NoError(t, err)
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName("EasyPay Alipay").
SetName("EasyPay Mixed").
SetConfig("{}").
SetSupportedTypes("alipay").
SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`).
SetSupportedTypes("alipay,wxpay").
SetLimits(`{"alipay":{"singleMin":20,"singleMax":200},"wxpay":{"singleMin":40,"singleMax":400}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create easypay alipay instance: %v", err)
}
require.NoError(t, err)
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("Official WeChat").
@ -335,31 +436,26 @@ func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testi
SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`).
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create official wxpay instance: %v", err)
}
require.NoError(t, err)
svc := &PaymentConfigService{
entClient: client,
entClient: client,
settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{}},
}
resp, err := svc.GetAvailableMethodLimits(ctx)
if err != nil {
t.Fatalf("GetAvailableMethodLimits returned error: %v", err)
}
require.NoError(t, err)
if _, ok := resp.Methods[payment.TypeAlipay]; ok {
t.Fatalf("alipay should be hidden when multiple enabled providers claim it, got %v", resp.Methods[payment.TypeAlipay])
}
alipayLimits, ok := resp.Methods[payment.TypeAlipay]
require.True(t, ok, "expected alipay limits to remain visible")
require.Equal(t, 10.0, alipayLimits.SingleMin)
require.Equal(t, 200.0, alipayLimits.SingleMax)
wxpayLimits, ok := resp.Methods[payment.TypeWxpay]
if !ok {
t.Fatalf("expected wxpay limits to remain visible, got %v", resp.Methods)
}
if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 {
t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits)
}
if resp.GlobalMin != 30 || resp.GlobalMax != 300 {
t.Fatalf("global range = (%v, %v), want (30, 300)", resp.GlobalMin, resp.GlobalMax)
}
require.True(t, ok, "expected wxpay limits to remain visible")
require.Equal(t, 30.0, wxpayLimits.SingleMin)
require.Equal(t, 400.0, wxpayLimits.SingleMax)
require.Equal(t, 10.0, resp.GlobalMin)
require.Equal(t, 400.0, resp.GlobalMax)
}

View File

@ -116,6 +116,17 @@ var providerSensitiveConfigFields = map[string]map[string]struct{}{
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
}
// providerPendingOrderProtectedConfigFields lists config keys that cannot be
// changed while the instance has in-progress orders. This includes secrets plus
// all provider identity fields that are snapshotted into orders or used by
// webhook/refund verification.
var providerPendingOrderProtectedConfigFields = map[string]map[string]struct{}{
payment.TypeEasyPay: {"pkey": {}, "pid": {}},
payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}, "appid": {}},
payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}, "appid": {}, "mpappid": {}, "mchid": {}, "publickeyid": {}, "certserial": {}},
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
}
func isSensitiveProviderConfigField(providerKey, fieldName string) bool {
fields, ok := providerSensitiveConfigFields[providerKey]
if !ok {
@ -125,6 +136,28 @@ func isSensitiveProviderConfigField(providerKey, fieldName string) bool {
return found
}
func hasPendingOrderProtectedConfigChange(providerKey string, currentConfig, nextConfig map[string]string) bool {
fields, ok := providerPendingOrderProtectedConfigFields[providerKey]
if !ok {
return false
}
for fieldName := range fields {
if providerConfigFieldValue(currentConfig, fieldName) != providerConfigFieldValue(nextConfig, fieldName) {
return true
}
}
return false
}
func providerConfigFieldValue(config map[string]string, fieldName string) string {
for key, value := range config {
if strings.EqualFold(key, fieldName) {
return value
}
}
return ""
}
func (s *PaymentConfigService) countPendingOrders(ctx context.Context, providerInstanceID int64) (int, error) {
return s.entClient.PaymentOrder.Query().
Where(
@ -190,6 +223,18 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
if err != nil {
return nil, fmt.Errorf("load provider instance: %w", err)
}
var pendingOrderCount *int
getPendingOrderCount := func() (int, error) {
if pendingOrderCount != nil {
return *pendingOrderCount, nil
}
count, err := s.countPendingOrders(ctx, id)
if err != nil {
return 0, fmt.Errorf("check pending orders: %w", err)
}
pendingOrderCount = &count
return count, nil
}
nextEnabled := current.Enabled
if req.Enabled != nil {
nextEnabled = *req.Enabled
@ -201,18 +246,20 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
if err := s.validateVisibleMethodEnablementConflicts(ctx, id, current.ProviderKey, nextSupportedTypes, nextEnabled); err != nil {
return nil, err
}
var mergedConfig map[string]string
if req.Config != nil {
hasSensitive := false
for k, v := range req.Config {
if v != "" && isSensitiveProviderConfigField(current.ProviderKey, k) {
hasSensitive = true
break
}
currentConfig, err := s.decryptConfig(current.Config)
if err != nil {
return nil, fmt.Errorf("decrypt existing config: %w", err)
}
if hasSensitive {
count, err := s.countPendingOrders(ctx, id)
mergedConfig, err = s.mergeConfig(ctx, id, req.Config)
if err != nil {
return nil, err
}
if hasPendingOrderProtectedConfigChange(current.ProviderKey, currentConfig, mergedConfig) {
count, err := getPendingOrderCount()
if err != nil {
return nil, fmt.Errorf("check pending orders: %w", err)
return nil, err
}
if count > 0 {
return nil, infraerrors.Conflict("PENDING_ORDERS", "instance has pending orders").
@ -221,9 +268,9 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
}
}
if req.Enabled != nil && !*req.Enabled {
count, err := s.countPendingOrders(ctx, id)
count, err := getPendingOrderCount()
if err != nil {
return nil, fmt.Errorf("check pending orders: %w", err)
return nil, err
}
if count > 0 {
return nil, infraerrors.Conflict("PENDING_ORDERS", "instance has pending orders").
@ -237,13 +284,6 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
if req.Enabled != nil {
finalEnabled = *req.Enabled
}
var mergedConfig map[string]string
if req.Config != nil {
mergedConfig, err = s.mergeConfig(ctx, id, req.Config)
if err != nil {
return nil, err
}
}
if finalEnabled {
configToValidate := mergedConfig
if configToValidate == nil {
@ -269,9 +309,9 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
}
if req.SupportedTypes != nil {
// Check pending orders before removing payment types
count, err := s.countPendingOrders(ctx, id)
count, err := getPendingOrderCount()
if err != nil {
return nil, fmt.Errorf("check pending orders: %w", err)
return nil, err
}
if count > 0 {
// Load current instance to compare types

View File

@ -4,8 +4,16 @@ package service
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"strconv"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -199,7 +207,7 @@ func TestJoinTypes(t *testing.T) {
}
}
func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *testing.T) {
func TestCreateProviderInstanceAllowsVisibleMethodProvidersFromDifferentSources(t *testing.T) {
t.Parallel()
ctx := context.Background()
@ -227,15 +235,14 @@ func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *test
_, err = svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: "alipay",
Name: "Official Alipay",
Config: map[string]string{"appId": "app-1"},
Config: map[string]string{"appId": "app-1", "privateKey": "private-key"},
SupportedTypes: []string{"alipay"},
Enabled: true,
})
require.Error(t, err)
require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err))
require.NoError(t, err)
}
func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t *testing.T) {
func TestUpdateProviderInstanceAllowsEnablingVisibleMethodProviderFromDifferentSource(t *testing.T) {
t.Parallel()
ctx := context.Background()
@ -264,7 +271,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t
candidate, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: "wxpay",
Name: "Official WeChat",
Config: map[string]string{"appId": "wx-app"},
Config: validWxpayProviderConfig(t),
SupportedTypes: []string{"wxpay"},
Enabled: false,
})
@ -273,8 +280,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t
_, err = svc.UpdateProviderInstance(ctx, candidate.ID, UpdateProviderInstanceRequest{
Enabled: boolPtrValue(true),
})
require.Error(t, err)
require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err))
require.NoError(t, err)
}
func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
@ -314,6 +320,289 @@ func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
require.Equal(t, "alipay,wxpay", saved.SupportedTypes)
}
func TestUpdateProviderInstanceRejectsProtectedConfigChangesWhilePendingOrders(t *testing.T) {
t.Parallel()
tests := []struct {
name string
providerKey string
createConfig func(*testing.T) map[string]string
supportedType []string
updateConfig map[string]string
fieldName string
wantValue string
}{
{
name: "wxpay appId",
providerKey: payment.TypeWxpay,
createConfig: validWxpayProviderConfig,
supportedType: []string{payment.TypeWxpay},
updateConfig: map[string]string{"appId": "wx-app-updated"},
fieldName: "appId",
wantValue: "wx-app-test",
},
{
name: "wxpay mpAppId",
providerKey: payment.TypeWxpay,
createConfig: validWxpayProviderConfigWithJSAPIAppID,
supportedType: []string{payment.TypeWxpay},
updateConfig: map[string]string{"mpAppId": "wx-mp-app-updated"},
fieldName: "mpAppId",
wantValue: "wx-mp-app-test",
},
{
name: "wxpay mchId",
providerKey: payment.TypeWxpay,
createConfig: validWxpayProviderConfig,
supportedType: []string{payment.TypeWxpay},
updateConfig: map[string]string{"mchId": "mch-updated"},
fieldName: "mchId",
wantValue: "mch-test",
},
{
name: "wxpay publicKeyId",
providerKey: payment.TypeWxpay,
createConfig: validWxpayProviderConfig,
supportedType: []string{payment.TypeWxpay},
updateConfig: map[string]string{"publicKeyId": "public-key-id-updated"},
fieldName: "publicKeyId",
wantValue: "public-key-id-test",
},
{
name: "wxpay certSerial",
providerKey: payment.TypeWxpay,
createConfig: validWxpayProviderConfig,
supportedType: []string{payment.TypeWxpay},
updateConfig: map[string]string{"certSerial": "cert-serial-updated"},
fieldName: "certSerial",
wantValue: "cert-serial-test",
},
{
name: "alipay appId",
providerKey: payment.TypeAlipay,
createConfig: validAlipayProviderConfig,
supportedType: []string{payment.TypeAlipay},
updateConfig: map[string]string{"appId": "alipay-app-updated"},
fieldName: "appId",
wantValue: "alipay-app-test",
},
{
name: "easypay pid",
providerKey: payment.TypeEasyPay,
createConfig: validEasyPayProviderConfig,
supportedType: []string{payment.TypeAlipay},
updateConfig: map[string]string{"pid": "pid-updated"},
fieldName: "pid",
wantValue: "pid-test",
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
svc := &PaymentConfigService{
entClient: client,
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
}
instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: tc.providerKey,
Name: "protected-config-instance",
Config: tc.createConfig(t),
SupportedTypes: tc.supportedType,
Enabled: true,
})
require.NoError(t, err)
createPendingProviderConfigOrder(t, ctx, client, instance)
updated, err := svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{
Config: tc.updateConfig,
})
require.Nil(t, updated)
require.Error(t, err)
require.Equal(t, "PENDING_ORDERS", infraerrors.Reason(err))
saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID)
require.NoError(t, err)
cfg, err := svc.decryptConfig(saved.Config)
require.NoError(t, err)
require.Equal(t, tc.wantValue, cfg[tc.fieldName])
})
}
}
func TestUpdateProviderInstanceAllowsSafeConfigChangesWhilePendingOrders(t *testing.T) {
t.Parallel()
tests := []struct {
name string
providerKey string
createConfig func(*testing.T) map[string]string
supportedType []string
updateConfig map[string]string
fieldName string
wantValue string
}{
{
name: "wxpay notifyUrl",
providerKey: payment.TypeWxpay,
createConfig: validWxpayProviderConfig,
supportedType: []string{payment.TypeWxpay},
updateConfig: map[string]string{"notifyUrl": "https://merchant.example.com/wxpay/notify-v2"},
fieldName: "notifyUrl",
wantValue: "https://merchant.example.com/wxpay/notify-v2",
},
{
name: "alipay same appId",
providerKey: payment.TypeAlipay,
createConfig: validAlipayProviderConfig,
supportedType: []string{payment.TypeAlipay},
updateConfig: map[string]string{"appId": "alipay-app-test"},
fieldName: "appId",
wantValue: "alipay-app-test",
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
svc := &PaymentConfigService{
entClient: client,
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
}
instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
ProviderKey: tc.providerKey,
Name: "safe-config-instance",
Config: tc.createConfig(t),
SupportedTypes: tc.supportedType,
Enabled: true,
})
require.NoError(t, err)
createPendingProviderConfigOrder(t, ctx, client, instance)
updated, err := svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{
Config: tc.updateConfig,
})
require.NoError(t, err)
require.NotNil(t, updated)
saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID)
require.NoError(t, err)
cfg, err := svc.decryptConfig(saved.Config)
require.NoError(t, err)
require.Equal(t, tc.wantValue, cfg[tc.fieldName])
})
}
}
func createPendingProviderConfigOrder(t *testing.T, ctx context.Context, client *dbent.Client, instance *dbent.PaymentProviderInstance) {
t.Helper()
user, err := client.User.Create().
SetEmail("provider-config-pending@example.com").
SetPasswordHash("hash").
SetUsername("provider-config-pending-user").
Save(ctx)
require.NoError(t, err)
instanceID := strconv.FormatInt(instance.ID, 10)
_, err = client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("PENDING-PROVIDER-CONFIG-" + instanceID).
SetOutTradeNo("sub2_pending_provider_config_" + instanceID).
SetPaymentType(providerPendingOrderPaymentType(instance.ProviderKey)).
SetPaymentTradeNo("").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
SetProviderInstanceID(instanceID).
SetProviderKey(instance.ProviderKey).
Save(ctx)
require.NoError(t, err)
}
func providerPendingOrderPaymentType(providerKey string) string {
switch providerKey {
case payment.TypeWxpay:
return payment.TypeWxpay
case payment.TypeAlipay:
return payment.TypeAlipay
default:
return payment.TypeAlipay
}
}
func boolPtrValue(v bool) *bool {
return &v
}
func validAlipayProviderConfig(t *testing.T) map[string]string {
t.Helper()
return map[string]string{
"appId": "alipay-app-test",
"privateKey": "alipay-private-key-test",
"notifyUrl": "https://merchant.example.com/alipay/notify",
"returnUrl": "https://merchant.example.com/alipay/return",
}
}
func validEasyPayProviderConfig(t *testing.T) map[string]string {
t.Helper()
return map[string]string{
"pid": "pid-test",
"pkey": "pkey-test",
"apiBase": "https://pay.example.com",
"notifyUrl": "https://merchant.example.com/easypay/notify",
"returnUrl": "https://merchant.example.com/easypay/return",
}
}
func validWxpayProviderConfig(t *testing.T) map[string]string {
t.Helper()
key, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
privDER, err := x509.MarshalPKCS8PrivateKey(key)
require.NoError(t, err)
pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
require.NoError(t, err)
return map[string]string{
"appId": "wx-app-test",
"mchId": "mch-test",
"privateKey": string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})),
"apiV3Key": "12345678901234567890123456789012",
"publicKey": string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER})),
"publicKeyId": "public-key-id-test",
"certSerial": "cert-serial-test",
}
}
func validWxpayProviderConfigWithJSAPIAppID(t *testing.T) map[string]string {
t.Helper()
cfg := validWxpayProviderConfig(t)
cfg["mpAppId"] = "wx-mp-app-test"
return cfg
}

View File

@ -80,21 +80,25 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
})
return err
}
// Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount).
// Also skip if paid is NaN/Inf (malformed provider data).
if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) {
if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
}
if !isValidProviderAmount(paid) {
s.writeAuditLog(ctx, o.ID, "PAYMENT_INVALID_AMOUNT", pk, map[string]any{
"expected": o.PayAmount,
"paid": paid,
"tradeNo": tradeNo,
})
return fmt.Errorf("invalid paid amount from provider: %v", paid)
}
// Use order's expected amount when provider didn't report one
if paid <= 0 || math.IsNaN(paid) || math.IsInf(paid, 0) {
paid = o.PayAmount
if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
}
return s.toPaid(ctx, o, tradeNo, paid, pk)
}
func isValidProviderAmount(amount float64) bool {
return amount > 0 && !math.IsNaN(amount) && !math.IsInf(amount, 0)
}
func validateProviderNotificationMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error {
return validateProviderSnapshotMetadata(order, providerKey, metadata)
}

View File

@ -5,6 +5,7 @@ package service
import (
"context"
"errors"
"math"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
@ -322,6 +323,16 @@ func TestParseLegacyPaymentOrderID(t *testing.T) {
assert.False(t, ok)
}
func TestIsValidProviderAmount(t *testing.T) {
t.Parallel()
assert.True(t, isValidProviderAmount(0.01))
assert.False(t, isValidProviderAmount(0))
assert.False(t, isValidProviderAmount(-1))
assert.False(t, isValidProviderAmount(math.NaN()))
assert.False(t, isValidProviderAmount(math.Inf(1)))
}
func TestValidateProviderNotificationMetadataRejectsAlipaySnapshotMismatch(t *testing.T) {
t.Parallel()

View File

@ -139,6 +139,10 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
tm = defaultOrderTimeoutMin
}
exp := time.Now().Add(time.Duration(tm) * time.Minute)
outTradeNo, err := s.allocateOutTradeNo(ctx, tx)
if err != nil {
return nil, err
}
providerSnapshot := buildPaymentOrderProviderSnapshot(sel, req)
selectedInstanceID := ""
selectedProviderKey := ""
@ -155,7 +159,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
SetPayAmount(payAmount).
SetFeeRate(feeRate).
SetRechargeCode("").
SetOutTradeNo(generateOutTradeNo()).
SetOutTradeNo(outTradeNo).
SetPaymentType(req.PaymentType).
SetPaymentTradeNo("").
SetOrderType(req.OrderType).
@ -193,6 +197,21 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
return order, nil
}
func (s *PaymentService) allocateOutTradeNo(ctx context.Context, tx *dbent.Tx) (string, error) {
const maxAttempts = 5
for attempt := 0; attempt < maxAttempts; attempt++ {
candidate := generateOutTradeNo()
exists, err := tx.PaymentOrder.Query().Where(paymentorder.OutTradeNo(candidate)).Exist(ctx)
if err != nil {
return "", fmt.Errorf("check out_trade_no uniqueness: %w", err)
}
if !exists {
return candidate, nil
}
}
return "", fmt.Errorf("generate unique out_trade_no: exhausted %d attempts", maxAttempts)
}
func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, userID int64, max int) error {
if max <= 0 {
max = defaultMaxPendingOrders
@ -360,13 +379,13 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
}
subject := s.buildPaymentSubject(plan, limitAmount, cfg)
outTradeNo := order.OutTradeNo
canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL, req.SrcHost)
canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL, req.SrcHost, req.SrcURL)
if err != nil {
return nil, err
}
resumeToken := ""
if resume := s.paymentResume(); resume != nil {
if resume.isSigningConfigured() {
if canonicalReturnURL != "" && resume.isSigningConfigured() {
resumeToken, err = resume.CreateToken(ResumeTokenClaims{
OrderID: order.ID,
UserID: order.UserID,
@ -380,7 +399,7 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
}
}
}
providerReturnURL, err := buildPaymentReturnURL(canonicalReturnURL, order.ID, resumeToken)
providerReturnURL, err := buildPaymentReturnURL(canonicalReturnURL, order.ID, outTradeNo, resumeToken)
if err != nil {
return nil, err
}
@ -482,6 +501,9 @@ func (s *PaymentService) buildWeChatOAuthRequiredResponse(ctx context.Context, r
if err != nil {
return nil, err
}
if err := s.paymentResume().ensureSigningKey(); err != nil {
return nil, err
}
authorizeURL, err := buildWeChatPaymentOAuthStartURL(req, "snsapi_base")
if err != nil {

View File

@ -31,3 +31,68 @@ func TestUsesOfficialWxpayVisibleMethodDerivesFromEnabledProviderInstance(t *tes
t.Fatal("expected official wxpay visible method to be detected from enabled provider instance")
}
}
func TestUsesOfficialWxpayVisibleMethodRespectsConfiguredSourceWhenMultipleProvidersEnabled(t *testing.T) {
tests := []struct {
name string
source string
wantOfficial bool
}{
{
name: "official source selected",
source: VisibleMethodSourceOfficialWechat,
wantOfficial: true,
},
{
name: "easypay source selected",
source: VisibleMethodSourceEasyPayWechat,
wantOfficial: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeWxpay).
SetName("Official WeChat").
SetConfig("{}").
SetSupportedTypes("wxpay").
SetEnabled(true).
SetSortOrder(1).
Save(ctx)
if err != nil {
t.Fatalf("create official wxpay instance: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName("EasyPay WeChat").
SetConfig("{}").
SetSupportedTypes("wxpay").
SetEnabled(true).
SetSortOrder(2).
Save(ctx)
if err != nil {
t.Fatalf("create easypay wxpay instance: %v", err)
}
svc := &PaymentService{
configService: &PaymentConfigService{
entClient: client,
settingRepo: &paymentConfigSettingRepoStub{
values: map[string]string{
SettingPaymentVisibleMethodWxpaySource: tt.source,
},
},
},
}
if got := svc.usesOfficialWxpayVisibleMethod(ctx); got != tt.wantOfficial {
t.Fatalf("usesOfficialWxpayVisibleMethod() = %v, want %v", got, tt.wantOfficial)
}
})
}
}

View File

@ -150,6 +150,20 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
return ""
}
if resp.Status == payment.ProviderStatusPaid {
if !isValidProviderAmount(resp.Amount) {
s.writeAuditLog(ctx, o.ID, "PAYMENT_INVALID_AMOUNT", prov.ProviderKey(), map[string]any{
"expected": o.PayAmount,
"paid": resp.Amount,
"tradeNo": resp.TradeNo,
"queryRef": queryRef,
})
slog.Warn("query upstream returned invalid paid amount", "orderID", o.ID, "queryRef", queryRef, "paid", resp.Amount)
retriedResp, retryOK := requeryPaidOrderOnce(ctx, prov, queryRef)
if !retryOK {
return ""
}
resp = retriedResp
}
notificationTradeNo := o.PaymentTradeNo
if upstreamTradeNo := strings.TrimSpace(resp.TradeNo); paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, notificationTradeNo) {
if _, updateErr := s.entClient.PaymentOrder.Update().
@ -174,6 +188,21 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
return ""
}
func requeryPaidOrderOnce(ctx context.Context, prov payment.Provider, queryRef string) (*payment.QueryOrderResponse, bool) {
if prov == nil || strings.TrimSpace(queryRef) == "" {
return nil, false
}
resp, err := prov.QueryOrder(ctx, queryRef)
if err != nil {
slog.Warn("query upstream retry failed", "queryRef", queryRef, "error", err)
return nil, false
}
if resp == nil || resp.Status != payment.ProviderStatusPaid || !isValidProviderAmount(resp.Amount) {
return nil, false
}
return resp, true
}
func paymentOrderQueryReference(order *dbent.PaymentOrder, prov payment.Provider) string {
if order == nil {
return ""
@ -224,6 +253,10 @@ func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, current
// if a payment was made, and processes it if so. This handles the case where
// the provider's notify callback was missed (e.g. EasyPay popup mode).
func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) {
outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo)
if err != nil {
return nil, err
}
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx)
@ -251,6 +284,10 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo
// triggering any upstream reconciliation. Signed resume-token recovery is the
// only public recovery path allowed to query upstream state.
func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) {
outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo)
if err != nil {
return nil, err
}
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx)
@ -260,6 +297,27 @@ func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo strin
return o, nil
}
func normalizeOrderLookupOutTradeNo(raw string) (string, error) {
outTradeNo := strings.TrimSpace(raw)
if outTradeNo == "" {
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is required")
}
if len(outTradeNo) > 64 {
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid")
}
for _, ch := range outTradeNo {
switch {
case ch >= 'a' && ch <= 'z':
case ch >= 'A' && ch <= 'Z':
case ch >= '0' && ch <= '9':
case ch == '_' || ch == '-':
default:
return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid")
}
}
return outTradeNo, nil
}
func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) {
now := time.Now()
orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx)

View File

@ -21,6 +21,8 @@ import (
type paymentOrderLifecycleQueryProvider struct {
lastQueryTradeNo string
queryCalls int
responses []*payment.QueryOrderResponse
resp *payment.QueryOrderResponse
}
@ -48,6 +50,14 @@ func (p *paymentOrderLifecycleQueryProvider) CreatePayment(context.Context, paym
func (p *paymentOrderLifecycleQueryProvider) QueryOrder(_ context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
p.lastQueryTradeNo = tradeNo
p.queryCalls++
if len(p.responses) > 0 {
resp := p.responses[0]
if len(p.responses) > 1 {
p.responses = p.responses[1:]
}
return resp, nil
}
return p.resp, nil
}
@ -234,6 +244,194 @@ func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) {
require.Equal(t, user.ID, redeemRepo.useCalls[0].userID)
}
func TestVerifyOrderByOutTradeNoRetriesZeroAmountPaidQueryOnce(t *testing.T) {
ctx := context.Background()
client := newPaymentOrderLifecycleTestClient(t)
user, err := client.User.Create().
SetEmail("checkpaid-retry@example.com").
SetPasswordHash("hash").
SetUsername("checkpaid-retry-user").
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("CHECKPAID-UPSTREAM-RETRY").
SetOutTradeNo("sub2_checkpaid_retry_zero_amount").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(ctx)
require.NoError(t, err)
userRepo := &mockUserRepo{
getByIDUser: &User{
ID: user.ID,
Email: user.Email,
Username: user.Username,
Balance: 0,
},
}
userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error {
require.Equal(t, user.ID, id)
if userRepo.getByIDUser != nil {
userRepo.getByIDUser.Balance += amount
}
return nil
}
redeemRepo := &paymentOrderLifecycleRedeemRepo{
codesByCode: map[string]*RedeemCode{
order.RechargeCode: {
ID: 1,
Code: order.RechargeCode,
Type: RedeemTypeBalance,
Value: order.Amount,
Status: StatusUnused,
},
},
}
redeemService := NewRedeemService(
redeemRepo,
userRepo,
nil,
nil,
nil,
client,
nil,
)
registry := payment.NewRegistry()
provider := &paymentOrderLifecycleQueryProvider{
responses: []*payment.QueryOrderResponse{
{
TradeNo: "upstream-trade-zero",
Status: payment.ProviderStatusPaid,
Amount: 0,
},
{
TradeNo: "upstream-trade-retry",
Status: payment.ProviderStatusPaid,
Amount: 88,
},
},
}
registry.Register(provider)
svc := &PaymentService{
entClient: client,
registry: registry,
redeemService: redeemService,
userRepo: userRepo,
providersLoaded: true,
}
got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
require.NoError(t, err)
require.Equal(t, 2, provider.queryCalls)
require.Equal(t, OrderStatusCompleted, got.Status)
require.Equal(t, "upstream-trade-retry", got.PaymentTradeNo)
}
func TestVerifyOrderByOutTradeNoRejectsPaidQueryWithZeroAmount(t *testing.T) {
ctx := context.Background()
client := newPaymentOrderLifecycleTestClient(t)
user, err := client.User.Create().
SetEmail("checkpaid-zero-amount@example.com").
SetPasswordHash("hash").
SetUsername("checkpaid-zero-amount-user").
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("CHECKPAID-ZERO-AMOUNT").
SetOutTradeNo("sub2_checkpaid_zero_amount").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(ctx)
require.NoError(t, err)
userRepo := &mockUserRepo{
getByIDUser: &User{
ID: user.ID,
Email: user.Email,
Username: user.Username,
Balance: 0,
},
}
redeemRepo := &paymentOrderLifecycleRedeemRepo{
codesByCode: map[string]*RedeemCode{
order.RechargeCode: {
ID: 1,
Code: order.RechargeCode,
Type: RedeemTypeBalance,
Value: order.Amount,
Status: StatusUnused,
},
},
}
redeemService := NewRedeemService(
redeemRepo,
userRepo,
nil,
nil,
nil,
client,
nil,
)
registry := payment.NewRegistry()
provider := &paymentOrderLifecycleQueryProvider{
resp: &payment.QueryOrderResponse{
TradeNo: "upstream-trade-zero",
Status: payment.ProviderStatusPaid,
Amount: 0,
},
}
registry.Register(provider)
svc := &PaymentService{
entClient: client,
registry: registry,
redeemService: redeemService,
userRepo: userRepo,
providersLoaded: true,
}
got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
require.NoError(t, err)
require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
require.Equal(t, OrderStatusPending, got.Status)
require.Empty(t, got.PaymentTradeNo)
reloaded, err := client.PaymentOrder.Get(ctx, order.ID)
require.NoError(t, err)
require.Equal(t, OrderStatusPending, reloaded.Status)
require.Empty(t, reloaded.PaymentTradeNo)
require.Equal(t, 0.0, userRepo.getByIDUser.Balance)
require.Empty(t, redeemRepo.useCalls)
}
func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsForAlipay(t *testing.T) {
ctx := context.Background()
client := newPaymentOrderLifecycleTestClient(t)

View File

@ -2,6 +2,7 @@ package service
import (
"context"
"strings"
"testing"
"time"
@ -91,6 +92,8 @@ func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) {
}
func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) {
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
svc := newWeChatPaymentOAuthTestService(map[string]string{
SettingKeyWeChatConnectEnabled: "true",
SettingKeyWeChatConnectAppID: "wx123456",
@ -159,6 +162,83 @@ func TestMaybeBuildWeChatOAuthRequiredResponseRequiresMPConfigInWeChat(t *testin
}
}
func TestMaybeBuildWeChatOAuthRequiredResponseRequiresResumeSigningKey(t *testing.T) {
t.Parallel()
svc := &PaymentService{
configService: &PaymentConfigService{
settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{
SettingKeyWeChatConnectEnabled: "true",
SettingKeyWeChatConnectAppID: "wx123456",
SettingKeyWeChatConnectAppSecret: "wechat-secret",
SettingKeyWeChatConnectMode: "mp",
SettingKeyWeChatConnectScopes: "snsapi_base",
SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
}},
// Intentionally missing payment resume signing key.
encryptionKey: nil,
},
}
resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
Amount: 12.5,
PaymentType: payment.TypeWxpay,
IsWeChatBrowser: true,
SrcURL: "https://merchant.example/payment?from=wechat",
OrderType: payment.OrderTypeBalance,
}, 12.5, 12.88, 0.03)
if resp != nil {
t.Fatalf("expected nil response, got %+v", resp)
}
if err == nil {
t.Fatal("expected error, got nil")
}
appErr := infraerrors.FromError(err)
if appErr.Reason != "PAYMENT_RESUME_NOT_CONFIGURED" {
t.Fatalf("reason = %q, want %q", appErr.Reason, "PAYMENT_RESUME_NOT_CONFIGURED")
}
}
func TestMaybeBuildWeChatOAuthRequiredResponseFallsBackToConfiguredLegacySigningKey(t *testing.T) {
svc := &PaymentService{
configService: &PaymentConfigService{
settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{
SettingKeyWeChatConnectEnabled: "true",
SettingKeyWeChatConnectAppID: "wx123456",
SettingKeyWeChatConnectAppSecret: "wechat-secret",
SettingKeyWeChatConnectMode: "mp",
SettingKeyWeChatConnectScopes: "snsapi_base",
SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
}},
// Legacy stable signing key remains available for no-config upgrade compatibility.
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
},
}
resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
Amount: 12.5,
PaymentType: payment.TypeWxpay,
IsWeChatBrowser: true,
SrcURL: "https://merchant.example/payment?from=wechat",
OrderType: payment.OrderTypeBalance,
}, 12.5, 12.88, 0.03)
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if resp == nil {
t.Fatal("expected oauth-required response, got nil")
}
if resp.ResultType != payment.CreatePaymentResultOAuthRequired {
t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOAuthRequired)
}
if resp.OAuth == nil || strings.TrimSpace(resp.OAuth.AuthorizeURL) == "" {
t.Fatalf("expected oauth redirect payload, got %+v", resp.OAuth)
}
}
func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) {
svc := newWeChatPaymentOAuthTestService(map[string]string{
SettingKeyWeChatConnectEnabled: "true",
@ -189,7 +269,8 @@ func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t
func newWeChatPaymentOAuthTestService(values map[string]string) *PaymentService {
return &PaymentService{
configService: &PaymentConfigService{
settingRepo: &paymentConfigSettingRepoStub{values: values},
settingRepo: &paymentConfigSettingRepoStub{values: values},
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
},
}
}

View File

@ -6,6 +6,7 @@ import (
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) {
@ -16,10 +17,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID)
if err != nil {
if dbent.IsNotFound(err) {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
return nil, fmt.Errorf("get order by resume token: %w", err)
}
if claims.UserID > 0 && order.UserID != claims.UserID {
return nil, fmt.Errorf("resume token user mismatch")
return nil, invalidResumeTokenMatchError()
}
snapshot := psOrderProviderSnapshot(order)
orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
@ -33,13 +37,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
}
}
if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID {
return nil, fmt.Errorf("resume token provider instance mismatch")
return nil, invalidResumeTokenMatchError()
}
if claims.ProviderKey != "" && orderProviderKey != claims.ProviderKey {
return nil, fmt.Errorf("resume token provider key mismatch")
if claims.ProviderKey != "" && !strings.EqualFold(orderProviderKey, claims.ProviderKey) {
return nil, invalidResumeTokenMatchError()
}
if claims.PaymentType != "" && strings.TrimSpace(order.PaymentType) != claims.PaymentType {
return nil, fmt.Errorf("resume token payment type mismatch")
if claims.PaymentType != "" && NormalizeVisibleMethod(order.PaymentType) != NormalizeVisibleMethod(claims.PaymentType) {
return nil, invalidResumeTokenMatchError()
}
if order.Status == OrderStatusPending || order.Status == OrderStatusExpired {
result := s.checkPaid(ctx, order)
@ -54,6 +58,10 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
return order, nil
}
func invalidResumeTokenMatchError() error {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token does not match the payment order")
}
func (s *PaymentService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
return s.paymentResume().ParseWeChatPaymentResumeToken(strings.TrimSpace(token))
}

View File

@ -8,6 +8,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
@ -143,7 +144,7 @@ func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) {
_, err = svc.GetPublicOrderByResumeToken(ctx, token)
require.Error(t, err)
require.Contains(t, err.Error(), "resume token")
require.Equal(t, "INVALID_RESUME_TOKEN", infraerrors.Reason(err))
}
func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) {
@ -302,3 +303,13 @@ func TestVerifyOrderPublicDoesNotCheckUpstreamForPendingOrder(t *testing.T) {
require.Equal(t, order.ID, got.ID)
require.Equal(t, 0, provider.queryCount)
}
func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
svc := &PaymentService{
entClient: newPaymentConfigServiceTestClient(t),
}
_, err := svc.VerifyOrderPublic(context.Background(), " ")
require.Error(t, err)
require.Equal(t, "INVALID_OUT_TRADE_NO", infraerrors.Reason(err))
}

View File

@ -1,6 +1,7 @@
package service
import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
@ -68,6 +69,7 @@ type WeChatPaymentResumeClaims struct {
type PaymentResumeService struct {
signingKey []byte
verifyKeys [][]byte
}
type visibleMethodLoadBalancer struct {
@ -75,8 +77,29 @@ type visibleMethodLoadBalancer struct {
configService *PaymentConfigService
}
func NewPaymentResumeService(signingKey []byte) *PaymentResumeService {
return &PaymentResumeService{signingKey: signingKey}
func NewPaymentResumeService(signingKey []byte, verifyFallbacks ...[]byte) *PaymentResumeService {
svc := &PaymentResumeService{}
if len(signingKey) > 0 {
svc.signingKey = append([]byte(nil), signingKey...)
svc.verifyKeys = append(svc.verifyKeys, svc.signingKey)
}
for _, fallback := range verifyFallbacks {
if len(fallback) == 0 {
continue
}
cloned := append([]byte(nil), fallback...)
duplicate := false
for _, existing := range svc.verifyKeys {
if bytes.Equal(existing, cloned) {
duplicate = true
break
}
}
if !duplicate {
svc.verifyKeys = append(svc.verifyKeys, cloned)
}
}
return svc
}
func (s *PaymentResumeService) isSigningConfigured() bool {
@ -209,7 +232,7 @@ func visibleMethodSourceSettingKey(method string) string {
}
}
func CanonicalizeReturnURL(raw string, srcHost string) (string, error) {
func CanonicalizeReturnURL(raw string, srcHost string, srcURL string) (string, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", nil
@ -228,13 +251,29 @@ func CanonicalizeReturnURL(raw string, srcHost string) (string, error) {
if parsed.Path != paymentResultReturnPath {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must target the canonical internal payment result page")
}
if !sameOriginHost(parsed.Host, srcHost) {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use the same host as the current site")
if !allowedReturnURLHost(parsed.Host, srcHost, srcURL) {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use the same host as the current site or browser origin")
}
return parsed.String(), nil
}
func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (string, error) {
func allowedReturnURLHost(returnURLHost string, requestHost string, refererURL string) bool {
if sameOriginHost(returnURLHost, requestHost) {
return true
}
refererURL = strings.TrimSpace(refererURL)
if refererURL == "" {
return false
}
parsedReferer, err := url.Parse(refererURL)
if err != nil || parsedReferer.Host == "" {
return false
}
return sameOriginHost(returnURLHost, parsedReferer.Host)
}
func buildPaymentReturnURL(base string, orderID int64, outTradeNo string, resumeToken string) (string, error) {
canonical := strings.TrimSpace(base)
if canonical == "" {
return "", nil
@ -253,6 +292,9 @@ func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (stri
if orderID > 0 {
query.Set("order_id", strconv.FormatInt(orderID, 10))
}
if strings.TrimSpace(outTradeNo) != "" {
query.Set("out_trade_no", strings.TrimSpace(outTradeNo))
}
if strings.TrimSpace(resumeToken) != "" {
query.Set("resume_token", strings.TrimSpace(resumeToken))
}
@ -391,7 +433,7 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
}
if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) {
if !s.verifySignature(parts[0], parts[1]) {
return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[0])
@ -401,6 +443,18 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
return json.Unmarshal(payload, dest)
}
func (s *PaymentResumeService) verifySignature(payload string, signature string) bool {
if s == nil {
return false
}
for _, key := range s.verifyKeys {
if hmac.Equal([]byte(signature), []byte(signPaymentResumePayload(payload, key))) {
return true
}
}
return false
}
func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
if expiresAt <= 0 {
return nil
@ -412,7 +466,11 @@ func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
}
func (s *PaymentResumeService) sign(payload string) string {
mac := hmac.New(sha256.New, s.signingKey)
return signPaymentResumePayload(payload, s.signingKey)
}
func signPaymentResumePayload(payload string, key []byte) string {
mac := hmac.New(sha256.New, key)
_, _ = mac.Write([]byte(payload))
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}

View File

@ -14,6 +14,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
func TestNormalizeVisibleMethods(t *testing.T) {
@ -64,7 +65,7 @@ func TestNormalizePaymentSource(t *testing.T) {
func TestCanonicalizeReturnURL(t *testing.T) {
t.Parallel()
got, err := CanonicalizeReturnURL("https://example.com/payment/result?b=2#a", "example.com")
got, err := CanonicalizeReturnURL("https://example.com/payment/result?b=2#a", "example.com", "")
if err != nil {
t.Fatalf("CanonicalizeReturnURL returned error: %v", err)
}
@ -76,7 +77,7 @@ func TestCanonicalizeReturnURL(t *testing.T) {
func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) {
t.Parallel()
if _, err := CanonicalizeReturnURL("/payment/result", "example.com"); err == nil {
if _, err := CanonicalizeReturnURL("/payment/result", "example.com", ""); err == nil {
t.Fatal("CanonicalizeReturnURL should reject relative URLs")
}
}
@ -84,15 +85,31 @@ func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) {
func TestCanonicalizeReturnURLRejectsExternalHost(t *testing.T) {
t.Parallel()
if _, err := CanonicalizeReturnURL("https://evil.example/payment/result", "app.example.com"); err == nil {
if _, err := CanonicalizeReturnURL("https://evil.example/payment/result", "app.example.com", ""); err == nil {
t.Fatal("CanonicalizeReturnURL should reject external hosts")
}
}
func TestCanonicalizeReturnURLAllowsConfiguredFrontendHost(t *testing.T) {
t.Parallel()
got, err := CanonicalizeReturnURL(
"https://app.example.com/payment/result?from=checkout",
"api.example.com",
"https://app.example.com/purchase",
)
if err != nil {
t.Fatalf("CanonicalizeReturnURL returned error: %v", err)
}
if got != "https://app.example.com/payment/result?from=checkout" {
t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://app.example.com/payment/result?from=checkout")
}
}
func TestCanonicalizeReturnURLRejectsNonCanonicalPath(t *testing.T) {
t.Parallel()
if _, err := CanonicalizeReturnURL("https://app.example.com/orders/42", "app.example.com"); err == nil {
if _, err := CanonicalizeReturnURL("https://app.example.com/orders/42", "app.example.com", ""); err == nil {
t.Fatal("CanonicalizeReturnURL should reject non-canonical result paths")
}
}
@ -100,7 +117,7 @@ func TestCanonicalizeReturnURLRejectsNonCanonicalPath(t *testing.T) {
func TestBuildPaymentReturnURL(t *testing.T) {
t.Parallel()
got, err := buildPaymentReturnURL("https://example.com/payment/result?from=checkout#fragment", 42, "resume-token")
got, err := buildPaymentReturnURL("https://example.com/payment/result?from=checkout#fragment", 42, "sub2_42", "resume-token")
if err != nil {
t.Fatalf("buildPaymentReturnURL returned error: %v", err)
}
@ -119,6 +136,9 @@ func TestBuildPaymentReturnURL(t *testing.T) {
if query.Get("order_id") != strconv.FormatInt(42, 10) {
t.Fatalf("order_id = %q", query.Get("order_id"))
}
if query.Get("out_trade_no") != "sub2_42" {
t.Fatalf("out_trade_no = %q", query.Get("out_trade_no"))
}
if query.Get("resume_token") != "resume-token" {
t.Fatalf("resume_token = %q", query.Get("resume_token"))
}
@ -127,10 +147,34 @@ func TestBuildPaymentReturnURL(t *testing.T) {
}
}
func TestBuildPaymentReturnURLWithoutResumeTokenStillIncludesOutTradeNo(t *testing.T) {
t.Parallel()
got, err := buildPaymentReturnURL("https://example.com/payment/result", 42, "sub2_42", "")
if err != nil {
t.Fatalf("buildPaymentReturnURL returned error: %v", err)
}
parsed, err := url.Parse(got)
if err != nil {
t.Fatalf("url.Parse returned error: %v", err)
}
query := parsed.Query()
if query.Get("order_id") != "42" {
t.Fatalf("order_id = %q", query.Get("order_id"))
}
if query.Get("out_trade_no") != "sub2_42" {
t.Fatalf("out_trade_no = %q", query.Get("out_trade_no"))
}
if query.Get("resume_token") != "" {
t.Fatalf("resume_token = %q, want empty", query.Get("resume_token"))
}
}
func TestBuildPaymentReturnURLEmptyBase(t *testing.T) {
t.Parallel()
got, err := buildPaymentReturnURL("", 42, "resume-token")
got, err := buildPaymentReturnURL("", 42, "sub2_42", "resume-token")
if err != nil {
t.Fatalf("buildPaymentReturnURL returned error: %v", err)
}
@ -290,6 +334,98 @@ func TestParseWeChatPaymentResumeTokenRejectsExpiredToken(t *testing.T) {
}
}
func TestPaymentServiceParseWeChatPaymentResumeTokenUsesExplicitSigningKey(t *testing.T) {
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
token, err := NewPaymentResumeService([]byte("explicit-payment-resume-signing-key")).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-explicit-key",
PaymentType: payment.TypeWxpay,
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
svc := &PaymentService{
configService: &PaymentConfigService{
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
},
}
claims, err := svc.ParseWeChatPaymentResumeToken(token)
if err != nil {
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
}
if claims.OpenID != "openid-explicit-key" {
t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-explicit-key")
}
}
func TestPaymentServiceParseWeChatPaymentResumeTokenAcceptsLegacyEncryptionKeyDuringMigration(t *testing.T) {
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
legacyKey := []byte("0123456789abcdef0123456789abcdef")
token, err := NewPaymentResumeService(legacyKey).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-legacy-key",
PaymentType: payment.TypeWxpay,
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
svc := &PaymentService{
configService: &PaymentConfigService{
encryptionKey: legacyKey,
},
}
claims, err := svc.ParseWeChatPaymentResumeToken(token)
if err != nil {
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
}
if claims.OpenID != "openid-legacy-key" {
t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-legacy-key")
}
}
func TestNewConfiguredPaymentResumeServicePrefersExplicitSigningKeyAndKeepsLegacyVerificationFallback(t *testing.T) {
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
legacyKey := []byte("0123456789abcdef0123456789abcdef")
svc := newLegacyAwarePaymentResumeService(legacyKey)
explicitToken, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-explicit-key",
PaymentType: payment.TypeWxpay,
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
explicitClaims, err := NewPaymentResumeService([]byte("explicit-payment-resume-signing-key")).ParseWeChatPaymentResumeToken(explicitToken)
if err != nil {
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
}
if explicitClaims.OpenID != "openid-explicit-key" {
t.Fatalf("openid = %q, want %q", explicitClaims.OpenID, "openid-explicit-key")
}
legacyToken, err := NewPaymentResumeService(legacyKey).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-legacy-key",
PaymentType: payment.TypeWxpay,
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
legacyClaims, err := svc.ParseWeChatPaymentResumeToken(legacyToken)
if err != nil {
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
}
if legacyClaims.OpenID != "openid-legacy-key" {
t.Fatalf("openid = %q, want %q", legacyClaims.OpenID, "openid-legacy-key")
}
}
func TestNormalizeVisibleMethodSource(t *testing.T) {
t.Parallel()
@ -376,6 +512,258 @@ func TestVisibleMethodLoadBalancerUsesEnabledProviderInstance(t *testing.T) {
}
}
func TestVisibleMethodLoadBalancerUsesConfiguredSourceWhenMultipleProvidersEnabled(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method payment.PaymentType
officialName string
officialTypes string
easyPayName string
easyPayTypes string
sourceSetting string
wantProvider string
}{
{
name: "alipay uses official source",
method: payment.TypeAlipay,
officialName: "Official Alipay",
officialTypes: "alipay",
easyPayName: "EasyPay Alipay",
easyPayTypes: "alipay",
sourceSetting: VisibleMethodSourceOfficialAlipay,
wantProvider: payment.TypeAlipay,
},
{
name: "alipay uses easypay source",
method: payment.TypeAlipay,
officialName: "Official Alipay",
officialTypes: "alipay",
easyPayName: "EasyPay Alipay",
easyPayTypes: "alipay",
sourceSetting: VisibleMethodSourceEasyPayAlipay,
wantProvider: payment.TypeEasyPay,
},
{
name: "wxpay uses official source",
method: payment.TypeWxpay,
officialName: "Official WeChat",
officialTypes: "wxpay",
easyPayName: "EasyPay WeChat",
easyPayTypes: "wxpay",
sourceSetting: VisibleMethodSourceOfficialWechat,
wantProvider: payment.TypeWxpay,
},
{
name: "wxpay uses easypay source",
method: payment.TypeWxpay,
officialName: "Official WeChat",
officialTypes: "wxpay",
easyPayName: "EasyPay WeChat",
easyPayTypes: "wxpay",
sourceSetting: VisibleMethodSourceEasyPayWechat,
wantProvider: payment.TypeEasyPay,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
officialProviderKey := payment.TypeAlipay
if tt.method == payment.TypeWxpay {
officialProviderKey = payment.TypeWxpay
}
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(officialProviderKey).
SetName(tt.officialName).
SetConfig("{}").
SetSupportedTypes(tt.officialTypes).
SetEnabled(true).
SetSortOrder(1).
Save(ctx)
if err != nil {
t.Fatalf("create official provider: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName(tt.easyPayName).
SetConfig("{}").
SetSupportedTypes(tt.easyPayTypes).
SetEnabled(true).
SetSortOrder(2).
Save(ctx)
if err != nil {
t.Fatalf("create easypay provider: %v", err)
}
inner := &captureLoadBalancer{}
configService := &PaymentConfigService{
entClient: client,
settingRepo: &paymentConfigSettingRepoStub{
values: map[string]string{
visibleMethodSourceSettingKey(tt.method): tt.sourceSetting,
},
},
}
lb := newVisibleMethodLoadBalancer(inner, configService)
_, err = lb.SelectInstance(ctx, "", tt.method, payment.StrategyRoundRobin, 12.5)
if err != nil {
t.Fatalf("SelectInstance returned error: %v", err)
}
if inner.lastProviderKey != tt.wantProvider {
t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, tt.wantProvider)
}
})
}
}
func TestVisibleMethodLoadBalancerPreservesLegacyCrossProviderRoutingWhenSourceMissing(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeAlipay).
SetName("Official Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetEnabled(true).
SetSortOrder(1).
Save(ctx)
if err != nil {
t.Fatalf("create official provider: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName("EasyPay Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetEnabled(true).
SetSortOrder(2).
Save(ctx)
if err != nil {
t.Fatalf("create easypay provider: %v", err)
}
inner := &captureLoadBalancer{}
configService := &PaymentConfigService{
entClient: client,
settingRepo: &paymentConfigSettingRepoStub{
values: map[string]string{
visibleMethodSourceSettingKey(payment.TypeAlipay): "",
},
},
}
lb := newVisibleMethodLoadBalancer(inner, configService)
_, err = lb.SelectInstance(ctx, "", payment.TypeAlipay, payment.StrategyRoundRobin, 9.9)
if err != nil {
t.Fatalf("SelectInstance returned error: %v", err)
}
if inner.lastProviderKey != "" {
t.Fatalf("lastProviderKey = %q, want legacy cross-provider empty key", inner.lastProviderKey)
}
if inner.lastPaymentType != payment.TypeAlipay {
t.Fatalf("lastPaymentType = %q, want %q", inner.lastPaymentType, payment.TypeAlipay)
}
}
func TestVisibleMethodLoadBalancerRejectsInvalidSourceWhenMultipleProvidersEnabled(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method payment.PaymentType
sourceValue string
wantMessage string
}{
{
name: "invalid wxpay source",
method: payment.TypeWxpay,
sourceValue: "stripe",
wantMessage: "wxpay source must be one of the supported payment providers",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
officialProviderKey := payment.TypeAlipay
officialSupportedTypes := "alipay"
officialName := "Official Alipay"
easyPaySupportedTypes := "alipay"
easyPayName := "EasyPay Alipay"
if tt.method == payment.TypeWxpay {
officialProviderKey = payment.TypeWxpay
officialSupportedTypes = "wxpay"
officialName = "Official WeChat"
easyPaySupportedTypes = "wxpay"
easyPayName = "EasyPay WeChat"
}
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(officialProviderKey).
SetName(officialName).
SetConfig("{}").
SetSupportedTypes(officialSupportedTypes).
SetEnabled(true).
SetSortOrder(1).
Save(ctx)
if err != nil {
t.Fatalf("create official provider: %v", err)
}
_, err = client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName(easyPayName).
SetConfig("{}").
SetSupportedTypes(easyPaySupportedTypes).
SetEnabled(true).
SetSortOrder(2).
Save(ctx)
if err != nil {
t.Fatalf("create easypay provider: %v", err)
}
inner := &captureLoadBalancer{}
configService := &PaymentConfigService{
entClient: client,
settingRepo: &paymentConfigSettingRepoStub{
values: map[string]string{
visibleMethodSourceSettingKey(tt.method): tt.sourceValue,
},
},
}
lb := newVisibleMethodLoadBalancer(inner, configService)
_, err = lb.SelectInstance(ctx, "", tt.method, payment.StrategyRoundRobin, 9.9)
if err == nil {
t.Fatal("SelectInstance should reject invalid visible method source configuration")
}
if infraerrors.Reason(err) != "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE" {
t.Fatalf("Reason(err) = %q, want %q", infraerrors.Reason(err), "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE")
}
if infraerrors.Message(err) != tt.wantMessage {
t.Fatalf("Message(err) = %q, want %q", infraerrors.Message(err), tt.wantMessage)
}
})
}
}
func TestVisibleMethodLoadBalancerRejectsMissingEnabledVisibleMethodProvider(t *testing.T) {
t.Parallel()

View File

@ -1,10 +1,14 @@
package service
import (
"bytes"
"context"
"encoding/hex"
"fmt"
"log/slog"
"math/rand/v2"
"os"
"strings"
"sync"
"time"
@ -44,6 +48,8 @@ const (
orderIDPrefix = "sub2_"
)
const paymentResumeSigningKeyEnv = "PAYMENT_RESUME_SIGNING_KEY"
// --- Types ---
// generateOutTradeNo creates a unique external order ID for payment providers.
@ -179,7 +185,7 @@ type PaymentService struct {
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService))
svc.resumeService = psNewPaymentResumeService(configService)
return svc
}
@ -259,16 +265,56 @@ func (s *PaymentService) paymentResume() *PaymentResumeService {
if s.resumeService != nil {
return s.resumeService
}
return NewPaymentResumeService(psResumeSigningKey(s.configService))
return psNewPaymentResumeService(s.configService)
}
func psResumeSigningKey(configService *PaymentConfigService) []byte {
func NewLegacyAwarePaymentResumeService(legacyKey []byte) *PaymentResumeService {
return newLegacyAwarePaymentResumeService(legacyKey)
}
func psNewPaymentResumeService(configService *PaymentConfigService) *PaymentResumeService {
return newLegacyAwarePaymentResumeService(psResumeLegacyVerificationKey(configService))
}
func newLegacyAwarePaymentResumeService(legacyKey []byte) *PaymentResumeService {
signingKey, verifyFallbacks := resolvePaymentResumeSigningKeys(legacyKey)
return NewPaymentResumeService(signingKey, verifyFallbacks...)
}
func psResumeLegacyVerificationKey(configService *PaymentConfigService) []byte {
if configService == nil {
return nil
}
return configService.encryptionKey
}
func resolvePaymentResumeSigningKeys(legacyKey []byte) ([]byte, [][]byte) {
signingKey := parsePaymentResumeSigningKey(os.Getenv(paymentResumeSigningKeyEnv))
if len(signingKey) == 0 {
if len(legacyKey) == 0 {
return nil, nil
}
return legacyKey, nil
}
if len(legacyKey) == 0 || bytes.Equal(legacyKey, signingKey) {
return signingKey, nil
}
return signingKey, [][]byte{legacyKey}
}
func parsePaymentResumeSigningKey(raw string) []byte {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
if len(raw) >= 64 && len(raw)%2 == 0 {
if decoded, err := hex.DecodeString(raw); err == nil && len(decoded) > 0 {
return decoded
}
}
return []byte(raw)
}
func psSliceContains(sl []string, s string) bool {
for _, v := range sl {
if v == s {

View File

@ -2,6 +2,7 @@ package service
import (
"context"
"errors"
"fmt"
"strings"
@ -82,19 +83,52 @@ func filterEnabledVisibleMethodInstances(instances []*dbent.PaymentProviderInsta
return filtered
}
func buildPaymentProviderConflictError(method string, conflicting *dbent.PaymentProviderInstance) error {
metadata := map[string]string{
"payment_method": NormalizeVisibleMethod(method),
func filterVisibleMethodInstancesByProviderKey(instances []*dbent.PaymentProviderInstance, method string, providerKey string) []*dbent.PaymentProviderInstance {
filtered := make([]*dbent.PaymentProviderInstance, 0, len(instances))
for _, inst := range instances {
if !providerSupportsVisibleMethod(inst, method) {
continue
}
if !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), strings.TrimSpace(providerKey)) {
continue
}
filtered = append(filtered, inst)
}
if conflicting != nil {
metadata["conflicting_provider_id"] = fmt.Sprintf("%d", conflicting.ID)
metadata["conflicting_provider_key"] = conflicting.ProviderKey
metadata["conflicting_provider_name"] = conflicting.Name
return filtered
}
func distinctVisibleMethodProviderKeys(instances []*dbent.PaymentProviderInstance) []string {
seen := make(map[string]struct{}, len(instances))
keys := make([]string, 0, len(instances))
for _, inst := range instances {
if inst == nil {
continue
}
key := strings.TrimSpace(inst.ProviderKey)
if key == "" {
continue
}
normalized := strings.ToLower(key)
if _, ok := seen[normalized]; ok {
continue
}
seen[normalized] = struct{}{}
keys = append(keys, key)
}
return infraerrors.Conflict(
"PAYMENT_PROVIDER_CONFLICT",
fmt.Sprintf("%s payment already has an enabled provider instance", NormalizeVisibleMethod(method)),
).WithMetadata(metadata)
return keys
}
func selectVisibleMethodInstanceByProviderKey(instances []*dbent.PaymentProviderInstance, providerKey string) *dbent.PaymentProviderInstance {
providerKey = strings.TrimSpace(providerKey)
if providerKey == "" {
return nil
}
for _, inst := range instances {
if strings.EqualFold(strings.TrimSpace(inst.ProviderKey), providerKey) {
return inst
}
}
return nil
}
func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts(
@ -104,33 +138,72 @@ func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts(
supportedTypes string,
enabled bool,
) error {
if s == nil || s.entClient == nil || !enabled {
return nil
}
// Visible methods are selected by configured source (official/easypay),
// so multiple enabled providers can intentionally claim the same user-facing
// method. Order creation and limits will route through the configured source.
_, _, _, _, _ = ctx, excludeID, providerKey, supportedTypes, enabled
return nil
}
claimedMethods := enabledVisibleMethodsForProvider(providerKey, supportedTypes)
if len(claimedMethods) == 0 {
return nil
}
query := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.EnabledEQ(true))
if excludeID > 0 {
query = query.Where(paymentproviderinstance.IDNEQ(excludeID))
}
instances, err := query.All(ctx)
if err != nil {
return fmt.Errorf("query enabled payment providers: %w", err)
}
for _, method := range claimedMethods {
for _, inst := range instances {
if providerSupportsVisibleMethod(inst, method) {
return buildPaymentProviderConflictError(method, inst)
func (s *PaymentConfigService) resolveVisibleMethodSourceProviderKey(ctx context.Context, method string) (string, error) {
method = NormalizeVisibleMethod(method)
sourceKey := visibleMethodSourceSettingKey(method)
rawSource := ""
if s != nil && s.settingRepo != nil && sourceKey != "" {
value, err := s.settingRepo.GetValue(ctx, sourceKey)
if err != nil {
if !errors.Is(err, ErrSettingNotFound) {
return "", fmt.Errorf("get %s: %w", sourceKey, err)
}
} else {
rawSource = value
}
}
return nil
normalizedSource, err := normalizeVisibleMethodSettingSource(method, rawSource, true)
if err != nil {
return "", err
}
if normalizedSource == "" {
return "", nil
}
providerKey, ok := VisibleMethodProviderKeyForSource(method, normalizedSource)
if !ok {
return "", infraerrors.BadRequest(
"INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
fmt.Sprintf("%s source must be one of the supported payment providers", method),
)
}
return providerKey, nil
}
func (s *PaymentConfigService) resolveVisibleMethodProviderKey(
ctx context.Context,
method string,
matching []*dbent.PaymentProviderInstance,
) (string, error) {
switch providerKeys := distinctVisibleMethodProviderKeys(matching); len(providerKeys) {
case 0:
return "", nil
case 1:
return strings.TrimSpace(providerKeys[0]), nil
default:
providerKey, err := s.resolveVisibleMethodSourceProviderKey(ctx, method)
if err != nil {
return "", err
}
if providerKey == "" {
return "", nil
}
selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey)
if selected == nil {
return "", infraerrors.BadRequest(
"INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
fmt.Sprintf("%s source has no enabled provider instance", method),
)
}
return strings.TrimSpace(selected.ProviderKey), nil
}
}
func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
@ -155,12 +228,15 @@ func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
}
matching := filterEnabledVisibleMethodInstances(instances, method)
switch len(matching) {
case 0:
return nil, nil
case 1:
return matching[0], nil
default:
return nil, buildPaymentProviderConflictError(method, matching[0])
providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching)
if err != nil {
return nil, err
}
if providerKey == "" {
if len(matching) == 0 {
return nil, nil
}
return &dbent.PaymentProviderInstance{ProviderKey: ""}, nil
}
return selectVisibleMethodInstanceByProviderKey(matching, providerKey), nil
}

View File

@ -245,15 +245,119 @@ func parseWeChatConnectCapabilitySettings(settings map[string]string, enabled bo
}
func normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled bool, mode string) string {
mode = normalizeWeChatConnectModeSetting(mode)
switch mode {
case "open":
if openEnabled {
return "open"
}
case "mp":
if mpEnabled {
return "mp"
}
case "mobile":
if mobileEnabled {
return "mobile"
}
}
switch {
case openEnabled:
return "open"
case mpEnabled:
return "mp"
case mobileEnabled:
return "mobile"
case openEnabled:
return "open"
default:
return normalizeWeChatConnectModeSetting(mode)
return mode
}
}
func mergeWeChatConnectCapabilitySettings(settings map[string]string, base config.WeChatConnectConfig, enabled bool, mode string) (bool, bool, bool) {
mode = normalizeWeChatConnectModeSetting(firstNonEmpty(mode, base.Mode))
rawOpen, hasOpen := settings[SettingKeyWeChatConnectOpenEnabled]
rawMP, hasMP := settings[SettingKeyWeChatConnectMPEnabled]
rawMobile, hasMobile := settings[SettingKeyWeChatConnectMobileEnabled]
openConfigured := hasOpen && strings.TrimSpace(rawOpen) != ""
mpConfigured := hasMP && strings.TrimSpace(rawMP) != ""
mobileConfigured := hasMobile && strings.TrimSpace(rawMobile) != ""
if openConfigured || mpConfigured || mobileConfigured {
openEnabled := strings.TrimSpace(rawOpen) == "true"
mpEnabled := strings.TrimSpace(rawMP) == "true"
mobileEnabled := strings.TrimSpace(rawMobile) == "true"
_, enabledConfigured := settings[SettingKeyWeChatConnectEnabled]
if !enabledConfigured &&
enabled &&
!openEnabled &&
!mpEnabled &&
!mobileEnabled &&
(base.OpenEnabled || base.MPEnabled || base.MobileEnabled) {
return base.OpenEnabled, base.MPEnabled, base.MobileEnabled
}
return openEnabled, mpEnabled, mobileEnabled
}
if !enabled {
return false, false, false
}
if base.OpenEnabled || base.MPEnabled || base.MobileEnabled {
return base.OpenEnabled, base.MPEnabled, base.MobileEnabled
}
return parseWeChatConnectCapabilitySettings(settings, enabled, mode)
}
func (s *SettingService) effectiveWeChatConnectOAuthConfig(settings map[string]string) WeChatConnectOAuthConfig {
base := config.WeChatConnectConfig{}
if s != nil && s.cfg != nil {
base = s.cfg.WeChat
}
enabled := base.Enabled
if raw, ok := settings[SettingKeyWeChatConnectEnabled]; ok {
enabled = strings.TrimSpace(raw) == "true"
}
legacyAppID := strings.TrimSpace(firstNonEmpty(
settings[SettingKeyWeChatConnectAppID],
base.AppID,
base.OpenAppID,
base.MPAppID,
base.MobileAppID,
))
legacyAppSecret := strings.TrimSpace(firstNonEmpty(
settings[SettingKeyWeChatConnectAppSecret],
base.AppSecret,
base.OpenAppSecret,
base.MPAppSecret,
base.MobileAppSecret,
))
openAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppID], base.OpenAppID, legacyAppID))
openAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppSecret], base.OpenAppSecret, legacyAppSecret))
mpAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppID], base.MPAppID, legacyAppID))
mpAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppSecret], base.MPAppSecret, legacyAppSecret))
mobileAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppID], base.MobileAppID, legacyAppID))
mobileAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppSecret], base.MobileAppSecret, legacyAppSecret))
modeRaw := firstNonEmpty(settings[SettingKeyWeChatConnectMode], base.Mode)
openEnabled, mpEnabled, mobileEnabled := mergeWeChatConnectCapabilitySettings(settings, base, enabled, modeRaw)
mode := normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled, modeRaw)
return WeChatConnectOAuthConfig{
Enabled: enabled,
LegacyAppID: legacyAppID,
LegacyAppSecret: legacyAppSecret,
OpenAppID: openAppID,
OpenAppSecret: openAppSecret,
MPAppID: mpAppID,
MPAppSecret: mpAppSecret,
MobileAppID: mobileAppID,
MobileAppSecret: mobileAppSecret,
OpenEnabled: openEnabled,
MPEnabled: mpEnabled,
MobileEnabled: mobileEnabled,
Mode: mode,
Scopes: normalizeWeChatConnectScopeSetting(firstNonEmpty(settings[SettingKeyWeChatConnectScopes], base.Scopes), mode),
RedirectURL: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectRedirectURL], base.RedirectURL)),
FrontendRedirectURL: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectFrontendRedirectURL], base.FrontendRedirectURL, defaultWeChatConnectFrontend)),
}
}
@ -535,32 +639,7 @@ func DefaultWeChatConnectScopesForMode(mode string) string {
}
func (s *SettingService) parseWeChatConnectOAuthConfig(settings map[string]string) (WeChatConnectOAuthConfig, error) {
enabled := settings[SettingKeyWeChatConnectEnabled] == "true"
mode := normalizeWeChatConnectModeSetting(settings[SettingKeyWeChatConnectMode])
openEnabled, mpEnabled, mobileEnabled := parseWeChatConnectCapabilitySettings(settings, enabled, mode)
mode = normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled, mode)
cfg := WeChatConnectOAuthConfig{
Enabled: enabled,
LegacyAppID: strings.TrimSpace(settings[SettingKeyWeChatConnectAppID]),
LegacyAppSecret: strings.TrimSpace(settings[SettingKeyWeChatConnectAppSecret]),
OpenAppID: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppID], settings[SettingKeyWeChatConnectAppID])),
OpenAppSecret: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppSecret], settings[SettingKeyWeChatConnectAppSecret])),
MPAppID: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppID], settings[SettingKeyWeChatConnectAppID])),
MPAppSecret: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppSecret], settings[SettingKeyWeChatConnectAppSecret])),
MobileAppID: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppID], settings[SettingKeyWeChatConnectAppID])),
MobileAppSecret: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppSecret], settings[SettingKeyWeChatConnectAppSecret])),
OpenEnabled: openEnabled,
MPEnabled: mpEnabled,
MobileEnabled: mobileEnabled,
Mode: mode,
Scopes: normalizeWeChatConnectScopeSetting(settings[SettingKeyWeChatConnectScopes], mode),
RedirectURL: strings.TrimSpace(settings[SettingKeyWeChatConnectRedirectURL]),
FrontendRedirectURL: strings.TrimSpace(settings[SettingKeyWeChatConnectFrontendRedirectURL]),
}
if cfg.FrontendRedirectURL == "" {
cfg.FrontendRedirectURL = defaultWeChatConnectFrontend
}
cfg := s.effectiveWeChatConnectOAuthConfig(settings)
if !cfg.Enabled || (!cfg.OpenEnabled && !cfg.MPEnabled) {
return WeChatConnectOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled")
@ -589,14 +668,10 @@ func (s *SettingService) parseWeChatConnectOAuthConfig(settings map[string]strin
return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth mobile app secret not configured")
}
}
if cfg.RedirectURL == "" {
return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured")
}
if cfg.FrontendRedirectURL == "" {
return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth frontend redirect url not configured")
}
if err := config.ValidateAbsoluteHTTPURL(cfg.RedirectURL); err != nil {
return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url invalid")
if v := strings.TrimSpace(cfg.RedirectURL); v != "" {
if err := config.ValidateAbsoluteHTTPURL(v); err != nil {
return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url invalid")
}
}
if err := config.ValidateFrontendRedirectURL(cfg.FrontendRedirectURL); err != nil {
return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth frontend redirect url invalid")
@ -605,33 +680,16 @@ func (s *SettingService) parseWeChatConnectOAuthConfig(settings map[string]strin
}
func (s *SettingService) weChatOAuthCapabilitiesFromSettings(settings map[string]string) (bool, bool, bool, bool) {
if settings[SettingKeyWeChatConnectEnabled] != "true" {
cfg := s.effectiveWeChatConnectOAuthConfig(settings)
if !cfg.Enabled {
return false, false, false, false
}
mode := normalizeWeChatConnectModeSetting(settings[SettingKeyWeChatConnectMode])
openEnabled, mpEnabled, mobileEnabled := parseWeChatConnectCapabilitySettings(settings, true, mode)
redirectURL := strings.TrimSpace(settings[SettingKeyWeChatConnectRedirectURL])
frontendRedirectURL := strings.TrimSpace(settings[SettingKeyWeChatConnectFrontendRedirectURL])
if frontendRedirectURL == "" {
frontendRedirectURL = defaultWeChatConnectFrontend
}
openReady := cfg.OpenEnabled && cfg.AppIDForMode("open") != "" && cfg.AppSecretForMode("open") != ""
mpReady := cfg.MPEnabled && cfg.AppIDForMode("mp") != "" && cfg.AppSecretForMode("mp") != ""
mobileReady := cfg.MobileEnabled && cfg.AppIDForMode("mobile") != "" && cfg.AppSecretForMode("mobile") != ""
legacyAppID := strings.TrimSpace(settings[SettingKeyWeChatConnectAppID])
legacyAppSecret := strings.TrimSpace(settings[SettingKeyWeChatConnectAppSecret])
openAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppID], legacyAppID))
openAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppSecret], legacyAppSecret))
mpAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppID], legacyAppID))
mpAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppSecret], legacyAppSecret))
mobileAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppID], legacyAppID))
mobileAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppSecret], legacyAppSecret))
webRedirectReady := redirectURL != "" && frontendRedirectURL != ""
openReady := openEnabled && webRedirectReady && openAppID != "" && openAppSecret != ""
mpReady := mpEnabled && webRedirectReady && mpAppID != "" && mpAppSecret != ""
mobileReady := mobileEnabled && mobileAppID != "" && mobileAppSecret != ""
return openReady || mpReady || mobileReady, openReady, mpReady, mobileReady
return openReady || mpReady, openReady, mpReady, mobileReady
}
// filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON
@ -756,6 +814,30 @@ func parseCustomMenuItemURLs(raw string) []string {
return urls
}
func oidcUsePKCECompatibilityDefault(base config.OIDCConnectConfig) bool {
if base.UsePKCEExplicit {
return base.UsePKCE
}
return true
}
func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool {
if base.ValidateIDTokenExplicit {
return base.ValidateIDToken
}
return true
}
func oidcCompatibilityWriteDefault(base config.OIDCConnectConfig, configured bool, raw string, explicit bool, explicitValue bool) bool {
if configured {
return strings.TrimSpace(raw) == "true"
}
if explicit {
return explicitValue
}
return false
}
// UpdateSettings 更新系统设置
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
updates, err := s.buildSystemSettingsUpdates(ctx, settings)
@ -770,6 +852,28 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
return err
}
func (s *SettingService) OIDCSecurityWriteDefaults(ctx context.Context) (bool, bool, error) {
rawSettings, err := s.settingRepo.GetMultiple(ctx, []string{
SettingKeyOIDCConnectUsePKCE,
SettingKeyOIDCConnectValidateIDToken,
})
if err != nil {
return false, false, fmt.Errorf("get oidc security write defaults: %w", err)
}
base := config.OIDCConnectConfig{}
if s != nil && s.cfg != nil {
base = s.cfg.OIDC
}
rawUsePKCE, hasUsePKCE := rawSettings[SettingKeyOIDCConnectUsePKCE]
rawValidateIDToken, hasValidateIDToken := rawSettings[SettingKeyOIDCConnectValidateIDToken]
return oidcCompatibilityWriteDefault(base, hasUsePKCE, rawUsePKCE, base.UsePKCEExplicit, base.UsePKCE),
oidcCompatibilityWriteDefault(base, hasValidateIDToken, rawValidateIDToken, base.ValidateIDTokenExplicit, base.ValidateIDToken),
nil
}
// UpdateSettingsWithAuthSourceDefaults persists system settings and auth-source defaults in a single write.
func (s *SettingService) UpdateSettingsWithAuthSourceDefaults(ctx context.Context, settings *SystemSettings, authDefaults *AuthSourceDefaultSettings) error {
updates, err := s.buildSystemSettingsUpdates(ctx, settings)
@ -1421,6 +1525,17 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
return fmt.Errorf("check existing settings: %w", err)
}
oidcUsePKCEDefault := true
oidcValidateIDTokenDefault := true
if s != nil && s.cfg != nil {
if s.cfg.OIDC.UsePKCEExplicit {
oidcUsePKCEDefault = s.cfg.OIDC.UsePKCE
}
if s.cfg.OIDC.ValidateIDTokenExplicit {
oidcValidateIDTokenDefault = s.cfg.OIDC.ValidateIDToken
}
}
// 初始化默认设置
defaults := map[string]string{
SettingKeyRegistrationEnabled: "true",
@ -1436,6 +1551,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyCustomMenuItems: "[]",
SettingKeyCustomEndpoints: "[]",
SettingKeyWeChatConnectEnabled: "false",
SettingKeyWeChatConnectAppID: "",
SettingKeyWeChatConnectAppSecret: "",
SettingKeyWeChatConnectOpenAppID: "",
SettingKeyWeChatConnectOpenAppSecret: "",
SettingKeyWeChatConnectMPAppID: "",
@ -1447,9 +1564,30 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyWeChatConnectMobileEnabled: "false",
SettingKeyWeChatConnectMode: "open",
SettingKeyWeChatConnectScopes: "snsapi_login",
SettingKeyWeChatConnectRedirectURL: "",
SettingKeyWeChatConnectFrontendRedirectURL: defaultWeChatConnectFrontend,
SettingKeyOIDCConnectEnabled: "false",
SettingKeyOIDCConnectProviderName: "OIDC",
SettingKeyOIDCConnectClientID: "",
SettingKeyOIDCConnectClientSecret: "",
SettingKeyOIDCConnectIssuerURL: "",
SettingKeyOIDCConnectDiscoveryURL: "",
SettingKeyOIDCConnectAuthorizeURL: "",
SettingKeyOIDCConnectTokenURL: "",
SettingKeyOIDCConnectUserInfoURL: "",
SettingKeyOIDCConnectJWKSURL: "",
SettingKeyOIDCConnectScopes: "openid email profile",
SettingKeyOIDCConnectRedirectURL: "",
SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
SettingKeyOIDCConnectUsePKCE: strconv.FormatBool(oidcUsePKCEDefault),
SettingKeyOIDCConnectValidateIDToken: strconv.FormatBool(oidcValidateIDTokenDefault),
SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
SettingKeyOIDCConnectClockSkewSeconds: "120",
SettingKeyOIDCConnectRequireEmailVerified: "false",
SettingKeyOIDCConnectUserInfoEmailPath: "",
SettingKeyOIDCConnectUserInfoIDPath: "",
SettingKeyOIDCConnectUserInfoUsernamePath: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyDefaultSubscriptions: "[]",
@ -1686,15 +1824,13 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
result.OIDCConnectUsePKCE = raw == "true"
} else {
result.OIDCConnectUsePKCE = oidcBase.UsePKCE
result.OIDCConnectUsePKCE = oidcUsePKCECompatibilityDefault(oidcBase)
}
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
result.OIDCConnectValidateIDToken = raw == "true"
} else {
result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken
result.OIDCConnectValidateIDToken = oidcValidateIDTokenCompatibilityDefault(oidcBase)
}
result.OIDCConnectUsePKCE = true
result.OIDCConnectValidateIDToken = true
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v)
} else {
@ -1739,37 +1875,30 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
}
result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != ""
// WeChat Connect 设置:完全以 DB 系统设置为准。
result.WeChatConnectEnabled = settings[SettingKeyWeChatConnectEnabled] == "true"
result.WeChatConnectAppID = strings.TrimSpace(settings[SettingKeyWeChatConnectAppID])
result.WeChatConnectAppSecret = strings.TrimSpace(settings[SettingKeyWeChatConnectAppSecret])
result.WeChatConnectAppSecretConfigured = result.WeChatConnectAppSecret != ""
result.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppID], result.WeChatConnectAppID))
result.WeChatConnectOpenAppSecret = strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppSecret], result.WeChatConnectAppSecret))
result.WeChatConnectOpenAppSecretConfigured = result.WeChatConnectOpenAppSecret != ""
result.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppID], result.WeChatConnectAppID))
result.WeChatConnectMPAppSecret = strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppSecret], result.WeChatConnectAppSecret))
result.WeChatConnectMPAppSecretConfigured = result.WeChatConnectMPAppSecret != ""
result.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppID], result.WeChatConnectAppID))
result.WeChatConnectMobileAppSecret = strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppSecret], result.WeChatConnectAppSecret))
result.WeChatConnectMobileAppSecretConfigured = result.WeChatConnectMobileAppSecret != ""
result.WeChatConnectOpenEnabled, result.WeChatConnectMPEnabled, result.WeChatConnectMobileEnabled = parseWeChatConnectCapabilitySettings(
settings,
result.WeChatConnectEnabled,
settings[SettingKeyWeChatConnectMode],
)
result.WeChatConnectMode = normalizeWeChatConnectStoredMode(
result.WeChatConnectOpenEnabled,
result.WeChatConnectMPEnabled,
result.WeChatConnectMobileEnabled,
settings[SettingKeyWeChatConnectMode],
)
result.WeChatConnectScopes = normalizeWeChatConnectScopeSetting(settings[SettingKeyWeChatConnectScopes], result.WeChatConnectMode)
result.WeChatConnectRedirectURL = strings.TrimSpace(settings[SettingKeyWeChatConnectRedirectURL])
result.WeChatConnectFrontendRedirectURL = strings.TrimSpace(settings[SettingKeyWeChatConnectFrontendRedirectURL])
if result.WeChatConnectFrontendRedirectURL == "" {
result.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend
}
// WeChat Connect 设置:
// - 优先读取 DB 系统设置
// - 缺失时回退到 config/env保持升级兼容
weChatEffective := s.effectiveWeChatConnectOAuthConfig(settings)
result.WeChatConnectEnabled = weChatEffective.Enabled
result.WeChatConnectAppID = weChatEffective.LegacyAppID
result.WeChatConnectAppSecret = weChatEffective.LegacyAppSecret
result.WeChatConnectAppSecretConfigured = weChatEffective.LegacyAppSecret != ""
result.WeChatConnectOpenAppID = weChatEffective.OpenAppID
result.WeChatConnectOpenAppSecret = weChatEffective.OpenAppSecret
result.WeChatConnectOpenAppSecretConfigured = weChatEffective.OpenAppSecret != ""
result.WeChatConnectMPAppID = weChatEffective.MPAppID
result.WeChatConnectMPAppSecret = weChatEffective.MPAppSecret
result.WeChatConnectMPAppSecretConfigured = weChatEffective.MPAppSecret != ""
result.WeChatConnectMobileAppID = weChatEffective.MobileAppID
result.WeChatConnectMobileAppSecret = weChatEffective.MobileAppSecret
result.WeChatConnectMobileAppSecretConfigured = weChatEffective.MobileAppSecret != ""
result.WeChatConnectOpenEnabled = weChatEffective.OpenEnabled
result.WeChatConnectMPEnabled = weChatEffective.MPEnabled
result.WeChatConnectMobileEnabled = weChatEffective.MobileEnabled
result.WeChatConnectMode = weChatEffective.Mode
result.WeChatConnectScopes = weChatEffective.Scopes
result.WeChatConnectRedirectURL = weChatEffective.RedirectURL
result.WeChatConnectFrontendRedirectURL = weChatEffective.FrontendRedirectURL
// Model fallback settings
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
@ -1861,14 +1990,9 @@ func isFalseSettingValue(value string) bool {
}
func normalizeVisibleMethodSettingSource(method, source string, enabled bool) (string, error) {
_ = enabled
source = strings.TrimSpace(source)
if source == "" {
if enabled {
return "", infraerrors.BadRequest(
"INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
fmt.Sprintf("%s source is required when the visible method is enabled", method),
)
}
return "", nil
}
@ -2196,8 +2320,6 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
effective.RedirectURL = strings.TrimSpace(v)
}
effective.UsePKCE = true
if !effective.Enabled {
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
}
@ -2417,12 +2539,14 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
}
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
effective.UsePKCE = raw == "true"
} else {
effective.UsePKCE = oidcUsePKCECompatibilityDefault(effective)
}
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
effective.ValidateIDToken = raw == "true"
} else {
effective.ValidateIDToken = oidcValidateIDTokenCompatibilityDefault(effective)
}
effective.UsePKCE = true
effective.ValidateIDToken = true
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
effective.AllowedSigningAlgs = strings.TrimSpace(v)
}

View File

@ -101,3 +101,151 @@ func TestGetOIDCConnectOAuthConfig_ResolvesEndpointsFromIssuerDiscovery(t *testi
require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/userinfo", got.UserInfoURL)
require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/certs", got.JWKSURL)
}
func TestSettingService_ParseSettings_PreservesOptionalOIDCCompatibilityFlags(t *testing.T) {
svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{})
got := svc.parseSettings(map[string]string{
SettingKeyOIDCConnectEnabled: "true",
SettingKeyOIDCConnectUsePKCE: "false",
SettingKeyOIDCConnectValidateIDToken: "false",
})
require.False(t, got.OIDCConnectUsePKCE)
require.False(t, got.OIDCConnectValidateIDToken)
}
func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValues(t *testing.T) {
svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{
OIDC: config.OIDCConnectConfig{
UsePKCE: true,
UsePKCEExplicit: true,
ValidateIDToken: true,
ValidateIDTokenExplicit: true,
},
})
got := svc.parseSettings(map[string]string{
SettingKeyOIDCConnectEnabled: "true",
})
require.True(t, got.OIDCConnectUsePKCE)
require.True(t, got.OIDCConnectValidateIDToken)
}
func TestSettingService_ParseSettings_DefaultsOIDCCompatibilityFlagsToSafeDefaultsWhenSettingsMissing(t *testing.T) {
svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{
OIDC: config.OIDCConnectConfig{
UsePKCE: true,
ValidateIDToken: true,
},
})
got := svc.parseSettings(map[string]string{
SettingKeyOIDCConnectEnabled: "true",
})
require.True(t, got.OIDCConnectUsePKCE)
require.True(t, got.OIDCConnectValidateIDToken)
}
func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) {
cfg := &config.Config{
OIDC: config.OIDCConnectConfig{
Enabled: true,
ProviderName: "OIDC",
ClientID: "oidc-client",
ClientSecret: "oidc-secret",
IssuerURL: "https://issuer.example.com",
AuthorizeURL: "https://issuer.example.com/auth",
TokenURL: "https://issuer.example.com/token",
UserInfoURL: "https://issuer.example.com/userinfo",
RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
FrontendRedirectURL: "/auth/oidc/callback",
Scopes: "openid email profile",
TokenAuthMethod: "client_secret_post",
},
}
repo := &settingOIDCRepoStub{values: map[string]string{
SettingKeyOIDCConnectEnabled: "true",
SettingKeyOIDCConnectUsePKCE: "false",
SettingKeyOIDCConnectValidateIDToken: "false",
}}
svc := NewSettingService(repo, cfg)
got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
require.NoError(t, err)
require.False(t, got.UsePKCE)
require.False(t, got.ValidateIDToken)
}
func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *testing.T) {
cfg := &config.Config{
OIDC: config.OIDCConnectConfig{
Enabled: true,
ProviderName: "OIDC",
ClientID: "oidc-client",
ClientSecret: "oidc-secret",
IssuerURL: "https://issuer.example.com",
AuthorizeURL: "https://issuer.example.com/auth",
TokenURL: "https://issuer.example.com/token",
UserInfoURL: "https://issuer.example.com/userinfo",
JWKSURL: "https://issuer.example.com/jwks",
RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
FrontendRedirectURL: "/auth/oidc/callback",
Scopes: "openid email profile",
TokenAuthMethod: "client_secret_post",
UsePKCE: true,
UsePKCEExplicit: true,
ValidateIDToken: true,
ValidateIDTokenExplicit: true,
AllowedSigningAlgs: "RS256",
ClockSkewSeconds: 120,
},
}
repo := &settingOIDCRepoStub{values: map[string]string{
SettingKeyOIDCConnectEnabled: "true",
}}
svc := NewSettingService(repo, cfg)
got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
require.NoError(t, err)
require.True(t, got.UsePKCE)
require.True(t, got.ValidateIDToken)
}
func TestGetOIDCConnectOAuthConfig_DefaultsCompatibilityFlagsToSafeValuesWhenSettingsMissing(t *testing.T) {
cfg := &config.Config{
OIDC: config.OIDCConnectConfig{
Enabled: true,
ProviderName: "OIDC",
ClientID: "oidc-client",
ClientSecret: "oidc-secret",
IssuerURL: "https://issuer.example.com",
AuthorizeURL: "https://issuer.example.com/auth",
TokenURL: "https://issuer.example.com/token",
UserInfoURL: "https://issuer.example.com/userinfo",
JWKSURL: "https://issuer.example.com/jwks",
RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
FrontendRedirectURL: "/auth/oidc/callback",
Scopes: "openid email profile",
TokenAuthMethod: "client_secret_post",
UsePKCE: true,
ValidateIDToken: true,
AllowedSigningAlgs: "RS256",
ClockSkewSeconds: 120,
},
}
repo := &settingOIDCRepoStub{values: map[string]string{
SettingKeyOIDCConnectEnabled: "true",
}}
svc := NewSettingService(repo, cfg)
got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
require.NoError(t, err)
require.True(t, got.UsePKCE)
require.True(t, got.ValidateIDToken)
}

View File

@ -112,3 +112,42 @@ func TestSettingService_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t *
require.True(t, settings.WeChatOAuthOpenEnabled)
require.True(t, settings.WeChatOAuthMPEnabled)
}
func TestSettingService_GetPublicSettings_DoesNotExposeMobileOnlyWeChatAsWebOAuthAvailable(t *testing.T) {
svc := NewSettingService(&settingPublicRepoStub{
values: map[string]string{
SettingKeyWeChatConnectEnabled: "true",
SettingKeyWeChatConnectMobileEnabled: "true",
SettingKeyWeChatConnectMode: "mobile",
SettingKeyWeChatConnectMobileAppID: "wx-mobile-app",
SettingKeyWeChatConnectMobileAppSecret: "wx-mobile-secret",
SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
},
}, &config.Config{})
settings, err := svc.GetPublicSettings(context.Background())
require.NoError(t, err)
require.False(t, settings.WeChatOAuthEnabled)
require.False(t, settings.WeChatOAuthOpenEnabled)
require.False(t, settings.WeChatOAuthMPEnabled)
require.True(t, settings.WeChatOAuthMobileEnabled)
}
func TestSettingService_GetPublicSettings_FallsBackToConfigForWeChatOAuthCapabilities(t *testing.T) {
svc := NewSettingService(&settingPublicRepoStub{values: map[string]string{}}, &config.Config{
WeChat: config.WeChatConnectConfig{
Enabled: true,
OpenEnabled: true,
OpenAppID: "wx-open-config",
OpenAppSecret: "wx-open-secret",
FrontendRedirectURL: "/auth/wechat/config-callback",
},
})
settings, err := svc.GetPublicSettings(context.Background())
require.NoError(t, err)
require.True(t, settings.WeChatOAuthEnabled)
require.True(t, settings.WeChatOAuthOpenEnabled)
require.False(t, settings.WeChatOAuthMPEnabled)
require.False(t, settings.WeChatOAuthMobileEnabled)
}

View File

@ -79,3 +79,84 @@ func TestSettingService_GetWeChatConnectOAuthConfig_UsesDatabaseOverrides(t *tes
require.Equal(t, "https://api.example.com/api/v1/auth/oauth/wechat/callback", got.RedirectURL)
require.Equal(t, "/auth/wechat/callback", got.FrontendRedirectURL)
}
func TestSettingService_GetWeChatConnectOAuthConfig_FallsBackToConfigWhenDatabaseEmpty(t *testing.T) {
repo := &settingWeChatRepoStub{values: map[string]string{}}
svc := NewSettingService(repo, &config.Config{
WeChat: config.WeChatConnectConfig{
Enabled: true,
OpenEnabled: true,
MPEnabled: true,
Mode: "open",
OpenAppID: "wx-open-config",
OpenAppSecret: "wx-open-secret",
MPAppID: "wx-mp-config",
MPAppSecret: "wx-mp-secret",
FrontendRedirectURL: "/auth/wechat/config-callback",
},
})
got, err := svc.GetWeChatConnectOAuthConfig(context.Background())
require.NoError(t, err)
require.True(t, got.Enabled)
require.True(t, got.OpenEnabled)
require.True(t, got.MPEnabled)
require.Equal(t, "wx-open-config", got.AppIDForMode("open"))
require.Equal(t, "wx-open-secret", got.AppSecretForMode("open"))
require.Equal(t, "wx-mp-config", got.AppIDForMode("mp"))
require.Equal(t, "wx-mp-secret", got.AppSecretForMode("mp"))
require.Equal(t, "/auth/wechat/config-callback", got.FrontendRedirectURL)
require.Empty(t, got.RedirectURL)
}
func TestSettingService_GetWeChatConnectOAuthConfig_IgnoresSyntheticDisabledCapabilitiesFromMigration118(t *testing.T) {
repo := &settingWeChatRepoStub{
values: map[string]string{
SettingKeyWeChatConnectOpenEnabled: "false",
SettingKeyWeChatConnectMPEnabled: "false",
},
}
svc := NewSettingService(repo, &config.Config{
WeChat: config.WeChatConnectConfig{
Enabled: true,
OpenEnabled: true,
MPEnabled: true,
Mode: "open",
OpenAppID: "wx-open-config",
OpenAppSecret: "wx-open-secret",
MPAppID: "wx-mp-config",
MPAppSecret: "wx-mp-secret",
FrontendRedirectURL: "/auth/wechat/config-callback",
},
})
got, err := svc.GetWeChatConnectOAuthConfig(context.Background())
require.NoError(t, err)
require.True(t, got.Enabled)
require.True(t, got.OpenEnabled)
require.True(t, got.MPEnabled)
require.Equal(t, "wx-open-config", got.AppIDForMode("open"))
require.Equal(t, "wx-mp-config", got.AppIDForMode("mp"))
}
func TestSettingService_ParseSettings_FallsBackToConfigForWeChatAdminView(t *testing.T) {
svc := NewSettingService(&settingWeChatRepoStub{values: map[string]string{}}, &config.Config{
WeChat: config.WeChatConnectConfig{
Enabled: true,
OpenEnabled: true,
Mode: "open",
OpenAppID: "wx-open-config",
OpenAppSecret: "wx-open-secret",
FrontendRedirectURL: "/auth/wechat/config-callback",
},
})
got := svc.parseSettings(map[string]string{})
require.True(t, got.WeChatConnectEnabled)
require.True(t, got.WeChatConnectOpenEnabled)
require.Equal(t, "wx-open-config", got.WeChatConnectOpenAppID)
require.True(t, got.WeChatConnectOpenAppSecretConfigured)
require.Equal(t, "/auth/wechat/config-callback", got.WeChatConnectFrontendRedirectURL)
require.Equal(t, "open", got.WeChatConnectMode)
require.Equal(t, "snsapi_login", got.WeChatConnectScopes)
}

View File

@ -23,12 +23,15 @@ type User struct {
Status string
AllowedGroups []int64
TokenVersion int64 // Incremented on password change to invalidate existing tokens
SignupSource string
LastLoginAt *time.Time
LastActiveAt *time.Time
LastUsedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
// TokenVersionResolved indicates TokenVersion already contains the fingerprint-derived
// value expected in JWT claims and refresh-token state.
TokenVersionResolved bool
SignupSource string
LastLoginAt *time.Time
LastActiveAt *time.Time
LastUsedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier

View File

@ -127,6 +127,7 @@ type UserIdentitySummary struct {
Bound bool `json:"bound"`
BoundCount int `json:"bound_count"`
DisplayName string `json:"display_name,omitempty"`
AvatarURL string `json:"-"`
SubjectHint string `json:"subject_hint,omitempty"`
ProviderKey string `json:"provider_key,omitempty"`
VerifiedAt *time.Time `json:"verified_at,omitempty"`
@ -228,6 +229,7 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
normalizeLoadedUserTokenVersion(user)
if err := s.hydrateUserAvatar(ctx, user); err != nil {
return nil, fmt.Errorf("get user avatar: %w", err)
}
@ -248,12 +250,59 @@ func (s *UserService) GetProfileIdentitySummaries(ctx context.Context, userID in
return UserIdentitySummarySet{}, err
}
return UserIdentitySummarySet{
summaries := UserIdentitySummarySet{
Email: s.buildEmailIdentitySummary(user, records),
LinuxDo: s.buildProviderIdentitySummary("linuxdo", user, records),
OIDC: s.buildProviderIdentitySummary("oidc", user, records),
WeChat: s.buildProviderIdentitySummary("wechat", user, records),
}, nil
}
s.applyExplicitProviderAvailability(ctx, &summaries)
return summaries, nil
}
func (s *UserService) applyExplicitProviderAvailability(ctx context.Context, summaries *UserIdentitySummarySet) {
if s == nil || summaries == nil || s.settingRepo == nil {
return
}
settings, err := s.settingRepo.GetMultiple(ctx, []string{
SettingKeyLinuxDoConnectEnabled,
SettingKeyOIDCConnectEnabled,
SettingKeyWeChatConnectEnabled,
SettingKeyWeChatConnectOpenEnabled,
SettingKeyWeChatConnectMPEnabled,
SettingKeyWeChatConnectMobileEnabled,
SettingKeyWeChatConnectMode,
})
if err != nil {
return
}
if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok && strings.TrimSpace(raw) != "" && raw != "true" {
disableIdentityBindAction(&summaries.LinuxDo)
}
if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok && strings.TrimSpace(raw) != "" && raw != "true" {
disableIdentityBindAction(&summaries.OIDC)
}
if raw, ok := settings[SettingKeyWeChatConnectEnabled]; ok && strings.TrimSpace(raw) != "" {
if raw != "true" {
disableIdentityBindAction(&summaries.WeChat)
return
}
openEnabled, mpEnabled, _ := parseWeChatConnectCapabilitySettings(settings, true, settings[SettingKeyWeChatConnectMode])
if !openEnabled && !mpEnabled {
disableIdentityBindAction(&summaries.WeChat)
}
}
}
func disableIdentityBindAction(summary *UserIdentitySummary) {
if summary == nil || summary.Bound {
return
}
summary.CanBind = false
summary.BindStartPath = ""
}
func (s *UserService) PrepareIdentityBindingStart(_ context.Context, req StartUserIdentityBindingRequest) (*StartUserIdentityBindingResult, error) {
@ -276,29 +325,34 @@ func (s *UserService) PrepareIdentityBindingStart(_ context.Context, req StartUs
}
func (s *UserService) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) (*User, error) {
user, _, err := s.UnbindUserAuthProviderWithResult(ctx, userID, provider)
return user, err
}
func (s *UserService) UnbindUserAuthProviderWithResult(ctx context.Context, userID int64, provider string) (*User, bool, error) {
provider = normalizeUserIdentityProvider(provider)
if provider == "" || provider == "email" {
return nil, ErrIdentityProviderInvalid
return nil, false, ErrIdentityProviderInvalid
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
return nil, false, fmt.Errorf("get user: %w", err)
}
records, err := s.listUserAuthIdentities(ctx, userID)
if err != nil {
return nil, err
return nil, false, err
}
if len(filterUserAuthIdentities(records, provider)) == 0 {
return user, nil
return user, false, nil
}
if !s.canUnbindProvider(provider, user, records) {
return nil, ErrIdentityUnbindLastMethod
return nil, false, ErrIdentityUnbindLastMethod
}
if err := s.userRepo.UnbindUserAuthProvider(ctx, userID, provider); err != nil {
return nil, err
return nil, false, err
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
@ -306,9 +360,9 @@ func (s *UserService) UnbindUserAuthProvider(ctx context.Context, userID int64,
updatedUser, err := s.GetProfile(ctx, userID)
if err != nil {
return nil, err
return nil, false, err
}
return updatedUser, nil
return updatedUser, true, nil
}
// UpdateProfile 更新用户资料
@ -608,6 +662,7 @@ func (s *UserService) buildProviderIdentitySummary(provider string, user *User,
summary.Bound = true
summary.BoundCount = len(filtered)
summary.DisplayName = userAuthIdentityDisplayName(primary)
summary.AvatarURL = strings.TrimSpace(firstStringIdentityValue(primary.Metadata, "avatar_url", "suggested_avatar_url", "headimgurl"))
summary.SubjectHint = maskOpaqueIdentity(primary.ProviderSubject)
summary.ProviderKey = strings.TrimSpace(primary.ProviderKey)
summary.VerifiedAt = primary.VerifiedAt
@ -625,7 +680,7 @@ func (s *UserService) canUnbindProvider(provider string, user *User, records []U
return false
}
if s.buildEmailIdentitySummary(user, records).Bound {
if s.canUseEmailAsSignInMethod(user, records) {
return true
}
@ -641,6 +696,44 @@ func (s *UserService) canUnbindProvider(provider string, user *User, records []U
return false
}
func (s *UserService) canUseEmailAsSignInMethod(user *User, records []UserAuthIdentityRecord) bool {
if user == nil {
return false
}
email := strings.ToLower(strings.TrimSpace(user.Email))
if email == "" || isReservedEmail(email) {
return false
}
if emailSignupSourceAllowsLogin(user.SignupSource) {
return true
}
for _, record := range filterUserAuthIdentities(records, "email") {
if emailIdentitySupportsSignIn(record) {
return true
}
}
return false
}
func emailSignupSourceAllowsLogin(signupSource string) bool {
signupSource = strings.ToLower(strings.TrimSpace(signupSource))
return signupSource == "" || signupSource == "email"
}
func emailIdentitySupportsSignIn(record UserAuthIdentityRecord) bool {
source := strings.TrimSpace(firstStringIdentityValue(record.Metadata, "source"))
switch source {
case "auth_service_email_bind", "auth_service_login_backfill", "auth_service_dual_write":
return true
default:
return false
}
}
func (s *UserService) listUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) {
if userID <= 0 || s == nil || s.userRepo == nil {
return nil, nil
@ -662,11 +755,11 @@ func buildUserIdentityBindAuthorizeURL(provider, redirectTo string) (string, err
path := ""
switch provider {
case "linuxdo":
path = "/api/v1/auth/oauth/linuxdo/start"
path = "/api/v1/auth/oauth/linuxdo/bind/start"
case "oidc":
path = "/api/v1/auth/oauth/oidc/start"
path = "/api/v1/auth/oauth/oidc/bind/start"
case "wechat":
path = "/api/v1/auth/oauth/wechat/start"
path = "/api/v1/auth/oauth/wechat/bind/start"
default:
return "", ErrIdentityProviderInvalid
}
@ -842,12 +935,21 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) {
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
normalizeLoadedUserTokenVersion(user)
if err := s.hydrateUserAvatar(ctx, user); err != nil {
return nil, fmt.Errorf("get user avatar: %w", err)
}
return user, nil
}
func normalizeLoadedUserTokenVersion(user *User) {
if user == nil || user.TokenVersionResolved {
return
}
user.TokenVersion = resolvedTokenVersion(user)
user.TokenVersionResolved = true
}
// TouchLastActive 通过防抖更新 users.last_active_at减少鉴权热路径写放大。
// 该操作为尽力而为,不应中断正常请求。
func (s *UserService) TouchLastActive(ctx context.Context, userID int64) {

View File

@ -51,6 +51,44 @@ type mockUserRepoTxState struct {
deleteAvatarIDs []int64
}
type mockUserSettingRepo struct {
values map[string]string
}
func (m *mockUserSettingRepo) Get(context.Context, string) (*Setting, error) {
panic("unexpected Get call")
}
func (m *mockUserSettingRepo) GetValue(context.Context, string) (string, error) {
panic("unexpected GetValue call")
}
func (m *mockUserSettingRepo) Set(context.Context, string, string) error {
panic("unexpected Set call")
}
func (m *mockUserSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
if value, ok := m.values[key]; ok {
out[key] = value
}
}
return out, nil
}
func (m *mockUserSettingRepo) SetMultiple(context.Context, map[string]string) error {
panic("unexpected SetMultiple call")
}
func (m *mockUserSettingRepo) GetAll(context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (m *mockUserSettingRepo) Delete(context.Context, string) error {
panic("unexpected Delete call")
}
func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
func (m *mockUserRepo) GetByID(ctx context.Context, _ int64) (*User, error) {
if m.getByIDErr != nil {
@ -349,6 +387,70 @@ func TestUnbindUserAuthProviderRejectsLastRemainingLoginMethod(t *testing.T) {
require.Empty(t, repo.unboundProviders)
}
func TestGetProfileIdentitySummaries_DoesNotTreatOAuthOnlyCompatEmailAsAlternativeLoginMethod(t *testing.T) {
repo := &mockUserRepo{
getByIDUser: &User{
ID: 10,
Email: "oauth-only@example.com",
SignupSource: "oidc",
},
identities: []UserAuthIdentityRecord{
{
ProviderType: "oidc",
ProviderKey: "https://issuer.example.com",
ProviderSubject: "oidc-only-subject",
},
},
}
svc := NewUserService(repo, nil, nil, nil)
summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 10, repo.getByIDUser)
require.NoError(t, err)
require.False(t, summaries.OIDC.CanUnbind)
_, err = svc.UnbindUserAuthProvider(context.Background(), 10, "oidc")
require.ErrorIs(t, err, ErrIdentityUnbindLastMethod)
require.Empty(t, repo.unboundProviders)
}
func TestGetProfileIdentitySummaries_DoesNotTreatCompatBackfilledEmailIdentityAsAlternativeLoginMethod(t *testing.T) {
repo := &mockUserRepo{
getByIDUser: &User{
ID: 11,
Email: "oauth-only@example.com",
SignupSource: "wechat",
},
identities: []UserAuthIdentityRecord{
{
ProviderType: "email",
ProviderKey: "email",
ProviderSubject: "oauth-only@example.com",
Metadata: map[string]any{
"backfill_source": "users.email",
"migration": "109_auth_identity_compat_backfill",
},
},
{
ProviderType: "wechat",
ProviderKey: "wechat",
ProviderSubject: "wechat-only-subject",
},
},
}
svc := NewUserService(repo, nil, nil, nil)
summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 11, repo.getByIDUser)
require.NoError(t, err)
require.True(t, summaries.Email.Bound)
require.False(t, summaries.WeChat.CanUnbind)
_, err = svc.UnbindUserAuthProvider(context.Background(), 11, "wechat")
require.ErrorIs(t, err, ErrIdentityUnbindLastMethod)
require.Empty(t, repo.unboundProviders)
}
func TestUnbindUserAuthProviderRemovesProviderAndReturnsUpdatedProfile(t *testing.T) {
repo := &mockUserRepo{
getByIDUser: &User{
@ -368,13 +470,15 @@ func TestUnbindUserAuthProviderRemovesProviderAndReturnsUpdatedProfile(t *testin
},
},
}
svc := NewUserService(repo, nil, nil, nil)
invalidator := &mockAuthCacheInvalidator{}
svc := NewUserService(repo, nil, invalidator, nil)
user, err := svc.UnbindUserAuthProvider(context.Background(), 12, "linuxdo")
require.NoError(t, err)
require.Equal(t, []string{"linuxdo"}, repo.unboundProviders)
require.Equal(t, int64(12), user.ID)
require.Equal(t, []int64{12}, invalidator.invalidatedUserIDs)
summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 12, user)
require.NoError(t, err)
@ -382,6 +486,71 @@ func TestUnbindUserAuthProviderRemovesProviderAndReturnsUpdatedProfile(t *testin
require.True(t, summaries.LinuxDo.CanBind)
}
func TestGetProfileIdentitySummaries_HidesBindActionWhenProviderExplicitlyDisabled(t *testing.T) {
repo := &mockUserRepo{
getByIDUser: &User{
ID: 15,
Email: "alice@example.com",
},
identities: []UserAuthIdentityRecord{
{
ProviderType: "email",
ProviderKey: "email",
ProviderSubject: "alice@example.com",
},
},
}
settingRepo := &mockUserSettingRepo{
values: map[string]string{
SettingKeyLinuxDoConnectEnabled: "false",
},
}
svc := NewUserService(repo, settingRepo, nil, nil)
summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 15, repo.getByIDUser)
require.NoError(t, err)
require.False(t, summaries.LinuxDo.Bound)
require.False(t, summaries.LinuxDo.CanBind)
require.Empty(t, summaries.LinuxDo.BindStartPath)
}
func TestGetProfileIdentitySummaries_UsesBindStartRoute(t *testing.T) {
repo := &mockUserRepo{
getByIDUser: &User{
ID: 16,
Email: "alice@example.com",
},
identities: []UserAuthIdentityRecord{
{
ProviderType: "email",
ProviderKey: "email",
ProviderSubject: "alice@example.com",
},
},
}
svc := NewUserService(repo, nil, nil, nil)
summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 16, repo.getByIDUser)
require.NoError(t, err)
require.Equal(
t,
"/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile",
summaries.LinuxDo.BindStartPath,
)
require.Equal(
t,
"/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile",
summaries.OIDC.BindStartPath,
)
require.Equal(
t,
"/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile",
summaries.WeChat.BindStartPath,
)
}
func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
repo := &mockUserRepo{}
svc := NewUserService(repo, nil, nil, nil) // billingCache = nil

View File

@ -0,0 +1,14 @@
DO $$
BEGIN
IF EXISTS (
SELECT 1
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = 'auth_identity_migration_reports'
AND column_name = 'report_type'
AND COALESCE(character_maximum_length, 0) < 80
) THEN
ALTER TABLE auth_identity_migration_reports
ALTER COLUMN report_type TYPE VARCHAR(80);
END IF;
END $$;

View File

@ -1,6 +1,3 @@
ALTER TABLE auth_identity_migration_reports
ALTER COLUMN report_type TYPE VARCHAR(80);
INSERT INTO auth_identities (
user_id,
provider_type,

View File

@ -38,23 +38,22 @@ VALUES
('auth_source_default_email_balance', '0'),
('auth_source_default_email_concurrency', '5'),
('auth_source_default_email_subscriptions', '[]'),
('auth_source_default_email_grant_on_signup', 'true'),
('auth_source_default_email_grant_on_signup', 'false'),
('auth_source_default_email_grant_on_first_bind', 'false'),
('auth_source_default_linuxdo_balance', '0'),
('auth_source_default_linuxdo_concurrency', '5'),
('auth_source_default_linuxdo_subscriptions', '[]'),
('auth_source_default_linuxdo_grant_on_signup', 'true'),
('auth_source_default_linuxdo_grant_on_signup', 'false'),
('auth_source_default_linuxdo_grant_on_first_bind', 'false'),
('auth_source_default_oidc_balance', '0'),
('auth_source_default_oidc_concurrency', '5'),
('auth_source_default_oidc_subscriptions', '[]'),
('auth_source_default_oidc_grant_on_signup', 'true'),
('auth_source_default_oidc_grant_on_signup', 'false'),
('auth_source_default_oidc_grant_on_first_bind', 'false'),
('auth_source_default_wechat_balance', '0'),
('auth_source_default_wechat_concurrency', '5'),
('auth_source_default_wechat_subscriptions', '[]'),
('auth_source_default_wechat_grant_on_signup', 'true'),
('auth_source_default_wechat_grant_on_signup', 'false'),
('auth_source_default_wechat_grant_on_first_bind', 'false'),
('force_email_on_third_party_signup', 'false')
ON CONFLICT (key) DO NOTHING;

View File

@ -1,4 +1,4 @@
ALTER TABLE payment_orders ADD COLUMN provider_key VARCHAR(30);
ALTER TABLE payment_orders ADD COLUMN IF NOT EXISTS provider_key VARCHAR(30);
UPDATE payment_orders
SET provider_key = (

View File

@ -31,6 +31,41 @@ BEGIN
END IF;
EXECUTE $sql$
WITH legacy AS (
SELECT
uei.id,
uei.user_id,
BTRIM(uei.provider_user_id) AS provider_user_id,
BTRIM(uei.provider_username) AS provider_username,
BTRIM(uei.display_name) AS display_name,
public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
uei.created_at,
uei.updated_at
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo'
AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
),
legacy_subjects AS (
SELECT
provider_user_id AS provider_subject,
COUNT(DISTINCT user_id) AS distinct_user_count
FROM legacy
GROUP BY provider_user_id
),
canonical_legacy AS (
SELECT
legacy.*,
ROW_NUMBER() OVER (
PARTITION BY legacy.provider_user_id
ORDER BY COALESCE(legacy.updated_at, legacy.created_at, NOW()) DESC, legacy.id DESC
) AS canonical_row_num
FROM legacy
JOIN legacy_subjects AS subjects
ON subjects.provider_subject = legacy.provider_user_id
AND subjects.distinct_user_count = 1
)
INSERT INTO auth_identities (
user_id,
provider_type,
@ -52,11 +87,18 @@ SELECT
'display_name', legacy.display_name,
'migration', '115_auth_identity_legacy_external_backfill'
)
FROM (
FROM canonical_legacy AS legacy
WHERE legacy.canonical_row_num = 1
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
$sql$;
EXECUTE $sql$
WITH legacy AS (
SELECT
uei.id,
uei.user_id,
BTRIM(uei.provider_user_id) AS provider_user_id,
BTRIM(uei.provider_union_id) AS provider_union_id,
BTRIM(uei.provider_username) AS provider_username,
BTRIM(uei.display_name) AS display_name,
public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
@ -65,13 +107,28 @@ FROM (
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo'
AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
) AS legacy
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
$sql$;
EXECUTE $sql$
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
),
legacy_subjects AS (
SELECT
provider_union_id AS provider_subject,
COUNT(DISTINCT user_id) AS distinct_user_count
FROM legacy
GROUP BY provider_union_id
),
canonical_legacy AS (
SELECT
legacy.*,
ROW_NUMBER() OVER (
PARTITION BY legacy.provider_union_id
ORDER BY COALESCE(legacy.updated_at, legacy.created_at, NOW()) DESC, legacy.id DESC
) AS canonical_row_num
FROM legacy
JOIN legacy_subjects AS subjects
ON subjects.provider_subject = legacy.provider_union_id
AND subjects.distinct_user_count = 1
)
INSERT INTO auth_identities (
user_id,
provider_type,
@ -96,27 +153,36 @@ SELECT
'display_name', legacy.display_name,
'migration', '115_auth_identity_legacy_external_backfill'
)
FROM (
SELECT
uei.id,
uei.user_id,
BTRIM(uei.provider_user_id) AS provider_user_id,
BTRIM(uei.provider_union_id) AS provider_union_id,
BTRIM(uei.provider_username) AS provider_username,
BTRIM(uei.display_name) AS display_name,
public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
uei.created_at,
uei.updated_at
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
) AS legacy
FROM canonical_legacy AS legacy
WHERE legacy.canonical_row_num = 1
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
$sql$;
EXECUTE $sql$
WITH legacy AS (
SELECT
uei.user_id,
BTRIM(uei.provider_user_id) AS provider_user_id,
BTRIM(uei.provider_union_id) AS provider_union_id,
BTRIM(COALESCE(meta.metadata_json ->> 'channel', '')) AS channel,
BTRIM(COALESCE(meta.metadata_json ->> 'channel_app_id', meta.metadata_json ->> 'appid', meta.metadata_json ->> 'app_id', '')) AS channel_app_id,
meta.metadata_json
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
CROSS JOIN LATERAL (
SELECT public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
) AS meta
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
),
legacy_subjects AS (
SELECT
provider_union_id AS provider_subject,
COUNT(DISTINCT user_id) AS distinct_user_count
FROM legacy
GROUP BY provider_union_id
)
INSERT INTO auth_identity_channels (
identity_id,
provider_type,
@ -138,23 +204,10 @@ SELECT
'unionid', legacy.provider_union_id,
'migration', '115_auth_identity_legacy_external_backfill'
)
FROM (
SELECT
uei.user_id,
BTRIM(uei.provider_user_id) AS provider_user_id,
BTRIM(uei.provider_union_id) AS provider_union_id,
BTRIM(COALESCE(meta.metadata_json ->> 'channel', '')) AS channel,
BTRIM(COALESCE(meta.metadata_json ->> 'channel_app_id', meta.metadata_json ->> 'appid', meta.metadata_json ->> 'app_id', '')) AS channel_app_id,
meta.metadata_json
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
CROSS JOIN LATERAL (
SELECT public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
) AS meta
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
) AS legacy
FROM legacy
JOIN legacy_subjects AS subjects
ON subjects.provider_subject = legacy.provider_union_id
AND subjects.distinct_user_count = 1
JOIN auth_identities AS ai
ON ai.user_id = legacy.user_id
AND ai.provider_type = 'wechat'

View File

@ -74,6 +74,82 @@ $sql$;
EXECUTE $sql$
INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
SELECT
'legacy_external_identity_conflict',
'legacy_external_identity:' || legacy.id::text,
legacy.metadata_json || jsonb_build_object(
'legacy_identity_id', legacy.id,
'legacy_user_id', legacy.user_id,
'provider_type', legacy.provider_type,
'provider_key', legacy.provider_key,
'provider_subject', legacy.provider_subject,
'conflicting_legacy_user_ids', ambiguous.conflicting_legacy_user_ids,
'reason', 'legacy canonical identity subject belongs to multiple legacy users and cannot be auto-resolved',
'migration', '116_auth_identity_legacy_external_safety_reports'
)
FROM (
SELECT
uei.id,
uei.user_id,
LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
CASE
WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
ELSE 'linuxdo'
END AS provider_key,
CASE
WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
END AS provider_subject,
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
AND (
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
OR
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
)
) AS legacy
JOIN (
SELECT
provider_type,
provider_key,
provider_subject,
to_jsonb(array_agg(DISTINCT user_id ORDER BY user_id)) AS conflicting_legacy_user_ids
FROM (
SELECT
uei.user_id,
LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
CASE
WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
ELSE 'linuxdo'
END AS provider_key,
CASE
WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
END AS provider_subject
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
AND (
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
OR
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
)
) AS legacy_subjects
GROUP BY provider_type, provider_key, provider_subject
HAVING COUNT(DISTINCT user_id) > 1
) AS ambiguous
ON ambiguous.provider_type = legacy.provider_type
AND ambiguous.provider_key = legacy.provider_key
AND ambiguous.provider_subject = legacy.provider_subject
ON CONFLICT (report_type, report_key) DO NOTHING;
$sql$;
EXECUTE $sql$
INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
SELECT
'legacy_external_identity_conflict',
'legacy_external_identity:' || legacy.id::text,
@ -116,6 +192,39 @@ FROM (
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
)
) AS legacy
JOIN (
SELECT
provider_type,
provider_key,
provider_subject
FROM (
SELECT
uei.user_id,
LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
CASE
WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
ELSE 'linuxdo'
END AS provider_key,
CASE
WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
END AS provider_subject
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
AND (
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
OR
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
)
) AS legacy_subjects
GROUP BY provider_type, provider_key, provider_subject
HAVING COUNT(DISTINCT user_id) = 1
) AS clear_subjects
ON clear_subjects.provider_type = legacy.provider_type
AND clear_subjects.provider_key = legacy.provider_key
AND clear_subjects.provider_subject = legacy.provider_subject
JOIN auth_identities AS ai
ON ai.provider_type = legacy.provider_type
AND ai.provider_key = legacy.provider_key
@ -125,29 +234,7 @@ ON CONFLICT (report_type, report_key) DO NOTHING;
$sql$;
EXECUTE $sql$
INSERT INTO auth_identities (
user_id,
provider_type,
provider_key,
provider_subject,
verified_at,
metadata
)
SELECT
legacy.user_id,
legacy.provider_type,
legacy.provider_key,
legacy.provider_subject,
legacy.verified_at,
legacy.metadata_json || jsonb_build_object(
'legacy_identity_id', legacy.id,
'provider_user_id', legacy.provider_user_id,
'provider_union_id', NULLIF(legacy.provider_union_id, ''),
'provider_username', legacy.provider_username,
'display_name', legacy.display_name,
'migration', '116_auth_identity_legacy_external_safety_reports'
)
FROM (
WITH legacy AS (
SELECT
uei.id,
uei.user_id,
@ -175,12 +262,58 @@ FROM (
OR
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
)
) AS legacy
),
clear_subjects AS (
SELECT
provider_type,
provider_key,
provider_subject
FROM legacy
GROUP BY provider_type, provider_key, provider_subject
HAVING COUNT(DISTINCT user_id) = 1
),
canonical_legacy AS (
SELECT
legacy.*,
ROW_NUMBER() OVER (
PARTITION BY legacy.provider_type, legacy.provider_key, legacy.provider_subject
ORDER BY legacy.verified_at DESC, legacy.id DESC
) AS canonical_row_num
FROM legacy
JOIN clear_subjects
ON clear_subjects.provider_type = legacy.provider_type
AND clear_subjects.provider_key = legacy.provider_key
AND clear_subjects.provider_subject = legacy.provider_subject
)
INSERT INTO auth_identities (
user_id,
provider_type,
provider_key,
provider_subject,
verified_at,
metadata
)
SELECT
legacy.user_id,
legacy.provider_type,
legacy.provider_key,
legacy.provider_subject,
legacy.verified_at,
legacy.metadata_json || jsonb_build_object(
'legacy_identity_id', legacy.id,
'provider_user_id', legacy.provider_user_id,
'provider_union_id', NULLIF(legacy.provider_union_id, ''),
'provider_username', legacy.provider_username,
'display_name', legacy.display_name,
'migration', '116_auth_identity_legacy_external_safety_reports'
)
FROM canonical_legacy AS legacy
LEFT JOIN auth_identities AS ai
ON ai.provider_type = legacy.provider_type
AND ai.provider_key = legacy.provider_key
AND ai.provider_subject = legacy.provider_subject
WHERE ai.id IS NULL
WHERE legacy.canonical_row_num = 1
AND ai.id IS NULL
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
$sql$;
@ -225,6 +358,19 @@ FROM (
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
) AS legacy
JOIN (
SELECT
BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_subject
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
GROUP BY BTRIM(COALESCE(uei.provider_union_id, ''))
HAVING COUNT(DISTINCT uei.user_id) = 1
) AS clear_subjects
ON clear_subjects.provider_subject = legacy.provider_union_id
JOIN auth_identities AS legacy_ai
ON legacy_ai.user_id = legacy.user_id
AND legacy_ai.provider_type = 'wechat'
@ -245,6 +391,33 @@ ON CONFLICT (report_type, report_key) DO NOTHING;
$sql$;
EXECUTE $sql$
WITH legacy AS (
SELECT
uei.user_id,
BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel,
BTRIM(COALESCE(
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id',
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid',
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id',
''
)) AS channel_app_id
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
),
clear_subjects AS (
SELECT
provider_union_id AS provider_subject
FROM legacy
GROUP BY provider_union_id
HAVING COUNT(DISTINCT user_id) = 1
)
INSERT INTO auth_identity_channels (
identity_id,
provider_type,
@ -266,26 +439,9 @@ SELECT
'unionid', legacy.provider_union_id,
'migration', '116_auth_identity_legacy_external_safety_reports'
)
FROM (
SELECT
uei.user_id,
BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel,
BTRIM(COALESCE(
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id',
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid',
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id',
''
)) AS channel_app_id
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
) AS legacy
FROM legacy
JOIN clear_subjects
ON clear_subjects.provider_subject = legacy.provider_union_id
JOIN auth_identities AS legacy_ai
ON legacy_ai.user_id = legacy.user_id
AND legacy_ai.provider_type = 'wechat'

View File

@ -3,6 +3,7 @@ VALUES
(
'wechat_connect_open_enabled',
CASE
WHEN NOT EXISTS (SELECT 1 FROM settings WHERE key = 'wechat_connect_enabled') THEN ''
WHEN COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_enabled'), 'false') <> 'true' THEN 'false'
WHEN LOWER(TRIM(COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_mode'), 'open'))) = 'mp' THEN 'false'
ELSE 'true'
@ -11,6 +12,7 @@ VALUES
(
'wechat_connect_mp_enabled',
CASE
WHEN NOT EXISTS (SELECT 1 FROM settings WHERE key = 'wechat_connect_enabled') THEN ''
WHEN COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_enabled'), 'false') <> 'true' THEN 'false'
WHEN LOWER(TRIM(COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_mode'), 'open'))) = 'mp' THEN 'true'
ELSE 'false'
@ -21,12 +23,3 @@ VALUES
('auth_source_default_oidc_grant_on_signup', 'false'),
('auth_source_default_wechat_grant_on_signup', 'false')
ON CONFLICT (key) DO NOTHING;
UPDATE settings
SET value = 'false'
WHERE key IN (
'auth_source_default_email_grant_on_signup',
'auth_source_default_linuxdo_grant_on_signup',
'auth_source_default_oidc_grant_on_signup',
'auth_source_default_wechat_grant_on_signup'
);

View File

@ -0,0 +1,6 @@
-- Intentionally left as a no-op.
-- The online index rollout lives in 120_enforce_payment_orders_out_trade_no_unique_notx.sql
DO $$
BEGIN
NULL;
END $$;

View File

@ -0,0 +1,10 @@
-- Build the payment order uniqueness guarantee online.
-- The migration runner performs an explicit duplicate out_trade_no precheck and
-- drops any stale invalid paymentorder_out_trade_no_unique index before retrying.
-- Create the new partial unique index concurrently first so writes keep flowing,
-- then remove the legacy index name once the replacement is ready.
CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
ON payment_orders (out_trade_no)
WHERE out_trade_no <> '';
DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;

View File

@ -0,0 +1,22 @@
DO $$
BEGIN
IF EXISTS (
SELECT 1
FROM pg_indexes
WHERE schemaname = 'public'
AND tablename = 'payment_orders'
AND indexname = 'paymentorder_out_trade_no_unique'
) THEN
IF EXISTS (
SELECT 1
FROM pg_indexes
WHERE schemaname = 'public'
AND tablename = 'payment_orders'
AND indexname = 'paymentorder_out_trade_no'
) THEN
EXECUTE 'DROP INDEX IF EXISTS paymentorder_out_trade_no';
END IF;
EXECUTE 'ALTER INDEX paymentorder_out_trade_no_unique RENAME TO paymentorder_out_trade_no';
END IF;
END $$;

View File

@ -0,0 +1,2 @@
ALTER TABLE auth_identity_migration_reports
ALTER COLUMN report_type TYPE VARCHAR(80);

View File

@ -0,0 +1,15 @@
UPDATE pending_auth_sessions
SET
local_flow_state = jsonb_set(
local_flow_state,
'{completion_response}',
((local_flow_state -> 'completion_response') - 'access_token' - 'refresh_token' - 'expires_in' - 'token_type'),
true
)
WHERE jsonb_typeof(local_flow_state -> 'completion_response') = 'object'
AND (
(local_flow_state -> 'completion_response') ? 'access_token'
OR (local_flow_state -> 'completion_response') ? 'refresh_token'
OR (local_flow_state -> 'completion_response') ? 'expires_in'
OR (local_flow_state -> 'completion_response') ? 'token_type'
);

View File

@ -0,0 +1,68 @@
-- Auto-backfill untouched migration 110 signup-grant defaults to the corrected false value.
-- Rows still matching the migration-110 default payload and timestamp window are treated as
-- untouched legacy defaults; any remaining legacy true values are reported for manual review.
WITH migration_110 AS (
SELECT applied_at
FROM schema_migrations
WHERE filename = '110_pending_auth_and_provider_default_grants.sql'
),
providers AS (
SELECT provider_type
FROM (
VALUES ('email'), ('linuxdo'), ('oidc'), ('wechat')
) AS providers(provider_type)
),
legacy_provider_defaults AS (
SELECT providers.provider_type
FROM providers
CROSS JOIN migration_110
JOIN settings balance
ON balance.key = 'auth_source_default_' || providers.provider_type || '_balance'
JOIN settings concurrency
ON concurrency.key = 'auth_source_default_' || providers.provider_type || '_concurrency'
JOIN settings subscriptions
ON subscriptions.key = 'auth_source_default_' || providers.provider_type || '_subscriptions'
JOIN settings grant_on_signup
ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup'
JOIN settings grant_on_first_bind
ON grant_on_first_bind.key = 'auth_source_default_' || providers.provider_type || '_grant_on_first_bind'
WHERE balance.value = '0'
AND concurrency.value = '5'
AND subscriptions.value = '[]'
AND grant_on_signup.value = 'true'
AND grant_on_first_bind.value = 'false'
AND balance.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
AND concurrency.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
AND subscriptions.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
AND grant_on_signup.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
AND grant_on_first_bind.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
),
updated_signup_grants AS (
UPDATE settings
SET
value = 'false',
updated_at = NOW()
FROM legacy_provider_defaults
WHERE settings.key = 'auth_source_default_' || legacy_provider_defaults.provider_type || '_grant_on_signup'
AND settings.value = 'true'
RETURNING legacy_provider_defaults.provider_type
)
INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
SELECT
'legacy_auth_source_signup_grant_review',
providers.provider_type,
jsonb_build_object(
'provider_type', providers.provider_type,
'current_value', grant_on_signup.value,
'auto_backfilled', FALSE,
'reason', 'legacy_true_default_not_auto_backfilled'
)
FROM providers
JOIN settings grant_on_signup
ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup'
LEFT JOIN updated_signup_grants
ON updated_signup_grants.provider_type = providers.provider_type
WHERE grant_on_signup.value = 'true'
AND updated_signup_grants.provider_type IS NULL
ON CONFLICT (report_type, report_key) DO NOTHING;

Some files were not shown because too many files have changed in this diff Show More