diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml
index d7e15377..f8b22ee7 100644
--- a/.github/workflows/backend-ci.yml
+++ b/.github/workflows/backend-ci.yml
@@ -28,6 +28,26 @@ jobs:
working-directory: backend
run: make test-integration
+ frontend:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v6
+ - name: Setup pnpm
+ uses: pnpm/action-setup@v4
+ with:
+ version: 9
+ - name: Setup Node.js
+ uses: actions/setup-node@v6
+ with:
+ node-version: '20'
+ cache: 'pnpm'
+ cache-dependency-path: frontend/pnpm-lock.yaml
+ - name: Install frontend dependencies
+ working-directory: frontend
+ run: pnpm install --frozen-lockfile
+ - name: Frontend typecheck and critical vitest
+ run: make test-frontend
+
golangci-lint:
runs-on: ubuntu-latest
steps:
@@ -46,4 +66,4 @@ jobs:
with:
version: v2.9
args: --timeout=30m
- working-directory: backend
\ No newline at end of file
+ working-directory: backend
diff --git a/.github/workflows/cla.yml b/.github/workflows/cla.yml
new file mode 100644
index 00000000..67c8d6e9
--- /dev/null
+++ b/.github/workflows/cla.yml
@@ -0,0 +1,59 @@
+name: "CLA Assistant"
+
+on:
+ issue_comment:
+ types: [created]
+ pull_request_target:
+ types: [opened, reopened, closed, synchronize]
+
+permissions:
+ actions: write
+ contents: write
+ pull-requests: write
+ statuses: write
+
+jobs:
+ cla-check:
+ if: |
+ github.event_name == 'issue_comment' ||
+ (github.event_name == 'pull_request_target' && github.event.action != 'closed')
+ runs-on: ubuntu-latest
+ steps:
+ - name: "CLA Assistant"
+ if: |
+ (github.event.comment.body == 'recheck' ||
+ github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') ||
+ github.event_name == 'pull_request_target'
+ uses: contributor-assistant/github-action@v2.6.1
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ with:
+ path-to-signatures: "cla.json"
+ path-to-document: "https://github.com/Wei-Shaw/sub2api/blob/main/CLA.md"
+ branch: "cla-signatures"
+ allowlist: "dependabot[bot],renovate[bot],bot*"
+ lock-pullrequest-aftermerge: false
+ custom-notsigned-prcomment: |
+ Thank you for your contribution! Before we can merge this PR, we need $you to sign our [Contributor License Agreement (CLA)](https://github.com/Wei-Shaw/sub2api/blob/main/CLA.md).
+
+ **To sign**, please reply with the following comment:
+
+ > I have read the CLA Document and I hereby sign the CLA
+
+ You only need to sign once — it will be valid for all your future contributions to this project.
+ custom-pr-sign-comment: "I have read the CLA Document and I hereby sign the CLA"
+ custom-allsigned-prcomment: "All contributors have signed the CLA. ✅"
+
+ cla-lock:
+ if: github.event_name == 'pull_request_target' && github.event.action == 'closed' && github.event.pull_request.merged == true
+ runs-on: ubuntu-latest
+ steps:
+ - name: "Lock merged PR"
+ uses: contributor-assistant/github-action@v2.6.1
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ with:
+ path-to-signatures: "cla.json"
+ path-to-document: "https://github.com/Wei-Shaw/sub2api/blob/main/CLA.md"
+ branch: "cla-signatures"
+ lock-pullrequest-aftermerge: true
diff --git a/.gitignore b/.gitignore
index 1a92ea3e..bf7ee064 100644
--- a/.gitignore
+++ b/.gitignore
@@ -129,9 +129,9 @@ vite.config.js
docs/*
!docs/PAYMENT.md
!docs/PAYMENT_CN.md
+!docs/ADMIN_PAYMENT_INTEGRATION_API.md
.serena/
.codex/
frontend/coverage/
aicodex
output/
-
diff --git a/CLA.md b/CLA.md
new file mode 100644
index 00000000..ed0d74b8
--- /dev/null
+++ b/CLA.md
@@ -0,0 +1,73 @@
+# Sub2API Individual Contributor License Agreement (v1.0)
+
+Thank you for your interest in contributing to Sub2API ("the Project"). This Contributor License Agreement ("Agreement") documents the rights granted by contributors to the Project.
+
+By signing this Agreement, you accept and agree to the following terms and conditions for your present and future contributions submitted to the Project.
+
+## 1. Definitions
+
+- **"You" (or "Your")** means the copyright owner or legal entity authorized by the copyright owner that is making this Agreement.
+- **"Contribution"** means any original work of authorship, including any modifications or additions to an existing work, that is intentionally submitted by You to the Project for inclusion in, or documentation of, any of the products owned or managed by the Project. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Project or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Project for the purpose of discussing and improving the Project, but excluding communication that is conspicuously marked or otherwise designated in writing by You as "Not a Contribution."
+- **"Project Owner"** means Wesley Liddick, or any individual or legal entity to whom Wesley Liddick has explicitly assigned or transferred ownership of the Project in writing, and their respective successors and assigns.
+
+## 2. Grant of Copyright License
+
+Subject to the terms and conditions of this Agreement, You hereby grant to the Project Owner a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense, and distribute Your Contributions and such derivative works. This license includes, without limitation, the right to sublicense, assign, and transfer these rights to any third party, including without limitation any successor, assignee, or acquiring entity of the Project or the Project Owner, and to use Your Contributions under any license, including proprietary or commercial licenses.
+
+## 3. Moral Rights
+
+To the fullest extent permitted by applicable law, You irrevocably waive and agree not to assert any moral rights (including rights of attribution and integrity) that You may have in Your Contributions, and agree that the Project Owner and its licensees may use, modify, and distribute Your Contributions without attribution or other obligations arising from moral rights.
+
+## 4. Grant of Patent License
+
+Subject to the terms and conditions of this Agreement, You hereby grant to the Project Owner a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer Your Contributions, where such license applies only to those patent claims licensable by You that are necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Project to which such Contribution(s) was submitted.
+
+## 5. Representations and Warranties
+
+You represent and warrant that:
+
+(a) You are legally entitled to grant the above licenses.
+
+(b) If Your employer(s) has rights to intellectual property that You create that includes Your Contributions, You have received permission to make Contributions on behalf of that employer, or that Your employer has waived such rights for Your Contributions to the Project.
+
+(c) Each of Your Contributions is Your original creation, or You have sufficient rights to submit it under the terms of this Agreement. You agree to provide, upon request, reasonable documentation or explanation of any third-party materials included in Your Contributions.
+
+## 6. No Warranty
+
+Your Contributions are provided on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are not expected to provide support for Your Contributions, except to the extent You desire to provide support.
+
+## 7. No Obligation
+
+You understand that the decision to include Your Contribution in any product or project is entirely at the discretion of the Project Owner, and this Agreement does not obligate the Project Owner to use Your Contribution.
+
+## 8. Retention of Rights
+
+You retain ownership of the copyright in Your Contributions. This Agreement does not transfer any copyright or other intellectual property rights from You to the Project Owner. This Agreement only grants the licenses described above.
+
+## 9. Term and Termination
+
+This Agreement shall remain in effect indefinitely. You may terminate this Agreement prospectively by providing written notice to the Project Owner, but such termination shall not affect the licenses granted for Contributions submitted prior to the effective date of termination. The licenses granted herein for Contributions submitted prior to termination are perpetual and irrevocable.
+
+## 10. Electronic Signature
+
+You agree that Your electronic signature (including but not limited to typing a specific phrase in a pull request, issue, or other electronic communication) is legally binding and has the same force and effect as a handwritten signature. You consent to the use of electronic means to enter into this Agreement and acknowledge that this Agreement is enforceable as if executed in a traditional written format.
+
+## 11. General Provisions
+
+**Entire Agreement.** This Agreement constitutes the entire agreement between You and the Project Owner with respect to Your Contributions and supersedes all prior or contemporaneous understandings regarding such subject matter.
+
+**Severability.** If any provision of this Agreement is held to be unenforceable or invalid, that provision will be enforced to the maximum extent possible and the remaining provisions will remain in full force and effect.
+
+**No Waiver.** The failure of the Project Owner to enforce any provision of this Agreement shall not constitute a waiver of that provision or any other provision.
+
+**Amendment.** This Agreement may only be modified by a written instrument signed by both parties. Modifications to this Agreement apply only to Contributions submitted after the modified Agreement is published and accepted by You. Prior Contributions remain governed by the version of the Agreement in effect at the time of submission.
+
+**Notification.** Notices under this Agreement shall be sent to the Project Owner via a GitHub issue on the Project repository. Notices are effective upon receipt.
+
+---
+
+**By signing this CLA, you acknowledge that you have read and understood this Agreement and agree to be bound by its terms.**
+
+To sign, reply in the pull request with:
+
+> I have read the CLA Document and I hereby sign the CLA
diff --git a/LICENSE b/LICENSE
index 7a94ca9d..153d416d 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,21 +1,165 @@
-MIT License
+ GNU LESSER GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
-Copyright (c) 2025 Wesley Liddick
+ Copyright (C) 2007 Free Software Foundation, Inc.
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
+ This version of the GNU Lesser General Public License incorporates
+the terms and conditions of version 3 of the GNU General Public
+License, supplemented by the additional permissions listed below.
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+ 0. Additional Definitions.
+
+ As used herein, "this License" refers to version 3 of the GNU Lesser
+General Public License, and the "GNU GPL" refers to version 3 of the GNU
+General Public License.
+
+ "The Library" refers to a covered work governed by this License,
+other than an Application or a Combined Work as defined below.
+
+ An "Application" is any work that makes use of an interface provided
+by the Library, but which is not otherwise based on the Library.
+Defining a subclass of a class defined by the Library is deemed a mode
+of using an interface provided by the Library.
+
+ A "Combined Work" is a work produced by combining or linking an
+Application with the Library. The particular version of the Library
+with which the Combined Work was made is also called the "Linked
+Version".
+
+ The "Minimal Corresponding Source" for a Combined Work means the
+Corresponding Source for the Combined Work, excluding any source code
+for portions of the Combined Work that, considered in isolation, are
+based on the Application, and not on the Linked Version.
+
+ The "Corresponding Application Code" for a Combined Work means the
+object code and/or source code for the Application, including any data
+and utility programs needed for reproducing the Combined Work from the
+Application, but excluding the System Libraries of the Combined Work.
+
+ 1. Exception to Section 3 of the GNU GPL.
+
+ You may convey a covered work under sections 3 and 4 of this License
+without being bound by section 3 of the GNU GPL.
+
+ 2. Conveying Modified Versions.
+
+ If you modify a copy of the Library, and, in your modifications, a
+facility refers to a function or data to be supplied by an Application
+that uses the facility (other than as an argument passed when the
+facility is invoked), then you may convey a copy of the modified
+version:
+
+ a) under this License, provided that you make a good faith effort to
+ ensure that, in the event an Application does not supply the
+ function or data, the facility still operates, and performs
+ whatever part of its purpose remains meaningful, or
+
+ b) under the GNU GPL, with none of the additional permissions of
+ this License applicable to that copy.
+
+ 3. Object Code Incorporating Material from Library Header Files.
+
+ The object code form of an Application may incorporate material from
+a header file that is part of the Library. You may convey such object
+code under terms of your choice, provided that, if the incorporated
+material is not limited to numerical parameters, data structure
+layouts and accessors, or small macros, inline functions and templates
+(ten or fewer lines in length), you do both of the following:
+
+ a) Give prominent notice with each copy of the object code that the
+ Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the object code with a copy of the GNU GPL and this license
+ document.
+
+ 4. Combined Works.
+
+ You may convey a Combined Work under terms of your choice that,
+taken together, effectively do not restrict modification of the
+portions of the Library contained in the Combined Work and reverse
+engineering for debugging such modifications, if you also do each of
+the following:
+
+ a) Give prominent notice with each copy of the Combined Work that
+ the Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the Combined Work with a copy of the GNU GPL and this license
+ document.
+
+ c) For a Combined Work that displays copyright notices during
+ execution, include the copyright notice for the Library among
+ these notices, as well as a reference directing the user to the
+ copies of the GNU GPL and this license document.
+
+ d) Do one of the following:
+
+ 0) Convey the Minimal Corresponding Source under the terms of this
+ License, and the Corresponding Application Code in a form
+ suitable for, and under terms that permit, the user to
+ recombine or relink the Application with a modified version of
+ the Linked Version to produce a modified Combined Work, in the
+ manner specified by section 6 of the GNU GPL for conveying
+ Corresponding Source.
+
+ 1) Use a suitable shared library mechanism for linking with the
+ Library. A suitable mechanism is one that (a) uses at run time
+ a copy of the Library already present on the user's computer
+ system, and (b) will operate properly with a modified version
+ of the Library that is interface-compatible with the Linked
+ Version.
+
+ e) Provide Installation Information, but only if you would otherwise
+ be required to provide such information under section 6 of the
+ GNU GPL, and only to the extent that such information is
+ necessary to install and execute a modified version of the
+ Combined Work produced by recombining or relinking the
+ Application with a modified version of the Linked Version. (If
+ you use option 4d0, the Installation Information must accompany
+ the Minimal Corresponding Source and Corresponding Application
+ Code. If you use option 4d1, you must provide the Installation
+ Information in the manner specified by section 6 of the GNU GPL
+ for conveying Corresponding Source.)
+
+ 5. Combined Libraries.
+
+ You may place library facilities that are a work based on the
+Library side by side in a single library together with other library
+facilities that are not Applications and are not covered by this
+License, and convey such a combined library under terms of your
+choice, if you do both of the following:
+
+ a) Accompany the combined library with a copy of the same work based
+ on the Library, uncombined with any other library facilities,
+ conveyed under the terms of this License.
+
+ b) Give prominent notice with the combined library that part of it
+ is a work based on the Library, and explaining where to find the
+ accompanying uncombined form of the same work.
+
+ 6. Revised Versions of the GNU Lesser General Public License.
+
+ The Free Software Foundation may publish revised and/or new versions
+of the GNU Lesser General Public License from time to time. Such new
+versions will be similar in spirit to the present version, but may
+differ in detail to address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Library as you received it specifies that a certain numbered version
+of the GNU Lesser General Public License "or any later version"
+applies to it, you have the option of following the terms and
+conditions either of that published version or of any later version
+published by the Free Software Foundation. If the Library as you
+received it does not specify a version number of the GNU Lesser
+General Public License, you may choose any version of the GNU Lesser
+General Public License ever published by the Free Software Foundation.
+
+ If the Library as you received it specifies that a proxy can decide
+whether future versions of the GNU Lesser General Public License shall
+apply, that proxy's public statement of acceptance of any version is
+permanent authorization for you to choose that version for the
+Library.
\ No newline at end of file
diff --git a/Makefile b/Makefile
index fd6a5a9a..d00d0c4f 100644
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,12 @@
-.PHONY: build build-backend build-frontend build-datamanagementd test test-backend test-frontend test-datamanagementd secret-scan
+.PHONY: build build-backend build-frontend build-datamanagementd test test-backend test-frontend test-frontend-critical test-datamanagementd secret-scan
+
+FRONTEND_CRITICAL_VITEST := \
+ src/views/auth/__tests__/LinuxDoCallbackView.spec.ts \
+ src/views/auth/__tests__/WechatCallbackView.spec.ts \
+ src/views/user/__tests__/PaymentView.spec.ts \
+ src/views/user/__tests__/PaymentResultView.spec.ts \
+ src/components/user/profile/__tests__/ProfileInfoCard.spec.ts \
+ src/views/admin/__tests__/SettingsView.spec.ts
# 一键编译前后端
build: build-backend build-frontend
@@ -24,6 +32,10 @@ test-backend:
test-frontend:
@pnpm --dir frontend run lint:check
@pnpm --dir frontend run typecheck
+ @$(MAKE) test-frontend-critical
+
+test-frontend-critical:
+ @pnpm --dir frontend exec vitest run $(FRONTEND_CRITICAL_VITEST)
test-datamanagementd:
@cd datamanagement && go test ./...
diff --git a/README.md b/README.md
index 74ab9af2..3e609d65 100644
--- a/README.md
+++ b/README.md
@@ -96,6 +96,11 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
Huge thanks to BmoPlus for sponsoring this project! BmoPlus is a highly reliable AI account provider built strictly for heavy AI users and developers. They offer rock-solid, ready-to-use accounts and official top-up services for ChatGPT Plus / ChatGPT Pro (Full Warranty) / Claude Pro / Super Grok / Gemini Pro. By registering and ordering through BmoPlus - Premium AI Accounts & Top-ups , users can unlock the mind-blowing rate of 10% of the official GPT subscription price (90% OFF)
+
+
+Thanks to Bestproxy for sponsoring this project! Bestproxy provides high-purity residential IPs with dedicated one-IP-per-account support. By combining real home networks with fingerprint isolation, it enables link environment isolation and reduces the probability of association-based risk control.
+
+
## Ecosystem
@@ -618,7 +623,9 @@ sub2api/
## License
-MIT License
+This project is licensed under the [GNU Lesser General Public License v3.0](LICENSE) (or later).
+
+Copyright (c) 2026 Wesley Liddick
---
diff --git a/README_CN.md b/README_CN.md
index c701372c..add32a17 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -95,6 +95,11 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
感谢 BmoPlus 赞助了本项目!BmoPlus 是一家专为AI订阅重度用户打造的可靠 AI 账号代充服务商,提供稳定的 ChatGPT Plus / ChatGPT Pro(全程质保) / Claude Pro / Super Grok / Gemini Pro 的官方代充&成品账号。 通过BmoPlus AI成品号专卖/代充 注册下单的用户,可享GPT 官网订阅一折 的震撼价格!
+
+
+感谢 Bestproxy 赞助了本项目!Bestproxy 是一家提供高纯度住宅IP,支持一号一IP独享,结合真实家庭网络与指纹隔离,可实现链路环境隔离,降低关联风控概率。
+
+
## 生态项目
@@ -679,7 +684,9 @@ sub2api/
## 许可证
-MIT License
+本项目基于 [GNU 宽通用公共许可证 v3.0](LICENSE)(或更高版本)授权。
+
+Copyright (c) 2026 Wesley Liddick
---
diff --git a/README_JA.md b/README_JA.md
index 0d4db616..ccd595b9 100644
--- a/README_JA.md
+++ b/README_JA.md
@@ -95,6 +95,11 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらのBmoPlus AIアカウント専門店/代行チャージ 経由でご登録・ご注文いただいたユーザー様は、GPTを 公式サイト価格の約1割(90% OFF) という驚異的な価格でご利用いただけます!
+
+
+Bestproxy のご支援に感謝します!Bestproxy は高純度の住宅IPを提供し、1アカウント1IP専有をサポートしています。実際の家庭ネットワークとフィンガープリント分離を組み合わせることで、リンク環境の分離を実現し、関連付けによるリスク管理の確率を低減します。
+
+
## エコシステム
@@ -617,7 +622,9 @@ sub2api/
## ライセンス
-MIT License
+本プロジェクトは [GNU Lesser General Public License v3.0](LICENSE)(またはそれ以降のバージョン)の下でライセンスされています。
+
+Copyright (c) 2026 Wesley Liddick
---
diff --git a/assets/partners/logos/bestproxy.png b/assets/partners/logos/bestproxy.png
new file mode 100644
index 00000000..87c58670
Binary files /dev/null and b/assets/partners/logos/bestproxy.png differ
diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION
index c29f5f75..72400828 100644
--- a/backend/cmd/server/VERSION
+++ b/backend/cmd/server/VERSION
@@ -1 +1 @@
-0.1.114
+0.1.115
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 3853b251..0ef63a07 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -79,7 +79,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
- userHandler := handler.NewUserHandler(userService, emailService, emailCache)
+ userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
diff --git a/backend/ent/authidentity.go b/backend/ent/authidentity.go
new file mode 100644
index 00000000..5ccfcf19
--- /dev/null
+++ b/backend/ent/authidentity.go
@@ -0,0 +1,266 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AuthIdentity is the model entity for the AuthIdentity schema.
+type AuthIdentity struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // UserID holds the value of the "user_id" field.
+ UserID int64 `json:"user_id,omitempty"`
+ // ProviderType holds the value of the "provider_type" field.
+ ProviderType string `json:"provider_type,omitempty"`
+ // ProviderKey holds the value of the "provider_key" field.
+ ProviderKey string `json:"provider_key,omitempty"`
+ // ProviderSubject holds the value of the "provider_subject" field.
+ ProviderSubject string `json:"provider_subject,omitempty"`
+ // VerifiedAt holds the value of the "verified_at" field.
+ VerifiedAt *time.Time `json:"verified_at,omitempty"`
+ // Issuer holds the value of the "issuer" field.
+ Issuer *string `json:"issuer,omitempty"`
+ // Metadata holds the value of the "metadata" field.
+ Metadata map[string]interface{} `json:"metadata,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the AuthIdentityQuery when eager-loading is set.
+ Edges AuthIdentityEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// AuthIdentityEdges holds the relations/edges for other nodes in the graph.
+type AuthIdentityEdges struct {
+ // User holds the value of the user edge.
+ User *User `json:"user,omitempty"`
+ // Channels holds the value of the channels edge.
+ Channels []*AuthIdentityChannel `json:"channels,omitempty"`
+ // AdoptionDecisions holds the value of the adoption_decisions edge.
+ AdoptionDecisions []*IdentityAdoptionDecision `json:"adoption_decisions,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [3]bool
+}
+
+// UserOrErr returns the User value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e AuthIdentityEdges) UserOrErr() (*User, error) {
+ if e.User != nil {
+ return e.User, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: user.Label}
+ }
+ return nil, &NotLoadedError{edge: "user"}
+}
+
+// ChannelsOrErr returns the Channels value or an error if the edge
+// was not loaded in eager-loading.
+func (e AuthIdentityEdges) ChannelsOrErr() ([]*AuthIdentityChannel, error) {
+ if e.loadedTypes[1] {
+ return e.Channels, nil
+ }
+ return nil, &NotLoadedError{edge: "channels"}
+}
+
+// AdoptionDecisionsOrErr returns the AdoptionDecisions value or an error if the edge
+// was not loaded in eager-loading.
+func (e AuthIdentityEdges) AdoptionDecisionsOrErr() ([]*IdentityAdoptionDecision, error) {
+ if e.loadedTypes[2] {
+ return e.AdoptionDecisions, nil
+ }
+ return nil, &NotLoadedError{edge: "adoption_decisions"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*AuthIdentity) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case authidentity.FieldMetadata:
+ values[i] = new([]byte)
+ case authidentity.FieldID, authidentity.FieldUserID:
+ values[i] = new(sql.NullInt64)
+ case authidentity.FieldProviderType, authidentity.FieldProviderKey, authidentity.FieldProviderSubject, authidentity.FieldIssuer:
+ values[i] = new(sql.NullString)
+ case authidentity.FieldCreatedAt, authidentity.FieldUpdatedAt, authidentity.FieldVerifiedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the AuthIdentity fields.
+func (_m *AuthIdentity) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case authidentity.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case authidentity.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case authidentity.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case authidentity.FieldUserID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field user_id", values[i])
+ } else if value.Valid {
+ _m.UserID = value.Int64
+ }
+ case authidentity.FieldProviderType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_type", values[i])
+ } else if value.Valid {
+ _m.ProviderType = value.String
+ }
+ case authidentity.FieldProviderKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_key", values[i])
+ } else if value.Valid {
+ _m.ProviderKey = value.String
+ }
+ case authidentity.FieldProviderSubject:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_subject", values[i])
+ } else if value.Valid {
+ _m.ProviderSubject = value.String
+ }
+ case authidentity.FieldVerifiedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field verified_at", values[i])
+ } else if value.Valid {
+ _m.VerifiedAt = new(time.Time)
+ *_m.VerifiedAt = value.Time
+ }
+ case authidentity.FieldIssuer:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field issuer", values[i])
+ } else if value.Valid {
+ _m.Issuer = new(string)
+ *_m.Issuer = value.String
+ }
+ case authidentity.FieldMetadata:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field metadata", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.Metadata); err != nil {
+ return fmt.Errorf("unmarshal field metadata: %w", err)
+ }
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentity.
+// This includes values selected through modifiers, order, etc.
+func (_m *AuthIdentity) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryUser queries the "user" edge of the AuthIdentity entity.
+func (_m *AuthIdentity) QueryUser() *UserQuery {
+ return NewAuthIdentityClient(_m.config).QueryUser(_m)
+}
+
+// QueryChannels queries the "channels" edge of the AuthIdentity entity.
+func (_m *AuthIdentity) QueryChannels() *AuthIdentityChannelQuery {
+ return NewAuthIdentityClient(_m.config).QueryChannels(_m)
+}
+
+// QueryAdoptionDecisions queries the "adoption_decisions" edge of the AuthIdentity entity.
+func (_m *AuthIdentity) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery {
+ return NewAuthIdentityClient(_m.config).QueryAdoptionDecisions(_m)
+}
+
+// Update returns a builder for updating this AuthIdentity.
+// Note that you need to call AuthIdentity.Unwrap() before calling this method if this AuthIdentity
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *AuthIdentity) Update() *AuthIdentityUpdateOne {
+ return NewAuthIdentityClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the AuthIdentity entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *AuthIdentity) Unwrap() *AuthIdentity {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: AuthIdentity is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *AuthIdentity) String() string {
+ var builder strings.Builder
+ builder.WriteString("AuthIdentity(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("user_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.UserID))
+ builder.WriteString(", ")
+ builder.WriteString("provider_type=")
+ builder.WriteString(_m.ProviderType)
+ builder.WriteString(", ")
+ builder.WriteString("provider_key=")
+ builder.WriteString(_m.ProviderKey)
+ builder.WriteString(", ")
+ builder.WriteString("provider_subject=")
+ builder.WriteString(_m.ProviderSubject)
+ builder.WriteString(", ")
+ if v := _m.VerifiedAt; v != nil {
+ builder.WriteString("verified_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.Issuer; v != nil {
+ builder.WriteString("issuer=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ builder.WriteString("metadata=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Metadata))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// AuthIdentities is a parsable slice of AuthIdentity.
+type AuthIdentities []*AuthIdentity
diff --git a/backend/ent/authidentity/authidentity.go b/backend/ent/authidentity/authidentity.go
new file mode 100644
index 00000000..c90be759
--- /dev/null
+++ b/backend/ent/authidentity/authidentity.go
@@ -0,0 +1,209 @@
+// Code generated by ent, DO NOT EDIT.
+
+package authidentity
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the authidentity type in the database.
+ Label = "auth_identity"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldUserID holds the string denoting the user_id field in the database.
+ FieldUserID = "user_id"
+ // FieldProviderType holds the string denoting the provider_type field in the database.
+ FieldProviderType = "provider_type"
+ // FieldProviderKey holds the string denoting the provider_key field in the database.
+ FieldProviderKey = "provider_key"
+ // FieldProviderSubject holds the string denoting the provider_subject field in the database.
+ FieldProviderSubject = "provider_subject"
+ // FieldVerifiedAt holds the string denoting the verified_at field in the database.
+ FieldVerifiedAt = "verified_at"
+ // FieldIssuer holds the string denoting the issuer field in the database.
+ FieldIssuer = "issuer"
+ // FieldMetadata holds the string denoting the metadata field in the database.
+ FieldMetadata = "metadata"
+ // EdgeUser holds the string denoting the user edge name in mutations.
+ EdgeUser = "user"
+ // EdgeChannels holds the string denoting the channels edge name in mutations.
+ EdgeChannels = "channels"
+ // EdgeAdoptionDecisions holds the string denoting the adoption_decisions edge name in mutations.
+ EdgeAdoptionDecisions = "adoption_decisions"
+ // Table holds the table name of the authidentity in the database.
+ Table = "auth_identities"
+ // UserTable is the table that holds the user relation/edge.
+ UserTable = "auth_identities"
+ // UserInverseTable is the table name for the User entity.
+ // It exists in this package in order to avoid circular dependency with the "user" package.
+ UserInverseTable = "users"
+ // UserColumn is the table column denoting the user relation/edge.
+ UserColumn = "user_id"
+ // ChannelsTable is the table that holds the channels relation/edge.
+ ChannelsTable = "auth_identity_channels"
+ // ChannelsInverseTable is the table name for the AuthIdentityChannel entity.
+ // It exists in this package in order to avoid circular dependency with the "authidentitychannel" package.
+ ChannelsInverseTable = "auth_identity_channels"
+ // ChannelsColumn is the table column denoting the channels relation/edge.
+ ChannelsColumn = "identity_id"
+ // AdoptionDecisionsTable is the table that holds the adoption_decisions relation/edge.
+ AdoptionDecisionsTable = "identity_adoption_decisions"
+ // AdoptionDecisionsInverseTable is the table name for the IdentityAdoptionDecision entity.
+ // It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package.
+ AdoptionDecisionsInverseTable = "identity_adoption_decisions"
+ // AdoptionDecisionsColumn is the table column denoting the adoption_decisions relation/edge.
+ AdoptionDecisionsColumn = "identity_id"
+)
+
+// Columns holds all SQL columns for authidentity fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldUserID,
+ FieldProviderType,
+ FieldProviderKey,
+ FieldProviderSubject,
+ FieldVerifiedAt,
+ FieldIssuer,
+ FieldMetadata,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ ProviderTypeValidator func(string) error
+ // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ ProviderKeyValidator func(string) error
+ // ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
+ ProviderSubjectValidator func(string) error
+ // DefaultMetadata holds the default value on creation for the "metadata" field.
+ DefaultMetadata func() map[string]interface{}
+)
+
+// OrderOption defines the ordering options for the AuthIdentity queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByUserID orders the results by the user_id field.
+func ByUserID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUserID, opts...).ToFunc()
+}
+
+// ByProviderType orders the results by the provider_type field.
+func ByProviderType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderType, opts...).ToFunc()
+}
+
+// ByProviderKey orders the results by the provider_key field.
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
+}
+
+// ByProviderSubject orders the results by the provider_subject field.
+func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderSubject, opts...).ToFunc()
+}
+
+// ByVerifiedAt orders the results by the verified_at field.
+func ByVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldVerifiedAt, opts...).ToFunc()
+}
+
+// ByIssuer orders the results by the issuer field.
+func ByIssuer(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIssuer, opts...).ToFunc()
+}
+
+// ByUserField orders the results by user field.
+func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...))
+ }
+}
+
+// ByChannelsCount orders the results by channels count.
+func ByChannelsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newChannelsStep(), opts...)
+ }
+}
+
+// ByChannels orders the results by channels terms.
+func ByChannels(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newChannelsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
+// ByAdoptionDecisionsCount orders the results by adoption_decisions count.
+func ByAdoptionDecisionsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newAdoptionDecisionsStep(), opts...)
+ }
+}
+
+// ByAdoptionDecisions orders the results by adoption_decisions terms.
+func ByAdoptionDecisions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+func newUserStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(UserInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
+ )
+}
+func newChannelsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(ChannelsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn),
+ )
+}
+func newAdoptionDecisionsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(AdoptionDecisionsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn),
+ )
+}
diff --git a/backend/ent/authidentity/where.go b/backend/ent/authidentity/where.go
new file mode 100644
index 00000000..3dbf3178
--- /dev/null
+++ b/backend/ent/authidentity/where.go
@@ -0,0 +1,600 @@
+// Code generated by ent, DO NOT EDIT.
+
+package authidentity
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
+func UserID(v int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v))
+}
+
+// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ.
+func ProviderType(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
+func ProviderKey(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderSubject applies equality check predicate on the "provider_subject" field. It's identical to ProviderSubjectEQ.
+func ProviderSubject(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v))
+}
+
+// VerifiedAt applies equality check predicate on the "verified_at" field. It's identical to VerifiedAtEQ.
+func VerifiedAt(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v))
+}
+
+// Issuer applies equality check predicate on the "issuer" field. It's identical to IssuerEQ.
+func Issuer(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// UserIDEQ applies the EQ predicate on the "user_id" field.
+func UserIDEQ(v int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v))
+}
+
+// UserIDNEQ applies the NEQ predicate on the "user_id" field.
+func UserIDNEQ(v int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldUserID, v))
+}
+
+// UserIDIn applies the In predicate on the "user_id" field.
+func UserIDIn(vs ...int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldUserID, vs...))
+}
+
+// UserIDNotIn applies the NotIn predicate on the "user_id" field.
+func UserIDNotIn(vs ...int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldUserID, vs...))
+}
+
+// ProviderTypeEQ applies the EQ predicate on the "provider_type" field.
+func ProviderTypeEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field.
+func ProviderTypeNEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderType, v))
+}
+
+// ProviderTypeIn applies the In predicate on the "provider_type" field.
+func ProviderTypeIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field.
+func ProviderTypeNotIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeGT applies the GT predicate on the "provider_type" field.
+func ProviderTypeGT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldProviderType, v))
+}
+
+// ProviderTypeGTE applies the GTE predicate on the "provider_type" field.
+func ProviderTypeGTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldProviderType, v))
+}
+
+// ProviderTypeLT applies the LT predicate on the "provider_type" field.
+func ProviderTypeLT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldProviderType, v))
+}
+
+// ProviderTypeLTE applies the LTE predicate on the "provider_type" field.
+func ProviderTypeLTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldProviderType, v))
+}
+
+// ProviderTypeContains applies the Contains predicate on the "provider_type" field.
+func ProviderTypeContains(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContains(FieldProviderType, v))
+}
+
+// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field.
+func ProviderTypeHasPrefix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderType, v))
+}
+
+// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field.
+func ProviderTypeHasSuffix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderType, v))
+}
+
+// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field.
+func ProviderTypeEqualFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderType, v))
+}
+
+// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field.
+func ProviderTypeContainsFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderType, v))
+}
+
+// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
+func ProviderKeyEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
+func ProviderKeyNEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyIn applies the In predicate on the "provider_key" field.
+func ProviderKeyIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
+func ProviderKeyNotIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyGT applies the GT predicate on the "provider_key" field.
+func ProviderKeyGT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldProviderKey, v))
+}
+
+// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
+func ProviderKeyGTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldProviderKey, v))
+}
+
+// ProviderKeyLT applies the LT predicate on the "provider_key" field.
+func ProviderKeyLT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldProviderKey, v))
+}
+
+// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
+func ProviderKeyLTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldProviderKey, v))
+}
+
+// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
+func ProviderKeyContains(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContains(FieldProviderKey, v))
+}
+
+// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
+func ProviderKeyHasPrefix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderKey, v))
+}
+
+// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
+func ProviderKeyHasSuffix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderKey, v))
+}
+
+// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
+func ProviderKeyEqualFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderKey, v))
+}
+
+// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
+func ProviderKeyContainsFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderKey, v))
+}
+
+// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field.
+func ProviderSubjectEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v))
+}
+
+// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field.
+func ProviderSubjectNEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderSubject, v))
+}
+
+// ProviderSubjectIn applies the In predicate on the "provider_subject" field.
+func ProviderSubjectIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldProviderSubject, vs...))
+}
+
+// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field.
+func ProviderSubjectNotIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderSubject, vs...))
+}
+
+// ProviderSubjectGT applies the GT predicate on the "provider_subject" field.
+func ProviderSubjectGT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldProviderSubject, v))
+}
+
+// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field.
+func ProviderSubjectGTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldProviderSubject, v))
+}
+
+// ProviderSubjectLT applies the LT predicate on the "provider_subject" field.
+func ProviderSubjectLT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldProviderSubject, v))
+}
+
+// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field.
+func ProviderSubjectLTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldProviderSubject, v))
+}
+
+// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field.
+func ProviderSubjectContains(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContains(FieldProviderSubject, v))
+}
+
+// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field.
+func ProviderSubjectHasPrefix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderSubject, v))
+}
+
+// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field.
+func ProviderSubjectHasSuffix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderSubject, v))
+}
+
+// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field.
+func ProviderSubjectEqualFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderSubject, v))
+}
+
+// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field.
+func ProviderSubjectContainsFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderSubject, v))
+}
+
+// VerifiedAtEQ applies the EQ predicate on the "verified_at" field.
+func VerifiedAtEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v))
+}
+
+// VerifiedAtNEQ applies the NEQ predicate on the "verified_at" field.
+func VerifiedAtNEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldVerifiedAt, v))
+}
+
+// VerifiedAtIn applies the In predicate on the "verified_at" field.
+func VerifiedAtIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldVerifiedAt, vs...))
+}
+
+// VerifiedAtNotIn applies the NotIn predicate on the "verified_at" field.
+func VerifiedAtNotIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldVerifiedAt, vs...))
+}
+
+// VerifiedAtGT applies the GT predicate on the "verified_at" field.
+func VerifiedAtGT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldVerifiedAt, v))
+}
+
+// VerifiedAtGTE applies the GTE predicate on the "verified_at" field.
+func VerifiedAtGTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldVerifiedAt, v))
+}
+
+// VerifiedAtLT applies the LT predicate on the "verified_at" field.
+func VerifiedAtLT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldVerifiedAt, v))
+}
+
+// VerifiedAtLTE applies the LTE predicate on the "verified_at" field.
+func VerifiedAtLTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldVerifiedAt, v))
+}
+
+// VerifiedAtIsNil applies the IsNil predicate on the "verified_at" field.
+func VerifiedAtIsNil() predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIsNull(FieldVerifiedAt))
+}
+
+// VerifiedAtNotNil applies the NotNil predicate on the "verified_at" field.
+func VerifiedAtNotNil() predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotNull(FieldVerifiedAt))
+}
+
+// IssuerEQ applies the EQ predicate on the "issuer" field.
+func IssuerEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v))
+}
+
+// IssuerNEQ applies the NEQ predicate on the "issuer" field.
+func IssuerNEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldIssuer, v))
+}
+
+// IssuerIn applies the In predicate on the "issuer" field.
+func IssuerIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldIssuer, vs...))
+}
+
+// IssuerNotIn applies the NotIn predicate on the "issuer" field.
+func IssuerNotIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldIssuer, vs...))
+}
+
+// IssuerGT applies the GT predicate on the "issuer" field.
+func IssuerGT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldIssuer, v))
+}
+
+// IssuerGTE applies the GTE predicate on the "issuer" field.
+func IssuerGTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldIssuer, v))
+}
+
+// IssuerLT applies the LT predicate on the "issuer" field.
+func IssuerLT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldIssuer, v))
+}
+
+// IssuerLTE applies the LTE predicate on the "issuer" field.
+func IssuerLTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldIssuer, v))
+}
+
+// IssuerContains applies the Contains predicate on the "issuer" field.
+func IssuerContains(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContains(FieldIssuer, v))
+}
+
+// IssuerHasPrefix applies the HasPrefix predicate on the "issuer" field.
+func IssuerHasPrefix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasPrefix(FieldIssuer, v))
+}
+
+// IssuerHasSuffix applies the HasSuffix predicate on the "issuer" field.
+func IssuerHasSuffix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasSuffix(FieldIssuer, v))
+}
+
+// IssuerIsNil applies the IsNil predicate on the "issuer" field.
+func IssuerIsNil() predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIsNull(FieldIssuer))
+}
+
+// IssuerNotNil applies the NotNil predicate on the "issuer" field.
+func IssuerNotNil() predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotNull(FieldIssuer))
+}
+
+// IssuerEqualFold applies the EqualFold predicate on the "issuer" field.
+func IssuerEqualFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEqualFold(FieldIssuer, v))
+}
+
+// IssuerContainsFold applies the ContainsFold predicate on the "issuer" field.
+func IssuerContainsFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContainsFold(FieldIssuer, v))
+}
+
+// HasUser applies the HasEdge predicate on the "user" edge.
+func HasUser() predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates).
+func HasUserWith(preds ...predicate.User) predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := newUserStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasChannels applies the HasEdge predicate on the "channels" edge.
+func HasChannels() predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasChannelsWith applies the HasEdge predicate on the "channels" edge with a given conditions (other predicates).
+func HasChannelsWith(preds ...predicate.AuthIdentityChannel) predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := newChannelsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasAdoptionDecisions applies the HasEdge predicate on the "adoption_decisions" edge.
+func HasAdoptionDecisions() predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasAdoptionDecisionsWith applies the HasEdge predicate on the "adoption_decisions" edge with a given conditions (other predicates).
+func HasAdoptionDecisionsWith(preds ...predicate.IdentityAdoptionDecision) predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := newAdoptionDecisionsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.AuthIdentity) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.AuthIdentity) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.AuthIdentity) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.NotPredicates(p))
+}
diff --git a/backend/ent/authidentity_create.go b/backend/ent/authidentity_create.go
new file mode 100644
index 00000000..e287705c
--- /dev/null
+++ b/backend/ent/authidentity_create.go
@@ -0,0 +1,1036 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AuthIdentityCreate is the builder for creating a AuthIdentity entity.
+type AuthIdentityCreate struct {
+ config
+ mutation *AuthIdentityMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *AuthIdentityCreate) SetCreatedAt(v time.Time) *AuthIdentityCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *AuthIdentityCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *AuthIdentityCreate) SetUpdatedAt(v time.Time) *AuthIdentityCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *AuthIdentityCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetUserID sets the "user_id" field.
+func (_c *AuthIdentityCreate) SetUserID(v int64) *AuthIdentityCreate {
+ _c.mutation.SetUserID(v)
+ return _c
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_c *AuthIdentityCreate) SetProviderType(v string) *AuthIdentityCreate {
+ _c.mutation.SetProviderType(v)
+ return _c
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_c *AuthIdentityCreate) SetProviderKey(v string) *AuthIdentityCreate {
+ _c.mutation.SetProviderKey(v)
+ return _c
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_c *AuthIdentityCreate) SetProviderSubject(v string) *AuthIdentityCreate {
+ _c.mutation.SetProviderSubject(v)
+ return _c
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (_c *AuthIdentityCreate) SetVerifiedAt(v time.Time) *AuthIdentityCreate {
+ _c.mutation.SetVerifiedAt(v)
+ return _c
+}
+
+// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil.
+func (_c *AuthIdentityCreate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityCreate {
+ if v != nil {
+ _c.SetVerifiedAt(*v)
+ }
+ return _c
+}
+
+// SetIssuer sets the "issuer" field.
+func (_c *AuthIdentityCreate) SetIssuer(v string) *AuthIdentityCreate {
+ _c.mutation.SetIssuer(v)
+ return _c
+}
+
+// SetNillableIssuer sets the "issuer" field if the given value is not nil.
+func (_c *AuthIdentityCreate) SetNillableIssuer(v *string) *AuthIdentityCreate {
+ if v != nil {
+ _c.SetIssuer(*v)
+ }
+ return _c
+}
+
+// SetMetadata sets the "metadata" field.
+func (_c *AuthIdentityCreate) SetMetadata(v map[string]interface{}) *AuthIdentityCreate {
+ _c.mutation.SetMetadata(v)
+ return _c
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_c *AuthIdentityCreate) SetUser(v *User) *AuthIdentityCreate {
+ return _c.SetUserID(v.ID)
+}
+
+// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs.
+func (_c *AuthIdentityCreate) AddChannelIDs(ids ...int64) *AuthIdentityCreate {
+ _c.mutation.AddChannelIDs(ids...)
+ return _c
+}
+
+// AddChannels adds the "channels" edges to the AuthIdentityChannel entity.
+func (_c *AuthIdentityCreate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddChannelIDs(ids...)
+}
+
+// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
+func (_c *AuthIdentityCreate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityCreate {
+ _c.mutation.AddAdoptionDecisionIDs(ids...)
+ return _c
+}
+
+// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_c *AuthIdentityCreate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddAdoptionDecisionIDs(ids...)
+}
+
+// Mutation returns the AuthIdentityMutation object of the builder.
+func (_c *AuthIdentityCreate) Mutation() *AuthIdentityMutation {
+ return _c.mutation
+}
+
+// Save creates the AuthIdentity in the database.
+func (_c *AuthIdentityCreate) Save(ctx context.Context) (*AuthIdentity, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *AuthIdentityCreate) SaveX(ctx context.Context) *AuthIdentity {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AuthIdentityCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AuthIdentityCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *AuthIdentityCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := authidentity.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := authidentity.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.Metadata(); !ok {
+ v := authidentity.DefaultMetadata()
+ _c.mutation.SetMetadata(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *AuthIdentityCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentity.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentity.updated_at"`)}
+ }
+ if _, ok := _c.mutation.UserID(); !ok {
+ return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "AuthIdentity.user_id"`)}
+ }
+ if _, ok := _c.mutation.ProviderType(); !ok {
+ return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentity.provider_type"`)}
+ }
+ if v, ok := _c.mutation.ProviderType(); ok {
+ if err := authidentity.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderKey(); !ok {
+ return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentity.provider_key"`)}
+ }
+ if v, ok := _c.mutation.ProviderKey(); ok {
+ if err := authidentity.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderSubject(); !ok {
+ return &ValidationError{Name: "provider_subject", err: errors.New(`ent: missing required field "AuthIdentity.provider_subject"`)}
+ }
+ if v, ok := _c.mutation.ProviderSubject(); ok {
+ if err := authidentity.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Metadata(); !ok {
+ return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentity.metadata"`)}
+ }
+ if len(_c.mutation.UserIDs()) == 0 {
+ return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "AuthIdentity.user"`)}
+ }
+ return nil
+}
+
+func (_c *AuthIdentityCreate) sqlSave(ctx context.Context) (*AuthIdentity, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *AuthIdentityCreate) createSpec() (*AuthIdentity, *sqlgraph.CreateSpec) {
+ var (
+ _node = &AuthIdentity{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(authidentity.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.ProviderType(); ok {
+ _spec.SetField(authidentity.FieldProviderType, field.TypeString, value)
+ _node.ProviderType = value
+ }
+ if value, ok := _c.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value)
+ _node.ProviderKey = value
+ }
+ if value, ok := _c.mutation.ProviderSubject(); ok {
+ _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value)
+ _node.ProviderSubject = value
+ }
+ if value, ok := _c.mutation.VerifiedAt(); ok {
+ _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value)
+ _node.VerifiedAt = &value
+ }
+ if value, ok := _c.mutation.Issuer(); ok {
+ _spec.SetField(authidentity.FieldIssuer, field.TypeString, value)
+ _node.Issuer = &value
+ }
+ if value, ok := _c.mutation.Metadata(); ok {
+ _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value)
+ _node.Metadata = value
+ }
+ if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.UserID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.ChannelsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AuthIdentity.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AuthIdentityUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *AuthIdentityCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityUpsertOne {
+ _c.conflict = opts
+ return &AuthIdentityUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AuthIdentityCreate) OnConflictColumns(columns ...string) *AuthIdentityUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AuthIdentityUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // AuthIdentityUpsertOne is the builder for "upsert"-ing
+ // one AuthIdentity node.
+ AuthIdentityUpsertOne struct {
+ create *AuthIdentityCreate
+ }
+
+ // AuthIdentityUpsert is the "OnConflict" setter.
+ AuthIdentityUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityUpsert) SetUpdatedAt(v time.Time) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateUpdatedAt() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldUpdatedAt)
+ return u
+}
+
+// SetUserID sets the "user_id" field.
+func (u *AuthIdentityUpsert) SetUserID(v int64) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldUserID, v)
+ return u
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateUserID() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldUserID)
+ return u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityUpsert) SetProviderType(v string) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldProviderType, v)
+ return u
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateProviderType() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldProviderType)
+ return u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityUpsert) SetProviderKey(v string) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldProviderKey, v)
+ return u
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateProviderKey() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldProviderKey)
+ return u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *AuthIdentityUpsert) SetProviderSubject(v string) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldProviderSubject, v)
+ return u
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateProviderSubject() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldProviderSubject)
+ return u
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (u *AuthIdentityUpsert) SetVerifiedAt(v time.Time) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldVerifiedAt, v)
+ return u
+}
+
+// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateVerifiedAt() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldVerifiedAt)
+ return u
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (u *AuthIdentityUpsert) ClearVerifiedAt() *AuthIdentityUpsert {
+ u.SetNull(authidentity.FieldVerifiedAt)
+ return u
+}
+
+// SetIssuer sets the "issuer" field.
+func (u *AuthIdentityUpsert) SetIssuer(v string) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldIssuer, v)
+ return u
+}
+
+// UpdateIssuer sets the "issuer" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateIssuer() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldIssuer)
+ return u
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (u *AuthIdentityUpsert) ClearIssuer() *AuthIdentityUpsert {
+ u.SetNull(authidentity.FieldIssuer)
+ return u
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldMetadata, v)
+ return u
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateMetadata() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldMetadata)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AuthIdentityUpsertOne) UpdateNewValues() *AuthIdentityUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(authidentity.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AuthIdentityUpsertOne) Ignore() *AuthIdentityUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AuthIdentityUpsertOne) DoNothing() *AuthIdentityUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AuthIdentityCreate.OnConflict
+// documentation for more info.
+func (u *AuthIdentityUpsertOne) Update(set func(*AuthIdentityUpsert)) *AuthIdentityUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AuthIdentityUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateUpdatedAt() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetUserID sets the "user_id" field.
+func (u *AuthIdentityUpsertOne) SetUserID(v int64) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetUserID(v)
+ })
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateUserID() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateUserID()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityUpsertOne) SetProviderType(v string) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateProviderType() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityUpsertOne) SetProviderKey(v string) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateProviderKey() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *AuthIdentityUpsertOne) SetProviderSubject(v string) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderSubject(v)
+ })
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateProviderSubject() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderSubject()
+ })
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (u *AuthIdentityUpsertOne) SetVerifiedAt(v time.Time) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetVerifiedAt(v)
+ })
+}
+
+// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateVerifiedAt() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateVerifiedAt()
+ })
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (u *AuthIdentityUpsertOne) ClearVerifiedAt() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.ClearVerifiedAt()
+ })
+}
+
+// SetIssuer sets the "issuer" field.
+func (u *AuthIdentityUpsertOne) SetIssuer(v string) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetIssuer(v)
+ })
+}
+
+// UpdateIssuer sets the "issuer" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateIssuer() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateIssuer()
+ })
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (u *AuthIdentityUpsertOne) ClearIssuer() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.ClearIssuer()
+ })
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetMetadata(v)
+ })
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateMetadata() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateMetadata()
+ })
+}
+
+// Exec executes the query.
+func (u *AuthIdentityUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AuthIdentityCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AuthIdentityUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *AuthIdentityUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *AuthIdentityUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// AuthIdentityCreateBulk is the builder for creating many AuthIdentity entities in bulk.
+type AuthIdentityCreateBulk struct {
+ config
+ err error
+ builders []*AuthIdentityCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the AuthIdentity entities in the database.
+func (_c *AuthIdentityCreateBulk) Save(ctx context.Context) ([]*AuthIdentity, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*AuthIdentity, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*AuthIdentityMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *AuthIdentityCreateBulk) SaveX(ctx context.Context) []*AuthIdentity {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AuthIdentityCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AuthIdentityCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AuthIdentity.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AuthIdentityUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *AuthIdentityCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityUpsertBulk {
+ _c.conflict = opts
+ return &AuthIdentityUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AuthIdentityCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AuthIdentityUpsertBulk{
+ create: _c,
+ }
+}
+
+// AuthIdentityUpsertBulk is the builder for "upsert"-ing
+// a bulk of AuthIdentity nodes.
+type AuthIdentityUpsertBulk struct {
+ create *AuthIdentityCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AuthIdentityUpsertBulk) UpdateNewValues() *AuthIdentityUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(authidentity.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AuthIdentityUpsertBulk) Ignore() *AuthIdentityUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AuthIdentityUpsertBulk) DoNothing() *AuthIdentityUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AuthIdentityCreateBulk.OnConflict
+// documentation for more info.
+func (u *AuthIdentityUpsertBulk) Update(set func(*AuthIdentityUpsert)) *AuthIdentityUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AuthIdentityUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateUpdatedAt() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetUserID sets the "user_id" field.
+func (u *AuthIdentityUpsertBulk) SetUserID(v int64) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetUserID(v)
+ })
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateUserID() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateUserID()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityUpsertBulk) SetProviderType(v string) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateProviderType() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityUpsertBulk) SetProviderKey(v string) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateProviderKey() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *AuthIdentityUpsertBulk) SetProviderSubject(v string) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderSubject(v)
+ })
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateProviderSubject() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderSubject()
+ })
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (u *AuthIdentityUpsertBulk) SetVerifiedAt(v time.Time) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetVerifiedAt(v)
+ })
+}
+
+// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateVerifiedAt() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateVerifiedAt()
+ })
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (u *AuthIdentityUpsertBulk) ClearVerifiedAt() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.ClearVerifiedAt()
+ })
+}
+
+// SetIssuer sets the "issuer" field.
+func (u *AuthIdentityUpsertBulk) SetIssuer(v string) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetIssuer(v)
+ })
+}
+
+// UpdateIssuer sets the "issuer" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateIssuer() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateIssuer()
+ })
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (u *AuthIdentityUpsertBulk) ClearIssuer() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.ClearIssuer()
+ })
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetMetadata(v)
+ })
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateMetadata() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateMetadata()
+ })
+}
+
+// Exec executes the query.
+func (u *AuthIdentityUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AuthIdentityCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AuthIdentityUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/authidentity_delete.go b/backend/ent/authidentity_delete.go
new file mode 100644
index 00000000..4f1f6f3c
--- /dev/null
+++ b/backend/ent/authidentity_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AuthIdentityDelete is the builder for deleting a AuthIdentity entity.
+type AuthIdentityDelete struct {
+ config
+ hooks []Hook
+ mutation *AuthIdentityMutation
+}
+
+// Where appends a list predicates to the AuthIdentityDelete builder.
+func (_d *AuthIdentityDelete) Where(ps ...predicate.AuthIdentity) *AuthIdentityDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *AuthIdentityDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AuthIdentityDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *AuthIdentityDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// AuthIdentityDeleteOne is the builder for deleting a single AuthIdentity entity.
+type AuthIdentityDeleteOne struct {
+ _d *AuthIdentityDelete
+}
+
+// Where appends a list predicates to the AuthIdentityDelete builder.
+func (_d *AuthIdentityDeleteOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *AuthIdentityDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{authidentity.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AuthIdentityDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/authidentity_query.go b/backend/ent/authidentity_query.go
new file mode 100644
index 00000000..ff27ef3c
--- /dev/null
+++ b/backend/ent/authidentity_query.go
@@ -0,0 +1,797 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AuthIdentityQuery is the builder for querying AuthIdentity entities.
+type AuthIdentityQuery struct {
+ config
+ ctx *QueryContext
+ order []authidentity.OrderOption
+ inters []Interceptor
+ predicates []predicate.AuthIdentity
+ withUser *UserQuery
+ withChannels *AuthIdentityChannelQuery
+ withAdoptionDecisions *IdentityAdoptionDecisionQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the AuthIdentityQuery builder.
+func (_q *AuthIdentityQuery) Where(ps ...predicate.AuthIdentity) *AuthIdentityQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *AuthIdentityQuery) Limit(limit int) *AuthIdentityQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *AuthIdentityQuery) Offset(offset int) *AuthIdentityQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *AuthIdentityQuery) Unique(unique bool) *AuthIdentityQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *AuthIdentityQuery) Order(o ...authidentity.OrderOption) *AuthIdentityQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryUser chains the current query on the "user" edge.
+func (_q *AuthIdentityQuery) QueryUser() *UserQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, selector),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryChannels chains the current query on the "channels" edge.
+func (_q *AuthIdentityQuery) QueryChannels() *AuthIdentityChannelQuery {
+ query := (&AuthIdentityChannelClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, selector),
+ sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryAdoptionDecisions chains the current query on the "adoption_decisions" edge.
+func (_q *AuthIdentityQuery) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, selector),
+ sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first AuthIdentity entity from the query.
+// Returns a *NotFoundError when no AuthIdentity was found.
+func (_q *AuthIdentityQuery) First(ctx context.Context) (*AuthIdentity, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{authidentity.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *AuthIdentityQuery) FirstX(ctx context.Context) *AuthIdentity {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first AuthIdentity ID from the query.
+// Returns a *NotFoundError when no AuthIdentity ID was found.
+func (_q *AuthIdentityQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{authidentity.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *AuthIdentityQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single AuthIdentity entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one AuthIdentity entity is found.
+// Returns a *NotFoundError when no AuthIdentity entities are found.
+func (_q *AuthIdentityQuery) Only(ctx context.Context) (*AuthIdentity, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{authidentity.Label}
+ default:
+ return nil, &NotSingularError{authidentity.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *AuthIdentityQuery) OnlyX(ctx context.Context) *AuthIdentity {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only AuthIdentity ID in the query.
+// Returns a *NotSingularError when more than one AuthIdentity ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *AuthIdentityQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{authidentity.Label}
+ default:
+ err = &NotSingularError{authidentity.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *AuthIdentityQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of AuthIdentities.
+func (_q *AuthIdentityQuery) All(ctx context.Context) ([]*AuthIdentity, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*AuthIdentity, *AuthIdentityQuery]()
+ return withInterceptors[[]*AuthIdentity](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *AuthIdentityQuery) AllX(ctx context.Context) []*AuthIdentity {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of AuthIdentity IDs.
+func (_q *AuthIdentityQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(authidentity.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *AuthIdentityQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *AuthIdentityQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *AuthIdentityQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *AuthIdentityQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *AuthIdentityQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the AuthIdentityQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *AuthIdentityQuery) Clone() *AuthIdentityQuery {
+ if _q == nil {
+ return nil
+ }
+ return &AuthIdentityQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]authidentity.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.AuthIdentity{}, _q.predicates...),
+ withUser: _q.withUser.Clone(),
+ withChannels: _q.withChannels.Clone(),
+ withAdoptionDecisions: _q.withAdoptionDecisions.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithUser tells the query-builder to eager-load the nodes that are connected to
+// the "user" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AuthIdentityQuery) WithUser(opts ...func(*UserQuery)) *AuthIdentityQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withUser = query
+ return _q
+}
+
+// WithChannels tells the query-builder to eager-load the nodes that are connected to
+// the "channels" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AuthIdentityQuery) WithChannels(opts ...func(*AuthIdentityChannelQuery)) *AuthIdentityQuery {
+ query := (&AuthIdentityChannelClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withChannels = query
+ return _q
+}
+
+// WithAdoptionDecisions tells the query-builder to eager-load the nodes that are connected to
+// the "adoption_decisions" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AuthIdentityQuery) WithAdoptionDecisions(opts ...func(*IdentityAdoptionDecisionQuery)) *AuthIdentityQuery {
+ query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withAdoptionDecisions = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.AuthIdentity.Query().
+// GroupBy(authidentity.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *AuthIdentityQuery) GroupBy(field string, fields ...string) *AuthIdentityGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &AuthIdentityGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = authidentity.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.AuthIdentity.Query().
+// Select(authidentity.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *AuthIdentityQuery) Select(fields ...string) *AuthIdentitySelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &AuthIdentitySelect{AuthIdentityQuery: _q}
+ sbuild.label = authidentity.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a AuthIdentitySelect configured with the given aggregations.
+func (_q *AuthIdentityQuery) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *AuthIdentityQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !authidentity.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *AuthIdentityQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentity, error) {
+ var (
+ nodes = []*AuthIdentity{}
+ _spec = _q.querySpec()
+ loadedTypes = [3]bool{
+ _q.withUser != nil,
+ _q.withChannels != nil,
+ _q.withAdoptionDecisions != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*AuthIdentity).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &AuthIdentity{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withUser; query != nil {
+ if err := _q.loadUser(ctx, query, nodes, nil,
+ func(n *AuthIdentity, e *User) { n.Edges.User = e }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withChannels; query != nil {
+ if err := _q.loadChannels(ctx, query, nodes,
+ func(n *AuthIdentity) { n.Edges.Channels = []*AuthIdentityChannel{} },
+ func(n *AuthIdentity, e *AuthIdentityChannel) { n.Edges.Channels = append(n.Edges.Channels, e) }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withAdoptionDecisions; query != nil {
+ if err := _q.loadAdoptionDecisions(ctx, query, nodes,
+ func(n *AuthIdentity) { n.Edges.AdoptionDecisions = []*IdentityAdoptionDecision{} },
+ func(n *AuthIdentity, e *IdentityAdoptionDecision) {
+ n.Edges.AdoptionDecisions = append(n.Edges.AdoptionDecisions, e)
+ }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *AuthIdentityQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *User)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*AuthIdentity)
+ for i := range nodes {
+ fk := nodes[i].UserID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(user.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+func (_q *AuthIdentityQuery) loadChannels(ctx context.Context, query *AuthIdentityChannelQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *AuthIdentityChannel)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*AuthIdentity)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(authidentitychannel.FieldIdentityID)
+ }
+ query.Where(predicate.AuthIdentityChannel(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(authidentity.ChannelsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.IdentityID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+func (_q *AuthIdentityQuery) loadAdoptionDecisions(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *IdentityAdoptionDecision)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*AuthIdentity)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(identityadoptiondecision.FieldIdentityID)
+ }
+ query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(authidentity.AdoptionDecisionsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.IdentityID
+ if fk == nil {
+ return fmt.Errorf(`foreign-key "identity_id" is nil for node %v`, n.ID)
+ }
+ node, ok := nodeids[*fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, *fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+
+func (_q *AuthIdentityQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *AuthIdentityQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID)
+ for i := range fields {
+ if fields[i] != authidentity.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withUser != nil {
+ _spec.Node.AddColumnOnce(authidentity.FieldUserID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *AuthIdentityQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(authidentity.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = authidentity.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *AuthIdentityQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *AuthIdentityQuery) ForShare(opts ...sql.LockOption) *AuthIdentityQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// AuthIdentityGroupBy is the group-by builder for AuthIdentity entities.
+type AuthIdentityGroupBy struct {
+ selector
+ build *AuthIdentityQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *AuthIdentityGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *AuthIdentityGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentityGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *AuthIdentityGroupBy) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// AuthIdentitySelect is the builder for selecting fields of AuthIdentity entities.
+type AuthIdentitySelect struct {
+ *AuthIdentityQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *AuthIdentitySelect) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *AuthIdentitySelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentitySelect](ctx, _s.AuthIdentityQuery, _s, _s.inters, v)
+}
+
+func (_s *AuthIdentitySelect) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/authidentity_update.go b/backend/ent/authidentity_update.go
new file mode 100644
index 00000000..c457470b
--- /dev/null
+++ b/backend/ent/authidentity_update.go
@@ -0,0 +1,923 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AuthIdentityUpdate is the builder for updating AuthIdentity entities.
+type AuthIdentityUpdate struct {
+ config
+ hooks []Hook
+ mutation *AuthIdentityMutation
+}
+
+// Where appends a list predicates to the AuthIdentityUpdate builder.
+func (_u *AuthIdentityUpdate) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AuthIdentityUpdate) SetUpdatedAt(v time.Time) *AuthIdentityUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetUserID sets the "user_id" field.
+func (_u *AuthIdentityUpdate) SetUserID(v int64) *AuthIdentityUpdate {
+ _u.mutation.SetUserID(v)
+ return _u
+}
+
+// SetNillableUserID sets the "user_id" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableUserID(v *int64) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetUserID(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *AuthIdentityUpdate) SetProviderType(v string) *AuthIdentityUpdate {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableProviderType(v *string) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *AuthIdentityUpdate) SetProviderKey(v string) *AuthIdentityUpdate {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableProviderKey(v *string) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_u *AuthIdentityUpdate) SetProviderSubject(v string) *AuthIdentityUpdate {
+ _u.mutation.SetProviderSubject(v)
+ return _u
+}
+
+// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableProviderSubject(v *string) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetProviderSubject(*v)
+ }
+ return _u
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (_u *AuthIdentityUpdate) SetVerifiedAt(v time.Time) *AuthIdentityUpdate {
+ _u.mutation.SetVerifiedAt(v)
+ return _u
+}
+
+// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (_u *AuthIdentityUpdate) ClearVerifiedAt() *AuthIdentityUpdate {
+ _u.mutation.ClearVerifiedAt()
+ return _u
+}
+
+// SetIssuer sets the "issuer" field.
+func (_u *AuthIdentityUpdate) SetIssuer(v string) *AuthIdentityUpdate {
+ _u.mutation.SetIssuer(v)
+ return _u
+}
+
+// SetNillableIssuer sets the "issuer" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableIssuer(v *string) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetIssuer(*v)
+ }
+ return _u
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (_u *AuthIdentityUpdate) ClearIssuer() *AuthIdentityUpdate {
+ _u.mutation.ClearIssuer()
+ return _u
+}
+
+// SetMetadata sets the "metadata" field.
+func (_u *AuthIdentityUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityUpdate {
+ _u.mutation.SetMetadata(v)
+ return _u
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_u *AuthIdentityUpdate) SetUser(v *User) *AuthIdentityUpdate {
+ return _u.SetUserID(v.ID)
+}
+
+// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs.
+func (_u *AuthIdentityUpdate) AddChannelIDs(ids ...int64) *AuthIdentityUpdate {
+ _u.mutation.AddChannelIDs(ids...)
+ return _u
+}
+
+// AddChannels adds the "channels" edges to the AuthIdentityChannel entity.
+func (_u *AuthIdentityUpdate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddChannelIDs(ids...)
+}
+
+// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
+func (_u *AuthIdentityUpdate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate {
+ _u.mutation.AddAdoptionDecisionIDs(ids...)
+ return _u
+}
+
+// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_u *AuthIdentityUpdate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAdoptionDecisionIDs(ids...)
+}
+
+// Mutation returns the AuthIdentityMutation object of the builder.
+func (_u *AuthIdentityUpdate) Mutation() *AuthIdentityMutation {
+ return _u.mutation
+}
+
+// ClearUser clears the "user" edge to the User entity.
+func (_u *AuthIdentityUpdate) ClearUser() *AuthIdentityUpdate {
+ _u.mutation.ClearUser()
+ return _u
+}
+
+// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity.
+func (_u *AuthIdentityUpdate) ClearChannels() *AuthIdentityUpdate {
+ _u.mutation.ClearChannels()
+ return _u
+}
+
+// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs.
+func (_u *AuthIdentityUpdate) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdate {
+ _u.mutation.RemoveChannelIDs(ids...)
+ return _u
+}
+
+// RemoveChannels removes "channels" edges to AuthIdentityChannel entities.
+func (_u *AuthIdentityUpdate) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveChannelIDs(ids...)
+}
+
+// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_u *AuthIdentityUpdate) ClearAdoptionDecisions() *AuthIdentityUpdate {
+ _u.mutation.ClearAdoptionDecisions()
+ return _u
+}
+
+// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs.
+func (_u *AuthIdentityUpdate) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate {
+ _u.mutation.RemoveAdoptionDecisionIDs(ids...)
+ return _u
+}
+
+// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities.
+func (_u *AuthIdentityUpdate) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAdoptionDecisionIDs(ids...)
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *AuthIdentityUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AuthIdentityUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *AuthIdentityUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AuthIdentityUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AuthIdentityUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := authidentity.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AuthIdentityUpdate) check() error {
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := authidentity.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := authidentity.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderSubject(); ok {
+ if err := authidentity.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)}
+ }
+ }
+ if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`)
+ }
+ return nil
+}
+
+func (_u *AuthIdentityUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(authidentity.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderSubject(); ok {
+ _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.VerifiedAt(); ok {
+ _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.VerifiedAtCleared() {
+ _spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.Issuer(); ok {
+ _spec.SetField(authidentity.FieldIssuer, field.TypeString, value)
+ }
+ if _u.mutation.IssuerCleared() {
+ _spec.ClearField(authidentity.FieldIssuer, field.TypeString)
+ }
+ if value, ok := _u.mutation.Metadata(); ok {
+ _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value)
+ }
+ if _u.mutation.UserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.ChannelsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.AdoptionDecisionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{authidentity.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// AuthIdentityUpdateOne is the builder for updating a single AuthIdentity entity.
+type AuthIdentityUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *AuthIdentityMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AuthIdentityUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetUserID sets the "user_id" field.
+func (_u *AuthIdentityUpdateOne) SetUserID(v int64) *AuthIdentityUpdateOne {
+ _u.mutation.SetUserID(v)
+ return _u
+}
+
+// SetNillableUserID sets the "user_id" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableUserID(v *int64) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetUserID(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *AuthIdentityUpdateOne) SetProviderType(v string) *AuthIdentityUpdateOne {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableProviderType(v *string) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *AuthIdentityUpdateOne) SetProviderKey(v string) *AuthIdentityUpdateOne {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_u *AuthIdentityUpdateOne) SetProviderSubject(v string) *AuthIdentityUpdateOne {
+ _u.mutation.SetProviderSubject(v)
+ return _u
+}
+
+// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableProviderSubject(v *string) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetProviderSubject(*v)
+ }
+ return _u
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (_u *AuthIdentityUpdateOne) SetVerifiedAt(v time.Time) *AuthIdentityUpdateOne {
+ _u.mutation.SetVerifiedAt(v)
+ return _u
+}
+
+// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (_u *AuthIdentityUpdateOne) ClearVerifiedAt() *AuthIdentityUpdateOne {
+ _u.mutation.ClearVerifiedAt()
+ return _u
+}
+
+// SetIssuer sets the "issuer" field.
+func (_u *AuthIdentityUpdateOne) SetIssuer(v string) *AuthIdentityUpdateOne {
+ _u.mutation.SetIssuer(v)
+ return _u
+}
+
+// SetNillableIssuer sets the "issuer" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableIssuer(v *string) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetIssuer(*v)
+ }
+ return _u
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (_u *AuthIdentityUpdateOne) ClearIssuer() *AuthIdentityUpdateOne {
+ _u.mutation.ClearIssuer()
+ return _u
+}
+
+// SetMetadata sets the "metadata" field.
+func (_u *AuthIdentityUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpdateOne {
+ _u.mutation.SetMetadata(v)
+ return _u
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_u *AuthIdentityUpdateOne) SetUser(v *User) *AuthIdentityUpdateOne {
+ return _u.SetUserID(v.ID)
+}
+
+// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs.
+func (_u *AuthIdentityUpdateOne) AddChannelIDs(ids ...int64) *AuthIdentityUpdateOne {
+ _u.mutation.AddChannelIDs(ids...)
+ return _u
+}
+
+// AddChannels adds the "channels" edges to the AuthIdentityChannel entity.
+func (_u *AuthIdentityUpdateOne) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddChannelIDs(ids...)
+}
+
+// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
+func (_u *AuthIdentityUpdateOne) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne {
+ _u.mutation.AddAdoptionDecisionIDs(ids...)
+ return _u
+}
+
+// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_u *AuthIdentityUpdateOne) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAdoptionDecisionIDs(ids...)
+}
+
+// Mutation returns the AuthIdentityMutation object of the builder.
+func (_u *AuthIdentityUpdateOne) Mutation() *AuthIdentityMutation {
+ return _u.mutation
+}
+
+// ClearUser clears the "user" edge to the User entity.
+func (_u *AuthIdentityUpdateOne) ClearUser() *AuthIdentityUpdateOne {
+ _u.mutation.ClearUser()
+ return _u
+}
+
+// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity.
+func (_u *AuthIdentityUpdateOne) ClearChannels() *AuthIdentityUpdateOne {
+ _u.mutation.ClearChannels()
+ return _u
+}
+
+// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs.
+func (_u *AuthIdentityUpdateOne) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdateOne {
+ _u.mutation.RemoveChannelIDs(ids...)
+ return _u
+}
+
+// RemoveChannels removes "channels" edges to AuthIdentityChannel entities.
+func (_u *AuthIdentityUpdateOne) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveChannelIDs(ids...)
+}
+
+// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_u *AuthIdentityUpdateOne) ClearAdoptionDecisions() *AuthIdentityUpdateOne {
+ _u.mutation.ClearAdoptionDecisions()
+ return _u
+}
+
+// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs.
+func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne {
+ _u.mutation.RemoveAdoptionDecisionIDs(ids...)
+ return _u
+}
+
+// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities.
+func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAdoptionDecisionIDs(ids...)
+}
+
+// Where appends a list predicates to the AuthIdentityUpdate builder.
+func (_u *AuthIdentityUpdateOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *AuthIdentityUpdateOne) Select(field string, fields ...string) *AuthIdentityUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated AuthIdentity entity.
+func (_u *AuthIdentityUpdateOne) Save(ctx context.Context) (*AuthIdentity, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AuthIdentityUpdateOne) SaveX(ctx context.Context) *AuthIdentity {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *AuthIdentityUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AuthIdentityUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AuthIdentityUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := authidentity.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AuthIdentityUpdateOne) check() error {
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := authidentity.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := authidentity.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderSubject(); ok {
+ if err := authidentity.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)}
+ }
+ }
+ if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`)
+ }
+ return nil
+}
+
+func (_u *AuthIdentityUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentity, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentity.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID)
+ for _, f := range fields {
+ if !authidentity.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != authidentity.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(authidentity.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderSubject(); ok {
+ _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.VerifiedAt(); ok {
+ _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.VerifiedAtCleared() {
+ _spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.Issuer(); ok {
+ _spec.SetField(authidentity.FieldIssuer, field.TypeString, value)
+ }
+ if _u.mutation.IssuerCleared() {
+ _spec.ClearField(authidentity.FieldIssuer, field.TypeString)
+ }
+ if value, ok := _u.mutation.Metadata(); ok {
+ _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value)
+ }
+ if _u.mutation.UserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.ChannelsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.AdoptionDecisionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &AuthIdentity{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{authidentity.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/authidentitychannel.go b/backend/ent/authidentitychannel.go
new file mode 100644
index 00000000..1ff3e5d1
--- /dev/null
+++ b/backend/ent/authidentitychannel.go
@@ -0,0 +1,228 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+)
+
+// AuthIdentityChannel is the model entity for the AuthIdentityChannel schema.
+type AuthIdentityChannel struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // IdentityID holds the value of the "identity_id" field.
+ IdentityID int64 `json:"identity_id,omitempty"`
+ // ProviderType holds the value of the "provider_type" field.
+ ProviderType string `json:"provider_type,omitempty"`
+ // ProviderKey holds the value of the "provider_key" field.
+ ProviderKey string `json:"provider_key,omitempty"`
+ // Channel holds the value of the "channel" field.
+ Channel string `json:"channel,omitempty"`
+ // ChannelAppID holds the value of the "channel_app_id" field.
+ ChannelAppID string `json:"channel_app_id,omitempty"`
+ // ChannelSubject holds the value of the "channel_subject" field.
+ ChannelSubject string `json:"channel_subject,omitempty"`
+ // Metadata holds the value of the "metadata" field.
+ Metadata map[string]interface{} `json:"metadata,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the AuthIdentityChannelQuery when eager-loading is set.
+ Edges AuthIdentityChannelEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// AuthIdentityChannelEdges holds the relations/edges for other nodes in the graph.
+type AuthIdentityChannelEdges struct {
+ // Identity holds the value of the identity edge.
+ Identity *AuthIdentity `json:"identity,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [1]bool
+}
+
+// IdentityOrErr returns the Identity value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e AuthIdentityChannelEdges) IdentityOrErr() (*AuthIdentity, error) {
+ if e.Identity != nil {
+ return e.Identity, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: authidentity.Label}
+ }
+ return nil, &NotLoadedError{edge: "identity"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*AuthIdentityChannel) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case authidentitychannel.FieldMetadata:
+ values[i] = new([]byte)
+ case authidentitychannel.FieldID, authidentitychannel.FieldIdentityID:
+ values[i] = new(sql.NullInt64)
+ case authidentitychannel.FieldProviderType, authidentitychannel.FieldProviderKey, authidentitychannel.FieldChannel, authidentitychannel.FieldChannelAppID, authidentitychannel.FieldChannelSubject:
+ values[i] = new(sql.NullString)
+ case authidentitychannel.FieldCreatedAt, authidentitychannel.FieldUpdatedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the AuthIdentityChannel fields.
+func (_m *AuthIdentityChannel) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case authidentitychannel.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case authidentitychannel.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case authidentitychannel.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case authidentitychannel.FieldIdentityID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field identity_id", values[i])
+ } else if value.Valid {
+ _m.IdentityID = value.Int64
+ }
+ case authidentitychannel.FieldProviderType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_type", values[i])
+ } else if value.Valid {
+ _m.ProviderType = value.String
+ }
+ case authidentitychannel.FieldProviderKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_key", values[i])
+ } else if value.Valid {
+ _m.ProviderKey = value.String
+ }
+ case authidentitychannel.FieldChannel:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field channel", values[i])
+ } else if value.Valid {
+ _m.Channel = value.String
+ }
+ case authidentitychannel.FieldChannelAppID:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field channel_app_id", values[i])
+ } else if value.Valid {
+ _m.ChannelAppID = value.String
+ }
+ case authidentitychannel.FieldChannelSubject:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field channel_subject", values[i])
+ } else if value.Valid {
+ _m.ChannelSubject = value.String
+ }
+ case authidentitychannel.FieldMetadata:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field metadata", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.Metadata); err != nil {
+ return fmt.Errorf("unmarshal field metadata: %w", err)
+ }
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentityChannel.
+// This includes values selected through modifiers, order, etc.
+func (_m *AuthIdentityChannel) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryIdentity queries the "identity" edge of the AuthIdentityChannel entity.
+func (_m *AuthIdentityChannel) QueryIdentity() *AuthIdentityQuery {
+ return NewAuthIdentityChannelClient(_m.config).QueryIdentity(_m)
+}
+
+// Update returns a builder for updating this AuthIdentityChannel.
+// Note that you need to call AuthIdentityChannel.Unwrap() before calling this method if this AuthIdentityChannel
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *AuthIdentityChannel) Update() *AuthIdentityChannelUpdateOne {
+ return NewAuthIdentityChannelClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the AuthIdentityChannel entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *AuthIdentityChannel) Unwrap() *AuthIdentityChannel {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: AuthIdentityChannel is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *AuthIdentityChannel) String() string {
+ var builder strings.Builder
+ builder.WriteString("AuthIdentityChannel(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("identity_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.IdentityID))
+ builder.WriteString(", ")
+ builder.WriteString("provider_type=")
+ builder.WriteString(_m.ProviderType)
+ builder.WriteString(", ")
+ builder.WriteString("provider_key=")
+ builder.WriteString(_m.ProviderKey)
+ builder.WriteString(", ")
+ builder.WriteString("channel=")
+ builder.WriteString(_m.Channel)
+ builder.WriteString(", ")
+ builder.WriteString("channel_app_id=")
+ builder.WriteString(_m.ChannelAppID)
+ builder.WriteString(", ")
+ builder.WriteString("channel_subject=")
+ builder.WriteString(_m.ChannelSubject)
+ builder.WriteString(", ")
+ builder.WriteString("metadata=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Metadata))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// AuthIdentityChannels is a parsable slice of AuthIdentityChannel.
+type AuthIdentityChannels []*AuthIdentityChannel
diff --git a/backend/ent/authidentitychannel/authidentitychannel.go b/backend/ent/authidentitychannel/authidentitychannel.go
new file mode 100644
index 00000000..7dcc98bb
--- /dev/null
+++ b/backend/ent/authidentitychannel/authidentitychannel.go
@@ -0,0 +1,153 @@
+// Code generated by ent, DO NOT EDIT.
+
+package authidentitychannel
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the authidentitychannel type in the database.
+ Label = "auth_identity_channel"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldIdentityID holds the string denoting the identity_id field in the database.
+ FieldIdentityID = "identity_id"
+ // FieldProviderType holds the string denoting the provider_type field in the database.
+ FieldProviderType = "provider_type"
+ // FieldProviderKey holds the string denoting the provider_key field in the database.
+ FieldProviderKey = "provider_key"
+ // FieldChannel holds the string denoting the channel field in the database.
+ FieldChannel = "channel"
+ // FieldChannelAppID holds the string denoting the channel_app_id field in the database.
+ FieldChannelAppID = "channel_app_id"
+ // FieldChannelSubject holds the string denoting the channel_subject field in the database.
+ FieldChannelSubject = "channel_subject"
+ // FieldMetadata holds the string denoting the metadata field in the database.
+ FieldMetadata = "metadata"
+ // EdgeIdentity holds the string denoting the identity edge name in mutations.
+ EdgeIdentity = "identity"
+ // Table holds the table name of the authidentitychannel in the database.
+ Table = "auth_identity_channels"
+ // IdentityTable is the table that holds the identity relation/edge.
+ IdentityTable = "auth_identity_channels"
+ // IdentityInverseTable is the table name for the AuthIdentity entity.
+ // It exists in this package in order to avoid circular dependency with the "authidentity" package.
+ IdentityInverseTable = "auth_identities"
+ // IdentityColumn is the table column denoting the identity relation/edge.
+ IdentityColumn = "identity_id"
+)
+
+// Columns holds all SQL columns for authidentitychannel fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldIdentityID,
+ FieldProviderType,
+ FieldProviderKey,
+ FieldChannel,
+ FieldChannelAppID,
+ FieldChannelSubject,
+ FieldMetadata,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ ProviderTypeValidator func(string) error
+ // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ ProviderKeyValidator func(string) error
+ // ChannelValidator is a validator for the "channel" field. It is called by the builders before save.
+ ChannelValidator func(string) error
+ // ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save.
+ ChannelAppIDValidator func(string) error
+ // ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save.
+ ChannelSubjectValidator func(string) error
+ // DefaultMetadata holds the default value on creation for the "metadata" field.
+ DefaultMetadata func() map[string]interface{}
+)
+
+// OrderOption defines the ordering options for the AuthIdentityChannel queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByIdentityID orders the results by the identity_id field.
+func ByIdentityID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIdentityID, opts...).ToFunc()
+}
+
+// ByProviderType orders the results by the provider_type field.
+func ByProviderType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderType, opts...).ToFunc()
+}
+
+// ByProviderKey orders the results by the provider_key field.
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
+}
+
+// ByChannel orders the results by the channel field.
+func ByChannel(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldChannel, opts...).ToFunc()
+}
+
+// ByChannelAppID orders the results by the channel_app_id field.
+func ByChannelAppID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldChannelAppID, opts...).ToFunc()
+}
+
+// ByChannelSubject orders the results by the channel_subject field.
+func ByChannelSubject(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldChannelSubject, opts...).ToFunc()
+}
+
+// ByIdentityField orders the results by identity field.
+func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newIdentityStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(IdentityInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
+ )
+}
diff --git a/backend/ent/authidentitychannel/where.go b/backend/ent/authidentitychannel/where.go
new file mode 100644
index 00000000..827dc384
--- /dev/null
+++ b/backend/ent/authidentitychannel/where.go
@@ -0,0 +1,559 @@
+// Code generated by ent, DO NOT EDIT.
+
+package authidentitychannel
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ.
+func IdentityID(v int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v))
+}
+
+// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ.
+func ProviderType(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
+func ProviderKey(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// Channel applies equality check predicate on the "channel" field. It's identical to ChannelEQ.
+func Channel(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v))
+}
+
+// ChannelAppID applies equality check predicate on the "channel_app_id" field. It's identical to ChannelAppIDEQ.
+func ChannelAppID(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v))
+}
+
+// ChannelSubject applies equality check predicate on the "channel_subject" field. It's identical to ChannelSubjectEQ.
+func ChannelSubject(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// IdentityIDEQ applies the EQ predicate on the "identity_id" field.
+func IdentityIDEQ(v int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v))
+}
+
+// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field.
+func IdentityIDNEQ(v int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldIdentityID, v))
+}
+
+// IdentityIDIn applies the In predicate on the "identity_id" field.
+func IdentityIDIn(vs ...int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldIdentityID, vs...))
+}
+
+// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field.
+func IdentityIDNotIn(vs ...int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldIdentityID, vs...))
+}
+
+// ProviderTypeEQ applies the EQ predicate on the "provider_type" field.
+func ProviderTypeEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field.
+func ProviderTypeNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderType, v))
+}
+
+// ProviderTypeIn applies the In predicate on the "provider_type" field.
+func ProviderTypeIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field.
+func ProviderTypeNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeGT applies the GT predicate on the "provider_type" field.
+func ProviderTypeGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderType, v))
+}
+
+// ProviderTypeGTE applies the GTE predicate on the "provider_type" field.
+func ProviderTypeGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderType, v))
+}
+
+// ProviderTypeLT applies the LT predicate on the "provider_type" field.
+func ProviderTypeLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderType, v))
+}
+
+// ProviderTypeLTE applies the LTE predicate on the "provider_type" field.
+func ProviderTypeLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderType, v))
+}
+
+// ProviderTypeContains applies the Contains predicate on the "provider_type" field.
+func ProviderTypeContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderType, v))
+}
+
+// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field.
+func ProviderTypeHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderType, v))
+}
+
+// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field.
+func ProviderTypeHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderType, v))
+}
+
+// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field.
+func ProviderTypeEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderType, v))
+}
+
+// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field.
+func ProviderTypeContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderType, v))
+}
+
+// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
+func ProviderKeyEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
+func ProviderKeyNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyIn applies the In predicate on the "provider_key" field.
+func ProviderKeyIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
+func ProviderKeyNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyGT applies the GT predicate on the "provider_key" field.
+func ProviderKeyGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderKey, v))
+}
+
+// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
+func ProviderKeyGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderKey, v))
+}
+
+// ProviderKeyLT applies the LT predicate on the "provider_key" field.
+func ProviderKeyLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderKey, v))
+}
+
+// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
+func ProviderKeyLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderKey, v))
+}
+
+// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
+func ProviderKeyContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderKey, v))
+}
+
+// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
+func ProviderKeyHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderKey, v))
+}
+
+// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
+func ProviderKeyHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderKey, v))
+}
+
+// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
+func ProviderKeyEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderKey, v))
+}
+
+// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
+func ProviderKeyContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderKey, v))
+}
+
+// ChannelEQ applies the EQ predicate on the "channel" field.
+func ChannelEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v))
+}
+
+// ChannelNEQ applies the NEQ predicate on the "channel" field.
+func ChannelNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannel, v))
+}
+
+// ChannelIn applies the In predicate on the "channel" field.
+func ChannelIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannel, vs...))
+}
+
+// ChannelNotIn applies the NotIn predicate on the "channel" field.
+func ChannelNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannel, vs...))
+}
+
+// ChannelGT applies the GT predicate on the "channel" field.
+func ChannelGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannel, v))
+}
+
+// ChannelGTE applies the GTE predicate on the "channel" field.
+func ChannelGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannel, v))
+}
+
+// ChannelLT applies the LT predicate on the "channel" field.
+func ChannelLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannel, v))
+}
+
+// ChannelLTE applies the LTE predicate on the "channel" field.
+func ChannelLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannel, v))
+}
+
+// ChannelContains applies the Contains predicate on the "channel" field.
+func ChannelContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannel, v))
+}
+
+// ChannelHasPrefix applies the HasPrefix predicate on the "channel" field.
+func ChannelHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannel, v))
+}
+
+// ChannelHasSuffix applies the HasSuffix predicate on the "channel" field.
+func ChannelHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannel, v))
+}
+
+// ChannelEqualFold applies the EqualFold predicate on the "channel" field.
+func ChannelEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannel, v))
+}
+
+// ChannelContainsFold applies the ContainsFold predicate on the "channel" field.
+func ChannelContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannel, v))
+}
+
+// ChannelAppIDEQ applies the EQ predicate on the "channel_app_id" field.
+func ChannelAppIDEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v))
+}
+
+// ChannelAppIDNEQ applies the NEQ predicate on the "channel_app_id" field.
+func ChannelAppIDNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelAppID, v))
+}
+
+// ChannelAppIDIn applies the In predicate on the "channel_app_id" field.
+func ChannelAppIDIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelAppID, vs...))
+}
+
+// ChannelAppIDNotIn applies the NotIn predicate on the "channel_app_id" field.
+func ChannelAppIDNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelAppID, vs...))
+}
+
+// ChannelAppIDGT applies the GT predicate on the "channel_app_id" field.
+func ChannelAppIDGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelAppID, v))
+}
+
+// ChannelAppIDGTE applies the GTE predicate on the "channel_app_id" field.
+func ChannelAppIDGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelAppID, v))
+}
+
+// ChannelAppIDLT applies the LT predicate on the "channel_app_id" field.
+func ChannelAppIDLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelAppID, v))
+}
+
+// ChannelAppIDLTE applies the LTE predicate on the "channel_app_id" field.
+func ChannelAppIDLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelAppID, v))
+}
+
+// ChannelAppIDContains applies the Contains predicate on the "channel_app_id" field.
+func ChannelAppIDContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelAppID, v))
+}
+
+// ChannelAppIDHasPrefix applies the HasPrefix predicate on the "channel_app_id" field.
+func ChannelAppIDHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelAppID, v))
+}
+
+// ChannelAppIDHasSuffix applies the HasSuffix predicate on the "channel_app_id" field.
+func ChannelAppIDHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelAppID, v))
+}
+
+// ChannelAppIDEqualFold applies the EqualFold predicate on the "channel_app_id" field.
+func ChannelAppIDEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelAppID, v))
+}
+
+// ChannelAppIDContainsFold applies the ContainsFold predicate on the "channel_app_id" field.
+func ChannelAppIDContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelAppID, v))
+}
+
+// ChannelSubjectEQ applies the EQ predicate on the "channel_subject" field.
+func ChannelSubjectEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v))
+}
+
+// ChannelSubjectNEQ applies the NEQ predicate on the "channel_subject" field.
+func ChannelSubjectNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelSubject, v))
+}
+
+// ChannelSubjectIn applies the In predicate on the "channel_subject" field.
+func ChannelSubjectIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelSubject, vs...))
+}
+
+// ChannelSubjectNotIn applies the NotIn predicate on the "channel_subject" field.
+func ChannelSubjectNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelSubject, vs...))
+}
+
+// ChannelSubjectGT applies the GT predicate on the "channel_subject" field.
+func ChannelSubjectGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelSubject, v))
+}
+
+// ChannelSubjectGTE applies the GTE predicate on the "channel_subject" field.
+func ChannelSubjectGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelSubject, v))
+}
+
+// ChannelSubjectLT applies the LT predicate on the "channel_subject" field.
+func ChannelSubjectLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelSubject, v))
+}
+
+// ChannelSubjectLTE applies the LTE predicate on the "channel_subject" field.
+func ChannelSubjectLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelSubject, v))
+}
+
+// ChannelSubjectContains applies the Contains predicate on the "channel_subject" field.
+func ChannelSubjectContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelSubject, v))
+}
+
+// ChannelSubjectHasPrefix applies the HasPrefix predicate on the "channel_subject" field.
+func ChannelSubjectHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelSubject, v))
+}
+
+// ChannelSubjectHasSuffix applies the HasSuffix predicate on the "channel_subject" field.
+func ChannelSubjectHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelSubject, v))
+}
+
+// ChannelSubjectEqualFold applies the EqualFold predicate on the "channel_subject" field.
+func ChannelSubjectEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelSubject, v))
+}
+
+// ChannelSubjectContainsFold applies the ContainsFold predicate on the "channel_subject" field.
+func ChannelSubjectContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelSubject, v))
+}
+
+// HasIdentity applies the HasEdge predicate on the "identity" edge.
+func HasIdentity() predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates).
+func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(func(s *sql.Selector) {
+ step := newIdentityStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.AuthIdentityChannel) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.NotPredicates(p))
+}
diff --git a/backend/ent/authidentitychannel_create.go b/backend/ent/authidentitychannel_create.go
new file mode 100644
index 00000000..4ce28479
--- /dev/null
+++ b/backend/ent/authidentitychannel_create.go
@@ -0,0 +1,932 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+)
+
+// AuthIdentityChannelCreate is the builder for creating a AuthIdentityChannel entity.
+type AuthIdentityChannelCreate struct {
+ config
+ mutation *AuthIdentityChannelMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *AuthIdentityChannelCreate) SetCreatedAt(v time.Time) *AuthIdentityChannelCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *AuthIdentityChannelCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityChannelCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *AuthIdentityChannelCreate) SetUpdatedAt(v time.Time) *AuthIdentityChannelCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *AuthIdentityChannelCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityChannelCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_c *AuthIdentityChannelCreate) SetIdentityID(v int64) *AuthIdentityChannelCreate {
+ _c.mutation.SetIdentityID(v)
+ return _c
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_c *AuthIdentityChannelCreate) SetProviderType(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetProviderType(v)
+ return _c
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_c *AuthIdentityChannelCreate) SetProviderKey(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetProviderKey(v)
+ return _c
+}
+
+// SetChannel sets the "channel" field.
+func (_c *AuthIdentityChannelCreate) SetChannel(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetChannel(v)
+ return _c
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (_c *AuthIdentityChannelCreate) SetChannelAppID(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetChannelAppID(v)
+ return _c
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (_c *AuthIdentityChannelCreate) SetChannelSubject(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetChannelSubject(v)
+ return _c
+}
+
+// SetMetadata sets the "metadata" field.
+func (_c *AuthIdentityChannelCreate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelCreate {
+ _c.mutation.SetMetadata(v)
+ return _c
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_c *AuthIdentityChannelCreate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelCreate {
+ return _c.SetIdentityID(v.ID)
+}
+
+// Mutation returns the AuthIdentityChannelMutation object of the builder.
+func (_c *AuthIdentityChannelCreate) Mutation() *AuthIdentityChannelMutation {
+ return _c.mutation
+}
+
+// Save creates the AuthIdentityChannel in the database.
+func (_c *AuthIdentityChannelCreate) Save(ctx context.Context) (*AuthIdentityChannel, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *AuthIdentityChannelCreate) SaveX(ctx context.Context) *AuthIdentityChannel {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AuthIdentityChannelCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AuthIdentityChannelCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *AuthIdentityChannelCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := authidentitychannel.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := authidentitychannel.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.Metadata(); !ok {
+ v := authidentitychannel.DefaultMetadata()
+ _c.mutation.SetMetadata(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *AuthIdentityChannelCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.updated_at"`)}
+ }
+ if _, ok := _c.mutation.IdentityID(); !ok {
+ return &ValidationError{Name: "identity_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.identity_id"`)}
+ }
+ if _, ok := _c.mutation.ProviderType(); !ok {
+ return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_type"`)}
+ }
+ if v, ok := _c.mutation.ProviderType(); ok {
+ if err := authidentitychannel.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderKey(); !ok {
+ return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_key"`)}
+ }
+ if v, ok := _c.mutation.ProviderKey(); ok {
+ if err := authidentitychannel.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Channel(); !ok {
+ return &ValidationError{Name: "channel", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel"`)}
+ }
+ if v, ok := _c.mutation.Channel(); ok {
+ if err := authidentitychannel.ChannelValidator(v); err != nil {
+ return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ChannelAppID(); !ok {
+ return &ValidationError{Name: "channel_app_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_app_id"`)}
+ }
+ if v, ok := _c.mutation.ChannelAppID(); ok {
+ if err := authidentitychannel.ChannelAppIDValidator(v); err != nil {
+ return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ChannelSubject(); !ok {
+ return &ValidationError{Name: "channel_subject", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_subject"`)}
+ }
+ if v, ok := _c.mutation.ChannelSubject(); ok {
+ if err := authidentitychannel.ChannelSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Metadata(); !ok {
+ return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentityChannel.metadata"`)}
+ }
+ if len(_c.mutation.IdentityIDs()) == 0 {
+ return &ValidationError{Name: "identity", err: errors.New(`ent: missing required edge "AuthIdentityChannel.identity"`)}
+ }
+ return nil
+}
+
+func (_c *AuthIdentityChannelCreate) sqlSave(ctx context.Context) (*AuthIdentityChannel, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *AuthIdentityChannelCreate) createSpec() (*AuthIdentityChannel, *sqlgraph.CreateSpec) {
+ var (
+ _node = &AuthIdentityChannel{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(authidentitychannel.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.ProviderType(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value)
+ _node.ProviderType = value
+ }
+ if value, ok := _c.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value)
+ _node.ProviderKey = value
+ }
+ if value, ok := _c.mutation.Channel(); ok {
+ _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value)
+ _node.Channel = value
+ }
+ if value, ok := _c.mutation.ChannelAppID(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value)
+ _node.ChannelAppID = value
+ }
+ if value, ok := _c.mutation.ChannelSubject(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value)
+ _node.ChannelSubject = value
+ }
+ if value, ok := _c.mutation.Metadata(); ok {
+ _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value)
+ _node.Metadata = value
+ }
+ if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.IdentityID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AuthIdentityChannel.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AuthIdentityChannelUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *AuthIdentityChannelCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertOne {
+ _c.conflict = opts
+ return &AuthIdentityChannelUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AuthIdentityChannelCreate) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AuthIdentityChannelUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // AuthIdentityChannelUpsertOne is the builder for "upsert"-ing
+ // one AuthIdentityChannel node.
+ AuthIdentityChannelUpsertOne struct {
+ create *AuthIdentityChannelCreate
+ }
+
+ // AuthIdentityChannelUpsert is the "OnConflict" setter.
+ AuthIdentityChannelUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityChannelUpsert) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateUpdatedAt() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldUpdatedAt)
+ return u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *AuthIdentityChannelUpsert) SetIdentityID(v int64) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldIdentityID, v)
+ return u
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateIdentityID() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldIdentityID)
+ return u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityChannelUpsert) SetProviderType(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldProviderType, v)
+ return u
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateProviderType() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldProviderType)
+ return u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityChannelUpsert) SetProviderKey(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldProviderKey, v)
+ return u
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateProviderKey() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldProviderKey)
+ return u
+}
+
+// SetChannel sets the "channel" field.
+func (u *AuthIdentityChannelUpsert) SetChannel(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldChannel, v)
+ return u
+}
+
+// UpdateChannel sets the "channel" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateChannel() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldChannel)
+ return u
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (u *AuthIdentityChannelUpsert) SetChannelAppID(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldChannelAppID, v)
+ return u
+}
+
+// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateChannelAppID() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldChannelAppID)
+ return u
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (u *AuthIdentityChannelUpsert) SetChannelSubject(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldChannelSubject, v)
+ return u
+}
+
+// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateChannelSubject() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldChannelSubject)
+ return u
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityChannelUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldMetadata, v)
+ return u
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateMetadata() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldMetadata)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AuthIdentityChannelUpsertOne) UpdateNewValues() *AuthIdentityChannelUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(authidentitychannel.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AuthIdentityChannelUpsertOne) Ignore() *AuthIdentityChannelUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AuthIdentityChannelUpsertOne) DoNothing() *AuthIdentityChannelUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreate.OnConflict
+// documentation for more info.
+func (u *AuthIdentityChannelUpsertOne) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AuthIdentityChannelUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityChannelUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateUpdatedAt() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *AuthIdentityChannelUpsertOne) SetIdentityID(v int64) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetIdentityID(v)
+ })
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateIdentityID() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateIdentityID()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityChannelUpsertOne) SetProviderType(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateProviderType() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityChannelUpsertOne) SetProviderKey(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateProviderKey() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetChannel sets the "channel" field.
+func (u *AuthIdentityChannelUpsertOne) SetChannel(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannel(v)
+ })
+}
+
+// UpdateChannel sets the "channel" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateChannel() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannel()
+ })
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (u *AuthIdentityChannelUpsertOne) SetChannelAppID(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannelAppID(v)
+ })
+}
+
+// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateChannelAppID() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannelAppID()
+ })
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (u *AuthIdentityChannelUpsertOne) SetChannelSubject(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannelSubject(v)
+ })
+}
+
+// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateChannelSubject() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannelSubject()
+ })
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityChannelUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetMetadata(v)
+ })
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateMetadata() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateMetadata()
+ })
+}
+
+// Exec executes the query.
+func (u *AuthIdentityChannelUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AuthIdentityChannelCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AuthIdentityChannelUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *AuthIdentityChannelUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *AuthIdentityChannelUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// AuthIdentityChannelCreateBulk is the builder for creating many AuthIdentityChannel entities in bulk.
+type AuthIdentityChannelCreateBulk struct {
+ config
+ err error
+ builders []*AuthIdentityChannelCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the AuthIdentityChannel entities in the database.
+func (_c *AuthIdentityChannelCreateBulk) Save(ctx context.Context) ([]*AuthIdentityChannel, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*AuthIdentityChannel, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*AuthIdentityChannelMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *AuthIdentityChannelCreateBulk) SaveX(ctx context.Context) []*AuthIdentityChannel {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AuthIdentityChannelCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AuthIdentityChannelCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AuthIdentityChannel.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AuthIdentityChannelUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *AuthIdentityChannelCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertBulk {
+ _c.conflict = opts
+ return &AuthIdentityChannelUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AuthIdentityChannelCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AuthIdentityChannelUpsertBulk{
+ create: _c,
+ }
+}
+
+// AuthIdentityChannelUpsertBulk is the builder for "upsert"-ing
+// a bulk of AuthIdentityChannel nodes.
+type AuthIdentityChannelUpsertBulk struct {
+ create *AuthIdentityChannelCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AuthIdentityChannelUpsertBulk) UpdateNewValues() *AuthIdentityChannelUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(authidentitychannel.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AuthIdentityChannelUpsertBulk) Ignore() *AuthIdentityChannelUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AuthIdentityChannelUpsertBulk) DoNothing() *AuthIdentityChannelUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreateBulk.OnConflict
+// documentation for more info.
+func (u *AuthIdentityChannelUpsertBulk) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AuthIdentityChannelUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityChannelUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateUpdatedAt() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *AuthIdentityChannelUpsertBulk) SetIdentityID(v int64) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetIdentityID(v)
+ })
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateIdentityID() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateIdentityID()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityChannelUpsertBulk) SetProviderType(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateProviderType() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityChannelUpsertBulk) SetProviderKey(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateProviderKey() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetChannel sets the "channel" field.
+func (u *AuthIdentityChannelUpsertBulk) SetChannel(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannel(v)
+ })
+}
+
+// UpdateChannel sets the "channel" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateChannel() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannel()
+ })
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (u *AuthIdentityChannelUpsertBulk) SetChannelAppID(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannelAppID(v)
+ })
+}
+
+// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateChannelAppID() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannelAppID()
+ })
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (u *AuthIdentityChannelUpsertBulk) SetChannelSubject(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannelSubject(v)
+ })
+}
+
+// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateChannelSubject() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannelSubject()
+ })
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityChannelUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetMetadata(v)
+ })
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateMetadata() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateMetadata()
+ })
+}
+
+// Exec executes the query.
+func (u *AuthIdentityChannelUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityChannelCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AuthIdentityChannelCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AuthIdentityChannelUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/authidentitychannel_delete.go b/backend/ent/authidentitychannel_delete.go
new file mode 100644
index 00000000..1a4acac5
--- /dev/null
+++ b/backend/ent/authidentitychannel_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AuthIdentityChannelDelete is the builder for deleting a AuthIdentityChannel entity.
+type AuthIdentityChannelDelete struct {
+ config
+ hooks []Hook
+ mutation *AuthIdentityChannelMutation
+}
+
+// Where appends a list predicates to the AuthIdentityChannelDelete builder.
+func (_d *AuthIdentityChannelDelete) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *AuthIdentityChannelDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AuthIdentityChannelDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *AuthIdentityChannelDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// AuthIdentityChannelDeleteOne is the builder for deleting a single AuthIdentityChannel entity.
+type AuthIdentityChannelDeleteOne struct {
+ _d *AuthIdentityChannelDelete
+}
+
+// Where appends a list predicates to the AuthIdentityChannelDelete builder.
+func (_d *AuthIdentityChannelDeleteOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *AuthIdentityChannelDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{authidentitychannel.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AuthIdentityChannelDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/authidentitychannel_query.go b/backend/ent/authidentitychannel_query.go
new file mode 100644
index 00000000..7a202b7f
--- /dev/null
+++ b/backend/ent/authidentitychannel_query.go
@@ -0,0 +1,643 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AuthIdentityChannelQuery is the builder for querying AuthIdentityChannel entities.
+type AuthIdentityChannelQuery struct {
+ config
+ ctx *QueryContext
+ order []authidentitychannel.OrderOption
+ inters []Interceptor
+ predicates []predicate.AuthIdentityChannel
+ withIdentity *AuthIdentityQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the AuthIdentityChannelQuery builder.
+func (_q *AuthIdentityChannelQuery) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *AuthIdentityChannelQuery) Limit(limit int) *AuthIdentityChannelQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *AuthIdentityChannelQuery) Offset(offset int) *AuthIdentityChannelQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *AuthIdentityChannelQuery) Unique(unique bool) *AuthIdentityChannelQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *AuthIdentityChannelQuery) Order(o ...authidentitychannel.OrderOption) *AuthIdentityChannelQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryIdentity chains the current query on the "identity" edge.
+func (_q *AuthIdentityChannelQuery) QueryIdentity() *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, selector),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first AuthIdentityChannel entity from the query.
+// Returns a *NotFoundError when no AuthIdentityChannel was found.
+func (_q *AuthIdentityChannelQuery) First(ctx context.Context) (*AuthIdentityChannel, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{authidentitychannel.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) FirstX(ctx context.Context) *AuthIdentityChannel {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first AuthIdentityChannel ID from the query.
+// Returns a *NotFoundError when no AuthIdentityChannel ID was found.
+func (_q *AuthIdentityChannelQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{authidentitychannel.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single AuthIdentityChannel entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one AuthIdentityChannel entity is found.
+// Returns a *NotFoundError when no AuthIdentityChannel entities are found.
+func (_q *AuthIdentityChannelQuery) Only(ctx context.Context) (*AuthIdentityChannel, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{authidentitychannel.Label}
+ default:
+ return nil, &NotSingularError{authidentitychannel.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) OnlyX(ctx context.Context) *AuthIdentityChannel {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only AuthIdentityChannel ID in the query.
+// Returns a *NotSingularError when more than one AuthIdentityChannel ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *AuthIdentityChannelQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{authidentitychannel.Label}
+ default:
+ err = &NotSingularError{authidentitychannel.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of AuthIdentityChannels.
+func (_q *AuthIdentityChannelQuery) All(ctx context.Context) ([]*AuthIdentityChannel, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*AuthIdentityChannel, *AuthIdentityChannelQuery]()
+ return withInterceptors[[]*AuthIdentityChannel](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) AllX(ctx context.Context) []*AuthIdentityChannel {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of AuthIdentityChannel IDs.
+func (_q *AuthIdentityChannelQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(authidentitychannel.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *AuthIdentityChannelQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityChannelQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *AuthIdentityChannelQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the AuthIdentityChannelQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *AuthIdentityChannelQuery) Clone() *AuthIdentityChannelQuery {
+ if _q == nil {
+ return nil
+ }
+ return &AuthIdentityChannelQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]authidentitychannel.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.AuthIdentityChannel{}, _q.predicates...),
+ withIdentity: _q.withIdentity.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithIdentity tells the query-builder to eager-load the nodes that are connected to
+// the "identity" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AuthIdentityChannelQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *AuthIdentityChannelQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withIdentity = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.AuthIdentityChannel.Query().
+// GroupBy(authidentitychannel.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *AuthIdentityChannelQuery) GroupBy(field string, fields ...string) *AuthIdentityChannelGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &AuthIdentityChannelGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = authidentitychannel.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.AuthIdentityChannel.Query().
+// Select(authidentitychannel.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *AuthIdentityChannelQuery) Select(fields ...string) *AuthIdentityChannelSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &AuthIdentityChannelSelect{AuthIdentityChannelQuery: _q}
+ sbuild.label = authidentitychannel.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a AuthIdentityChannelSelect configured with the given aggregations.
+func (_q *AuthIdentityChannelQuery) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *AuthIdentityChannelQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !authidentitychannel.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *AuthIdentityChannelQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentityChannel, error) {
+ var (
+ nodes = []*AuthIdentityChannel{}
+ _spec = _q.querySpec()
+ loadedTypes = [1]bool{
+ _q.withIdentity != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*AuthIdentityChannel).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &AuthIdentityChannel{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withIdentity; query != nil {
+ if err := _q.loadIdentity(ctx, query, nodes, nil,
+ func(n *AuthIdentityChannel, e *AuthIdentity) { n.Edges.Identity = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *AuthIdentityChannelQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*AuthIdentityChannel, init func(*AuthIdentityChannel), assign func(*AuthIdentityChannel, *AuthIdentity)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*AuthIdentityChannel)
+ for i := range nodes {
+ fk := nodes[i].IdentityID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(authidentity.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+
+func (_q *AuthIdentityChannelQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *AuthIdentityChannelQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID)
+ for i := range fields {
+ if fields[i] != authidentitychannel.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withIdentity != nil {
+ _spec.Node.AddColumnOnce(authidentitychannel.FieldIdentityID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *AuthIdentityChannelQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(authidentitychannel.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = authidentitychannel.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *AuthIdentityChannelQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityChannelQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *AuthIdentityChannelQuery) ForShare(opts ...sql.LockOption) *AuthIdentityChannelQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// AuthIdentityChannelGroupBy is the group-by builder for AuthIdentityChannel entities.
+type AuthIdentityChannelGroupBy struct {
+ selector
+ build *AuthIdentityChannelQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *AuthIdentityChannelGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *AuthIdentityChannelGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *AuthIdentityChannelGroupBy) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// AuthIdentityChannelSelect is the builder for selecting fields of AuthIdentityChannel entities.
+type AuthIdentityChannelSelect struct {
+ *AuthIdentityChannelQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *AuthIdentityChannelSelect) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *AuthIdentityChannelSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelSelect](ctx, _s.AuthIdentityChannelQuery, _s, _s.inters, v)
+}
+
+func (_s *AuthIdentityChannelSelect) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/authidentitychannel_update.go b/backend/ent/authidentitychannel_update.go
new file mode 100644
index 00000000..b550c454
--- /dev/null
+++ b/backend/ent/authidentitychannel_update.go
@@ -0,0 +1,581 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AuthIdentityChannelUpdate is the builder for updating AuthIdentityChannel entities.
+type AuthIdentityChannelUpdate struct {
+ config
+ hooks []Hook
+ mutation *AuthIdentityChannelMutation
+}
+
+// Where appends a list predicates to the AuthIdentityChannelUpdate builder.
+func (_u *AuthIdentityChannelUpdate) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AuthIdentityChannelUpdate) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_u *AuthIdentityChannelUpdate) SetIdentityID(v int64) *AuthIdentityChannelUpdate {
+ _u.mutation.SetIdentityID(v)
+ return _u
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetIdentityID(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *AuthIdentityChannelUpdate) SetProviderType(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableProviderType(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *AuthIdentityChannelUpdate) SetProviderKey(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetChannel sets the "channel" field.
+func (_u *AuthIdentityChannelUpdate) SetChannel(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetChannel(v)
+ return _u
+}
+
+// SetNillableChannel sets the "channel" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableChannel(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetChannel(*v)
+ }
+ return _u
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (_u *AuthIdentityChannelUpdate) SetChannelAppID(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetChannelAppID(v)
+ return _u
+}
+
+// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetChannelAppID(*v)
+ }
+ return _u
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (_u *AuthIdentityChannelUpdate) SetChannelSubject(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetChannelSubject(v)
+ return _u
+}
+
+// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetChannelSubject(*v)
+ }
+ return _u
+}
+
+// SetMetadata sets the "metadata" field.
+func (_u *AuthIdentityChannelUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdate {
+ _u.mutation.SetMetadata(v)
+ return _u
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_u *AuthIdentityChannelUpdate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdate {
+ return _u.SetIdentityID(v.ID)
+}
+
+// Mutation returns the AuthIdentityChannelMutation object of the builder.
+func (_u *AuthIdentityChannelUpdate) Mutation() *AuthIdentityChannelMutation {
+ return _u.mutation
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (_u *AuthIdentityChannelUpdate) ClearIdentity() *AuthIdentityChannelUpdate {
+ _u.mutation.ClearIdentity()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *AuthIdentityChannelUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AuthIdentityChannelUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *AuthIdentityChannelUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AuthIdentityChannelUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AuthIdentityChannelUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := authidentitychannel.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AuthIdentityChannelUpdate) check() error {
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := authidentitychannel.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := authidentitychannel.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Channel(); ok {
+ if err := authidentitychannel.ChannelValidator(v); err != nil {
+ return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ChannelAppID(); ok {
+ if err := authidentitychannel.ChannelAppIDValidator(v); err != nil {
+ return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ChannelSubject(); ok {
+ if err := authidentitychannel.ChannelSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)}
+ }
+ }
+ if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`)
+ }
+ return nil
+}
+
+func (_u *AuthIdentityChannelUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Channel(); ok {
+ _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ChannelAppID(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ChannelSubject(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Metadata(); ok {
+ _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value)
+ }
+ if _u.mutation.IdentityCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{authidentitychannel.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// AuthIdentityChannelUpdateOne is the builder for updating a single AuthIdentityChannel entity.
+type AuthIdentityChannelUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *AuthIdentityChannelMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AuthIdentityChannelUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_u *AuthIdentityChannelUpdateOne) SetIdentityID(v int64) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetIdentityID(v)
+ return _u
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetIdentityID(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *AuthIdentityChannelUpdateOne) SetProviderType(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderType(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *AuthIdentityChannelUpdateOne) SetProviderKey(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetChannel sets the "channel" field.
+func (_u *AuthIdentityChannelUpdateOne) SetChannel(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetChannel(v)
+ return _u
+}
+
+// SetNillableChannel sets the "channel" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableChannel(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetChannel(*v)
+ }
+ return _u
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (_u *AuthIdentityChannelUpdateOne) SetChannelAppID(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetChannelAppID(v)
+ return _u
+}
+
+// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetChannelAppID(*v)
+ }
+ return _u
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (_u *AuthIdentityChannelUpdateOne) SetChannelSubject(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetChannelSubject(v)
+ return _u
+}
+
+// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetChannelSubject(*v)
+ }
+ return _u
+}
+
+// SetMetadata sets the "metadata" field.
+func (_u *AuthIdentityChannelUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetMetadata(v)
+ return _u
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_u *AuthIdentityChannelUpdateOne) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdateOne {
+ return _u.SetIdentityID(v.ID)
+}
+
+// Mutation returns the AuthIdentityChannelMutation object of the builder.
+func (_u *AuthIdentityChannelUpdateOne) Mutation() *AuthIdentityChannelMutation {
+ return _u.mutation
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (_u *AuthIdentityChannelUpdateOne) ClearIdentity() *AuthIdentityChannelUpdateOne {
+ _u.mutation.ClearIdentity()
+ return _u
+}
+
+// Where appends a list predicates to the AuthIdentityChannelUpdate builder.
+func (_u *AuthIdentityChannelUpdateOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *AuthIdentityChannelUpdateOne) Select(field string, fields ...string) *AuthIdentityChannelUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated AuthIdentityChannel entity.
+func (_u *AuthIdentityChannelUpdateOne) Save(ctx context.Context) (*AuthIdentityChannel, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AuthIdentityChannelUpdateOne) SaveX(ctx context.Context) *AuthIdentityChannel {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *AuthIdentityChannelUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AuthIdentityChannelUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AuthIdentityChannelUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := authidentitychannel.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AuthIdentityChannelUpdateOne) check() error {
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := authidentitychannel.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := authidentitychannel.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Channel(); ok {
+ if err := authidentitychannel.ChannelValidator(v); err != nil {
+ return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ChannelAppID(); ok {
+ if err := authidentitychannel.ChannelAppIDValidator(v); err != nil {
+ return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ChannelSubject(); ok {
+ if err := authidentitychannel.ChannelSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)}
+ }
+ }
+ if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`)
+ }
+ return nil
+}
+
+func (_u *AuthIdentityChannelUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentityChannel, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentityChannel.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID)
+ for _, f := range fields {
+ if !authidentitychannel.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != authidentitychannel.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Channel(); ok {
+ _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ChannelAppID(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ChannelSubject(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Metadata(); ok {
+ _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value)
+ }
+ if _u.mutation.IdentityCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &AuthIdentityChannel{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{authidentitychannel.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/client.go b/backend/ent/client.go
index e52e015a..b02f519b 100644
--- a/backend/ent/client.go
+++ b/backend/ent/client.go
@@ -20,12 +20,16 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
@@ -60,18 +64,26 @@ type Client struct {
Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient
+ // AuthIdentity is the client for interacting with the AuthIdentity builders.
+ AuthIdentity *AuthIdentityClient
+ // AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders.
+ AuthIdentityChannel *AuthIdentityChannelClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// IdempotencyRecord is the client for interacting with the IdempotencyRecord builders.
IdempotencyRecord *IdempotencyRecordClient
+ // IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders.
+ IdentityAdoptionDecision *IdentityAdoptionDecisionClient
// PaymentAuditLog is the client for interacting with the PaymentAuditLog builders.
PaymentAuditLog *PaymentAuditLogClient
// PaymentOrder is the client for interacting with the PaymentOrder builders.
PaymentOrder *PaymentOrderClient
// PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders.
PaymentProviderInstance *PaymentProviderInstanceClient
+ // PendingAuthSession is the client for interacting with the PendingAuthSession builders.
+ PendingAuthSession *PendingAuthSessionClient
// PromoCode is the client for interacting with the PromoCode builders.
PromoCode *PromoCodeClient
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
@@ -118,12 +130,16 @@ func (c *Client) init() {
c.AccountGroup = NewAccountGroupClient(c.config)
c.Announcement = NewAnnouncementClient(c.config)
c.AnnouncementRead = NewAnnouncementReadClient(c.config)
+ c.AuthIdentity = NewAuthIdentityClient(c.config)
+ c.AuthIdentityChannel = NewAuthIdentityChannelClient(c.config)
c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config)
c.Group = NewGroupClient(c.config)
c.IdempotencyRecord = NewIdempotencyRecordClient(c.config)
+ c.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(c.config)
c.PaymentAuditLog = NewPaymentAuditLogClient(c.config)
c.PaymentOrder = NewPaymentOrderClient(c.config)
c.PaymentProviderInstance = NewPaymentProviderInstanceClient(c.config)
+ c.PendingAuthSession = NewPendingAuthSessionClient(c.config)
c.PromoCode = NewPromoCodeClient(c.config)
c.PromoCodeUsage = NewPromoCodeUsageClient(c.config)
c.Proxy = NewProxyClient(c.config)
@@ -229,34 +245,38 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
cfg := c.config
cfg.driver = tx
return &Tx{
- ctx: ctx,
- config: cfg,
- APIKey: NewAPIKeyClient(cfg),
- Account: NewAccountClient(cfg),
- AccountGroup: NewAccountGroupClient(cfg),
- Announcement: NewAnnouncementClient(cfg),
- AnnouncementRead: NewAnnouncementReadClient(cfg),
- ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
- Group: NewGroupClient(cfg),
- IdempotencyRecord: NewIdempotencyRecordClient(cfg),
- PaymentAuditLog: NewPaymentAuditLogClient(cfg),
- PaymentOrder: NewPaymentOrderClient(cfg),
- PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
- PromoCode: NewPromoCodeClient(cfg),
- PromoCodeUsage: NewPromoCodeUsageClient(cfg),
- Proxy: NewProxyClient(cfg),
- RedeemCode: NewRedeemCodeClient(cfg),
- SecuritySecret: NewSecuritySecretClient(cfg),
- Setting: NewSettingClient(cfg),
- SubscriptionPlan: NewSubscriptionPlanClient(cfg),
- TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
- UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
- UsageLog: NewUsageLogClient(cfg),
- User: NewUserClient(cfg),
- UserAllowedGroup: NewUserAllowedGroupClient(cfg),
- UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
- UserAttributeValue: NewUserAttributeValueClient(cfg),
- UserSubscription: NewUserSubscriptionClient(cfg),
+ ctx: ctx,
+ config: cfg,
+ APIKey: NewAPIKeyClient(cfg),
+ Account: NewAccountClient(cfg),
+ AccountGroup: NewAccountGroupClient(cfg),
+ Announcement: NewAnnouncementClient(cfg),
+ AnnouncementRead: NewAnnouncementReadClient(cfg),
+ AuthIdentity: NewAuthIdentityClient(cfg),
+ AuthIdentityChannel: NewAuthIdentityChannelClient(cfg),
+ ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
+ Group: NewGroupClient(cfg),
+ IdempotencyRecord: NewIdempotencyRecordClient(cfg),
+ IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg),
+ PaymentAuditLog: NewPaymentAuditLogClient(cfg),
+ PaymentOrder: NewPaymentOrderClient(cfg),
+ PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
+ PendingAuthSession: NewPendingAuthSessionClient(cfg),
+ PromoCode: NewPromoCodeClient(cfg),
+ PromoCodeUsage: NewPromoCodeUsageClient(cfg),
+ Proxy: NewProxyClient(cfg),
+ RedeemCode: NewRedeemCodeClient(cfg),
+ SecuritySecret: NewSecuritySecretClient(cfg),
+ Setting: NewSettingClient(cfg),
+ SubscriptionPlan: NewSubscriptionPlanClient(cfg),
+ TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
+ UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
+ UsageLog: NewUsageLogClient(cfg),
+ User: NewUserClient(cfg),
+ UserAllowedGroup: NewUserAllowedGroupClient(cfg),
+ UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
+ UserAttributeValue: NewUserAttributeValueClient(cfg),
+ UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@@ -274,34 +294,38 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
cfg := c.config
cfg.driver = &txDriver{tx: tx, drv: c.driver}
return &Tx{
- ctx: ctx,
- config: cfg,
- APIKey: NewAPIKeyClient(cfg),
- Account: NewAccountClient(cfg),
- AccountGroup: NewAccountGroupClient(cfg),
- Announcement: NewAnnouncementClient(cfg),
- AnnouncementRead: NewAnnouncementReadClient(cfg),
- ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
- Group: NewGroupClient(cfg),
- IdempotencyRecord: NewIdempotencyRecordClient(cfg),
- PaymentAuditLog: NewPaymentAuditLogClient(cfg),
- PaymentOrder: NewPaymentOrderClient(cfg),
- PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
- PromoCode: NewPromoCodeClient(cfg),
- PromoCodeUsage: NewPromoCodeUsageClient(cfg),
- Proxy: NewProxyClient(cfg),
- RedeemCode: NewRedeemCodeClient(cfg),
- SecuritySecret: NewSecuritySecretClient(cfg),
- Setting: NewSettingClient(cfg),
- SubscriptionPlan: NewSubscriptionPlanClient(cfg),
- TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
- UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
- UsageLog: NewUsageLogClient(cfg),
- User: NewUserClient(cfg),
- UserAllowedGroup: NewUserAllowedGroupClient(cfg),
- UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
- UserAttributeValue: NewUserAttributeValueClient(cfg),
- UserSubscription: NewUserSubscriptionClient(cfg),
+ ctx: ctx,
+ config: cfg,
+ APIKey: NewAPIKeyClient(cfg),
+ Account: NewAccountClient(cfg),
+ AccountGroup: NewAccountGroupClient(cfg),
+ Announcement: NewAnnouncementClient(cfg),
+ AnnouncementRead: NewAnnouncementReadClient(cfg),
+ AuthIdentity: NewAuthIdentityClient(cfg),
+ AuthIdentityChannel: NewAuthIdentityChannelClient(cfg),
+ ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
+ Group: NewGroupClient(cfg),
+ IdempotencyRecord: NewIdempotencyRecordClient(cfg),
+ IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg),
+ PaymentAuditLog: NewPaymentAuditLogClient(cfg),
+ PaymentOrder: NewPaymentOrderClient(cfg),
+ PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
+ PendingAuthSession: NewPendingAuthSessionClient(cfg),
+ PromoCode: NewPromoCodeClient(cfg),
+ PromoCodeUsage: NewPromoCodeUsageClient(cfg),
+ Proxy: NewProxyClient(cfg),
+ RedeemCode: NewRedeemCodeClient(cfg),
+ SecuritySecret: NewSecuritySecretClient(cfg),
+ Setting: NewSettingClient(cfg),
+ SubscriptionPlan: NewSubscriptionPlanClient(cfg),
+ TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
+ UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
+ UsageLog: NewUsageLogClient(cfg),
+ User: NewUserClient(cfg),
+ UserAllowedGroup: NewUserAllowedGroupClient(cfg),
+ UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
+ UserAttributeValue: NewUserAttributeValueClient(cfg),
+ UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@@ -332,11 +356,12 @@ func (c *Client) Close() error {
func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
- c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog,
- c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage,
- c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan,
- c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User,
- c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
+ c.AuthIdentity, c.AuthIdentityChannel, c.ErrorPassthroughRule, c.Group,
+ c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog,
+ c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode,
+ c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
+ c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
+ c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Use(hooks...)
@@ -348,11 +373,12 @@ func (c *Client) Use(hooks ...Hook) {
func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
- c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog,
- c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage,
- c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan,
- c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User,
- c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
+ c.AuthIdentity, c.AuthIdentityChannel, c.ErrorPassthroughRule, c.Group,
+ c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog,
+ c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode,
+ c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
+ c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
+ c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Intercept(interceptors...)
@@ -372,18 +398,26 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.Announcement.mutate(ctx, m)
case *AnnouncementReadMutation:
return c.AnnouncementRead.mutate(ctx, m)
+ case *AuthIdentityMutation:
+ return c.AuthIdentity.mutate(ctx, m)
+ case *AuthIdentityChannelMutation:
+ return c.AuthIdentityChannel.mutate(ctx, m)
case *ErrorPassthroughRuleMutation:
return c.ErrorPassthroughRule.mutate(ctx, m)
case *GroupMutation:
return c.Group.mutate(ctx, m)
case *IdempotencyRecordMutation:
return c.IdempotencyRecord.mutate(ctx, m)
+ case *IdentityAdoptionDecisionMutation:
+ return c.IdentityAdoptionDecision.mutate(ctx, m)
case *PaymentAuditLogMutation:
return c.PaymentAuditLog.mutate(ctx, m)
case *PaymentOrderMutation:
return c.PaymentOrder.mutate(ctx, m)
case *PaymentProviderInstanceMutation:
return c.PaymentProviderInstance.mutate(ctx, m)
+ case *PendingAuthSessionMutation:
+ return c.PendingAuthSession.mutate(ctx, m)
case *PromoCodeMutation:
return c.PromoCode.mutate(ctx, m)
case *PromoCodeUsageMutation:
@@ -1231,6 +1265,336 @@ func (c *AnnouncementReadClient) mutate(ctx context.Context, m *AnnouncementRead
}
}
+// AuthIdentityClient is a client for the AuthIdentity schema.
+type AuthIdentityClient struct {
+ config
+}
+
+// NewAuthIdentityClient returns a client for the AuthIdentity from the given config.
+func NewAuthIdentityClient(c config) *AuthIdentityClient {
+ return &AuthIdentityClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `authidentity.Hooks(f(g(h())))`.
+func (c *AuthIdentityClient) Use(hooks ...Hook) {
+ c.hooks.AuthIdentity = append(c.hooks.AuthIdentity, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `authidentity.Intercept(f(g(h())))`.
+func (c *AuthIdentityClient) Intercept(interceptors ...Interceptor) {
+ c.inters.AuthIdentity = append(c.inters.AuthIdentity, interceptors...)
+}
+
+// Create returns a builder for creating a AuthIdentity entity.
+func (c *AuthIdentityClient) Create() *AuthIdentityCreate {
+ mutation := newAuthIdentityMutation(c.config, OpCreate)
+ return &AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of AuthIdentity entities.
+func (c *AuthIdentityClient) CreateBulk(builders ...*AuthIdentityCreate) *AuthIdentityCreateBulk {
+ return &AuthIdentityCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *AuthIdentityClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityCreate, int)) *AuthIdentityCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &AuthIdentityCreateBulk{err: fmt.Errorf("calling to AuthIdentityClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*AuthIdentityCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &AuthIdentityCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for AuthIdentity.
+func (c *AuthIdentityClient) Update() *AuthIdentityUpdate {
+ mutation := newAuthIdentityMutation(c.config, OpUpdate)
+ return &AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *AuthIdentityClient) UpdateOne(_m *AuthIdentity) *AuthIdentityUpdateOne {
+ mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentity(_m))
+ return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *AuthIdentityClient) UpdateOneID(id int64) *AuthIdentityUpdateOne {
+ mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentityID(id))
+ return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for AuthIdentity.
+func (c *AuthIdentityClient) Delete() *AuthIdentityDelete {
+ mutation := newAuthIdentityMutation(c.config, OpDelete)
+ return &AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *AuthIdentityClient) DeleteOne(_m *AuthIdentity) *AuthIdentityDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *AuthIdentityClient) DeleteOneID(id int64) *AuthIdentityDeleteOne {
+ builder := c.Delete().Where(authidentity.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &AuthIdentityDeleteOne{builder}
+}
+
+// Query returns a query builder for AuthIdentity.
+func (c *AuthIdentityClient) Query() *AuthIdentityQuery {
+ return &AuthIdentityQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeAuthIdentity},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a AuthIdentity entity by its id.
+func (c *AuthIdentityClient) Get(ctx context.Context, id int64) (*AuthIdentity, error) {
+ return c.Query().Where(authidentity.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *AuthIdentityClient) GetX(ctx context.Context, id int64) *AuthIdentity {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryUser queries the user edge of a AuthIdentity.
+func (c *AuthIdentityClient) QueryUser(_m *AuthIdentity) *UserQuery {
+ query := (&UserClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, id),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryChannels queries the channels edge of a AuthIdentity.
+func (c *AuthIdentityClient) QueryChannels(_m *AuthIdentity) *AuthIdentityChannelQuery {
+ query := (&AuthIdentityChannelClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, id),
+ sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryAdoptionDecisions queries the adoption_decisions edge of a AuthIdentity.
+func (c *AuthIdentityClient) QueryAdoptionDecisions(_m *AuthIdentity) *IdentityAdoptionDecisionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, id),
+ sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *AuthIdentityClient) Hooks() []Hook {
+ return c.hooks.AuthIdentity
+}
+
+// Interceptors returns the client interceptors.
+func (c *AuthIdentityClient) Interceptors() []Interceptor {
+ return c.inters.AuthIdentity
+}
+
+func (c *AuthIdentityClient) mutate(ctx context.Context, m *AuthIdentityMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown AuthIdentity mutation op: %q", m.Op())
+ }
+}
+
+// AuthIdentityChannelClient is a client for the AuthIdentityChannel schema.
+type AuthIdentityChannelClient struct {
+ config
+}
+
+// NewAuthIdentityChannelClient returns a client for the AuthIdentityChannel from the given config.
+func NewAuthIdentityChannelClient(c config) *AuthIdentityChannelClient {
+ return &AuthIdentityChannelClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `authidentitychannel.Hooks(f(g(h())))`.
+func (c *AuthIdentityChannelClient) Use(hooks ...Hook) {
+ c.hooks.AuthIdentityChannel = append(c.hooks.AuthIdentityChannel, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `authidentitychannel.Intercept(f(g(h())))`.
+func (c *AuthIdentityChannelClient) Intercept(interceptors ...Interceptor) {
+ c.inters.AuthIdentityChannel = append(c.inters.AuthIdentityChannel, interceptors...)
+}
+
+// Create returns a builder for creating a AuthIdentityChannel entity.
+func (c *AuthIdentityChannelClient) Create() *AuthIdentityChannelCreate {
+ mutation := newAuthIdentityChannelMutation(c.config, OpCreate)
+ return &AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of AuthIdentityChannel entities.
+func (c *AuthIdentityChannelClient) CreateBulk(builders ...*AuthIdentityChannelCreate) *AuthIdentityChannelCreateBulk {
+ return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *AuthIdentityChannelClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityChannelCreate, int)) *AuthIdentityChannelCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &AuthIdentityChannelCreateBulk{err: fmt.Errorf("calling to AuthIdentityChannelClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*AuthIdentityChannelCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for AuthIdentityChannel.
+func (c *AuthIdentityChannelClient) Update() *AuthIdentityChannelUpdate {
+ mutation := newAuthIdentityChannelMutation(c.config, OpUpdate)
+ return &AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *AuthIdentityChannelClient) UpdateOne(_m *AuthIdentityChannel) *AuthIdentityChannelUpdateOne {
+ mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannel(_m))
+ return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *AuthIdentityChannelClient) UpdateOneID(id int64) *AuthIdentityChannelUpdateOne {
+ mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannelID(id))
+ return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for AuthIdentityChannel.
+func (c *AuthIdentityChannelClient) Delete() *AuthIdentityChannelDelete {
+ mutation := newAuthIdentityChannelMutation(c.config, OpDelete)
+ return &AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *AuthIdentityChannelClient) DeleteOne(_m *AuthIdentityChannel) *AuthIdentityChannelDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *AuthIdentityChannelClient) DeleteOneID(id int64) *AuthIdentityChannelDeleteOne {
+ builder := c.Delete().Where(authidentitychannel.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &AuthIdentityChannelDeleteOne{builder}
+}
+
+// Query returns a query builder for AuthIdentityChannel.
+func (c *AuthIdentityChannelClient) Query() *AuthIdentityChannelQuery {
+ return &AuthIdentityChannelQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeAuthIdentityChannel},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a AuthIdentityChannel entity by its id.
+func (c *AuthIdentityChannelClient) Get(ctx context.Context, id int64) (*AuthIdentityChannel, error) {
+ return c.Query().Where(authidentitychannel.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *AuthIdentityChannelClient) GetX(ctx context.Context, id int64) *AuthIdentityChannel {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryIdentity queries the identity edge of a AuthIdentityChannel.
+func (c *AuthIdentityChannelClient) QueryIdentity(_m *AuthIdentityChannel) *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, id),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *AuthIdentityChannelClient) Hooks() []Hook {
+ return c.hooks.AuthIdentityChannel
+}
+
+// Interceptors returns the client interceptors.
+func (c *AuthIdentityChannelClient) Interceptors() []Interceptor {
+ return c.inters.AuthIdentityChannel
+}
+
+func (c *AuthIdentityChannelClient) mutate(ctx context.Context, m *AuthIdentityChannelMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown AuthIdentityChannel mutation op: %q", m.Op())
+ }
+}
+
// ErrorPassthroughRuleClient is a client for the ErrorPassthroughRule schema.
type ErrorPassthroughRuleClient struct {
config
@@ -1760,6 +2124,171 @@ func (c *IdempotencyRecordClient) mutate(ctx context.Context, m *IdempotencyReco
}
}
+// IdentityAdoptionDecisionClient is a client for the IdentityAdoptionDecision schema.
+type IdentityAdoptionDecisionClient struct {
+ config
+}
+
+// NewIdentityAdoptionDecisionClient returns a client for the IdentityAdoptionDecision from the given config.
+func NewIdentityAdoptionDecisionClient(c config) *IdentityAdoptionDecisionClient {
+ return &IdentityAdoptionDecisionClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `identityadoptiondecision.Hooks(f(g(h())))`.
+func (c *IdentityAdoptionDecisionClient) Use(hooks ...Hook) {
+ c.hooks.IdentityAdoptionDecision = append(c.hooks.IdentityAdoptionDecision, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `identityadoptiondecision.Intercept(f(g(h())))`.
+func (c *IdentityAdoptionDecisionClient) Intercept(interceptors ...Interceptor) {
+ c.inters.IdentityAdoptionDecision = append(c.inters.IdentityAdoptionDecision, interceptors...)
+}
+
+// Create returns a builder for creating a IdentityAdoptionDecision entity.
+func (c *IdentityAdoptionDecisionClient) Create() *IdentityAdoptionDecisionCreate {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpCreate)
+ return &IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of IdentityAdoptionDecision entities.
+func (c *IdentityAdoptionDecisionClient) CreateBulk(builders ...*IdentityAdoptionDecisionCreate) *IdentityAdoptionDecisionCreateBulk {
+ return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *IdentityAdoptionDecisionClient) MapCreateBulk(slice any, setFunc func(*IdentityAdoptionDecisionCreate, int)) *IdentityAdoptionDecisionCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &IdentityAdoptionDecisionCreateBulk{err: fmt.Errorf("calling to IdentityAdoptionDecisionClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*IdentityAdoptionDecisionCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) Update() *IdentityAdoptionDecisionUpdate {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdate)
+ return &IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *IdentityAdoptionDecisionClient) UpdateOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecision(_m))
+ return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *IdentityAdoptionDecisionClient) UpdateOneID(id int64) *IdentityAdoptionDecisionUpdateOne {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecisionID(id))
+ return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) Delete() *IdentityAdoptionDecisionDelete {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpDelete)
+ return &IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *IdentityAdoptionDecisionClient) DeleteOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *IdentityAdoptionDecisionClient) DeleteOneID(id int64) *IdentityAdoptionDecisionDeleteOne {
+ builder := c.Delete().Where(identityadoptiondecision.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &IdentityAdoptionDecisionDeleteOne{builder}
+}
+
+// Query returns a query builder for IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) Query() *IdentityAdoptionDecisionQuery {
+ return &IdentityAdoptionDecisionQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeIdentityAdoptionDecision},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a IdentityAdoptionDecision entity by its id.
+func (c *IdentityAdoptionDecisionClient) Get(ctx context.Context, id int64) (*IdentityAdoptionDecision, error) {
+ return c.Query().Where(identityadoptiondecision.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *IdentityAdoptionDecisionClient) GetX(ctx context.Context, id int64) *IdentityAdoptionDecision {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryPendingAuthSession queries the pending_auth_session edge of a IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) QueryPendingAuthSession(_m *IdentityAdoptionDecision) *PendingAuthSessionQuery {
+ query := (&PendingAuthSessionClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id),
+ sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryIdentity queries the identity edge of a IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) QueryIdentity(_m *IdentityAdoptionDecision) *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *IdentityAdoptionDecisionClient) Hooks() []Hook {
+ return c.hooks.IdentityAdoptionDecision
+}
+
+// Interceptors returns the client interceptors.
+func (c *IdentityAdoptionDecisionClient) Interceptors() []Interceptor {
+ return c.inters.IdentityAdoptionDecision
+}
+
+func (c *IdentityAdoptionDecisionClient) mutate(ctx context.Context, m *IdentityAdoptionDecisionMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown IdentityAdoptionDecision mutation op: %q", m.Op())
+ }
+}
+
// PaymentAuditLogClient is a client for the PaymentAuditLog schema.
type PaymentAuditLogClient struct {
config
@@ -2175,6 +2704,171 @@ func (c *PaymentProviderInstanceClient) mutate(ctx context.Context, m *PaymentPr
}
}
+// PendingAuthSessionClient is a client for the PendingAuthSession schema.
+type PendingAuthSessionClient struct {
+ config
+}
+
+// NewPendingAuthSessionClient returns a client for the PendingAuthSession from the given config.
+func NewPendingAuthSessionClient(c config) *PendingAuthSessionClient {
+ return &PendingAuthSessionClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `pendingauthsession.Hooks(f(g(h())))`.
+func (c *PendingAuthSessionClient) Use(hooks ...Hook) {
+ c.hooks.PendingAuthSession = append(c.hooks.PendingAuthSession, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `pendingauthsession.Intercept(f(g(h())))`.
+func (c *PendingAuthSessionClient) Intercept(interceptors ...Interceptor) {
+ c.inters.PendingAuthSession = append(c.inters.PendingAuthSession, interceptors...)
+}
+
+// Create returns a builder for creating a PendingAuthSession entity.
+func (c *PendingAuthSessionClient) Create() *PendingAuthSessionCreate {
+ mutation := newPendingAuthSessionMutation(c.config, OpCreate)
+ return &PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of PendingAuthSession entities.
+func (c *PendingAuthSessionClient) CreateBulk(builders ...*PendingAuthSessionCreate) *PendingAuthSessionCreateBulk {
+ return &PendingAuthSessionCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *PendingAuthSessionClient) MapCreateBulk(slice any, setFunc func(*PendingAuthSessionCreate, int)) *PendingAuthSessionCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &PendingAuthSessionCreateBulk{err: fmt.Errorf("calling to PendingAuthSessionClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*PendingAuthSessionCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &PendingAuthSessionCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for PendingAuthSession.
+func (c *PendingAuthSessionClient) Update() *PendingAuthSessionUpdate {
+ mutation := newPendingAuthSessionMutation(c.config, OpUpdate)
+ return &PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *PendingAuthSessionClient) UpdateOne(_m *PendingAuthSession) *PendingAuthSessionUpdateOne {
+ mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSession(_m))
+ return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *PendingAuthSessionClient) UpdateOneID(id int64) *PendingAuthSessionUpdateOne {
+ mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSessionID(id))
+ return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for PendingAuthSession.
+func (c *PendingAuthSessionClient) Delete() *PendingAuthSessionDelete {
+ mutation := newPendingAuthSessionMutation(c.config, OpDelete)
+ return &PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *PendingAuthSessionClient) DeleteOne(_m *PendingAuthSession) *PendingAuthSessionDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *PendingAuthSessionClient) DeleteOneID(id int64) *PendingAuthSessionDeleteOne {
+ builder := c.Delete().Where(pendingauthsession.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &PendingAuthSessionDeleteOne{builder}
+}
+
+// Query returns a query builder for PendingAuthSession.
+func (c *PendingAuthSessionClient) Query() *PendingAuthSessionQuery {
+ return &PendingAuthSessionQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypePendingAuthSession},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a PendingAuthSession entity by its id.
+func (c *PendingAuthSessionClient) Get(ctx context.Context, id int64) (*PendingAuthSession, error) {
+ return c.Query().Where(pendingauthsession.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *PendingAuthSessionClient) GetX(ctx context.Context, id int64) *PendingAuthSession {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryTargetUser queries the target_user edge of a PendingAuthSession.
+func (c *PendingAuthSessionClient) QueryTargetUser(_m *PendingAuthSession) *UserQuery {
+ query := (&UserClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryAdoptionDecision queries the adoption_decision edge of a PendingAuthSession.
+func (c *PendingAuthSessionClient) QueryAdoptionDecision(_m *PendingAuthSession) *IdentityAdoptionDecisionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id),
+ sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *PendingAuthSessionClient) Hooks() []Hook {
+ return c.hooks.PendingAuthSession
+}
+
+// Interceptors returns the client interceptors.
+func (c *PendingAuthSessionClient) Interceptors() []Interceptor {
+ return c.inters.PendingAuthSession
+}
+
+func (c *PendingAuthSessionClient) mutate(ctx context.Context, m *PendingAuthSessionMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown PendingAuthSession mutation op: %q", m.Op())
+ }
+}
+
// PromoCodeClient is a client for the PromoCode schema.
type PromoCodeClient struct {
config
@@ -3951,6 +4645,38 @@ func (c *UserClient) QueryPaymentOrders(_m *User) *PaymentOrderQuery {
return query
}
+// QueryAuthIdentities queries the auth_identities edge of a User.
+func (c *UserClient) QueryAuthIdentities(_m *User) *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, id),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryPendingAuthSessions queries the pending_auth_sessions edge of a User.
+func (c *UserClient) QueryPendingAuthSessions(_m *User) *PendingAuthSessionQuery {
+ query := (&PendingAuthSessionClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, id),
+ sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
// QueryUserAllowedGroups queries the user_allowed_groups edge of a User.
func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: c.config}).Query()
@@ -4628,18 +5354,20 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
// hooks and interceptors per client, for fast access.
type (
hooks struct {
- APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
- ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder,
- PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
- SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
+ APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity,
+ AuthIdentityChannel, ErrorPassthroughRule, Group, IdempotencyRecord,
+ IdentityAdoptionDecision, PaymentAuditLog, PaymentOrder,
+ PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy,
+ RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
UserAttributeValue, UserSubscription []ent.Hook
}
inters struct {
- APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
- ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder,
- PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
- SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
+ APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity,
+ AuthIdentityChannel, ErrorPassthroughRule, Group, IdempotencyRecord,
+ IdentityAdoptionDecision, PaymentAuditLog, PaymentOrder,
+ PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy,
+ RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
UserAttributeValue, UserSubscription []ent.Interceptor
}
diff --git a/backend/ent/ent.go b/backend/ent/ent.go
index 96ed5e03..339e5369 100644
--- a/backend/ent/ent.go
+++ b/backend/ent/ent.go
@@ -17,12 +17,16 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
@@ -98,32 +102,36 @@ var (
func checkColumn(t, c string) error {
initCheck.Do(func() {
columnCheck = sql.NewColumnCheck(map[string]func(string) bool{
- apikey.Table: apikey.ValidColumn,
- account.Table: account.ValidColumn,
- accountgroup.Table: accountgroup.ValidColumn,
- announcement.Table: announcement.ValidColumn,
- announcementread.Table: announcementread.ValidColumn,
- errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
- group.Table: group.ValidColumn,
- idempotencyrecord.Table: idempotencyrecord.ValidColumn,
- paymentauditlog.Table: paymentauditlog.ValidColumn,
- paymentorder.Table: paymentorder.ValidColumn,
- paymentproviderinstance.Table: paymentproviderinstance.ValidColumn,
- promocode.Table: promocode.ValidColumn,
- promocodeusage.Table: promocodeusage.ValidColumn,
- proxy.Table: proxy.ValidColumn,
- redeemcode.Table: redeemcode.ValidColumn,
- securitysecret.Table: securitysecret.ValidColumn,
- setting.Table: setting.ValidColumn,
- subscriptionplan.Table: subscriptionplan.ValidColumn,
- tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn,
- usagecleanuptask.Table: usagecleanuptask.ValidColumn,
- usagelog.Table: usagelog.ValidColumn,
- user.Table: user.ValidColumn,
- userallowedgroup.Table: userallowedgroup.ValidColumn,
- userattributedefinition.Table: userattributedefinition.ValidColumn,
- userattributevalue.Table: userattributevalue.ValidColumn,
- usersubscription.Table: usersubscription.ValidColumn,
+ apikey.Table: apikey.ValidColumn,
+ account.Table: account.ValidColumn,
+ accountgroup.Table: accountgroup.ValidColumn,
+ announcement.Table: announcement.ValidColumn,
+ announcementread.Table: announcementread.ValidColumn,
+ authidentity.Table: authidentity.ValidColumn,
+ authidentitychannel.Table: authidentitychannel.ValidColumn,
+ errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
+ group.Table: group.ValidColumn,
+ idempotencyrecord.Table: idempotencyrecord.ValidColumn,
+ identityadoptiondecision.Table: identityadoptiondecision.ValidColumn,
+ paymentauditlog.Table: paymentauditlog.ValidColumn,
+ paymentorder.Table: paymentorder.ValidColumn,
+ paymentproviderinstance.Table: paymentproviderinstance.ValidColumn,
+ pendingauthsession.Table: pendingauthsession.ValidColumn,
+ promocode.Table: promocode.ValidColumn,
+ promocodeusage.Table: promocodeusage.ValidColumn,
+ proxy.Table: proxy.ValidColumn,
+ redeemcode.Table: redeemcode.ValidColumn,
+ securitysecret.Table: securitysecret.ValidColumn,
+ setting.Table: setting.ValidColumn,
+ subscriptionplan.Table: subscriptionplan.ValidColumn,
+ tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn,
+ usagecleanuptask.Table: usagecleanuptask.ValidColumn,
+ usagelog.Table: usagelog.ValidColumn,
+ user.Table: user.ValidColumn,
+ userallowedgroup.Table: userallowedgroup.ValidColumn,
+ userattributedefinition.Table: userattributedefinition.ValidColumn,
+ userattributevalue.Table: userattributevalue.ValidColumn,
+ usersubscription.Table: usersubscription.ValidColumn,
})
})
return columnCheck(t, c)
diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go
index 199dacea..46ac02bc 100644
--- a/backend/ent/hook/hook.go
+++ b/backend/ent/hook/hook.go
@@ -69,6 +69,30 @@ func (f AnnouncementReadFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.V
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementReadMutation", m)
}
+// The AuthIdentityFunc type is an adapter to allow the use of ordinary
+// function as AuthIdentity mutator.
+type AuthIdentityFunc func(context.Context, *ent.AuthIdentityMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f AuthIdentityFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.AuthIdentityMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityMutation", m)
+}
+
+// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary
+// function as AuthIdentityChannel mutator.
+type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f AuthIdentityChannelFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.AuthIdentityChannelMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityChannelMutation", m)
+}
+
// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary
// function as ErrorPassthroughRule mutator.
type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleMutation) (ent.Value, error)
@@ -105,6 +129,18 @@ func (f IdempotencyRecordFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdempotencyRecordMutation", m)
}
+// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary
+// function as IdentityAdoptionDecision mutator.
+type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f IdentityAdoptionDecisionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.IdentityAdoptionDecisionMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdentityAdoptionDecisionMutation", m)
+}
+
// The PaymentAuditLogFunc type is an adapter to allow the use of ordinary
// function as PaymentAuditLog mutator.
type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogMutation) (ent.Value, error)
@@ -141,6 +177,18 @@ func (f PaymentProviderInstanceFunc) Mutate(ctx context.Context, m ent.Mutation)
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PaymentProviderInstanceMutation", m)
}
+// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary
+// function as PendingAuthSession mutator.
+type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f PendingAuthSessionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.PendingAuthSessionMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PendingAuthSessionMutation", m)
+}
+
// The PromoCodeFunc type is an adapter to allow the use of ordinary
// function as PromoCode mutator.
type PromoCodeFunc func(context.Context, *ent.PromoCodeMutation) (ent.Value, error)
diff --git a/backend/ent/identityadoptiondecision.go b/backend/ent/identityadoptiondecision.go
new file mode 100644
index 00000000..ecaee65c
--- /dev/null
+++ b/backend/ent/identityadoptiondecision.go
@@ -0,0 +1,223 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+)
+
+// IdentityAdoptionDecision is the model entity for the IdentityAdoptionDecision schema.
+type IdentityAdoptionDecision struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // PendingAuthSessionID holds the value of the "pending_auth_session_id" field.
+ PendingAuthSessionID int64 `json:"pending_auth_session_id,omitempty"`
+ // IdentityID holds the value of the "identity_id" field.
+ IdentityID *int64 `json:"identity_id,omitempty"`
+ // AdoptDisplayName holds the value of the "adopt_display_name" field.
+ AdoptDisplayName bool `json:"adopt_display_name,omitempty"`
+ // AdoptAvatar holds the value of the "adopt_avatar" field.
+ AdoptAvatar bool `json:"adopt_avatar,omitempty"`
+ // DecidedAt holds the value of the "decided_at" field.
+ DecidedAt time.Time `json:"decided_at,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the IdentityAdoptionDecisionQuery when eager-loading is set.
+ Edges IdentityAdoptionDecisionEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// IdentityAdoptionDecisionEdges holds the relations/edges for other nodes in the graph.
+type IdentityAdoptionDecisionEdges struct {
+ // PendingAuthSession holds the value of the pending_auth_session edge.
+ PendingAuthSession *PendingAuthSession `json:"pending_auth_session,omitempty"`
+ // Identity holds the value of the identity edge.
+ Identity *AuthIdentity `json:"identity,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [2]bool
+}
+
+// PendingAuthSessionOrErr returns the PendingAuthSession value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e IdentityAdoptionDecisionEdges) PendingAuthSessionOrErr() (*PendingAuthSession, error) {
+ if e.PendingAuthSession != nil {
+ return e.PendingAuthSession, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: pendingauthsession.Label}
+ }
+ return nil, &NotLoadedError{edge: "pending_auth_session"}
+}
+
+// IdentityOrErr returns the Identity value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e IdentityAdoptionDecisionEdges) IdentityOrErr() (*AuthIdentity, error) {
+ if e.Identity != nil {
+ return e.Identity, nil
+ } else if e.loadedTypes[1] {
+ return nil, &NotFoundError{label: authidentity.Label}
+ }
+ return nil, &NotLoadedError{edge: "identity"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*IdentityAdoptionDecision) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case identityadoptiondecision.FieldAdoptDisplayName, identityadoptiondecision.FieldAdoptAvatar:
+ values[i] = new(sql.NullBool)
+ case identityadoptiondecision.FieldID, identityadoptiondecision.FieldPendingAuthSessionID, identityadoptiondecision.FieldIdentityID:
+ values[i] = new(sql.NullInt64)
+ case identityadoptiondecision.FieldCreatedAt, identityadoptiondecision.FieldUpdatedAt, identityadoptiondecision.FieldDecidedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the IdentityAdoptionDecision fields.
+func (_m *IdentityAdoptionDecision) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case identityadoptiondecision.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case identityadoptiondecision.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case identityadoptiondecision.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field pending_auth_session_id", values[i])
+ } else if value.Valid {
+ _m.PendingAuthSessionID = value.Int64
+ }
+ case identityadoptiondecision.FieldIdentityID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field identity_id", values[i])
+ } else if value.Valid {
+ _m.IdentityID = new(int64)
+ *_m.IdentityID = value.Int64
+ }
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field adopt_display_name", values[i])
+ } else if value.Valid {
+ _m.AdoptDisplayName = value.Bool
+ }
+ case identityadoptiondecision.FieldAdoptAvatar:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field adopt_avatar", values[i])
+ } else if value.Valid {
+ _m.AdoptAvatar = value.Bool
+ }
+ case identityadoptiondecision.FieldDecidedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field decided_at", values[i])
+ } else if value.Valid {
+ _m.DecidedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the IdentityAdoptionDecision.
+// This includes values selected through modifiers, order, etc.
+func (_m *IdentityAdoptionDecision) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryPendingAuthSession queries the "pending_auth_session" edge of the IdentityAdoptionDecision entity.
+func (_m *IdentityAdoptionDecision) QueryPendingAuthSession() *PendingAuthSessionQuery {
+ return NewIdentityAdoptionDecisionClient(_m.config).QueryPendingAuthSession(_m)
+}
+
+// QueryIdentity queries the "identity" edge of the IdentityAdoptionDecision entity.
+func (_m *IdentityAdoptionDecision) QueryIdentity() *AuthIdentityQuery {
+ return NewIdentityAdoptionDecisionClient(_m.config).QueryIdentity(_m)
+}
+
+// Update returns a builder for updating this IdentityAdoptionDecision.
+// Note that you need to call IdentityAdoptionDecision.Unwrap() before calling this method if this IdentityAdoptionDecision
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *IdentityAdoptionDecision) Update() *IdentityAdoptionDecisionUpdateOne {
+ return NewIdentityAdoptionDecisionClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the IdentityAdoptionDecision entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *IdentityAdoptionDecision) Unwrap() *IdentityAdoptionDecision {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: IdentityAdoptionDecision is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *IdentityAdoptionDecision) String() string {
+ var builder strings.Builder
+ builder.WriteString("IdentityAdoptionDecision(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("pending_auth_session_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.PendingAuthSessionID))
+ builder.WriteString(", ")
+ if v := _m.IdentityID; v != nil {
+ builder.WriteString("identity_id=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("adopt_display_name=")
+ builder.WriteString(fmt.Sprintf("%v", _m.AdoptDisplayName))
+ builder.WriteString(", ")
+ builder.WriteString("adopt_avatar=")
+ builder.WriteString(fmt.Sprintf("%v", _m.AdoptAvatar))
+ builder.WriteString(", ")
+ builder.WriteString("decided_at=")
+ builder.WriteString(_m.DecidedAt.Format(time.ANSIC))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// IdentityAdoptionDecisions is a parsable slice of IdentityAdoptionDecision.
+type IdentityAdoptionDecisions []*IdentityAdoptionDecision
diff --git a/backend/ent/identityadoptiondecision/identityadoptiondecision.go b/backend/ent/identityadoptiondecision/identityadoptiondecision.go
new file mode 100644
index 00000000..93adaf73
--- /dev/null
+++ b/backend/ent/identityadoptiondecision/identityadoptiondecision.go
@@ -0,0 +1,159 @@
+// Code generated by ent, DO NOT EDIT.
+
+package identityadoptiondecision
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the identityadoptiondecision type in the database.
+ Label = "identity_adoption_decision"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldPendingAuthSessionID holds the string denoting the pending_auth_session_id field in the database.
+ FieldPendingAuthSessionID = "pending_auth_session_id"
+ // FieldIdentityID holds the string denoting the identity_id field in the database.
+ FieldIdentityID = "identity_id"
+ // FieldAdoptDisplayName holds the string denoting the adopt_display_name field in the database.
+ FieldAdoptDisplayName = "adopt_display_name"
+ // FieldAdoptAvatar holds the string denoting the adopt_avatar field in the database.
+ FieldAdoptAvatar = "adopt_avatar"
+ // FieldDecidedAt holds the string denoting the decided_at field in the database.
+ FieldDecidedAt = "decided_at"
+ // EdgePendingAuthSession holds the string denoting the pending_auth_session edge name in mutations.
+ EdgePendingAuthSession = "pending_auth_session"
+ // EdgeIdentity holds the string denoting the identity edge name in mutations.
+ EdgeIdentity = "identity"
+ // Table holds the table name of the identityadoptiondecision in the database.
+ Table = "identity_adoption_decisions"
+ // PendingAuthSessionTable is the table that holds the pending_auth_session relation/edge.
+ PendingAuthSessionTable = "identity_adoption_decisions"
+ // PendingAuthSessionInverseTable is the table name for the PendingAuthSession entity.
+ // It exists in this package in order to avoid circular dependency with the "pendingauthsession" package.
+ PendingAuthSessionInverseTable = "pending_auth_sessions"
+ // PendingAuthSessionColumn is the table column denoting the pending_auth_session relation/edge.
+ PendingAuthSessionColumn = "pending_auth_session_id"
+ // IdentityTable is the table that holds the identity relation/edge.
+ IdentityTable = "identity_adoption_decisions"
+ // IdentityInverseTable is the table name for the AuthIdentity entity.
+ // It exists in this package in order to avoid circular dependency with the "authidentity" package.
+ IdentityInverseTable = "auth_identities"
+ // IdentityColumn is the table column denoting the identity relation/edge.
+ IdentityColumn = "identity_id"
+)
+
+// Columns holds all SQL columns for identityadoptiondecision fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldPendingAuthSessionID,
+ FieldIdentityID,
+ FieldAdoptDisplayName,
+ FieldAdoptAvatar,
+ FieldDecidedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // DefaultAdoptDisplayName holds the default value on creation for the "adopt_display_name" field.
+ DefaultAdoptDisplayName bool
+ // DefaultAdoptAvatar holds the default value on creation for the "adopt_avatar" field.
+ DefaultAdoptAvatar bool
+ // DefaultDecidedAt holds the default value on creation for the "decided_at" field.
+ DefaultDecidedAt func() time.Time
+)
+
+// OrderOption defines the ordering options for the IdentityAdoptionDecision queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByPendingAuthSessionID orders the results by the pending_auth_session_id field.
+func ByPendingAuthSessionID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPendingAuthSessionID, opts...).ToFunc()
+}
+
+// ByIdentityID orders the results by the identity_id field.
+func ByIdentityID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIdentityID, opts...).ToFunc()
+}
+
+// ByAdoptDisplayName orders the results by the adopt_display_name field.
+func ByAdoptDisplayName(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAdoptDisplayName, opts...).ToFunc()
+}
+
+// ByAdoptAvatar orders the results by the adopt_avatar field.
+func ByAdoptAvatar(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAdoptAvatar, opts...).ToFunc()
+}
+
+// ByDecidedAt orders the results by the decided_at field.
+func ByDecidedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldDecidedAt, opts...).ToFunc()
+}
+
+// ByPendingAuthSessionField orders the results by pending_auth_session field.
+func ByPendingAuthSessionField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionStep(), sql.OrderByField(field, opts...))
+ }
+}
+
+// ByIdentityField orders the results by identity field.
+func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newPendingAuthSessionStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(PendingAuthSessionInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn),
+ )
+}
+func newIdentityStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(IdentityInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
+ )
+}
diff --git a/backend/ent/identityadoptiondecision/where.go b/backend/ent/identityadoptiondecision/where.go
new file mode 100644
index 00000000..1968f175
--- /dev/null
+++ b/backend/ent/identityadoptiondecision/where.go
@@ -0,0 +1,342 @@
+// Code generated by ent, DO NOT EDIT.
+
+package identityadoptiondecision
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// PendingAuthSessionID applies equality check predicate on the "pending_auth_session_id" field. It's identical to PendingAuthSessionIDEQ.
+func PendingAuthSessionID(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v))
+}
+
+// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ.
+func IdentityID(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v))
+}
+
+// AdoptDisplayName applies equality check predicate on the "adopt_display_name" field. It's identical to AdoptDisplayNameEQ.
+func AdoptDisplayName(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v))
+}
+
+// AdoptAvatar applies equality check predicate on the "adopt_avatar" field. It's identical to AdoptAvatarEQ.
+func AdoptAvatar(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v))
+}
+
+// DecidedAt applies equality check predicate on the "decided_at" field. It's identical to DecidedAtEQ.
+func DecidedAt(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// PendingAuthSessionIDEQ applies the EQ predicate on the "pending_auth_session_id" field.
+func PendingAuthSessionIDEQ(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v))
+}
+
+// PendingAuthSessionIDNEQ applies the NEQ predicate on the "pending_auth_session_id" field.
+func PendingAuthSessionIDNEQ(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldPendingAuthSessionID, v))
+}
+
+// PendingAuthSessionIDIn applies the In predicate on the "pending_auth_session_id" field.
+func PendingAuthSessionIDIn(vs ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldPendingAuthSessionID, vs...))
+}
+
+// PendingAuthSessionIDNotIn applies the NotIn predicate on the "pending_auth_session_id" field.
+func PendingAuthSessionIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldPendingAuthSessionID, vs...))
+}
+
+// IdentityIDEQ applies the EQ predicate on the "identity_id" field.
+func IdentityIDEQ(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v))
+}
+
+// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field.
+func IdentityIDNEQ(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldIdentityID, v))
+}
+
+// IdentityIDIn applies the In predicate on the "identity_id" field.
+func IdentityIDIn(vs ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldIdentityID, vs...))
+}
+
+// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field.
+func IdentityIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldIdentityID, vs...))
+}
+
+// IdentityIDIsNil applies the IsNil predicate on the "identity_id" field.
+func IdentityIDIsNil() predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIsNull(FieldIdentityID))
+}
+
+// IdentityIDNotNil applies the NotNil predicate on the "identity_id" field.
+func IdentityIDNotNil() predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotNull(FieldIdentityID))
+}
+
+// AdoptDisplayNameEQ applies the EQ predicate on the "adopt_display_name" field.
+func AdoptDisplayNameEQ(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v))
+}
+
+// AdoptDisplayNameNEQ applies the NEQ predicate on the "adopt_display_name" field.
+func AdoptDisplayNameNEQ(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptDisplayName, v))
+}
+
+// AdoptAvatarEQ applies the EQ predicate on the "adopt_avatar" field.
+func AdoptAvatarEQ(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v))
+}
+
+// AdoptAvatarNEQ applies the NEQ predicate on the "adopt_avatar" field.
+func AdoptAvatarNEQ(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptAvatar, v))
+}
+
+// DecidedAtEQ applies the EQ predicate on the "decided_at" field.
+func DecidedAtEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v))
+}
+
+// DecidedAtNEQ applies the NEQ predicate on the "decided_at" field.
+func DecidedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldDecidedAt, v))
+}
+
+// DecidedAtIn applies the In predicate on the "decided_at" field.
+func DecidedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldDecidedAt, vs...))
+}
+
+// DecidedAtNotIn applies the NotIn predicate on the "decided_at" field.
+func DecidedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldDecidedAt, vs...))
+}
+
+// DecidedAtGT applies the GT predicate on the "decided_at" field.
+func DecidedAtGT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldDecidedAt, v))
+}
+
+// DecidedAtGTE applies the GTE predicate on the "decided_at" field.
+func DecidedAtGTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldDecidedAt, v))
+}
+
+// DecidedAtLT applies the LT predicate on the "decided_at" field.
+func DecidedAtLT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldDecidedAt, v))
+}
+
+// DecidedAtLTE applies the LTE predicate on the "decided_at" field.
+func DecidedAtLTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldDecidedAt, v))
+}
+
+// HasPendingAuthSession applies the HasEdge predicate on the "pending_auth_session" edge.
+func HasPendingAuthSession() predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasPendingAuthSessionWith applies the HasEdge predicate on the "pending_auth_session" edge with a given conditions (other predicates).
+func HasPendingAuthSessionWith(preds ...predicate.PendingAuthSession) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ step := newPendingAuthSessionStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasIdentity applies the HasEdge predicate on the "identity" edge.
+func HasIdentity() predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates).
+func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ step := newIdentityStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.NotPredicates(p))
+}
diff --git a/backend/ent/identityadoptiondecision_create.go b/backend/ent/identityadoptiondecision_create.go
new file mode 100644
index 00000000..491ba9f9
--- /dev/null
+++ b/backend/ent/identityadoptiondecision_create.go
@@ -0,0 +1,843 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+)
+
+// IdentityAdoptionDecisionCreate is the builder for creating a IdentityAdoptionDecision entity.
+type IdentityAdoptionDecisionCreate struct {
+ config
+ mutation *IdentityAdoptionDecisionMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *IdentityAdoptionDecisionCreate) SetCreatedAt(v time.Time) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableCreatedAt(v *time.Time) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *IdentityAdoptionDecisionCreate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableUpdatedAt(v *time.Time) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetPendingAuthSessionID(v)
+ return _c
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_c *IdentityAdoptionDecisionCreate) SetIdentityID(v int64) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetIdentityID(v)
+ return _c
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetIdentityID(*v)
+ }
+ return _c
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (_c *IdentityAdoptionDecisionCreate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetAdoptDisplayName(v)
+ return _c
+}
+
+// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetAdoptDisplayName(*v)
+ }
+ return _c
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (_c *IdentityAdoptionDecisionCreate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetAdoptAvatar(v)
+ return _c
+}
+
+// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetAdoptAvatar(*v)
+ }
+ return _c
+}
+
+// SetDecidedAt sets the "decided_at" field.
+func (_c *IdentityAdoptionDecisionCreate) SetDecidedAt(v time.Time) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetDecidedAt(v)
+ return _c
+}
+
+// SetNillableDecidedAt sets the "decided_at" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableDecidedAt(v *time.Time) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetDecidedAt(*v)
+ }
+ return _c
+}
+
+// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionCreate {
+ return _c.SetPendingAuthSessionID(v.ID)
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_c *IdentityAdoptionDecisionCreate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionCreate {
+ return _c.SetIdentityID(v.ID)
+}
+
+// Mutation returns the IdentityAdoptionDecisionMutation object of the builder.
+func (_c *IdentityAdoptionDecisionCreate) Mutation() *IdentityAdoptionDecisionMutation {
+ return _c.mutation
+}
+
+// Save creates the IdentityAdoptionDecision in the database.
+func (_c *IdentityAdoptionDecisionCreate) Save(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *IdentityAdoptionDecisionCreate) SaveX(ctx context.Context) *IdentityAdoptionDecision {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *IdentityAdoptionDecisionCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *IdentityAdoptionDecisionCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *IdentityAdoptionDecisionCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := identityadoptiondecision.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := identityadoptiondecision.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.AdoptDisplayName(); !ok {
+ v := identityadoptiondecision.DefaultAdoptDisplayName
+ _c.mutation.SetAdoptDisplayName(v)
+ }
+ if _, ok := _c.mutation.AdoptAvatar(); !ok {
+ v := identityadoptiondecision.DefaultAdoptAvatar
+ _c.mutation.SetAdoptAvatar(v)
+ }
+ if _, ok := _c.mutation.DecidedAt(); !ok {
+ v := identityadoptiondecision.DefaultDecidedAt()
+ _c.mutation.SetDecidedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *IdentityAdoptionDecisionCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.updated_at"`)}
+ }
+ if _, ok := _c.mutation.PendingAuthSessionID(); !ok {
+ return &ValidationError{Name: "pending_auth_session_id", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.pending_auth_session_id"`)}
+ }
+ if _, ok := _c.mutation.AdoptDisplayName(); !ok {
+ return &ValidationError{Name: "adopt_display_name", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_display_name"`)}
+ }
+ if _, ok := _c.mutation.AdoptAvatar(); !ok {
+ return &ValidationError{Name: "adopt_avatar", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_avatar"`)}
+ }
+ if _, ok := _c.mutation.DecidedAt(); !ok {
+ return &ValidationError{Name: "decided_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.decided_at"`)}
+ }
+ if len(_c.mutation.PendingAuthSessionIDs()) == 0 {
+ return &ValidationError{Name: "pending_auth_session", err: errors.New(`ent: missing required edge "IdentityAdoptionDecision.pending_auth_session"`)}
+ }
+ return nil
+}
+
+func (_c *IdentityAdoptionDecisionCreate) sqlSave(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *IdentityAdoptionDecisionCreate) createSpec() (*IdentityAdoptionDecision, *sqlgraph.CreateSpec) {
+ var (
+ _node = &IdentityAdoptionDecision{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.AdoptDisplayName(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value)
+ _node.AdoptDisplayName = value
+ }
+ if value, ok := _c.mutation.AdoptAvatar(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value)
+ _node.AdoptAvatar = value
+ }
+ if value, ok := _c.mutation.DecidedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldDecidedAt, field.TypeTime, value)
+ _node.DecidedAt = value
+ }
+ if nodes := _c.mutation.PendingAuthSessionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.PendingAuthSessionID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.IdentityID = &nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.IdentityAdoptionDecision.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.IdentityAdoptionDecisionUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *IdentityAdoptionDecisionCreate) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertOne {
+ _c.conflict = opts
+ return &IdentityAdoptionDecisionUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *IdentityAdoptionDecisionCreate) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &IdentityAdoptionDecisionUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // IdentityAdoptionDecisionUpsertOne is the builder for "upsert"-ing
+ // one IdentityAdoptionDecision node.
+ IdentityAdoptionDecisionUpsertOne struct {
+ create *IdentityAdoptionDecisionCreate
+ }
+
+ // IdentityAdoptionDecisionUpsert is the "OnConflict" setter.
+ IdentityAdoptionDecisionUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *IdentityAdoptionDecisionUpsert) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldUpdatedAt)
+ return u
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (u *IdentityAdoptionDecisionUpsert) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldPendingAuthSessionID, v)
+ return u
+}
+
+// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldPendingAuthSessionID)
+ return u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsert) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldIdentityID, v)
+ return u
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdateIdentityID() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldIdentityID)
+ return u
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsert) ClearIdentityID() *IdentityAdoptionDecisionUpsert {
+ u.SetNull(identityadoptiondecision.FieldIdentityID)
+ return u
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (u *IdentityAdoptionDecisionUpsert) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldAdoptDisplayName, v)
+ return u
+}
+
+// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldAdoptDisplayName)
+ return u
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (u *IdentityAdoptionDecisionUpsert) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldAdoptAvatar, v)
+ return u
+}
+
+// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldAdoptAvatar)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateNewValues() *IdentityAdoptionDecisionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(identityadoptiondecision.FieldCreatedAt)
+ }
+ if _, exists := u.create.mutation.DecidedAt(); exists {
+ s.SetIgnore(identityadoptiondecision.FieldDecidedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *IdentityAdoptionDecisionUpsertOne) Ignore() *IdentityAdoptionDecisionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *IdentityAdoptionDecisionUpsertOne) DoNothing() *IdentityAdoptionDecisionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreate.OnConflict
+// documentation for more info.
+func (u *IdentityAdoptionDecisionUpsertOne) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&IdentityAdoptionDecisionUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetPendingAuthSessionID(v)
+ })
+}
+
+// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdatePendingAuthSessionID()
+ })
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetIdentityID(v)
+ })
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateIdentityID() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateIdentityID()
+ })
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsertOne) ClearIdentityID() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.ClearIdentityID()
+ })
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetAdoptDisplayName(v)
+ })
+}
+
+// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateAdoptDisplayName()
+ })
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetAdoptAvatar(v)
+ })
+}
+
+// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateAdoptAvatar()
+ })
+}
+
+// Exec executes the query.
+func (u *IdentityAdoptionDecisionUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for IdentityAdoptionDecisionCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *IdentityAdoptionDecisionUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *IdentityAdoptionDecisionUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *IdentityAdoptionDecisionUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// IdentityAdoptionDecisionCreateBulk is the builder for creating many IdentityAdoptionDecision entities in bulk.
+type IdentityAdoptionDecisionCreateBulk struct {
+ config
+ err error
+ builders []*IdentityAdoptionDecisionCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the IdentityAdoptionDecision entities in the database.
+func (_c *IdentityAdoptionDecisionCreateBulk) Save(ctx context.Context) ([]*IdentityAdoptionDecision, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*IdentityAdoptionDecision, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*IdentityAdoptionDecisionMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *IdentityAdoptionDecisionCreateBulk) SaveX(ctx context.Context) []*IdentityAdoptionDecision {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *IdentityAdoptionDecisionCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *IdentityAdoptionDecisionCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.IdentityAdoptionDecision.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.IdentityAdoptionDecisionUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *IdentityAdoptionDecisionCreateBulk) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertBulk {
+ _c.conflict = opts
+ return &IdentityAdoptionDecisionUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *IdentityAdoptionDecisionCreateBulk) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &IdentityAdoptionDecisionUpsertBulk{
+ create: _c,
+ }
+}
+
+// IdentityAdoptionDecisionUpsertBulk is the builder for "upsert"-ing
+// a bulk of IdentityAdoptionDecision nodes.
+type IdentityAdoptionDecisionUpsertBulk struct {
+ create *IdentityAdoptionDecisionCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateNewValues() *IdentityAdoptionDecisionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(identityadoptiondecision.FieldCreatedAt)
+ }
+ if _, exists := b.mutation.DecidedAt(); exists {
+ s.SetIgnore(identityadoptiondecision.FieldDecidedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *IdentityAdoptionDecisionUpsertBulk) Ignore() *IdentityAdoptionDecisionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *IdentityAdoptionDecisionUpsertBulk) DoNothing() *IdentityAdoptionDecisionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreateBulk.OnConflict
+// documentation for more info.
+func (u *IdentityAdoptionDecisionUpsertBulk) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&IdentityAdoptionDecisionUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetPendingAuthSessionID(v)
+ })
+}
+
+// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdatePendingAuthSessionID()
+ })
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetIdentityID(v)
+ })
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateIdentityID() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateIdentityID()
+ })
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) ClearIdentityID() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.ClearIdentityID()
+ })
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetAdoptDisplayName(v)
+ })
+}
+
+// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateAdoptDisplayName()
+ })
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetAdoptAvatar(v)
+ })
+}
+
+// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateAdoptAvatar()
+ })
+}
+
+// Exec executes the query.
+func (u *IdentityAdoptionDecisionUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the IdentityAdoptionDecisionCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for IdentityAdoptionDecisionCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *IdentityAdoptionDecisionUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/identityadoptiondecision_delete.go b/backend/ent/identityadoptiondecision_delete.go
new file mode 100644
index 00000000..ef3d328d
--- /dev/null
+++ b/backend/ent/identityadoptiondecision_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// IdentityAdoptionDecisionDelete is the builder for deleting a IdentityAdoptionDecision entity.
+type IdentityAdoptionDecisionDelete struct {
+ config
+ hooks []Hook
+ mutation *IdentityAdoptionDecisionMutation
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder.
+func (_d *IdentityAdoptionDecisionDelete) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *IdentityAdoptionDecisionDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *IdentityAdoptionDecisionDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *IdentityAdoptionDecisionDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// IdentityAdoptionDecisionDeleteOne is the builder for deleting a single IdentityAdoptionDecision entity.
+type IdentityAdoptionDecisionDeleteOne struct {
+ _d *IdentityAdoptionDecisionDelete
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder.
+func (_d *IdentityAdoptionDecisionDeleteOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *IdentityAdoptionDecisionDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{identityadoptiondecision.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *IdentityAdoptionDecisionDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/identityadoptiondecision_query.go b/backend/ent/identityadoptiondecision_query.go
new file mode 100644
index 00000000..4082d8ee
--- /dev/null
+++ b/backend/ent/identityadoptiondecision_query.go
@@ -0,0 +1,721 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// IdentityAdoptionDecisionQuery is the builder for querying IdentityAdoptionDecision entities.
+type IdentityAdoptionDecisionQuery struct {
+ config
+ ctx *QueryContext
+ order []identityadoptiondecision.OrderOption
+ inters []Interceptor
+ predicates []predicate.IdentityAdoptionDecision
+ withPendingAuthSession *PendingAuthSessionQuery
+ withIdentity *AuthIdentityQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the IdentityAdoptionDecisionQuery builder.
+func (_q *IdentityAdoptionDecisionQuery) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *IdentityAdoptionDecisionQuery) Limit(limit int) *IdentityAdoptionDecisionQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *IdentityAdoptionDecisionQuery) Offset(offset int) *IdentityAdoptionDecisionQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *IdentityAdoptionDecisionQuery) Unique(unique bool) *IdentityAdoptionDecisionQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *IdentityAdoptionDecisionQuery) Order(o ...identityadoptiondecision.OrderOption) *IdentityAdoptionDecisionQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryPendingAuthSession chains the current query on the "pending_auth_session" edge.
+func (_q *IdentityAdoptionDecisionQuery) QueryPendingAuthSession() *PendingAuthSessionQuery {
+ query := (&PendingAuthSessionClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector),
+ sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryIdentity chains the current query on the "identity" edge.
+func (_q *IdentityAdoptionDecisionQuery) QueryIdentity() *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first IdentityAdoptionDecision entity from the query.
+// Returns a *NotFoundError when no IdentityAdoptionDecision was found.
+func (_q *IdentityAdoptionDecisionQuery) First(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{identityadoptiondecision.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) FirstX(ctx context.Context) *IdentityAdoptionDecision {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first IdentityAdoptionDecision ID from the query.
+// Returns a *NotFoundError when no IdentityAdoptionDecision ID was found.
+func (_q *IdentityAdoptionDecisionQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{identityadoptiondecision.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single IdentityAdoptionDecision entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one IdentityAdoptionDecision entity is found.
+// Returns a *NotFoundError when no IdentityAdoptionDecision entities are found.
+func (_q *IdentityAdoptionDecisionQuery) Only(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{identityadoptiondecision.Label}
+ default:
+ return nil, &NotSingularError{identityadoptiondecision.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) OnlyX(ctx context.Context) *IdentityAdoptionDecision {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only IdentityAdoptionDecision ID in the query.
+// Returns a *NotSingularError when more than one IdentityAdoptionDecision ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *IdentityAdoptionDecisionQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{identityadoptiondecision.Label}
+ default:
+ err = &NotSingularError{identityadoptiondecision.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of IdentityAdoptionDecisions.
+func (_q *IdentityAdoptionDecisionQuery) All(ctx context.Context) ([]*IdentityAdoptionDecision, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*IdentityAdoptionDecision, *IdentityAdoptionDecisionQuery]()
+ return withInterceptors[[]*IdentityAdoptionDecision](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) AllX(ctx context.Context) []*IdentityAdoptionDecision {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of IdentityAdoptionDecision IDs.
+func (_q *IdentityAdoptionDecisionQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(identityadoptiondecision.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *IdentityAdoptionDecisionQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*IdentityAdoptionDecisionQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *IdentityAdoptionDecisionQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the IdentityAdoptionDecisionQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *IdentityAdoptionDecisionQuery) Clone() *IdentityAdoptionDecisionQuery {
+ if _q == nil {
+ return nil
+ }
+ return &IdentityAdoptionDecisionQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]identityadoptiondecision.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.IdentityAdoptionDecision{}, _q.predicates...),
+ withPendingAuthSession: _q.withPendingAuthSession.Clone(),
+ withIdentity: _q.withIdentity.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithPendingAuthSession tells the query-builder to eager-load the nodes that are connected to
+// the "pending_auth_session" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *IdentityAdoptionDecisionQuery) WithPendingAuthSession(opts ...func(*PendingAuthSessionQuery)) *IdentityAdoptionDecisionQuery {
+ query := (&PendingAuthSessionClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withPendingAuthSession = query
+ return _q
+}
+
+// WithIdentity tells the query-builder to eager-load the nodes that are connected to
+// the "identity" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *IdentityAdoptionDecisionQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *IdentityAdoptionDecisionQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withIdentity = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.IdentityAdoptionDecision.Query().
+// GroupBy(identityadoptiondecision.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *IdentityAdoptionDecisionQuery) GroupBy(field string, fields ...string) *IdentityAdoptionDecisionGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &IdentityAdoptionDecisionGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = identityadoptiondecision.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.IdentityAdoptionDecision.Query().
+// Select(identityadoptiondecision.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *IdentityAdoptionDecisionQuery) Select(fields ...string) *IdentityAdoptionDecisionSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &IdentityAdoptionDecisionSelect{IdentityAdoptionDecisionQuery: _q}
+ sbuild.label = identityadoptiondecision.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a IdentityAdoptionDecisionSelect configured with the given aggregations.
+func (_q *IdentityAdoptionDecisionQuery) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *IdentityAdoptionDecisionQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !identityadoptiondecision.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *IdentityAdoptionDecisionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*IdentityAdoptionDecision, error) {
+ var (
+ nodes = []*IdentityAdoptionDecision{}
+ _spec = _q.querySpec()
+ loadedTypes = [2]bool{
+ _q.withPendingAuthSession != nil,
+ _q.withIdentity != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*IdentityAdoptionDecision).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &IdentityAdoptionDecision{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withPendingAuthSession; query != nil {
+ if err := _q.loadPendingAuthSession(ctx, query, nodes, nil,
+ func(n *IdentityAdoptionDecision, e *PendingAuthSession) { n.Edges.PendingAuthSession = e }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withIdentity; query != nil {
+ if err := _q.loadIdentity(ctx, query, nodes, nil,
+ func(n *IdentityAdoptionDecision, e *AuthIdentity) { n.Edges.Identity = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *IdentityAdoptionDecisionQuery) loadPendingAuthSession(ctx context.Context, query *PendingAuthSessionQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *PendingAuthSession)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*IdentityAdoptionDecision)
+ for i := range nodes {
+ fk := nodes[i].PendingAuthSessionID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(pendingauthsession.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "pending_auth_session_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+func (_q *IdentityAdoptionDecisionQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *AuthIdentity)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*IdentityAdoptionDecision)
+ for i := range nodes {
+ if nodes[i].IdentityID == nil {
+ continue
+ }
+ fk := *nodes[i].IdentityID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(authidentity.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+
+func (_q *IdentityAdoptionDecisionQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *IdentityAdoptionDecisionQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID)
+ for i := range fields {
+ if fields[i] != identityadoptiondecision.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withPendingAuthSession != nil {
+ _spec.Node.AddColumnOnce(identityadoptiondecision.FieldPendingAuthSessionID)
+ }
+ if _q.withIdentity != nil {
+ _spec.Node.AddColumnOnce(identityadoptiondecision.FieldIdentityID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *IdentityAdoptionDecisionQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(identityadoptiondecision.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = identityadoptiondecision.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *IdentityAdoptionDecisionQuery) ForUpdate(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *IdentityAdoptionDecisionQuery) ForShare(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// IdentityAdoptionDecisionGroupBy is the group-by builder for IdentityAdoptionDecision entities.
+type IdentityAdoptionDecisionGroupBy struct {
+ selector
+ build *IdentityAdoptionDecisionQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *IdentityAdoptionDecisionGroupBy) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *IdentityAdoptionDecisionGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *IdentityAdoptionDecisionGroupBy) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// IdentityAdoptionDecisionSelect is the builder for selecting fields of IdentityAdoptionDecision entities.
+type IdentityAdoptionDecisionSelect struct {
+ *IdentityAdoptionDecisionQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *IdentityAdoptionDecisionSelect) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *IdentityAdoptionDecisionSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionSelect](ctx, _s.IdentityAdoptionDecisionQuery, _s, _s.inters, v)
+}
+
+func (_s *IdentityAdoptionDecisionSelect) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/identityadoptiondecision_update.go b/backend/ent/identityadoptiondecision_update.go
new file mode 100644
index 00000000..0ca21d27
--- /dev/null
+++ b/backend/ent/identityadoptiondecision_update.go
@@ -0,0 +1,532 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// IdentityAdoptionDecisionUpdate is the builder for updating IdentityAdoptionDecision entities.
+type IdentityAdoptionDecisionUpdate struct {
+ config
+ hooks []Hook
+ mutation *IdentityAdoptionDecisionMutation
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder.
+func (_u *IdentityAdoptionDecisionUpdate) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetPendingAuthSessionID(v)
+ return _u
+}
+
+// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdate) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdate {
+ if v != nil {
+ _u.SetPendingAuthSessionID(*v)
+ }
+ return _u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetIdentityID(v)
+ return _u
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdate {
+ if v != nil {
+ _u.SetIdentityID(*v)
+ }
+ return _u
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (_u *IdentityAdoptionDecisionUpdate) ClearIdentityID() *IdentityAdoptionDecisionUpdate {
+ _u.mutation.ClearIdentityID()
+ return _u
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetAdoptDisplayName(v)
+ return _u
+}
+
+// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdate {
+ if v != nil {
+ _u.SetAdoptDisplayName(*v)
+ }
+ return _u
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetAdoptAvatar(v)
+ return _u
+}
+
+// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdate {
+ if v != nil {
+ _u.SetAdoptAvatar(*v)
+ }
+ return _u
+}
+
+// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdate {
+ return _u.SetPendingAuthSessionID(v.ID)
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_u *IdentityAdoptionDecisionUpdate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdate {
+ return _u.SetIdentityID(v.ID)
+}
+
+// Mutation returns the IdentityAdoptionDecisionMutation object of the builder.
+func (_u *IdentityAdoptionDecisionUpdate) Mutation() *IdentityAdoptionDecisionMutation {
+ return _u.mutation
+}
+
+// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_u *IdentityAdoptionDecisionUpdate) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdate {
+ _u.mutation.ClearPendingAuthSession()
+ return _u
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (_u *IdentityAdoptionDecisionUpdate) ClearIdentity() *IdentityAdoptionDecisionUpdate {
+ _u.mutation.ClearIdentity()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *IdentityAdoptionDecisionUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *IdentityAdoptionDecisionUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *IdentityAdoptionDecisionUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *IdentityAdoptionDecisionUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *IdentityAdoptionDecisionUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := identityadoptiondecision.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *IdentityAdoptionDecisionUpdate) check() error {
+ if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`)
+ }
+ return nil
+}
+
+func (_u *IdentityAdoptionDecisionUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.AdoptDisplayName(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.AdoptAvatar(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value)
+ }
+ if _u.mutation.PendingAuthSessionCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.IdentityCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{identityadoptiondecision.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// IdentityAdoptionDecisionUpdateOne is the builder for updating a single IdentityAdoptionDecision entity.
+type IdentityAdoptionDecisionUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *IdentityAdoptionDecisionMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetPendingAuthSessionID(v)
+ return _u
+}
+
+// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdateOne {
+ if v != nil {
+ _u.SetPendingAuthSessionID(*v)
+ }
+ return _u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetIdentityID(v)
+ return _u
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdateOne {
+ if v != nil {
+ _u.SetIdentityID(*v)
+ }
+ return _u
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentityID() *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.ClearIdentityID()
+ return _u
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetAdoptDisplayName(v)
+ return _u
+}
+
+// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdateOne {
+ if v != nil {
+ _u.SetAdoptDisplayName(*v)
+ }
+ return _u
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetAdoptAvatar(v)
+ return _u
+}
+
+// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdateOne {
+ if v != nil {
+ _u.SetAdoptAvatar(*v)
+ }
+ return _u
+}
+
+// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdateOne {
+ return _u.SetPendingAuthSessionID(v.ID)
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdateOne {
+ return _u.SetIdentityID(v.ID)
+}
+
+// Mutation returns the IdentityAdoptionDecisionMutation object of the builder.
+func (_u *IdentityAdoptionDecisionUpdateOne) Mutation() *IdentityAdoptionDecisionMutation {
+ return _u.mutation
+}
+
+// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.ClearPendingAuthSession()
+ return _u
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentity() *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.ClearIdentity()
+ return _u
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder.
+func (_u *IdentityAdoptionDecisionUpdateOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *IdentityAdoptionDecisionUpdateOne) Select(field string, fields ...string) *IdentityAdoptionDecisionUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated IdentityAdoptionDecision entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) Save(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *IdentityAdoptionDecisionUpdateOne) SaveX(ctx context.Context) *IdentityAdoptionDecision {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *IdentityAdoptionDecisionUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *IdentityAdoptionDecisionUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := identityadoptiondecision.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *IdentityAdoptionDecisionUpdateOne) check() error {
+ if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`)
+ }
+ return nil
+}
+
+func (_u *IdentityAdoptionDecisionUpdateOne) sqlSave(ctx context.Context) (_node *IdentityAdoptionDecision, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "IdentityAdoptionDecision.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID)
+ for _, f := range fields {
+ if !identityadoptiondecision.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != identityadoptiondecision.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.AdoptDisplayName(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.AdoptAvatar(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value)
+ }
+ if _u.mutation.PendingAuthSessionCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.IdentityCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &IdentityAdoptionDecision{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{identityadoptiondecision.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go
index 8d8320bb..157c5122 100644
--- a/backend/ent/intercept/intercept.go
+++ b/backend/ent/intercept/intercept.go
@@ -13,12 +13,16 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
@@ -228,6 +232,60 @@ func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) err
return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q)
}
+// The AuthIdentityFunc type is an adapter to allow the use of ordinary function as a Querier.
+type AuthIdentityFunc func(context.Context, *ent.AuthIdentityQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f AuthIdentityFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.AuthIdentityQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q)
+}
+
+// The TraverseAuthIdentity type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseAuthIdentity func(context.Context, *ent.AuthIdentityQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseAuthIdentity) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseAuthIdentity) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.AuthIdentityQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q)
+}
+
+// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary function as a Querier.
+type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f AuthIdentityChannelFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.AuthIdentityChannelQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q)
+}
+
+// The TraverseAuthIdentityChannel type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseAuthIdentityChannel func(context.Context, *ent.AuthIdentityChannelQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseAuthIdentityChannel) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseAuthIdentityChannel) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.AuthIdentityChannelQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q)
+}
+
// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary function as a Querier.
type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleQuery) (ent.Value, error)
@@ -309,6 +367,33 @@ func (f TraverseIdempotencyRecord) Traverse(ctx context.Context, q ent.Query) er
return fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q)
}
+// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary function as a Querier.
+type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f IdentityAdoptionDecisionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q)
+}
+
+// The TraverseIdentityAdoptionDecision type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseIdentityAdoptionDecision func(context.Context, *ent.IdentityAdoptionDecisionQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseIdentityAdoptionDecision) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseIdentityAdoptionDecision) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q)
+}
+
// The PaymentAuditLogFunc type is an adapter to allow the use of ordinary function as a Querier.
type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogQuery) (ent.Value, error)
@@ -390,6 +475,33 @@ func (f TraversePaymentProviderInstance) Traverse(ctx context.Context, q ent.Que
return fmt.Errorf("unexpected query type %T. expect *ent.PaymentProviderInstanceQuery", q)
}
+// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary function as a Querier.
+type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f PendingAuthSessionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.PendingAuthSessionQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q)
+}
+
+// The TraversePendingAuthSession type is an adapter to allow the use of ordinary function as Traverser.
+type TraversePendingAuthSession func(context.Context, *ent.PendingAuthSessionQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraversePendingAuthSession) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraversePendingAuthSession) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.PendingAuthSessionQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q)
+}
+
// The PromoCodeFunc type is an adapter to allow the use of ordinary function as a Querier.
type PromoCodeFunc func(context.Context, *ent.PromoCodeQuery) (ent.Value, error)
@@ -808,18 +920,26 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil
case *ent.AnnouncementReadQuery:
return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil
+ case *ent.AuthIdentityQuery:
+ return &query[*ent.AuthIdentityQuery, predicate.AuthIdentity, authidentity.OrderOption]{typ: ent.TypeAuthIdentity, tq: q}, nil
+ case *ent.AuthIdentityChannelQuery:
+ return &query[*ent.AuthIdentityChannelQuery, predicate.AuthIdentityChannel, authidentitychannel.OrderOption]{typ: ent.TypeAuthIdentityChannel, tq: q}, nil
case *ent.ErrorPassthroughRuleQuery:
return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil
case *ent.GroupQuery:
return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil
case *ent.IdempotencyRecordQuery:
return &query[*ent.IdempotencyRecordQuery, predicate.IdempotencyRecord, idempotencyrecord.OrderOption]{typ: ent.TypeIdempotencyRecord, tq: q}, nil
+ case *ent.IdentityAdoptionDecisionQuery:
+ return &query[*ent.IdentityAdoptionDecisionQuery, predicate.IdentityAdoptionDecision, identityadoptiondecision.OrderOption]{typ: ent.TypeIdentityAdoptionDecision, tq: q}, nil
case *ent.PaymentAuditLogQuery:
return &query[*ent.PaymentAuditLogQuery, predicate.PaymentAuditLog, paymentauditlog.OrderOption]{typ: ent.TypePaymentAuditLog, tq: q}, nil
case *ent.PaymentOrderQuery:
return &query[*ent.PaymentOrderQuery, predicate.PaymentOrder, paymentorder.OrderOption]{typ: ent.TypePaymentOrder, tq: q}, nil
case *ent.PaymentProviderInstanceQuery:
return &query[*ent.PaymentProviderInstanceQuery, predicate.PaymentProviderInstance, paymentproviderinstance.OrderOption]{typ: ent.TypePaymentProviderInstance, tq: q}, nil
+ case *ent.PendingAuthSessionQuery:
+ return &query[*ent.PendingAuthSessionQuery, predicate.PendingAuthSession, pendingauthsession.OrderOption]{typ: ent.TypePendingAuthSession, tq: q}, nil
case *ent.PromoCodeQuery:
return &query[*ent.PromoCodeQuery, predicate.PromoCode, promocode.OrderOption]{typ: ent.TypePromoCode, tq: q}, nil
case *ent.PromoCodeUsageQuery:
diff --git a/backend/ent/migrate/auth_identity_fk_ondelete_test.go b/backend/ent/migrate/auth_identity_fk_ondelete_test.go
new file mode 100644
index 00000000..0e37025a
--- /dev/null
+++ b/backend/ent/migrate/auth_identity_fk_ondelete_test.go
@@ -0,0 +1,73 @@
+package migrate
+
+import (
+ "testing"
+
+ "entgo.io/ent/dialect/entsql"
+ entschema "entgo.io/ent/dialect/sql/schema"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthIdentityFoundationForeignKeyOnDeleteActions(t *testing.T) {
+ require.Equal(
+ t,
+ entschema.Cascade,
+ findForeignKeyBySymbol(t, AuthIdentitiesTable, "auth_identities_users_auth_identities").OnDelete,
+ )
+ require.Equal(
+ t,
+ entschema.Cascade,
+ findForeignKeyBySymbol(t, AuthIdentityChannelsTable, "auth_identity_channels_auth_identities_channels").OnDelete,
+ )
+ require.Equal(
+ t,
+ entschema.Cascade,
+ findForeignKeyBySymbol(t, IdentityAdoptionDecisionsTable, "identity_adoption_decisions_pending_auth_sessions_adoption_decision").OnDelete,
+ )
+
+ require.Equal(
+ t,
+ entschema.SetNull,
+ findForeignKeyBySymbol(t, PendingAuthSessionsTable, "pending_auth_sessions_users_pending_auth_sessions").OnDelete,
+ )
+ require.Equal(
+ t,
+ entschema.SetNull,
+ findForeignKeyBySymbol(t, IdentityAdoptionDecisionsTable, "identity_adoption_decisions_auth_identities_adoption_decisions").OnDelete,
+ )
+}
+
+func TestPaymentOrdersOutTradeNoPartialUniqueIndex(t *testing.T) {
+ idx := findIndexByName(t, PaymentOrdersTable, "paymentorder_out_trade_no")
+ require.True(t, idx.Unique)
+ require.Len(t, idx.Columns, 1)
+ require.Equal(t, "out_trade_no", idx.Columns[0].Name)
+ require.NotNil(t, idx.Annotation)
+ require.Equal(t, (&entsql.IndexAnnotation{Where: "out_trade_no <> ''"}).Where, idx.Annotation.Where)
+}
+
+func findForeignKeyBySymbol(t *testing.T, table *entschema.Table, symbol string) *entschema.ForeignKey {
+ t.Helper()
+
+ for _, fk := range table.ForeignKeys {
+ if fk.Symbol == symbol {
+ return fk
+ }
+ }
+
+ require.Failf(t, "missing foreign key", "table %s should include foreign key %s", table.Name, symbol)
+ return nil
+}
+
+func findIndexByName(t *testing.T, table *entschema.Table, name string) *entschema.Index {
+ t.Helper()
+
+ for _, idx := range table.Indexes {
+ if idx.Name == name {
+ return idx
+ }
+ }
+
+ require.Failf(t, "missing index", "table %s should include index %s", table.Name, name)
+ return nil
+}
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index 68bdbf55..40b326a9 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -338,6 +338,89 @@ var (
},
},
}
+ // AuthIdentitiesColumns holds the columns for the "auth_identities" table.
+ AuthIdentitiesColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "provider_type", Type: field.TypeString, Size: 20},
+ {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "issuer", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "user_id", Type: field.TypeInt64},
+ }
+ // AuthIdentitiesTable holds the schema information for the "auth_identities" table.
+ AuthIdentitiesTable = &schema.Table{
+ Name: "auth_identities",
+ Columns: AuthIdentitiesColumns,
+ PrimaryKey: []*schema.Column{AuthIdentitiesColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "auth_identities_users_auth_identities",
+ Columns: []*schema.Column{AuthIdentitiesColumns[9]},
+ RefColumns: []*schema.Column{UsersColumns[0]},
+ OnDelete: schema.Cascade,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "authidentity_provider_type_provider_key_provider_subject",
+ Unique: true,
+ Columns: []*schema.Column{AuthIdentitiesColumns[3], AuthIdentitiesColumns[4], AuthIdentitiesColumns[5]},
+ },
+ {
+ Name: "authidentity_user_id",
+ Unique: false,
+ Columns: []*schema.Column{AuthIdentitiesColumns[9]},
+ },
+ {
+ Name: "authidentity_user_id_provider_type",
+ Unique: false,
+ Columns: []*schema.Column{AuthIdentitiesColumns[9], AuthIdentitiesColumns[3]},
+ },
+ },
+ }
+ // AuthIdentityChannelsColumns holds the columns for the "auth_identity_channels" table.
+ AuthIdentityChannelsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "provider_type", Type: field.TypeString, Size: 20},
+ {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "channel", Type: field.TypeString, Size: 20},
+ {Name: "channel_app_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "channel_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "identity_id", Type: field.TypeInt64},
+ }
+ // AuthIdentityChannelsTable holds the schema information for the "auth_identity_channels" table.
+ AuthIdentityChannelsTable = &schema.Table{
+ Name: "auth_identity_channels",
+ Columns: AuthIdentityChannelsColumns,
+ PrimaryKey: []*schema.Column{AuthIdentityChannelsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "auth_identity_channels_auth_identities_channels",
+ Columns: []*schema.Column{AuthIdentityChannelsColumns[9]},
+ RefColumns: []*schema.Column{AuthIdentitiesColumns[0]},
+ OnDelete: schema.Cascade,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "authidentitychannel_provider_type_provider_key_channel_channel_app_id_channel_subject",
+ Unique: true,
+ Columns: []*schema.Column{AuthIdentityChannelsColumns[3], AuthIdentityChannelsColumns[4], AuthIdentityChannelsColumns[5], AuthIdentityChannelsColumns[6], AuthIdentityChannelsColumns[7]},
+ },
+ {
+ Name: "authidentitychannel_identity_id",
+ Unique: false,
+ Columns: []*schema.Column{AuthIdentityChannelsColumns[9]},
+ },
+ },
+ }
// ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table.
ErrorPassthroughRulesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -485,6 +568,49 @@ var (
},
},
}
+ // IdentityAdoptionDecisionsColumns holds the columns for the "identity_adoption_decisions" table.
+ IdentityAdoptionDecisionsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "adopt_display_name", Type: field.TypeBool, Default: false},
+ {Name: "adopt_avatar", Type: field.TypeBool, Default: false},
+ {Name: "decided_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "identity_id", Type: field.TypeInt64, Nullable: true},
+ {Name: "pending_auth_session_id", Type: field.TypeInt64, Unique: true},
+ }
+ // IdentityAdoptionDecisionsTable holds the schema information for the "identity_adoption_decisions" table.
+ IdentityAdoptionDecisionsTable = &schema.Table{
+ Name: "identity_adoption_decisions",
+ Columns: IdentityAdoptionDecisionsColumns,
+ PrimaryKey: []*schema.Column{IdentityAdoptionDecisionsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "identity_adoption_decisions_auth_identities_adoption_decisions",
+ Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]},
+ RefColumns: []*schema.Column{AuthIdentitiesColumns[0]},
+ OnDelete: schema.SetNull,
+ },
+ {
+ Symbol: "identity_adoption_decisions_pending_auth_sessions_adoption_decision",
+ Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]},
+ RefColumns: []*schema.Column{PendingAuthSessionsColumns[0]},
+ OnDelete: schema.Cascade,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "identityadoptiondecision_pending_auth_session_id",
+ Unique: true,
+ Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]},
+ },
+ {
+ Name: "identityadoptiondecision_identity_id",
+ Unique: false,
+ Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]},
+ },
+ },
+ }
// PaymentAuditLogsColumns holds the columns for the "payment_audit_logs" table.
PaymentAuditLogsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -528,6 +654,8 @@ var (
{Name: "subscription_group_id", Type: field.TypeInt64, Nullable: true},
{Name: "subscription_days", Type: field.TypeInt, Nullable: true},
{Name: "provider_instance_id", Type: field.TypeString, Nullable: true, Size: 64},
+ {Name: "provider_key", Type: field.TypeString, Nullable: true, Size: 30},
+ {Name: "provider_snapshot", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "status", Type: field.TypeString, Size: 30, Default: "PENDING"},
{Name: "refund_amount", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,2)"}},
{Name: "refund_reason", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
@@ -556,7 +684,7 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "payment_orders_users_payment_orders",
- Columns: []*schema.Column{PaymentOrdersColumns[37]},
+ Columns: []*schema.Column{PaymentOrdersColumns[39]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
@@ -564,38 +692,41 @@ var (
Indexes: []*schema.Index{
{
Name: "paymentorder_out_trade_no",
- Unique: false,
+ Unique: true,
Columns: []*schema.Column{PaymentOrdersColumns[8]},
+ Annotation: &entsql.IndexAnnotation{
+ Where: "out_trade_no <> ''",
+ },
},
{
Name: "paymentorder_user_id",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[37]},
+ Columns: []*schema.Column{PaymentOrdersColumns[39]},
},
{
Name: "paymentorder_status",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[19]},
+ Columns: []*schema.Column{PaymentOrdersColumns[21]},
},
{
Name: "paymentorder_expires_at",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[27]},
+ Columns: []*schema.Column{PaymentOrdersColumns[29]},
},
{
Name: "paymentorder_created_at",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[35]},
+ Columns: []*schema.Column{PaymentOrdersColumns[37]},
},
{
Name: "paymentorder_paid_at",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[28]},
+ Columns: []*schema.Column{PaymentOrdersColumns[30]},
},
{
Name: "paymentorder_payment_type_paid_at",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[9], PaymentOrdersColumns[28]},
+ Columns: []*schema.Column{PaymentOrdersColumns[9], PaymentOrdersColumns[30]},
},
{
Name: "paymentorder_order_type",
@@ -638,6 +769,72 @@ var (
},
},
}
+ // PendingAuthSessionsColumns holds the columns for the "pending_auth_sessions" table.
+ PendingAuthSessionsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "session_token", Type: field.TypeString, Size: 255},
+ {Name: "intent", Type: field.TypeString, Size: 40},
+ {Name: "provider_type", Type: field.TypeString, Size: 20},
+ {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "redirect_to", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "resolved_email", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "registration_password_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "upstream_identity_claims", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "local_flow_state", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "browser_session_key", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "completion_code_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "completion_code_expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "email_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "password_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "totp_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "expires_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "consumed_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "target_user_id", Type: field.TypeInt64, Nullable: true},
+ }
+ // PendingAuthSessionsTable holds the schema information for the "pending_auth_sessions" table.
+ PendingAuthSessionsTable = &schema.Table{
+ Name: "pending_auth_sessions",
+ Columns: PendingAuthSessionsColumns,
+ PrimaryKey: []*schema.Column{PendingAuthSessionsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "pending_auth_sessions_users_pending_auth_sessions",
+ Columns: []*schema.Column{PendingAuthSessionsColumns[21]},
+ RefColumns: []*schema.Column{UsersColumns[0]},
+ OnDelete: schema.SetNull,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "pendingauthsession_session_token",
+ Unique: true,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[3]},
+ },
+ {
+ Name: "pendingauthsession_target_user_id",
+ Unique: false,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[21]},
+ },
+ {
+ Name: "pendingauthsession_expires_at",
+ Unique: false,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[19]},
+ },
+ {
+ Name: "pendingauthsession_provider_type_provider_key_provider_subject",
+ Unique: false,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[5], PendingAuthSessionsColumns[6], PendingAuthSessionsColumns[7]},
+ },
+ {
+ Name: "pendingauthsession_completion_code_hash",
+ Unique: false,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[14]},
+ },
+ },
+ }
// PromoCodesColumns holds the columns for the "promo_codes" table.
PromoCodesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -1079,6 +1276,9 @@ var (
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
+ {Name: "signup_source", Type: field.TypeString, Size: 20, Default: "email"},
+ {Name: "last_login_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "last_active_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "balance_notify_enabled", Type: field.TypeBool, Default: true},
{Name: "balance_notify_threshold_type", Type: field.TypeString, Default: "fixed"},
{Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
@@ -1318,12 +1518,16 @@ var (
AccountGroupsTable,
AnnouncementsTable,
AnnouncementReadsTable,
+ AuthIdentitiesTable,
+ AuthIdentityChannelsTable,
ErrorPassthroughRulesTable,
GroupsTable,
IdempotencyRecordsTable,
+ IdentityAdoptionDecisionsTable,
PaymentAuditLogsTable,
PaymentOrdersTable,
PaymentProviderInstancesTable,
+ PendingAuthSessionsTable,
PromoCodesTable,
PromoCodeUsagesTable,
ProxiesTable,
@@ -1365,6 +1569,14 @@ func init() {
AnnouncementReadsTable.Annotation = &entsql.Annotation{
Table: "announcement_reads",
}
+ AuthIdentitiesTable.ForeignKeys[0].RefTable = UsersTable
+ AuthIdentitiesTable.Annotation = &entsql.Annotation{
+ Table: "auth_identities",
+ }
+ AuthIdentityChannelsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable
+ AuthIdentityChannelsTable.Annotation = &entsql.Annotation{
+ Table: "auth_identity_channels",
+ }
ErrorPassthroughRulesTable.Annotation = &entsql.Annotation{
Table: "error_passthrough_rules",
}
@@ -1374,6 +1586,11 @@ func init() {
IdempotencyRecordsTable.Annotation = &entsql.Annotation{
Table: "idempotency_records",
}
+ IdentityAdoptionDecisionsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable
+ IdentityAdoptionDecisionsTable.ForeignKeys[1].RefTable = PendingAuthSessionsTable
+ IdentityAdoptionDecisionsTable.Annotation = &entsql.Annotation{
+ Table: "identity_adoption_decisions",
+ }
PaymentAuditLogsTable.Annotation = &entsql.Annotation{
Table: "payment_audit_logs",
}
@@ -1384,6 +1601,10 @@ func init() {
PaymentProviderInstancesTable.Annotation = &entsql.Annotation{
Table: "payment_provider_instances",
}
+ PendingAuthSessionsTable.ForeignKeys[0].RefTable = UsersTable
+ PendingAuthSessionsTable.Annotation = &entsql.Annotation{
+ Table: "pending_auth_sessions",
+ }
PromoCodesTable.Annotation = &entsql.Annotation{
Table: "promo_codes",
}
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index 524ccb92..ec4a4070 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -17,12 +17,16 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
@@ -51,32 +55,36 @@ const (
OpUpdateOne = ent.OpUpdateOne
// Node types.
- TypeAPIKey = "APIKey"
- TypeAccount = "Account"
- TypeAccountGroup = "AccountGroup"
- TypeAnnouncement = "Announcement"
- TypeAnnouncementRead = "AnnouncementRead"
- TypeErrorPassthroughRule = "ErrorPassthroughRule"
- TypeGroup = "Group"
- TypeIdempotencyRecord = "IdempotencyRecord"
- TypePaymentAuditLog = "PaymentAuditLog"
- TypePaymentOrder = "PaymentOrder"
- TypePaymentProviderInstance = "PaymentProviderInstance"
- TypePromoCode = "PromoCode"
- TypePromoCodeUsage = "PromoCodeUsage"
- TypeProxy = "Proxy"
- TypeRedeemCode = "RedeemCode"
- TypeSecuritySecret = "SecuritySecret"
- TypeSetting = "Setting"
- TypeSubscriptionPlan = "SubscriptionPlan"
- TypeTLSFingerprintProfile = "TLSFingerprintProfile"
- TypeUsageCleanupTask = "UsageCleanupTask"
- TypeUsageLog = "UsageLog"
- TypeUser = "User"
- TypeUserAllowedGroup = "UserAllowedGroup"
- TypeUserAttributeDefinition = "UserAttributeDefinition"
- TypeUserAttributeValue = "UserAttributeValue"
- TypeUserSubscription = "UserSubscription"
+ TypeAPIKey = "APIKey"
+ TypeAccount = "Account"
+ TypeAccountGroup = "AccountGroup"
+ TypeAnnouncement = "Announcement"
+ TypeAnnouncementRead = "AnnouncementRead"
+ TypeAuthIdentity = "AuthIdentity"
+ TypeAuthIdentityChannel = "AuthIdentityChannel"
+ TypeErrorPassthroughRule = "ErrorPassthroughRule"
+ TypeGroup = "Group"
+ TypeIdempotencyRecord = "IdempotencyRecord"
+ TypeIdentityAdoptionDecision = "IdentityAdoptionDecision"
+ TypePaymentAuditLog = "PaymentAuditLog"
+ TypePaymentOrder = "PaymentOrder"
+ TypePaymentProviderInstance = "PaymentProviderInstance"
+ TypePendingAuthSession = "PendingAuthSession"
+ TypePromoCode = "PromoCode"
+ TypePromoCodeUsage = "PromoCodeUsage"
+ TypeProxy = "Proxy"
+ TypeRedeemCode = "RedeemCode"
+ TypeSecuritySecret = "SecuritySecret"
+ TypeSetting = "Setting"
+ TypeSubscriptionPlan = "SubscriptionPlan"
+ TypeTLSFingerprintProfile = "TLSFingerprintProfile"
+ TypeUsageCleanupTask = "UsageCleanupTask"
+ TypeUsageLog = "UsageLog"
+ TypeUser = "User"
+ TypeUserAllowedGroup = "UserAllowedGroup"
+ TypeUserAttributeDefinition = "UserAttributeDefinition"
+ TypeUserAttributeValue = "UserAttributeValue"
+ TypeUserSubscription = "UserSubscription"
)
// APIKeyMutation represents an operation that mutates the APIKey nodes in the graph.
@@ -6887,6 +6895,1845 @@ func (m *AnnouncementReadMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown AnnouncementRead edge %s", name)
}
+// AuthIdentityMutation represents an operation that mutates the AuthIdentity nodes in the graph.
+type AuthIdentityMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ provider_type *string
+ provider_key *string
+ provider_subject *string
+ verified_at *time.Time
+ issuer *string
+ metadata *map[string]interface{}
+ clearedFields map[string]struct{}
+ user *int64
+ cleareduser bool
+ channels map[int64]struct{}
+ removedchannels map[int64]struct{}
+ clearedchannels bool
+ adoption_decisions map[int64]struct{}
+ removedadoption_decisions map[int64]struct{}
+ clearedadoption_decisions bool
+ done bool
+ oldValue func(context.Context) (*AuthIdentity, error)
+ predicates []predicate.AuthIdentity
+}
+
+var _ ent.Mutation = (*AuthIdentityMutation)(nil)
+
+// authidentityOption allows management of the mutation configuration using functional options.
+type authidentityOption func(*AuthIdentityMutation)
+
+// newAuthIdentityMutation creates new mutation for the AuthIdentity entity.
+func newAuthIdentityMutation(c config, op Op, opts ...authidentityOption) *AuthIdentityMutation {
+ m := &AuthIdentityMutation{
+ config: c,
+ op: op,
+ typ: TypeAuthIdentity,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withAuthIdentityID sets the ID field of the mutation.
+func withAuthIdentityID(id int64) authidentityOption {
+ return func(m *AuthIdentityMutation) {
+ var (
+ err error
+ once sync.Once
+ value *AuthIdentity
+ )
+ m.oldValue = func(ctx context.Context) (*AuthIdentity, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().AuthIdentity.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withAuthIdentity sets the old AuthIdentity of the mutation.
+func withAuthIdentity(node *AuthIdentity) authidentityOption {
+ return func(m *AuthIdentityMutation) {
+ m.oldValue = func(context.Context) (*AuthIdentity, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m AuthIdentityMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m AuthIdentityMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *AuthIdentityMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *AuthIdentityMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().AuthIdentity.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *AuthIdentityMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *AuthIdentityMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *AuthIdentityMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *AuthIdentityMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *AuthIdentityMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *AuthIdentityMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetUserID sets the "user_id" field.
+func (m *AuthIdentityMutation) SetUserID(i int64) {
+ m.user = &i
+}
+
+// UserID returns the value of the "user_id" field in the mutation.
+func (m *AuthIdentityMutation) UserID() (r int64, exists bool) {
+ v := m.user
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUserID returns the old "user_id" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldUserID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUserID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUserID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUserID: %w", err)
+ }
+ return oldValue.UserID, nil
+}
+
+// ResetUserID resets all changes to the "user_id" field.
+func (m *AuthIdentityMutation) ResetUserID() {
+ m.user = nil
+}
+
+// SetProviderType sets the "provider_type" field.
+func (m *AuthIdentityMutation) SetProviderType(s string) {
+ m.provider_type = &s
+}
+
+// ProviderType returns the value of the "provider_type" field in the mutation.
+func (m *AuthIdentityMutation) ProviderType() (r string, exists bool) {
+ v := m.provider_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderType returns the old "provider_type" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldProviderType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderType: %w", err)
+ }
+ return oldValue.ProviderType, nil
+}
+
+// ResetProviderType resets all changes to the "provider_type" field.
+func (m *AuthIdentityMutation) ResetProviderType() {
+ m.provider_type = nil
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (m *AuthIdentityMutation) SetProviderKey(s string) {
+ m.provider_key = &s
+}
+
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *AuthIdentityMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderKey returns the old "provider_key" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldProviderKey(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ }
+ return oldValue.ProviderKey, nil
+}
+
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *AuthIdentityMutation) ResetProviderKey() {
+ m.provider_key = nil
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (m *AuthIdentityMutation) SetProviderSubject(s string) {
+ m.provider_subject = &s
+}
+
+// ProviderSubject returns the value of the "provider_subject" field in the mutation.
+func (m *AuthIdentityMutation) ProviderSubject() (r string, exists bool) {
+ v := m.provider_subject
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderSubject returns the old "provider_subject" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldProviderSubject(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderSubject is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderSubject requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderSubject: %w", err)
+ }
+ return oldValue.ProviderSubject, nil
+}
+
+// ResetProviderSubject resets all changes to the "provider_subject" field.
+func (m *AuthIdentityMutation) ResetProviderSubject() {
+ m.provider_subject = nil
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (m *AuthIdentityMutation) SetVerifiedAt(t time.Time) {
+ m.verified_at = &t
+}
+
+// VerifiedAt returns the value of the "verified_at" field in the mutation.
+func (m *AuthIdentityMutation) VerifiedAt() (r time.Time, exists bool) {
+ v := m.verified_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldVerifiedAt returns the old "verified_at" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldVerifiedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldVerifiedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldVerifiedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldVerifiedAt: %w", err)
+ }
+ return oldValue.VerifiedAt, nil
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (m *AuthIdentityMutation) ClearVerifiedAt() {
+ m.verified_at = nil
+ m.clearedFields[authidentity.FieldVerifiedAt] = struct{}{}
+}
+
+// VerifiedAtCleared returns if the "verified_at" field was cleared in this mutation.
+func (m *AuthIdentityMutation) VerifiedAtCleared() bool {
+ _, ok := m.clearedFields[authidentity.FieldVerifiedAt]
+ return ok
+}
+
+// ResetVerifiedAt resets all changes to the "verified_at" field.
+func (m *AuthIdentityMutation) ResetVerifiedAt() {
+ m.verified_at = nil
+ delete(m.clearedFields, authidentity.FieldVerifiedAt)
+}
+
+// SetIssuer sets the "issuer" field.
+func (m *AuthIdentityMutation) SetIssuer(s string) {
+ m.issuer = &s
+}
+
+// Issuer returns the value of the "issuer" field in the mutation.
+func (m *AuthIdentityMutation) Issuer() (r string, exists bool) {
+ v := m.issuer
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIssuer returns the old "issuer" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldIssuer(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIssuer is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIssuer requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIssuer: %w", err)
+ }
+ return oldValue.Issuer, nil
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (m *AuthIdentityMutation) ClearIssuer() {
+ m.issuer = nil
+ m.clearedFields[authidentity.FieldIssuer] = struct{}{}
+}
+
+// IssuerCleared returns if the "issuer" field was cleared in this mutation.
+func (m *AuthIdentityMutation) IssuerCleared() bool {
+ _, ok := m.clearedFields[authidentity.FieldIssuer]
+ return ok
+}
+
+// ResetIssuer resets all changes to the "issuer" field.
+func (m *AuthIdentityMutation) ResetIssuer() {
+ m.issuer = nil
+ delete(m.clearedFields, authidentity.FieldIssuer)
+}
+
+// SetMetadata sets the "metadata" field.
+func (m *AuthIdentityMutation) SetMetadata(value map[string]interface{}) {
+ m.metadata = &value
+}
+
+// Metadata returns the value of the "metadata" field in the mutation.
+func (m *AuthIdentityMutation) Metadata() (r map[string]interface{}, exists bool) {
+ v := m.metadata
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldMetadata returns the old "metadata" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldMetadata is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldMetadata requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldMetadata: %w", err)
+ }
+ return oldValue.Metadata, nil
+}
+
+// ResetMetadata resets all changes to the "metadata" field.
+func (m *AuthIdentityMutation) ResetMetadata() {
+ m.metadata = nil
+}
+
+// ClearUser clears the "user" edge to the User entity.
+func (m *AuthIdentityMutation) ClearUser() {
+ m.cleareduser = true
+ m.clearedFields[authidentity.FieldUserID] = struct{}{}
+}
+
+// UserCleared reports if the "user" edge to the User entity was cleared.
+func (m *AuthIdentityMutation) UserCleared() bool {
+ return m.cleareduser
+}
+
+// UserIDs returns the "user" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// UserID instead. It exists only for internal usage by the builders.
+func (m *AuthIdentityMutation) UserIDs() (ids []int64) {
+ if id := m.user; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetUser resets all changes to the "user" edge.
+func (m *AuthIdentityMutation) ResetUser() {
+ m.user = nil
+ m.cleareduser = false
+}
+
+// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by ids.
+func (m *AuthIdentityMutation) AddChannelIDs(ids ...int64) {
+ if m.channels == nil {
+ m.channels = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.channels[ids[i]] = struct{}{}
+ }
+}
+
+// ClearChannels clears the "channels" edge to the AuthIdentityChannel entity.
+func (m *AuthIdentityMutation) ClearChannels() {
+ m.clearedchannels = true
+}
+
+// ChannelsCleared reports if the "channels" edge to the AuthIdentityChannel entity was cleared.
+func (m *AuthIdentityMutation) ChannelsCleared() bool {
+ return m.clearedchannels
+}
+
+// RemoveChannelIDs removes the "channels" edge to the AuthIdentityChannel entity by IDs.
+func (m *AuthIdentityMutation) RemoveChannelIDs(ids ...int64) {
+ if m.removedchannels == nil {
+ m.removedchannels = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.channels, ids[i])
+ m.removedchannels[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedChannels returns the removed IDs of the "channels" edge to the AuthIdentityChannel entity.
+func (m *AuthIdentityMutation) RemovedChannelsIDs() (ids []int64) {
+ for id := range m.removedchannels {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ChannelsIDs returns the "channels" edge IDs in the mutation.
+func (m *AuthIdentityMutation) ChannelsIDs() (ids []int64) {
+ for id := range m.channels {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetChannels resets all changes to the "channels" edge.
+func (m *AuthIdentityMutation) ResetChannels() {
+ m.channels = nil
+ m.clearedchannels = false
+ m.removedchannels = nil
+}
+
+// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by ids.
+func (m *AuthIdentityMutation) AddAdoptionDecisionIDs(ids ...int64) {
+ if m.adoption_decisions == nil {
+ m.adoption_decisions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.adoption_decisions[ids[i]] = struct{}{}
+ }
+}
+
+// ClearAdoptionDecisions clears the "adoption_decisions" edge to the IdentityAdoptionDecision entity.
+func (m *AuthIdentityMutation) ClearAdoptionDecisions() {
+ m.clearedadoption_decisions = true
+}
+
+// AdoptionDecisionsCleared reports if the "adoption_decisions" edge to the IdentityAdoptionDecision entity was cleared.
+func (m *AuthIdentityMutation) AdoptionDecisionsCleared() bool {
+ return m.clearedadoption_decisions
+}
+
+// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
+func (m *AuthIdentityMutation) RemoveAdoptionDecisionIDs(ids ...int64) {
+ if m.removedadoption_decisions == nil {
+ m.removedadoption_decisions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.adoption_decisions, ids[i])
+ m.removedadoption_decisions[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedAdoptionDecisions returns the removed IDs of the "adoption_decisions" edge to the IdentityAdoptionDecision entity.
+func (m *AuthIdentityMutation) RemovedAdoptionDecisionsIDs() (ids []int64) {
+ for id := range m.removedadoption_decisions {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// AdoptionDecisionsIDs returns the "adoption_decisions" edge IDs in the mutation.
+func (m *AuthIdentityMutation) AdoptionDecisionsIDs() (ids []int64) {
+ for id := range m.adoption_decisions {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetAdoptionDecisions resets all changes to the "adoption_decisions" edge.
+func (m *AuthIdentityMutation) ResetAdoptionDecisions() {
+ m.adoption_decisions = nil
+ m.clearedadoption_decisions = false
+ m.removedadoption_decisions = nil
+}
+
+// Where appends a list predicates to the AuthIdentityMutation builder.
+func (m *AuthIdentityMutation) Where(ps ...predicate.AuthIdentity) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the AuthIdentityMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *AuthIdentityMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.AuthIdentity, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *AuthIdentityMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *AuthIdentityMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (AuthIdentity).
+func (m *AuthIdentityMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *AuthIdentityMutation) Fields() []string {
+ fields := make([]string, 0, 9)
+ if m.created_at != nil {
+ fields = append(fields, authidentity.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, authidentity.FieldUpdatedAt)
+ }
+ if m.user != nil {
+ fields = append(fields, authidentity.FieldUserID)
+ }
+ if m.provider_type != nil {
+ fields = append(fields, authidentity.FieldProviderType)
+ }
+ if m.provider_key != nil {
+ fields = append(fields, authidentity.FieldProviderKey)
+ }
+ if m.provider_subject != nil {
+ fields = append(fields, authidentity.FieldProviderSubject)
+ }
+ if m.verified_at != nil {
+ fields = append(fields, authidentity.FieldVerifiedAt)
+ }
+ if m.issuer != nil {
+ fields = append(fields, authidentity.FieldIssuer)
+ }
+ if m.metadata != nil {
+ fields = append(fields, authidentity.FieldMetadata)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *AuthIdentityMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case authidentity.FieldCreatedAt:
+ return m.CreatedAt()
+ case authidentity.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case authidentity.FieldUserID:
+ return m.UserID()
+ case authidentity.FieldProviderType:
+ return m.ProviderType()
+ case authidentity.FieldProviderKey:
+ return m.ProviderKey()
+ case authidentity.FieldProviderSubject:
+ return m.ProviderSubject()
+ case authidentity.FieldVerifiedAt:
+ return m.VerifiedAt()
+ case authidentity.FieldIssuer:
+ return m.Issuer()
+ case authidentity.FieldMetadata:
+ return m.Metadata()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *AuthIdentityMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case authidentity.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case authidentity.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case authidentity.FieldUserID:
+ return m.OldUserID(ctx)
+ case authidentity.FieldProviderType:
+ return m.OldProviderType(ctx)
+ case authidentity.FieldProviderKey:
+ return m.OldProviderKey(ctx)
+ case authidentity.FieldProviderSubject:
+ return m.OldProviderSubject(ctx)
+ case authidentity.FieldVerifiedAt:
+ return m.OldVerifiedAt(ctx)
+ case authidentity.FieldIssuer:
+ return m.OldIssuer(ctx)
+ case authidentity.FieldMetadata:
+ return m.OldMetadata(ctx)
+ }
+ return nil, fmt.Errorf("unknown AuthIdentity field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AuthIdentityMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case authidentity.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case authidentity.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case authidentity.FieldUserID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUserID(v)
+ return nil
+ case authidentity.FieldProviderType:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderType(v)
+ return nil
+ case authidentity.FieldProviderKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderKey(v)
+ return nil
+ case authidentity.FieldProviderSubject:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderSubject(v)
+ return nil
+ case authidentity.FieldVerifiedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetVerifiedAt(v)
+ return nil
+ case authidentity.FieldIssuer:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIssuer(v)
+ return nil
+ case authidentity.FieldMetadata:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetMetadata(v)
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *AuthIdentityMutation) AddedFields() []string {
+ var fields []string
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *AuthIdentityMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AuthIdentityMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown AuthIdentity numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *AuthIdentityMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(authidentity.FieldVerifiedAt) {
+ fields = append(fields, authidentity.FieldVerifiedAt)
+ }
+ if m.FieldCleared(authidentity.FieldIssuer) {
+ fields = append(fields, authidentity.FieldIssuer)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *AuthIdentityMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *AuthIdentityMutation) ClearField(name string) error {
+ switch name {
+ case authidentity.FieldVerifiedAt:
+ m.ClearVerifiedAt()
+ return nil
+ case authidentity.FieldIssuer:
+ m.ClearIssuer()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *AuthIdentityMutation) ResetField(name string) error {
+ switch name {
+ case authidentity.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case authidentity.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case authidentity.FieldUserID:
+ m.ResetUserID()
+ return nil
+ case authidentity.FieldProviderType:
+ m.ResetProviderType()
+ return nil
+ case authidentity.FieldProviderKey:
+ m.ResetProviderKey()
+ return nil
+ case authidentity.FieldProviderSubject:
+ m.ResetProviderSubject()
+ return nil
+ case authidentity.FieldVerifiedAt:
+ m.ResetVerifiedAt()
+ return nil
+ case authidentity.FieldIssuer:
+ m.ResetIssuer()
+ return nil
+ case authidentity.FieldMetadata:
+ m.ResetMetadata()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *AuthIdentityMutation) AddedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.user != nil {
+ edges = append(edges, authidentity.EdgeUser)
+ }
+ if m.channels != nil {
+ edges = append(edges, authidentity.EdgeChannels)
+ }
+ if m.adoption_decisions != nil {
+ edges = append(edges, authidentity.EdgeAdoptionDecisions)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *AuthIdentityMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case authidentity.EdgeUser:
+ if id := m.user; id != nil {
+ return []ent.Value{*id}
+ }
+ case authidentity.EdgeChannels:
+ ids := make([]ent.Value, 0, len(m.channels))
+ for id := range m.channels {
+ ids = append(ids, id)
+ }
+ return ids
+ case authidentity.EdgeAdoptionDecisions:
+ ids := make([]ent.Value, 0, len(m.adoption_decisions))
+ for id := range m.adoption_decisions {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *AuthIdentityMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.removedchannels != nil {
+ edges = append(edges, authidentity.EdgeChannels)
+ }
+ if m.removedadoption_decisions != nil {
+ edges = append(edges, authidentity.EdgeAdoptionDecisions)
+ }
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *AuthIdentityMutation) RemovedIDs(name string) []ent.Value {
+ switch name {
+ case authidentity.EdgeChannels:
+ ids := make([]ent.Value, 0, len(m.removedchannels))
+ for id := range m.removedchannels {
+ ids = append(ids, id)
+ }
+ return ids
+ case authidentity.EdgeAdoptionDecisions:
+ ids := make([]ent.Value, 0, len(m.removedadoption_decisions))
+ for id := range m.removedadoption_decisions {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *AuthIdentityMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.cleareduser {
+ edges = append(edges, authidentity.EdgeUser)
+ }
+ if m.clearedchannels {
+ edges = append(edges, authidentity.EdgeChannels)
+ }
+ if m.clearedadoption_decisions {
+ edges = append(edges, authidentity.EdgeAdoptionDecisions)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *AuthIdentityMutation) EdgeCleared(name string) bool {
+ switch name {
+ case authidentity.EdgeUser:
+ return m.cleareduser
+ case authidentity.EdgeChannels:
+ return m.clearedchannels
+ case authidentity.EdgeAdoptionDecisions:
+ return m.clearedadoption_decisions
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *AuthIdentityMutation) ClearEdge(name string) error {
+ switch name {
+ case authidentity.EdgeUser:
+ m.ClearUser()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *AuthIdentityMutation) ResetEdge(name string) error {
+ switch name {
+ case authidentity.EdgeUser:
+ m.ResetUser()
+ return nil
+ case authidentity.EdgeChannels:
+ m.ResetChannels()
+ return nil
+ case authidentity.EdgeAdoptionDecisions:
+ m.ResetAdoptionDecisions()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity edge %s", name)
+}
+
+// AuthIdentityChannelMutation represents an operation that mutates the AuthIdentityChannel nodes in the graph.
+type AuthIdentityChannelMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ provider_type *string
+ provider_key *string
+ channel *string
+ channel_app_id *string
+ channel_subject *string
+ metadata *map[string]interface{}
+ clearedFields map[string]struct{}
+ identity *int64
+ clearedidentity bool
+ done bool
+ oldValue func(context.Context) (*AuthIdentityChannel, error)
+ predicates []predicate.AuthIdentityChannel
+}
+
+var _ ent.Mutation = (*AuthIdentityChannelMutation)(nil)
+
+// authidentitychannelOption allows management of the mutation configuration using functional options.
+type authidentitychannelOption func(*AuthIdentityChannelMutation)
+
+// newAuthIdentityChannelMutation creates new mutation for the AuthIdentityChannel entity.
+func newAuthIdentityChannelMutation(c config, op Op, opts ...authidentitychannelOption) *AuthIdentityChannelMutation {
+ m := &AuthIdentityChannelMutation{
+ config: c,
+ op: op,
+ typ: TypeAuthIdentityChannel,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withAuthIdentityChannelID sets the ID field of the mutation.
+func withAuthIdentityChannelID(id int64) authidentitychannelOption {
+ return func(m *AuthIdentityChannelMutation) {
+ var (
+ err error
+ once sync.Once
+ value *AuthIdentityChannel
+ )
+ m.oldValue = func(ctx context.Context) (*AuthIdentityChannel, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().AuthIdentityChannel.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withAuthIdentityChannel sets the old AuthIdentityChannel of the mutation.
+func withAuthIdentityChannel(node *AuthIdentityChannel) authidentitychannelOption {
+ return func(m *AuthIdentityChannelMutation) {
+ m.oldValue = func(context.Context) (*AuthIdentityChannel, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m AuthIdentityChannelMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m AuthIdentityChannelMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *AuthIdentityChannelMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *AuthIdentityChannelMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().AuthIdentityChannel.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *AuthIdentityChannelMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *AuthIdentityChannelMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *AuthIdentityChannelMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *AuthIdentityChannelMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *AuthIdentityChannelMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *AuthIdentityChannelMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (m *AuthIdentityChannelMutation) SetIdentityID(i int64) {
+ m.identity = &i
+}
+
+// IdentityID returns the value of the "identity_id" field in the mutation.
+func (m *AuthIdentityChannelMutation) IdentityID() (r int64, exists bool) {
+ v := m.identity
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIdentityID returns the old "identity_id" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldIdentityID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIdentityID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIdentityID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIdentityID: %w", err)
+ }
+ return oldValue.IdentityID, nil
+}
+
+// ResetIdentityID resets all changes to the "identity_id" field.
+func (m *AuthIdentityChannelMutation) ResetIdentityID() {
+ m.identity = nil
+}
+
+// SetProviderType sets the "provider_type" field.
+func (m *AuthIdentityChannelMutation) SetProviderType(s string) {
+ m.provider_type = &s
+}
+
+// ProviderType returns the value of the "provider_type" field in the mutation.
+func (m *AuthIdentityChannelMutation) ProviderType() (r string, exists bool) {
+ v := m.provider_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderType returns the old "provider_type" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldProviderType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderType: %w", err)
+ }
+ return oldValue.ProviderType, nil
+}
+
+// ResetProviderType resets all changes to the "provider_type" field.
+func (m *AuthIdentityChannelMutation) ResetProviderType() {
+ m.provider_type = nil
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (m *AuthIdentityChannelMutation) SetProviderKey(s string) {
+ m.provider_key = &s
+}
+
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *AuthIdentityChannelMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderKey returns the old "provider_key" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldProviderKey(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ }
+ return oldValue.ProviderKey, nil
+}
+
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *AuthIdentityChannelMutation) ResetProviderKey() {
+ m.provider_key = nil
+}
+
+// SetChannel sets the "channel" field.
+func (m *AuthIdentityChannelMutation) SetChannel(s string) {
+ m.channel = &s
+}
+
+// Channel returns the value of the "channel" field in the mutation.
+func (m *AuthIdentityChannelMutation) Channel() (r string, exists bool) {
+ v := m.channel
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldChannel returns the old "channel" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldChannel(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldChannel is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldChannel requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldChannel: %w", err)
+ }
+ return oldValue.Channel, nil
+}
+
+// ResetChannel resets all changes to the "channel" field.
+func (m *AuthIdentityChannelMutation) ResetChannel() {
+ m.channel = nil
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (m *AuthIdentityChannelMutation) SetChannelAppID(s string) {
+ m.channel_app_id = &s
+}
+
+// ChannelAppID returns the value of the "channel_app_id" field in the mutation.
+func (m *AuthIdentityChannelMutation) ChannelAppID() (r string, exists bool) {
+ v := m.channel_app_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldChannelAppID returns the old "channel_app_id" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldChannelAppID(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldChannelAppID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldChannelAppID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldChannelAppID: %w", err)
+ }
+ return oldValue.ChannelAppID, nil
+}
+
+// ResetChannelAppID resets all changes to the "channel_app_id" field.
+func (m *AuthIdentityChannelMutation) ResetChannelAppID() {
+ m.channel_app_id = nil
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (m *AuthIdentityChannelMutation) SetChannelSubject(s string) {
+ m.channel_subject = &s
+}
+
+// ChannelSubject returns the value of the "channel_subject" field in the mutation.
+func (m *AuthIdentityChannelMutation) ChannelSubject() (r string, exists bool) {
+ v := m.channel_subject
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldChannelSubject returns the old "channel_subject" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldChannelSubject(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldChannelSubject is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldChannelSubject requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldChannelSubject: %w", err)
+ }
+ return oldValue.ChannelSubject, nil
+}
+
+// ResetChannelSubject resets all changes to the "channel_subject" field.
+func (m *AuthIdentityChannelMutation) ResetChannelSubject() {
+ m.channel_subject = nil
+}
+
+// SetMetadata sets the "metadata" field.
+func (m *AuthIdentityChannelMutation) SetMetadata(value map[string]interface{}) {
+ m.metadata = &value
+}
+
+// Metadata returns the value of the "metadata" field in the mutation.
+func (m *AuthIdentityChannelMutation) Metadata() (r map[string]interface{}, exists bool) {
+ v := m.metadata
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldMetadata returns the old "metadata" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldMetadata is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldMetadata requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldMetadata: %w", err)
+ }
+ return oldValue.Metadata, nil
+}
+
+// ResetMetadata resets all changes to the "metadata" field.
+func (m *AuthIdentityChannelMutation) ResetMetadata() {
+ m.metadata = nil
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (m *AuthIdentityChannelMutation) ClearIdentity() {
+ m.clearedidentity = true
+ m.clearedFields[authidentitychannel.FieldIdentityID] = struct{}{}
+}
+
+// IdentityCleared reports if the "identity" edge to the AuthIdentity entity was cleared.
+func (m *AuthIdentityChannelMutation) IdentityCleared() bool {
+ return m.clearedidentity
+}
+
+// IdentityIDs returns the "identity" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// IdentityID instead. It exists only for internal usage by the builders.
+func (m *AuthIdentityChannelMutation) IdentityIDs() (ids []int64) {
+ if id := m.identity; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetIdentity resets all changes to the "identity" edge.
+func (m *AuthIdentityChannelMutation) ResetIdentity() {
+ m.identity = nil
+ m.clearedidentity = false
+}
+
+// Where appends a list predicates to the AuthIdentityChannelMutation builder.
+func (m *AuthIdentityChannelMutation) Where(ps ...predicate.AuthIdentityChannel) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the AuthIdentityChannelMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *AuthIdentityChannelMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.AuthIdentityChannel, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *AuthIdentityChannelMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *AuthIdentityChannelMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (AuthIdentityChannel).
+func (m *AuthIdentityChannelMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *AuthIdentityChannelMutation) Fields() []string {
+ fields := make([]string, 0, 9)
+ if m.created_at != nil {
+ fields = append(fields, authidentitychannel.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, authidentitychannel.FieldUpdatedAt)
+ }
+ if m.identity != nil {
+ fields = append(fields, authidentitychannel.FieldIdentityID)
+ }
+ if m.provider_type != nil {
+ fields = append(fields, authidentitychannel.FieldProviderType)
+ }
+ if m.provider_key != nil {
+ fields = append(fields, authidentitychannel.FieldProviderKey)
+ }
+ if m.channel != nil {
+ fields = append(fields, authidentitychannel.FieldChannel)
+ }
+ if m.channel_app_id != nil {
+ fields = append(fields, authidentitychannel.FieldChannelAppID)
+ }
+ if m.channel_subject != nil {
+ fields = append(fields, authidentitychannel.FieldChannelSubject)
+ }
+ if m.metadata != nil {
+ fields = append(fields, authidentitychannel.FieldMetadata)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *AuthIdentityChannelMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case authidentitychannel.FieldCreatedAt:
+ return m.CreatedAt()
+ case authidentitychannel.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case authidentitychannel.FieldIdentityID:
+ return m.IdentityID()
+ case authidentitychannel.FieldProviderType:
+ return m.ProviderType()
+ case authidentitychannel.FieldProviderKey:
+ return m.ProviderKey()
+ case authidentitychannel.FieldChannel:
+ return m.Channel()
+ case authidentitychannel.FieldChannelAppID:
+ return m.ChannelAppID()
+ case authidentitychannel.FieldChannelSubject:
+ return m.ChannelSubject()
+ case authidentitychannel.FieldMetadata:
+ return m.Metadata()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *AuthIdentityChannelMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case authidentitychannel.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case authidentitychannel.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case authidentitychannel.FieldIdentityID:
+ return m.OldIdentityID(ctx)
+ case authidentitychannel.FieldProviderType:
+ return m.OldProviderType(ctx)
+ case authidentitychannel.FieldProviderKey:
+ return m.OldProviderKey(ctx)
+ case authidentitychannel.FieldChannel:
+ return m.OldChannel(ctx)
+ case authidentitychannel.FieldChannelAppID:
+ return m.OldChannelAppID(ctx)
+ case authidentitychannel.FieldChannelSubject:
+ return m.OldChannelSubject(ctx)
+ case authidentitychannel.FieldMetadata:
+ return m.OldMetadata(ctx)
+ }
+ return nil, fmt.Errorf("unknown AuthIdentityChannel field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AuthIdentityChannelMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case authidentitychannel.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case authidentitychannel.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case authidentitychannel.FieldIdentityID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIdentityID(v)
+ return nil
+ case authidentitychannel.FieldProviderType:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderType(v)
+ return nil
+ case authidentitychannel.FieldProviderKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderKey(v)
+ return nil
+ case authidentitychannel.FieldChannel:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetChannel(v)
+ return nil
+ case authidentitychannel.FieldChannelAppID:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetChannelAppID(v)
+ return nil
+ case authidentitychannel.FieldChannelSubject:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetChannelSubject(v)
+ return nil
+ case authidentitychannel.FieldMetadata:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetMetadata(v)
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *AuthIdentityChannelMutation) AddedFields() []string {
+ var fields []string
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *AuthIdentityChannelMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AuthIdentityChannelMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *AuthIdentityChannelMutation) ClearedFields() []string {
+ return nil
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *AuthIdentityChannelMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *AuthIdentityChannelMutation) ClearField(name string) error {
+ return fmt.Errorf("unknown AuthIdentityChannel nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *AuthIdentityChannelMutation) ResetField(name string) error {
+ switch name {
+ case authidentitychannel.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case authidentitychannel.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case authidentitychannel.FieldIdentityID:
+ m.ResetIdentityID()
+ return nil
+ case authidentitychannel.FieldProviderType:
+ m.ResetProviderType()
+ return nil
+ case authidentitychannel.FieldProviderKey:
+ m.ResetProviderKey()
+ return nil
+ case authidentitychannel.FieldChannel:
+ m.ResetChannel()
+ return nil
+ case authidentitychannel.FieldChannelAppID:
+ m.ResetChannelAppID()
+ return nil
+ case authidentitychannel.FieldChannelSubject:
+ m.ResetChannelSubject()
+ return nil
+ case authidentitychannel.FieldMetadata:
+ m.ResetMetadata()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *AuthIdentityChannelMutation) AddedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.identity != nil {
+ edges = append(edges, authidentitychannel.EdgeIdentity)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *AuthIdentityChannelMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case authidentitychannel.EdgeIdentity:
+ if id := m.identity; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *AuthIdentityChannelMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 1)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *AuthIdentityChannelMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *AuthIdentityChannelMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.clearedidentity {
+ edges = append(edges, authidentitychannel.EdgeIdentity)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *AuthIdentityChannelMutation) EdgeCleared(name string) bool {
+ switch name {
+ case authidentitychannel.EdgeIdentity:
+ return m.clearedidentity
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *AuthIdentityChannelMutation) ClearEdge(name string) error {
+ switch name {
+ case authidentitychannel.EdgeIdentity:
+ m.ClearIdentity()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *AuthIdentityChannelMutation) ResetEdge(name string) error {
+ switch name {
+ case authidentitychannel.EdgeIdentity:
+ m.ResetIdentity()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel edge %s", name)
+}
+
// ErrorPassthroughRuleMutation represents an operation that mutates the ErrorPassthroughRule nodes in the graph.
type ErrorPassthroughRuleMutation struct {
config
@@ -12191,6 +14038,781 @@ func (m *IdempotencyRecordMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown IdempotencyRecord edge %s", name)
}
+// IdentityAdoptionDecisionMutation represents an operation that mutates the IdentityAdoptionDecision nodes in the graph.
+type IdentityAdoptionDecisionMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ adopt_display_name *bool
+ adopt_avatar *bool
+ decided_at *time.Time
+ clearedFields map[string]struct{}
+ pending_auth_session *int64
+ clearedpending_auth_session bool
+ identity *int64
+ clearedidentity bool
+ done bool
+ oldValue func(context.Context) (*IdentityAdoptionDecision, error)
+ predicates []predicate.IdentityAdoptionDecision
+}
+
+var _ ent.Mutation = (*IdentityAdoptionDecisionMutation)(nil)
+
+// identityadoptiondecisionOption allows management of the mutation configuration using functional options.
+type identityadoptiondecisionOption func(*IdentityAdoptionDecisionMutation)
+
+// newIdentityAdoptionDecisionMutation creates new mutation for the IdentityAdoptionDecision entity.
+func newIdentityAdoptionDecisionMutation(c config, op Op, opts ...identityadoptiondecisionOption) *IdentityAdoptionDecisionMutation {
+ m := &IdentityAdoptionDecisionMutation{
+ config: c,
+ op: op,
+ typ: TypeIdentityAdoptionDecision,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withIdentityAdoptionDecisionID sets the ID field of the mutation.
+func withIdentityAdoptionDecisionID(id int64) identityadoptiondecisionOption {
+ return func(m *IdentityAdoptionDecisionMutation) {
+ var (
+ err error
+ once sync.Once
+ value *IdentityAdoptionDecision
+ )
+ m.oldValue = func(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().IdentityAdoptionDecision.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withIdentityAdoptionDecision sets the old IdentityAdoptionDecision of the mutation.
+func withIdentityAdoptionDecision(node *IdentityAdoptionDecision) identityadoptiondecisionOption {
+ return func(m *IdentityAdoptionDecisionMutation) {
+ m.oldValue = func(context.Context) (*IdentityAdoptionDecision, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m IdentityAdoptionDecisionMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m IdentityAdoptionDecisionMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *IdentityAdoptionDecisionMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *IdentityAdoptionDecisionMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().IdentityAdoptionDecision.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *IdentityAdoptionDecisionMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *IdentityAdoptionDecisionMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *IdentityAdoptionDecisionMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *IdentityAdoptionDecisionMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (m *IdentityAdoptionDecisionMutation) SetPendingAuthSessionID(i int64) {
+ m.pending_auth_session = &i
+}
+
+// PendingAuthSessionID returns the value of the "pending_auth_session_id" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionID() (r int64, exists bool) {
+ v := m.pending_auth_session
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPendingAuthSessionID returns the old "pending_auth_session_id" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldPendingAuthSessionID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPendingAuthSessionID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPendingAuthSessionID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPendingAuthSessionID: %w", err)
+ }
+ return oldValue.PendingAuthSessionID, nil
+}
+
+// ResetPendingAuthSessionID resets all changes to the "pending_auth_session_id" field.
+func (m *IdentityAdoptionDecisionMutation) ResetPendingAuthSessionID() {
+ m.pending_auth_session = nil
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (m *IdentityAdoptionDecisionMutation) SetIdentityID(i int64) {
+ m.identity = &i
+}
+
+// IdentityID returns the value of the "identity_id" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) IdentityID() (r int64, exists bool) {
+ v := m.identity
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIdentityID returns the old "identity_id" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldIdentityID(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIdentityID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIdentityID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIdentityID: %w", err)
+ }
+ return oldValue.IdentityID, nil
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (m *IdentityAdoptionDecisionMutation) ClearIdentityID() {
+ m.identity = nil
+ m.clearedFields[identityadoptiondecision.FieldIdentityID] = struct{}{}
+}
+
+// IdentityIDCleared returns if the "identity_id" field was cleared in this mutation.
+func (m *IdentityAdoptionDecisionMutation) IdentityIDCleared() bool {
+ _, ok := m.clearedFields[identityadoptiondecision.FieldIdentityID]
+ return ok
+}
+
+// ResetIdentityID resets all changes to the "identity_id" field.
+func (m *IdentityAdoptionDecisionMutation) ResetIdentityID() {
+ m.identity = nil
+ delete(m.clearedFields, identityadoptiondecision.FieldIdentityID)
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (m *IdentityAdoptionDecisionMutation) SetAdoptDisplayName(b bool) {
+ m.adopt_display_name = &b
+}
+
+// AdoptDisplayName returns the value of the "adopt_display_name" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) AdoptDisplayName() (r bool, exists bool) {
+ v := m.adopt_display_name
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAdoptDisplayName returns the old "adopt_display_name" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldAdoptDisplayName(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAdoptDisplayName is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAdoptDisplayName requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAdoptDisplayName: %w", err)
+ }
+ return oldValue.AdoptDisplayName, nil
+}
+
+// ResetAdoptDisplayName resets all changes to the "adopt_display_name" field.
+func (m *IdentityAdoptionDecisionMutation) ResetAdoptDisplayName() {
+ m.adopt_display_name = nil
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (m *IdentityAdoptionDecisionMutation) SetAdoptAvatar(b bool) {
+ m.adopt_avatar = &b
+}
+
+// AdoptAvatar returns the value of the "adopt_avatar" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) AdoptAvatar() (r bool, exists bool) {
+ v := m.adopt_avatar
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAdoptAvatar returns the old "adopt_avatar" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldAdoptAvatar(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAdoptAvatar is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAdoptAvatar requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAdoptAvatar: %w", err)
+ }
+ return oldValue.AdoptAvatar, nil
+}
+
+// ResetAdoptAvatar resets all changes to the "adopt_avatar" field.
+func (m *IdentityAdoptionDecisionMutation) ResetAdoptAvatar() {
+ m.adopt_avatar = nil
+}
+
+// SetDecidedAt sets the "decided_at" field.
+func (m *IdentityAdoptionDecisionMutation) SetDecidedAt(t time.Time) {
+ m.decided_at = &t
+}
+
+// DecidedAt returns the value of the "decided_at" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) DecidedAt() (r time.Time, exists bool) {
+ v := m.decided_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldDecidedAt returns the old "decided_at" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldDecidedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldDecidedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldDecidedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldDecidedAt: %w", err)
+ }
+ return oldValue.DecidedAt, nil
+}
+
+// ResetDecidedAt resets all changes to the "decided_at" field.
+func (m *IdentityAdoptionDecisionMutation) ResetDecidedAt() {
+ m.decided_at = nil
+}
+
+// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity.
+func (m *IdentityAdoptionDecisionMutation) ClearPendingAuthSession() {
+ m.clearedpending_auth_session = true
+ m.clearedFields[identityadoptiondecision.FieldPendingAuthSessionID] = struct{}{}
+}
+
+// PendingAuthSessionCleared reports if the "pending_auth_session" edge to the PendingAuthSession entity was cleared.
+func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionCleared() bool {
+ return m.clearedpending_auth_session
+}
+
+// PendingAuthSessionIDs returns the "pending_auth_session" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// PendingAuthSessionID instead. It exists only for internal usage by the builders.
+func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionIDs() (ids []int64) {
+ if id := m.pending_auth_session; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetPendingAuthSession resets all changes to the "pending_auth_session" edge.
+func (m *IdentityAdoptionDecisionMutation) ResetPendingAuthSession() {
+ m.pending_auth_session = nil
+ m.clearedpending_auth_session = false
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (m *IdentityAdoptionDecisionMutation) ClearIdentity() {
+ m.clearedidentity = true
+ m.clearedFields[identityadoptiondecision.FieldIdentityID] = struct{}{}
+}
+
+// IdentityCleared reports if the "identity" edge to the AuthIdentity entity was cleared.
+func (m *IdentityAdoptionDecisionMutation) IdentityCleared() bool {
+ return m.IdentityIDCleared() || m.clearedidentity
+}
+
+// IdentityIDs returns the "identity" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// IdentityID instead. It exists only for internal usage by the builders.
+func (m *IdentityAdoptionDecisionMutation) IdentityIDs() (ids []int64) {
+ if id := m.identity; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetIdentity resets all changes to the "identity" edge.
+func (m *IdentityAdoptionDecisionMutation) ResetIdentity() {
+ m.identity = nil
+ m.clearedidentity = false
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionMutation builder.
+func (m *IdentityAdoptionDecisionMutation) Where(ps ...predicate.IdentityAdoptionDecision) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the IdentityAdoptionDecisionMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *IdentityAdoptionDecisionMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.IdentityAdoptionDecision, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *IdentityAdoptionDecisionMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *IdentityAdoptionDecisionMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (IdentityAdoptionDecision).
+func (m *IdentityAdoptionDecisionMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *IdentityAdoptionDecisionMutation) Fields() []string {
+ fields := make([]string, 0, 7)
+ if m.created_at != nil {
+ fields = append(fields, identityadoptiondecision.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, identityadoptiondecision.FieldUpdatedAt)
+ }
+ if m.pending_auth_session != nil {
+ fields = append(fields, identityadoptiondecision.FieldPendingAuthSessionID)
+ }
+ if m.identity != nil {
+ fields = append(fields, identityadoptiondecision.FieldIdentityID)
+ }
+ if m.adopt_display_name != nil {
+ fields = append(fields, identityadoptiondecision.FieldAdoptDisplayName)
+ }
+ if m.adopt_avatar != nil {
+ fields = append(fields, identityadoptiondecision.FieldAdoptAvatar)
+ }
+ if m.decided_at != nil {
+ fields = append(fields, identityadoptiondecision.FieldDecidedAt)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *IdentityAdoptionDecisionMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case identityadoptiondecision.FieldCreatedAt:
+ return m.CreatedAt()
+ case identityadoptiondecision.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ return m.PendingAuthSessionID()
+ case identityadoptiondecision.FieldIdentityID:
+ return m.IdentityID()
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ return m.AdoptDisplayName()
+ case identityadoptiondecision.FieldAdoptAvatar:
+ return m.AdoptAvatar()
+ case identityadoptiondecision.FieldDecidedAt:
+ return m.DecidedAt()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *IdentityAdoptionDecisionMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case identityadoptiondecision.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case identityadoptiondecision.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ return m.OldPendingAuthSessionID(ctx)
+ case identityadoptiondecision.FieldIdentityID:
+ return m.OldIdentityID(ctx)
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ return m.OldAdoptDisplayName(ctx)
+ case identityadoptiondecision.FieldAdoptAvatar:
+ return m.OldAdoptAvatar(ctx)
+ case identityadoptiondecision.FieldDecidedAt:
+ return m.OldDecidedAt(ctx)
+ }
+ return nil, fmt.Errorf("unknown IdentityAdoptionDecision field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *IdentityAdoptionDecisionMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case identityadoptiondecision.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case identityadoptiondecision.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPendingAuthSessionID(v)
+ return nil
+ case identityadoptiondecision.FieldIdentityID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIdentityID(v)
+ return nil
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAdoptDisplayName(v)
+ return nil
+ case identityadoptiondecision.FieldAdoptAvatar:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAdoptAvatar(v)
+ return nil
+ case identityadoptiondecision.FieldDecidedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetDecidedAt(v)
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *IdentityAdoptionDecisionMutation) AddedFields() []string {
+ var fields []string
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *IdentityAdoptionDecisionMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *IdentityAdoptionDecisionMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(identityadoptiondecision.FieldIdentityID) {
+ fields = append(fields, identityadoptiondecision.FieldIdentityID)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *IdentityAdoptionDecisionMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) ClearField(name string) error {
+ switch name {
+ case identityadoptiondecision.FieldIdentityID:
+ m.ClearIdentityID()
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) ResetField(name string) error {
+ switch name {
+ case identityadoptiondecision.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case identityadoptiondecision.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ m.ResetPendingAuthSessionID()
+ return nil
+ case identityadoptiondecision.FieldIdentityID:
+ m.ResetIdentityID()
+ return nil
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ m.ResetAdoptDisplayName()
+ return nil
+ case identityadoptiondecision.FieldAdoptAvatar:
+ m.ResetAdoptAvatar()
+ return nil
+ case identityadoptiondecision.FieldDecidedAt:
+ m.ResetDecidedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *IdentityAdoptionDecisionMutation) AddedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.pending_auth_session != nil {
+ edges = append(edges, identityadoptiondecision.EdgePendingAuthSession)
+ }
+ if m.identity != nil {
+ edges = append(edges, identityadoptiondecision.EdgeIdentity)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *IdentityAdoptionDecisionMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case identityadoptiondecision.EdgePendingAuthSession:
+ if id := m.pending_auth_session; id != nil {
+ return []ent.Value{*id}
+ }
+ case identityadoptiondecision.EdgeIdentity:
+ if id := m.identity; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *IdentityAdoptionDecisionMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 2)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *IdentityAdoptionDecisionMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *IdentityAdoptionDecisionMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.clearedpending_auth_session {
+ edges = append(edges, identityadoptiondecision.EdgePendingAuthSession)
+ }
+ if m.clearedidentity {
+ edges = append(edges, identityadoptiondecision.EdgeIdentity)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *IdentityAdoptionDecisionMutation) EdgeCleared(name string) bool {
+ switch name {
+ case identityadoptiondecision.EdgePendingAuthSession:
+ return m.clearedpending_auth_session
+ case identityadoptiondecision.EdgeIdentity:
+ return m.clearedidentity
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) ClearEdge(name string) error {
+ switch name {
+ case identityadoptiondecision.EdgePendingAuthSession:
+ m.ClearPendingAuthSession()
+ return nil
+ case identityadoptiondecision.EdgeIdentity:
+ m.ClearIdentity()
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) ResetEdge(name string) error {
+ switch name {
+ case identityadoptiondecision.EdgePendingAuthSession:
+ m.ResetPendingAuthSession()
+ return nil
+ case identityadoptiondecision.EdgeIdentity:
+ m.ResetIdentity()
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision edge %s", name)
+}
+
// PaymentAuditLogMutation represents an operation that mutates the PaymentAuditLog nodes in the graph.
type PaymentAuditLogMutation struct {
config
@@ -12763,6 +15385,8 @@ type PaymentOrderMutation struct {
subscription_days *int
addsubscription_days *int
provider_instance_id *string
+ provider_key *string
+ provider_snapshot *map[string]interface{}
status *string
refund_amount *float64
addrefund_amount *float64
@@ -13799,6 +16423,104 @@ func (m *PaymentOrderMutation) ResetProviderInstanceID() {
delete(m.clearedFields, paymentorder.FieldProviderInstanceID)
}
+// SetProviderKey sets the "provider_key" field.
+func (m *PaymentOrderMutation) SetProviderKey(s string) {
+ m.provider_key = &s
+}
+
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *PaymentOrderMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderKey returns the old "provider_key" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldProviderKey(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ }
+ return oldValue.ProviderKey, nil
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (m *PaymentOrderMutation) ClearProviderKey() {
+ m.provider_key = nil
+ m.clearedFields[paymentorder.FieldProviderKey] = struct{}{}
+}
+
+// ProviderKeyCleared returns if the "provider_key" field was cleared in this mutation.
+func (m *PaymentOrderMutation) ProviderKeyCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldProviderKey]
+ return ok
+}
+
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *PaymentOrderMutation) ResetProviderKey() {
+ m.provider_key = nil
+ delete(m.clearedFields, paymentorder.FieldProviderKey)
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (m *PaymentOrderMutation) SetProviderSnapshot(value map[string]interface{}) {
+ m.provider_snapshot = &value
+}
+
+// ProviderSnapshot returns the value of the "provider_snapshot" field in the mutation.
+func (m *PaymentOrderMutation) ProviderSnapshot() (r map[string]interface{}, exists bool) {
+ v := m.provider_snapshot
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderSnapshot returns the old "provider_snapshot" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldProviderSnapshot(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderSnapshot is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderSnapshot requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderSnapshot: %w", err)
+ }
+ return oldValue.ProviderSnapshot, nil
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (m *PaymentOrderMutation) ClearProviderSnapshot() {
+ m.provider_snapshot = nil
+ m.clearedFields[paymentorder.FieldProviderSnapshot] = struct{}{}
+}
+
+// ProviderSnapshotCleared returns if the "provider_snapshot" field was cleared in this mutation.
+func (m *PaymentOrderMutation) ProviderSnapshotCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldProviderSnapshot]
+ return ok
+}
+
+// ResetProviderSnapshot resets all changes to the "provider_snapshot" field.
+func (m *PaymentOrderMutation) ResetProviderSnapshot() {
+ m.provider_snapshot = nil
+ delete(m.clearedFields, paymentorder.FieldProviderSnapshot)
+}
+
// SetStatus sets the "status" field.
func (m *PaymentOrderMutation) SetStatus(s string) {
m.status = &s
@@ -14658,7 +17380,7 @@ func (m *PaymentOrderMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *PaymentOrderMutation) Fields() []string {
- fields := make([]string, 0, 37)
+ fields := make([]string, 0, 39)
if m.user != nil {
fields = append(fields, paymentorder.FieldUserID)
}
@@ -14716,6 +17438,12 @@ func (m *PaymentOrderMutation) Fields() []string {
if m.provider_instance_id != nil {
fields = append(fields, paymentorder.FieldProviderInstanceID)
}
+ if m.provider_key != nil {
+ fields = append(fields, paymentorder.FieldProviderKey)
+ }
+ if m.provider_snapshot != nil {
+ fields = append(fields, paymentorder.FieldProviderSnapshot)
+ }
if m.status != nil {
fields = append(fields, paymentorder.FieldStatus)
}
@@ -14816,6 +17544,10 @@ func (m *PaymentOrderMutation) Field(name string) (ent.Value, bool) {
return m.SubscriptionDays()
case paymentorder.FieldProviderInstanceID:
return m.ProviderInstanceID()
+ case paymentorder.FieldProviderKey:
+ return m.ProviderKey()
+ case paymentorder.FieldProviderSnapshot:
+ return m.ProviderSnapshot()
case paymentorder.FieldStatus:
return m.Status()
case paymentorder.FieldRefundAmount:
@@ -14899,6 +17631,10 @@ func (m *PaymentOrderMutation) OldField(ctx context.Context, name string) (ent.V
return m.OldSubscriptionDays(ctx)
case paymentorder.FieldProviderInstanceID:
return m.OldProviderInstanceID(ctx)
+ case paymentorder.FieldProviderKey:
+ return m.OldProviderKey(ctx)
+ case paymentorder.FieldProviderSnapshot:
+ return m.OldProviderSnapshot(ctx)
case paymentorder.FieldStatus:
return m.OldStatus(ctx)
case paymentorder.FieldRefundAmount:
@@ -15077,6 +17813,20 @@ func (m *PaymentOrderMutation) SetField(name string, value ent.Value) error {
}
m.SetProviderInstanceID(v)
return nil
+ case paymentorder.FieldProviderKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderKey(v)
+ return nil
+ case paymentorder.FieldProviderSnapshot:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderSnapshot(v)
+ return nil
case paymentorder.FieldStatus:
v, ok := value.(string)
if !ok {
@@ -15344,6 +18094,12 @@ func (m *PaymentOrderMutation) ClearedFields() []string {
if m.FieldCleared(paymentorder.FieldProviderInstanceID) {
fields = append(fields, paymentorder.FieldProviderInstanceID)
}
+ if m.FieldCleared(paymentorder.FieldProviderKey) {
+ fields = append(fields, paymentorder.FieldProviderKey)
+ }
+ if m.FieldCleared(paymentorder.FieldProviderSnapshot) {
+ fields = append(fields, paymentorder.FieldProviderSnapshot)
+ }
if m.FieldCleared(paymentorder.FieldRefundReason) {
fields = append(fields, paymentorder.FieldRefundReason)
}
@@ -15412,6 +18168,12 @@ func (m *PaymentOrderMutation) ClearField(name string) error {
case paymentorder.FieldProviderInstanceID:
m.ClearProviderInstanceID()
return nil
+ case paymentorder.FieldProviderKey:
+ m.ClearProviderKey()
+ return nil
+ case paymentorder.FieldProviderSnapshot:
+ m.ClearProviderSnapshot()
+ return nil
case paymentorder.FieldRefundReason:
m.ClearRefundReason()
return nil
@@ -15507,6 +18269,12 @@ func (m *PaymentOrderMutation) ResetField(name string) error {
case paymentorder.FieldProviderInstanceID:
m.ResetProviderInstanceID()
return nil
+ case paymentorder.FieldProviderKey:
+ m.ResetProviderKey()
+ return nil
+ case paymentorder.FieldProviderSnapshot:
+ m.ResetProviderSnapshot()
+ return nil
case paymentorder.FieldStatus:
m.ResetStatus()
return nil
@@ -16595,6 +19363,1645 @@ func (m *PaymentProviderInstanceMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown PaymentProviderInstance edge %s", name)
}
+// PendingAuthSessionMutation represents an operation that mutates the PendingAuthSession nodes in the graph.
+type PendingAuthSessionMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ session_token *string
+ intent *string
+ provider_type *string
+ provider_key *string
+ provider_subject *string
+ redirect_to *string
+ resolved_email *string
+ registration_password_hash *string
+ upstream_identity_claims *map[string]interface{}
+ local_flow_state *map[string]interface{}
+ browser_session_key *string
+ completion_code_hash *string
+ completion_code_expires_at *time.Time
+ email_verified_at *time.Time
+ password_verified_at *time.Time
+ totp_verified_at *time.Time
+ expires_at *time.Time
+ consumed_at *time.Time
+ clearedFields map[string]struct{}
+ target_user *int64
+ clearedtarget_user bool
+ adoption_decision *int64
+ clearedadoption_decision bool
+ done bool
+ oldValue func(context.Context) (*PendingAuthSession, error)
+ predicates []predicate.PendingAuthSession
+}
+
+var _ ent.Mutation = (*PendingAuthSessionMutation)(nil)
+
+// pendingauthsessionOption allows management of the mutation configuration using functional options.
+type pendingauthsessionOption func(*PendingAuthSessionMutation)
+
+// newPendingAuthSessionMutation creates new mutation for the PendingAuthSession entity.
+func newPendingAuthSessionMutation(c config, op Op, opts ...pendingauthsessionOption) *PendingAuthSessionMutation {
+ m := &PendingAuthSessionMutation{
+ config: c,
+ op: op,
+ typ: TypePendingAuthSession,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withPendingAuthSessionID sets the ID field of the mutation.
+func withPendingAuthSessionID(id int64) pendingauthsessionOption {
+ return func(m *PendingAuthSessionMutation) {
+ var (
+ err error
+ once sync.Once
+ value *PendingAuthSession
+ )
+ m.oldValue = func(ctx context.Context) (*PendingAuthSession, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().PendingAuthSession.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withPendingAuthSession sets the old PendingAuthSession of the mutation.
+func withPendingAuthSession(node *PendingAuthSession) pendingauthsessionOption {
+ return func(m *PendingAuthSessionMutation) {
+ m.oldValue = func(context.Context) (*PendingAuthSession, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m PendingAuthSessionMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m PendingAuthSessionMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *PendingAuthSessionMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *PendingAuthSessionMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().PendingAuthSession.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *PendingAuthSessionMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *PendingAuthSessionMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *PendingAuthSessionMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *PendingAuthSessionMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *PendingAuthSessionMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *PendingAuthSessionMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetSessionToken sets the "session_token" field.
+func (m *PendingAuthSessionMutation) SetSessionToken(s string) {
+ m.session_token = &s
+}
+
+// SessionToken returns the value of the "session_token" field in the mutation.
+func (m *PendingAuthSessionMutation) SessionToken() (r string, exists bool) {
+ v := m.session_token
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSessionToken returns the old "session_token" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldSessionToken(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSessionToken is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSessionToken requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSessionToken: %w", err)
+ }
+ return oldValue.SessionToken, nil
+}
+
+// ResetSessionToken resets all changes to the "session_token" field.
+func (m *PendingAuthSessionMutation) ResetSessionToken() {
+ m.session_token = nil
+}
+
+// SetIntent sets the "intent" field.
+func (m *PendingAuthSessionMutation) SetIntent(s string) {
+ m.intent = &s
+}
+
+// Intent returns the value of the "intent" field in the mutation.
+func (m *PendingAuthSessionMutation) Intent() (r string, exists bool) {
+ v := m.intent
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIntent returns the old "intent" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldIntent(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIntent is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIntent requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIntent: %w", err)
+ }
+ return oldValue.Intent, nil
+}
+
+// ResetIntent resets all changes to the "intent" field.
+func (m *PendingAuthSessionMutation) ResetIntent() {
+ m.intent = nil
+}
+
+// SetProviderType sets the "provider_type" field.
+func (m *PendingAuthSessionMutation) SetProviderType(s string) {
+ m.provider_type = &s
+}
+
+// ProviderType returns the value of the "provider_type" field in the mutation.
+func (m *PendingAuthSessionMutation) ProviderType() (r string, exists bool) {
+ v := m.provider_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderType returns the old "provider_type" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldProviderType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderType: %w", err)
+ }
+ return oldValue.ProviderType, nil
+}
+
+// ResetProviderType resets all changes to the "provider_type" field.
+func (m *PendingAuthSessionMutation) ResetProviderType() {
+ m.provider_type = nil
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (m *PendingAuthSessionMutation) SetProviderKey(s string) {
+ m.provider_key = &s
+}
+
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *PendingAuthSessionMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderKey returns the old "provider_key" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldProviderKey(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ }
+ return oldValue.ProviderKey, nil
+}
+
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *PendingAuthSessionMutation) ResetProviderKey() {
+ m.provider_key = nil
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (m *PendingAuthSessionMutation) SetProviderSubject(s string) {
+ m.provider_subject = &s
+}
+
+// ProviderSubject returns the value of the "provider_subject" field in the mutation.
+func (m *PendingAuthSessionMutation) ProviderSubject() (r string, exists bool) {
+ v := m.provider_subject
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderSubject returns the old "provider_subject" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldProviderSubject(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderSubject is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderSubject requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderSubject: %w", err)
+ }
+ return oldValue.ProviderSubject, nil
+}
+
+// ResetProviderSubject resets all changes to the "provider_subject" field.
+func (m *PendingAuthSessionMutation) ResetProviderSubject() {
+ m.provider_subject = nil
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (m *PendingAuthSessionMutation) SetTargetUserID(i int64) {
+ m.target_user = &i
+}
+
+// TargetUserID returns the value of the "target_user_id" field in the mutation.
+func (m *PendingAuthSessionMutation) TargetUserID() (r int64, exists bool) {
+ v := m.target_user
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTargetUserID returns the old "target_user_id" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldTargetUserID(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTargetUserID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTargetUserID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTargetUserID: %w", err)
+ }
+ return oldValue.TargetUserID, nil
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (m *PendingAuthSessionMutation) ClearTargetUserID() {
+ m.target_user = nil
+ m.clearedFields[pendingauthsession.FieldTargetUserID] = struct{}{}
+}
+
+// TargetUserIDCleared returns if the "target_user_id" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) TargetUserIDCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldTargetUserID]
+ return ok
+}
+
+// ResetTargetUserID resets all changes to the "target_user_id" field.
+func (m *PendingAuthSessionMutation) ResetTargetUserID() {
+ m.target_user = nil
+ delete(m.clearedFields, pendingauthsession.FieldTargetUserID)
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (m *PendingAuthSessionMutation) SetRedirectTo(s string) {
+ m.redirect_to = &s
+}
+
+// RedirectTo returns the value of the "redirect_to" field in the mutation.
+func (m *PendingAuthSessionMutation) RedirectTo() (r string, exists bool) {
+ v := m.redirect_to
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRedirectTo returns the old "redirect_to" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldRedirectTo(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRedirectTo is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRedirectTo requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRedirectTo: %w", err)
+ }
+ return oldValue.RedirectTo, nil
+}
+
+// ResetRedirectTo resets all changes to the "redirect_to" field.
+func (m *PendingAuthSessionMutation) ResetRedirectTo() {
+ m.redirect_to = nil
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (m *PendingAuthSessionMutation) SetResolvedEmail(s string) {
+ m.resolved_email = &s
+}
+
+// ResolvedEmail returns the value of the "resolved_email" field in the mutation.
+func (m *PendingAuthSessionMutation) ResolvedEmail() (r string, exists bool) {
+ v := m.resolved_email
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldResolvedEmail returns the old "resolved_email" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldResolvedEmail(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldResolvedEmail is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldResolvedEmail requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldResolvedEmail: %w", err)
+ }
+ return oldValue.ResolvedEmail, nil
+}
+
+// ResetResolvedEmail resets all changes to the "resolved_email" field.
+func (m *PendingAuthSessionMutation) ResetResolvedEmail() {
+ m.resolved_email = nil
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (m *PendingAuthSessionMutation) SetRegistrationPasswordHash(s string) {
+ m.registration_password_hash = &s
+}
+
+// RegistrationPasswordHash returns the value of the "registration_password_hash" field in the mutation.
+func (m *PendingAuthSessionMutation) RegistrationPasswordHash() (r string, exists bool) {
+ v := m.registration_password_hash
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRegistrationPasswordHash returns the old "registration_password_hash" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldRegistrationPasswordHash(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRegistrationPasswordHash is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRegistrationPasswordHash requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRegistrationPasswordHash: %w", err)
+ }
+ return oldValue.RegistrationPasswordHash, nil
+}
+
+// ResetRegistrationPasswordHash resets all changes to the "registration_password_hash" field.
+func (m *PendingAuthSessionMutation) ResetRegistrationPasswordHash() {
+ m.registration_password_hash = nil
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (m *PendingAuthSessionMutation) SetUpstreamIdentityClaims(value map[string]interface{}) {
+ m.upstream_identity_claims = &value
+}
+
+// UpstreamIdentityClaims returns the value of the "upstream_identity_claims" field in the mutation.
+func (m *PendingAuthSessionMutation) UpstreamIdentityClaims() (r map[string]interface{}, exists bool) {
+ v := m.upstream_identity_claims
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpstreamIdentityClaims returns the old "upstream_identity_claims" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldUpstreamIdentityClaims(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpstreamIdentityClaims is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpstreamIdentityClaims requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpstreamIdentityClaims: %w", err)
+ }
+ return oldValue.UpstreamIdentityClaims, nil
+}
+
+// ResetUpstreamIdentityClaims resets all changes to the "upstream_identity_claims" field.
+func (m *PendingAuthSessionMutation) ResetUpstreamIdentityClaims() {
+ m.upstream_identity_claims = nil
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (m *PendingAuthSessionMutation) SetLocalFlowState(value map[string]interface{}) {
+ m.local_flow_state = &value
+}
+
+// LocalFlowState returns the value of the "local_flow_state" field in the mutation.
+func (m *PendingAuthSessionMutation) LocalFlowState() (r map[string]interface{}, exists bool) {
+ v := m.local_flow_state
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldLocalFlowState returns the old "local_flow_state" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldLocalFlowState(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldLocalFlowState is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldLocalFlowState requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldLocalFlowState: %w", err)
+ }
+ return oldValue.LocalFlowState, nil
+}
+
+// ResetLocalFlowState resets all changes to the "local_flow_state" field.
+func (m *PendingAuthSessionMutation) ResetLocalFlowState() {
+ m.local_flow_state = nil
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (m *PendingAuthSessionMutation) SetBrowserSessionKey(s string) {
+ m.browser_session_key = &s
+}
+
+// BrowserSessionKey returns the value of the "browser_session_key" field in the mutation.
+func (m *PendingAuthSessionMutation) BrowserSessionKey() (r string, exists bool) {
+ v := m.browser_session_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBrowserSessionKey returns the old "browser_session_key" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldBrowserSessionKey(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBrowserSessionKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBrowserSessionKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBrowserSessionKey: %w", err)
+ }
+ return oldValue.BrowserSessionKey, nil
+}
+
+// ResetBrowserSessionKey resets all changes to the "browser_session_key" field.
+func (m *PendingAuthSessionMutation) ResetBrowserSessionKey() {
+ m.browser_session_key = nil
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (m *PendingAuthSessionMutation) SetCompletionCodeHash(s string) {
+ m.completion_code_hash = &s
+}
+
+// CompletionCodeHash returns the value of the "completion_code_hash" field in the mutation.
+func (m *PendingAuthSessionMutation) CompletionCodeHash() (r string, exists bool) {
+ v := m.completion_code_hash
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCompletionCodeHash returns the old "completion_code_hash" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldCompletionCodeHash(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCompletionCodeHash is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCompletionCodeHash requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCompletionCodeHash: %w", err)
+ }
+ return oldValue.CompletionCodeHash, nil
+}
+
+// ResetCompletionCodeHash resets all changes to the "completion_code_hash" field.
+func (m *PendingAuthSessionMutation) ResetCompletionCodeHash() {
+ m.completion_code_hash = nil
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (m *PendingAuthSessionMutation) SetCompletionCodeExpiresAt(t time.Time) {
+ m.completion_code_expires_at = &t
+}
+
+// CompletionCodeExpiresAt returns the value of the "completion_code_expires_at" field in the mutation.
+func (m *PendingAuthSessionMutation) CompletionCodeExpiresAt() (r time.Time, exists bool) {
+ v := m.completion_code_expires_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCompletionCodeExpiresAt returns the old "completion_code_expires_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldCompletionCodeExpiresAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCompletionCodeExpiresAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCompletionCodeExpiresAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCompletionCodeExpiresAt: %w", err)
+ }
+ return oldValue.CompletionCodeExpiresAt, nil
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (m *PendingAuthSessionMutation) ClearCompletionCodeExpiresAt() {
+ m.completion_code_expires_at = nil
+ m.clearedFields[pendingauthsession.FieldCompletionCodeExpiresAt] = struct{}{}
+}
+
+// CompletionCodeExpiresAtCleared returns if the "completion_code_expires_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) CompletionCodeExpiresAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldCompletionCodeExpiresAt]
+ return ok
+}
+
+// ResetCompletionCodeExpiresAt resets all changes to the "completion_code_expires_at" field.
+func (m *PendingAuthSessionMutation) ResetCompletionCodeExpiresAt() {
+ m.completion_code_expires_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldCompletionCodeExpiresAt)
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (m *PendingAuthSessionMutation) SetEmailVerifiedAt(t time.Time) {
+ m.email_verified_at = &t
+}
+
+// EmailVerifiedAt returns the value of the "email_verified_at" field in the mutation.
+func (m *PendingAuthSessionMutation) EmailVerifiedAt() (r time.Time, exists bool) {
+ v := m.email_verified_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldEmailVerifiedAt returns the old "email_verified_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldEmailVerifiedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldEmailVerifiedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldEmailVerifiedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldEmailVerifiedAt: %w", err)
+ }
+ return oldValue.EmailVerifiedAt, nil
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (m *PendingAuthSessionMutation) ClearEmailVerifiedAt() {
+ m.email_verified_at = nil
+ m.clearedFields[pendingauthsession.FieldEmailVerifiedAt] = struct{}{}
+}
+
+// EmailVerifiedAtCleared returns if the "email_verified_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) EmailVerifiedAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldEmailVerifiedAt]
+ return ok
+}
+
+// ResetEmailVerifiedAt resets all changes to the "email_verified_at" field.
+func (m *PendingAuthSessionMutation) ResetEmailVerifiedAt() {
+ m.email_verified_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldEmailVerifiedAt)
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (m *PendingAuthSessionMutation) SetPasswordVerifiedAt(t time.Time) {
+ m.password_verified_at = &t
+}
+
+// PasswordVerifiedAt returns the value of the "password_verified_at" field in the mutation.
+func (m *PendingAuthSessionMutation) PasswordVerifiedAt() (r time.Time, exists bool) {
+ v := m.password_verified_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPasswordVerifiedAt returns the old "password_verified_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldPasswordVerifiedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPasswordVerifiedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPasswordVerifiedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPasswordVerifiedAt: %w", err)
+ }
+ return oldValue.PasswordVerifiedAt, nil
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (m *PendingAuthSessionMutation) ClearPasswordVerifiedAt() {
+ m.password_verified_at = nil
+ m.clearedFields[pendingauthsession.FieldPasswordVerifiedAt] = struct{}{}
+}
+
+// PasswordVerifiedAtCleared returns if the "password_verified_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) PasswordVerifiedAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldPasswordVerifiedAt]
+ return ok
+}
+
+// ResetPasswordVerifiedAt resets all changes to the "password_verified_at" field.
+func (m *PendingAuthSessionMutation) ResetPasswordVerifiedAt() {
+ m.password_verified_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldPasswordVerifiedAt)
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (m *PendingAuthSessionMutation) SetTotpVerifiedAt(t time.Time) {
+ m.totp_verified_at = &t
+}
+
+// TotpVerifiedAt returns the value of the "totp_verified_at" field in the mutation.
+func (m *PendingAuthSessionMutation) TotpVerifiedAt() (r time.Time, exists bool) {
+ v := m.totp_verified_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTotpVerifiedAt returns the old "totp_verified_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldTotpVerifiedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTotpVerifiedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTotpVerifiedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTotpVerifiedAt: %w", err)
+ }
+ return oldValue.TotpVerifiedAt, nil
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (m *PendingAuthSessionMutation) ClearTotpVerifiedAt() {
+ m.totp_verified_at = nil
+ m.clearedFields[pendingauthsession.FieldTotpVerifiedAt] = struct{}{}
+}
+
+// TotpVerifiedAtCleared returns if the "totp_verified_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) TotpVerifiedAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldTotpVerifiedAt]
+ return ok
+}
+
+// ResetTotpVerifiedAt resets all changes to the "totp_verified_at" field.
+func (m *PendingAuthSessionMutation) ResetTotpVerifiedAt() {
+ m.totp_verified_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldTotpVerifiedAt)
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (m *PendingAuthSessionMutation) SetExpiresAt(t time.Time) {
+ m.expires_at = &t
+}
+
+// ExpiresAt returns the value of the "expires_at" field in the mutation.
+func (m *PendingAuthSessionMutation) ExpiresAt() (r time.Time, exists bool) {
+ v := m.expires_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldExpiresAt returns the old "expires_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldExpiresAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err)
+ }
+ return oldValue.ExpiresAt, nil
+}
+
+// ResetExpiresAt resets all changes to the "expires_at" field.
+func (m *PendingAuthSessionMutation) ResetExpiresAt() {
+ m.expires_at = nil
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (m *PendingAuthSessionMutation) SetConsumedAt(t time.Time) {
+ m.consumed_at = &t
+}
+
+// ConsumedAt returns the value of the "consumed_at" field in the mutation.
+func (m *PendingAuthSessionMutation) ConsumedAt() (r time.Time, exists bool) {
+ v := m.consumed_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldConsumedAt returns the old "consumed_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldConsumedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldConsumedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldConsumedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldConsumedAt: %w", err)
+ }
+ return oldValue.ConsumedAt, nil
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (m *PendingAuthSessionMutation) ClearConsumedAt() {
+ m.consumed_at = nil
+ m.clearedFields[pendingauthsession.FieldConsumedAt] = struct{}{}
+}
+
+// ConsumedAtCleared returns if the "consumed_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) ConsumedAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldConsumedAt]
+ return ok
+}
+
+// ResetConsumedAt resets all changes to the "consumed_at" field.
+func (m *PendingAuthSessionMutation) ResetConsumedAt() {
+ m.consumed_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldConsumedAt)
+}
+
+// ClearTargetUser clears the "target_user" edge to the User entity.
+func (m *PendingAuthSessionMutation) ClearTargetUser() {
+ m.clearedtarget_user = true
+ m.clearedFields[pendingauthsession.FieldTargetUserID] = struct{}{}
+}
+
+// TargetUserCleared reports if the "target_user" edge to the User entity was cleared.
+func (m *PendingAuthSessionMutation) TargetUserCleared() bool {
+ return m.TargetUserIDCleared() || m.clearedtarget_user
+}
+
+// TargetUserIDs returns the "target_user" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// TargetUserID instead. It exists only for internal usage by the builders.
+func (m *PendingAuthSessionMutation) TargetUserIDs() (ids []int64) {
+ if id := m.target_user; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetTargetUser resets all changes to the "target_user" edge.
+func (m *PendingAuthSessionMutation) ResetTargetUser() {
+ m.target_user = nil
+ m.clearedtarget_user = false
+}
+
+// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by id.
+func (m *PendingAuthSessionMutation) SetAdoptionDecisionID(id int64) {
+ m.adoption_decision = &id
+}
+
+// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (m *PendingAuthSessionMutation) ClearAdoptionDecision() {
+ m.clearedadoption_decision = true
+}
+
+// AdoptionDecisionCleared reports if the "adoption_decision" edge to the IdentityAdoptionDecision entity was cleared.
+func (m *PendingAuthSessionMutation) AdoptionDecisionCleared() bool {
+ return m.clearedadoption_decision
+}
+
+// AdoptionDecisionID returns the "adoption_decision" edge ID in the mutation.
+func (m *PendingAuthSessionMutation) AdoptionDecisionID() (id int64, exists bool) {
+ if m.adoption_decision != nil {
+ return *m.adoption_decision, true
+ }
+ return
+}
+
+// AdoptionDecisionIDs returns the "adoption_decision" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// AdoptionDecisionID instead. It exists only for internal usage by the builders.
+func (m *PendingAuthSessionMutation) AdoptionDecisionIDs() (ids []int64) {
+ if id := m.adoption_decision; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetAdoptionDecision resets all changes to the "adoption_decision" edge.
+func (m *PendingAuthSessionMutation) ResetAdoptionDecision() {
+ m.adoption_decision = nil
+ m.clearedadoption_decision = false
+}
+
+// Where appends a list predicates to the PendingAuthSessionMutation builder.
+func (m *PendingAuthSessionMutation) Where(ps ...predicate.PendingAuthSession) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the PendingAuthSessionMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *PendingAuthSessionMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.PendingAuthSession, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *PendingAuthSessionMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *PendingAuthSessionMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (PendingAuthSession).
+func (m *PendingAuthSessionMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *PendingAuthSessionMutation) Fields() []string {
+ fields := make([]string, 0, 21)
+ if m.created_at != nil {
+ fields = append(fields, pendingauthsession.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, pendingauthsession.FieldUpdatedAt)
+ }
+ if m.session_token != nil {
+ fields = append(fields, pendingauthsession.FieldSessionToken)
+ }
+ if m.intent != nil {
+ fields = append(fields, pendingauthsession.FieldIntent)
+ }
+ if m.provider_type != nil {
+ fields = append(fields, pendingauthsession.FieldProviderType)
+ }
+ if m.provider_key != nil {
+ fields = append(fields, pendingauthsession.FieldProviderKey)
+ }
+ if m.provider_subject != nil {
+ fields = append(fields, pendingauthsession.FieldProviderSubject)
+ }
+ if m.target_user != nil {
+ fields = append(fields, pendingauthsession.FieldTargetUserID)
+ }
+ if m.redirect_to != nil {
+ fields = append(fields, pendingauthsession.FieldRedirectTo)
+ }
+ if m.resolved_email != nil {
+ fields = append(fields, pendingauthsession.FieldResolvedEmail)
+ }
+ if m.registration_password_hash != nil {
+ fields = append(fields, pendingauthsession.FieldRegistrationPasswordHash)
+ }
+ if m.upstream_identity_claims != nil {
+ fields = append(fields, pendingauthsession.FieldUpstreamIdentityClaims)
+ }
+ if m.local_flow_state != nil {
+ fields = append(fields, pendingauthsession.FieldLocalFlowState)
+ }
+ if m.browser_session_key != nil {
+ fields = append(fields, pendingauthsession.FieldBrowserSessionKey)
+ }
+ if m.completion_code_hash != nil {
+ fields = append(fields, pendingauthsession.FieldCompletionCodeHash)
+ }
+ if m.completion_code_expires_at != nil {
+ fields = append(fields, pendingauthsession.FieldCompletionCodeExpiresAt)
+ }
+ if m.email_verified_at != nil {
+ fields = append(fields, pendingauthsession.FieldEmailVerifiedAt)
+ }
+ if m.password_verified_at != nil {
+ fields = append(fields, pendingauthsession.FieldPasswordVerifiedAt)
+ }
+ if m.totp_verified_at != nil {
+ fields = append(fields, pendingauthsession.FieldTotpVerifiedAt)
+ }
+ if m.expires_at != nil {
+ fields = append(fields, pendingauthsession.FieldExpiresAt)
+ }
+ if m.consumed_at != nil {
+ fields = append(fields, pendingauthsession.FieldConsumedAt)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *PendingAuthSessionMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case pendingauthsession.FieldCreatedAt:
+ return m.CreatedAt()
+ case pendingauthsession.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case pendingauthsession.FieldSessionToken:
+ return m.SessionToken()
+ case pendingauthsession.FieldIntent:
+ return m.Intent()
+ case pendingauthsession.FieldProviderType:
+ return m.ProviderType()
+ case pendingauthsession.FieldProviderKey:
+ return m.ProviderKey()
+ case pendingauthsession.FieldProviderSubject:
+ return m.ProviderSubject()
+ case pendingauthsession.FieldTargetUserID:
+ return m.TargetUserID()
+ case pendingauthsession.FieldRedirectTo:
+ return m.RedirectTo()
+ case pendingauthsession.FieldResolvedEmail:
+ return m.ResolvedEmail()
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ return m.RegistrationPasswordHash()
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ return m.UpstreamIdentityClaims()
+ case pendingauthsession.FieldLocalFlowState:
+ return m.LocalFlowState()
+ case pendingauthsession.FieldBrowserSessionKey:
+ return m.BrowserSessionKey()
+ case pendingauthsession.FieldCompletionCodeHash:
+ return m.CompletionCodeHash()
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ return m.CompletionCodeExpiresAt()
+ case pendingauthsession.FieldEmailVerifiedAt:
+ return m.EmailVerifiedAt()
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ return m.PasswordVerifiedAt()
+ case pendingauthsession.FieldTotpVerifiedAt:
+ return m.TotpVerifiedAt()
+ case pendingauthsession.FieldExpiresAt:
+ return m.ExpiresAt()
+ case pendingauthsession.FieldConsumedAt:
+ return m.ConsumedAt()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *PendingAuthSessionMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case pendingauthsession.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case pendingauthsession.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case pendingauthsession.FieldSessionToken:
+ return m.OldSessionToken(ctx)
+ case pendingauthsession.FieldIntent:
+ return m.OldIntent(ctx)
+ case pendingauthsession.FieldProviderType:
+ return m.OldProviderType(ctx)
+ case pendingauthsession.FieldProviderKey:
+ return m.OldProviderKey(ctx)
+ case pendingauthsession.FieldProviderSubject:
+ return m.OldProviderSubject(ctx)
+ case pendingauthsession.FieldTargetUserID:
+ return m.OldTargetUserID(ctx)
+ case pendingauthsession.FieldRedirectTo:
+ return m.OldRedirectTo(ctx)
+ case pendingauthsession.FieldResolvedEmail:
+ return m.OldResolvedEmail(ctx)
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ return m.OldRegistrationPasswordHash(ctx)
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ return m.OldUpstreamIdentityClaims(ctx)
+ case pendingauthsession.FieldLocalFlowState:
+ return m.OldLocalFlowState(ctx)
+ case pendingauthsession.FieldBrowserSessionKey:
+ return m.OldBrowserSessionKey(ctx)
+ case pendingauthsession.FieldCompletionCodeHash:
+ return m.OldCompletionCodeHash(ctx)
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ return m.OldCompletionCodeExpiresAt(ctx)
+ case pendingauthsession.FieldEmailVerifiedAt:
+ return m.OldEmailVerifiedAt(ctx)
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ return m.OldPasswordVerifiedAt(ctx)
+ case pendingauthsession.FieldTotpVerifiedAt:
+ return m.OldTotpVerifiedAt(ctx)
+ case pendingauthsession.FieldExpiresAt:
+ return m.OldExpiresAt(ctx)
+ case pendingauthsession.FieldConsumedAt:
+ return m.OldConsumedAt(ctx)
+ }
+ return nil, fmt.Errorf("unknown PendingAuthSession field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *PendingAuthSessionMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case pendingauthsession.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case pendingauthsession.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case pendingauthsession.FieldSessionToken:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSessionToken(v)
+ return nil
+ case pendingauthsession.FieldIntent:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIntent(v)
+ return nil
+ case pendingauthsession.FieldProviderType:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderType(v)
+ return nil
+ case pendingauthsession.FieldProviderKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderKey(v)
+ return nil
+ case pendingauthsession.FieldProviderSubject:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderSubject(v)
+ return nil
+ case pendingauthsession.FieldTargetUserID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTargetUserID(v)
+ return nil
+ case pendingauthsession.FieldRedirectTo:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRedirectTo(v)
+ return nil
+ case pendingauthsession.FieldResolvedEmail:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetResolvedEmail(v)
+ return nil
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRegistrationPasswordHash(v)
+ return nil
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpstreamIdentityClaims(v)
+ return nil
+ case pendingauthsession.FieldLocalFlowState:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetLocalFlowState(v)
+ return nil
+ case pendingauthsession.FieldBrowserSessionKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBrowserSessionKey(v)
+ return nil
+ case pendingauthsession.FieldCompletionCodeHash:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCompletionCodeHash(v)
+ return nil
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCompletionCodeExpiresAt(v)
+ return nil
+ case pendingauthsession.FieldEmailVerifiedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetEmailVerifiedAt(v)
+ return nil
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPasswordVerifiedAt(v)
+ return nil
+ case pendingauthsession.FieldTotpVerifiedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTotpVerifiedAt(v)
+ return nil
+ case pendingauthsession.FieldExpiresAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetExpiresAt(v)
+ return nil
+ case pendingauthsession.FieldConsumedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetConsumedAt(v)
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *PendingAuthSessionMutation) AddedFields() []string {
+ var fields []string
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *PendingAuthSessionMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *PendingAuthSessionMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown PendingAuthSession numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *PendingAuthSessionMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(pendingauthsession.FieldTargetUserID) {
+ fields = append(fields, pendingauthsession.FieldTargetUserID)
+ }
+ if m.FieldCleared(pendingauthsession.FieldCompletionCodeExpiresAt) {
+ fields = append(fields, pendingauthsession.FieldCompletionCodeExpiresAt)
+ }
+ if m.FieldCleared(pendingauthsession.FieldEmailVerifiedAt) {
+ fields = append(fields, pendingauthsession.FieldEmailVerifiedAt)
+ }
+ if m.FieldCleared(pendingauthsession.FieldPasswordVerifiedAt) {
+ fields = append(fields, pendingauthsession.FieldPasswordVerifiedAt)
+ }
+ if m.FieldCleared(pendingauthsession.FieldTotpVerifiedAt) {
+ fields = append(fields, pendingauthsession.FieldTotpVerifiedAt)
+ }
+ if m.FieldCleared(pendingauthsession.FieldConsumedAt) {
+ fields = append(fields, pendingauthsession.FieldConsumedAt)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *PendingAuthSessionMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *PendingAuthSessionMutation) ClearField(name string) error {
+ switch name {
+ case pendingauthsession.FieldTargetUserID:
+ m.ClearTargetUserID()
+ return nil
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ m.ClearCompletionCodeExpiresAt()
+ return nil
+ case pendingauthsession.FieldEmailVerifiedAt:
+ m.ClearEmailVerifiedAt()
+ return nil
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ m.ClearPasswordVerifiedAt()
+ return nil
+ case pendingauthsession.FieldTotpVerifiedAt:
+ m.ClearTotpVerifiedAt()
+ return nil
+ case pendingauthsession.FieldConsumedAt:
+ m.ClearConsumedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *PendingAuthSessionMutation) ResetField(name string) error {
+ switch name {
+ case pendingauthsession.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case pendingauthsession.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case pendingauthsession.FieldSessionToken:
+ m.ResetSessionToken()
+ return nil
+ case pendingauthsession.FieldIntent:
+ m.ResetIntent()
+ return nil
+ case pendingauthsession.FieldProviderType:
+ m.ResetProviderType()
+ return nil
+ case pendingauthsession.FieldProviderKey:
+ m.ResetProviderKey()
+ return nil
+ case pendingauthsession.FieldProviderSubject:
+ m.ResetProviderSubject()
+ return nil
+ case pendingauthsession.FieldTargetUserID:
+ m.ResetTargetUserID()
+ return nil
+ case pendingauthsession.FieldRedirectTo:
+ m.ResetRedirectTo()
+ return nil
+ case pendingauthsession.FieldResolvedEmail:
+ m.ResetResolvedEmail()
+ return nil
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ m.ResetRegistrationPasswordHash()
+ return nil
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ m.ResetUpstreamIdentityClaims()
+ return nil
+ case pendingauthsession.FieldLocalFlowState:
+ m.ResetLocalFlowState()
+ return nil
+ case pendingauthsession.FieldBrowserSessionKey:
+ m.ResetBrowserSessionKey()
+ return nil
+ case pendingauthsession.FieldCompletionCodeHash:
+ m.ResetCompletionCodeHash()
+ return nil
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ m.ResetCompletionCodeExpiresAt()
+ return nil
+ case pendingauthsession.FieldEmailVerifiedAt:
+ m.ResetEmailVerifiedAt()
+ return nil
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ m.ResetPasswordVerifiedAt()
+ return nil
+ case pendingauthsession.FieldTotpVerifiedAt:
+ m.ResetTotpVerifiedAt()
+ return nil
+ case pendingauthsession.FieldExpiresAt:
+ m.ResetExpiresAt()
+ return nil
+ case pendingauthsession.FieldConsumedAt:
+ m.ResetConsumedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *PendingAuthSessionMutation) AddedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.target_user != nil {
+ edges = append(edges, pendingauthsession.EdgeTargetUser)
+ }
+ if m.adoption_decision != nil {
+ edges = append(edges, pendingauthsession.EdgeAdoptionDecision)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *PendingAuthSessionMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case pendingauthsession.EdgeTargetUser:
+ if id := m.target_user; id != nil {
+ return []ent.Value{*id}
+ }
+ case pendingauthsession.EdgeAdoptionDecision:
+ if id := m.adoption_decision; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *PendingAuthSessionMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 2)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *PendingAuthSessionMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *PendingAuthSessionMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.clearedtarget_user {
+ edges = append(edges, pendingauthsession.EdgeTargetUser)
+ }
+ if m.clearedadoption_decision {
+ edges = append(edges, pendingauthsession.EdgeAdoptionDecision)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *PendingAuthSessionMutation) EdgeCleared(name string) bool {
+ switch name {
+ case pendingauthsession.EdgeTargetUser:
+ return m.clearedtarget_user
+ case pendingauthsession.EdgeAdoptionDecision:
+ return m.clearedadoption_decision
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *PendingAuthSessionMutation) ClearEdge(name string) error {
+ switch name {
+ case pendingauthsession.EdgeTargetUser:
+ m.ClearTargetUser()
+ return nil
+ case pendingauthsession.EdgeAdoptionDecision:
+ m.ClearAdoptionDecision()
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *PendingAuthSessionMutation) ResetEdge(name string) error {
+ switch name {
+ case pendingauthsession.EdgeTargetUser:
+ m.ResetTargetUser()
+ return nil
+ case pendingauthsession.EdgeAdoptionDecision:
+ m.ResetAdoptionDecision()
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession edge %s", name)
+}
+
// PromoCodeMutation represents an operation that mutates the PromoCode nodes in the graph.
type PromoCodeMutation struct {
config
@@ -28264,6 +32671,9 @@ type UserMutation struct {
totp_secret_encrypted *string
totp_enabled *bool
totp_enabled_at *time.Time
+ signup_source *string
+ last_login_at *time.Time
+ last_active_at *time.Time
balance_notify_enabled *bool
balance_notify_threshold_type *string
balance_notify_threshold *float64
@@ -28302,6 +32712,12 @@ type UserMutation struct {
payment_orders map[int64]struct{}
removedpayment_orders map[int64]struct{}
clearedpayment_orders bool
+ auth_identities map[int64]struct{}
+ removedauth_identities map[int64]struct{}
+ clearedauth_identities bool
+ pending_auth_sessions map[int64]struct{}
+ removedpending_auth_sessions map[int64]struct{}
+ clearedpending_auth_sessions bool
done bool
oldValue func(context.Context) (*User, error)
predicates []predicate.User
@@ -28988,6 +33404,140 @@ func (m *UserMutation) ResetTotpEnabledAt() {
delete(m.clearedFields, user.FieldTotpEnabledAt)
}
+// SetSignupSource sets the "signup_source" field.
+func (m *UserMutation) SetSignupSource(s string) {
+ m.signup_source = &s
+}
+
+// SignupSource returns the value of the "signup_source" field in the mutation.
+func (m *UserMutation) SignupSource() (r string, exists bool) {
+ v := m.signup_source
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSignupSource returns the old "signup_source" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldSignupSource(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSignupSource is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSignupSource requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSignupSource: %w", err)
+ }
+ return oldValue.SignupSource, nil
+}
+
+// ResetSignupSource resets all changes to the "signup_source" field.
+func (m *UserMutation) ResetSignupSource() {
+ m.signup_source = nil
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (m *UserMutation) SetLastLoginAt(t time.Time) {
+ m.last_login_at = &t
+}
+
+// LastLoginAt returns the value of the "last_login_at" field in the mutation.
+func (m *UserMutation) LastLoginAt() (r time.Time, exists bool) {
+ v := m.last_login_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldLastLoginAt returns the old "last_login_at" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldLastLoginAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldLastLoginAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldLastLoginAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldLastLoginAt: %w", err)
+ }
+ return oldValue.LastLoginAt, nil
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (m *UserMutation) ClearLastLoginAt() {
+ m.last_login_at = nil
+ m.clearedFields[user.FieldLastLoginAt] = struct{}{}
+}
+
+// LastLoginAtCleared returns if the "last_login_at" field was cleared in this mutation.
+func (m *UserMutation) LastLoginAtCleared() bool {
+ _, ok := m.clearedFields[user.FieldLastLoginAt]
+ return ok
+}
+
+// ResetLastLoginAt resets all changes to the "last_login_at" field.
+func (m *UserMutation) ResetLastLoginAt() {
+ m.last_login_at = nil
+ delete(m.clearedFields, user.FieldLastLoginAt)
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (m *UserMutation) SetLastActiveAt(t time.Time) {
+ m.last_active_at = &t
+}
+
+// LastActiveAt returns the value of the "last_active_at" field in the mutation.
+func (m *UserMutation) LastActiveAt() (r time.Time, exists bool) {
+ v := m.last_active_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldLastActiveAt returns the old "last_active_at" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldLastActiveAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldLastActiveAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldLastActiveAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldLastActiveAt: %w", err)
+ }
+ return oldValue.LastActiveAt, nil
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (m *UserMutation) ClearLastActiveAt() {
+ m.last_active_at = nil
+ m.clearedFields[user.FieldLastActiveAt] = struct{}{}
+}
+
+// LastActiveAtCleared returns if the "last_active_at" field was cleared in this mutation.
+func (m *UserMutation) LastActiveAtCleared() bool {
+ _, ok := m.clearedFields[user.FieldLastActiveAt]
+ return ok
+}
+
+// ResetLastActiveAt resets all changes to the "last_active_at" field.
+func (m *UserMutation) ResetLastActiveAt() {
+ m.last_active_at = nil
+ delete(m.clearedFields, user.FieldLastActiveAt)
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (m *UserMutation) SetBalanceNotifyEnabled(b bool) {
m.balance_notify_enabled = &b
@@ -29762,6 +34312,114 @@ func (m *UserMutation) ResetPaymentOrders() {
m.removedpayment_orders = nil
}
+// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by ids.
+func (m *UserMutation) AddAuthIdentityIDs(ids ...int64) {
+ if m.auth_identities == nil {
+ m.auth_identities = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.auth_identities[ids[i]] = struct{}{}
+ }
+}
+
+// ClearAuthIdentities clears the "auth_identities" edge to the AuthIdentity entity.
+func (m *UserMutation) ClearAuthIdentities() {
+ m.clearedauth_identities = true
+}
+
+// AuthIdentitiesCleared reports if the "auth_identities" edge to the AuthIdentity entity was cleared.
+func (m *UserMutation) AuthIdentitiesCleared() bool {
+ return m.clearedauth_identities
+}
+
+// RemoveAuthIdentityIDs removes the "auth_identities" edge to the AuthIdentity entity by IDs.
+func (m *UserMutation) RemoveAuthIdentityIDs(ids ...int64) {
+ if m.removedauth_identities == nil {
+ m.removedauth_identities = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.auth_identities, ids[i])
+ m.removedauth_identities[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedAuthIdentities returns the removed IDs of the "auth_identities" edge to the AuthIdentity entity.
+func (m *UserMutation) RemovedAuthIdentitiesIDs() (ids []int64) {
+ for id := range m.removedauth_identities {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// AuthIdentitiesIDs returns the "auth_identities" edge IDs in the mutation.
+func (m *UserMutation) AuthIdentitiesIDs() (ids []int64) {
+ for id := range m.auth_identities {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetAuthIdentities resets all changes to the "auth_identities" edge.
+func (m *UserMutation) ResetAuthIdentities() {
+ m.auth_identities = nil
+ m.clearedauth_identities = false
+ m.removedauth_identities = nil
+}
+
+// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by ids.
+func (m *UserMutation) AddPendingAuthSessionIDs(ids ...int64) {
+ if m.pending_auth_sessions == nil {
+ m.pending_auth_sessions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.pending_auth_sessions[ids[i]] = struct{}{}
+ }
+}
+
+// ClearPendingAuthSessions clears the "pending_auth_sessions" edge to the PendingAuthSession entity.
+func (m *UserMutation) ClearPendingAuthSessions() {
+ m.clearedpending_auth_sessions = true
+}
+
+// PendingAuthSessionsCleared reports if the "pending_auth_sessions" edge to the PendingAuthSession entity was cleared.
+func (m *UserMutation) PendingAuthSessionsCleared() bool {
+ return m.clearedpending_auth_sessions
+}
+
+// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
+func (m *UserMutation) RemovePendingAuthSessionIDs(ids ...int64) {
+ if m.removedpending_auth_sessions == nil {
+ m.removedpending_auth_sessions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.pending_auth_sessions, ids[i])
+ m.removedpending_auth_sessions[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedPendingAuthSessions returns the removed IDs of the "pending_auth_sessions" edge to the PendingAuthSession entity.
+func (m *UserMutation) RemovedPendingAuthSessionsIDs() (ids []int64) {
+ for id := range m.removedpending_auth_sessions {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// PendingAuthSessionsIDs returns the "pending_auth_sessions" edge IDs in the mutation.
+func (m *UserMutation) PendingAuthSessionsIDs() (ids []int64) {
+ for id := range m.pending_auth_sessions {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetPendingAuthSessions resets all changes to the "pending_auth_sessions" edge.
+func (m *UserMutation) ResetPendingAuthSessions() {
+ m.pending_auth_sessions = nil
+ m.clearedpending_auth_sessions = false
+ m.removedpending_auth_sessions = nil
+}
+
// Where appends a list predicates to the UserMutation builder.
func (m *UserMutation) Where(ps ...predicate.User) {
m.predicates = append(m.predicates, ps...)
@@ -29796,7 +34454,7 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
- fields := make([]string, 0, 19)
+ fields := make([]string, 0, 22)
if m.created_at != nil {
fields = append(fields, user.FieldCreatedAt)
}
@@ -29839,6 +34497,15 @@ func (m *UserMutation) Fields() []string {
if m.totp_enabled_at != nil {
fields = append(fields, user.FieldTotpEnabledAt)
}
+ if m.signup_source != nil {
+ fields = append(fields, user.FieldSignupSource)
+ }
+ if m.last_login_at != nil {
+ fields = append(fields, user.FieldLastLoginAt)
+ }
+ if m.last_active_at != nil {
+ fields = append(fields, user.FieldLastActiveAt)
+ }
if m.balance_notify_enabled != nil {
fields = append(fields, user.FieldBalanceNotifyEnabled)
}
@@ -29890,6 +34557,12 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.TotpEnabled()
case user.FieldTotpEnabledAt:
return m.TotpEnabledAt()
+ case user.FieldSignupSource:
+ return m.SignupSource()
+ case user.FieldLastLoginAt:
+ return m.LastLoginAt()
+ case user.FieldLastActiveAt:
+ return m.LastActiveAt()
case user.FieldBalanceNotifyEnabled:
return m.BalanceNotifyEnabled()
case user.FieldBalanceNotifyThresholdType:
@@ -29937,6 +34610,12 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldTotpEnabled(ctx)
case user.FieldTotpEnabledAt:
return m.OldTotpEnabledAt(ctx)
+ case user.FieldSignupSource:
+ return m.OldSignupSource(ctx)
+ case user.FieldLastLoginAt:
+ return m.OldLastLoginAt(ctx)
+ case user.FieldLastActiveAt:
+ return m.OldLastActiveAt(ctx)
case user.FieldBalanceNotifyEnabled:
return m.OldBalanceNotifyEnabled(ctx)
case user.FieldBalanceNotifyThresholdType:
@@ -30054,6 +34733,27 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetTotpEnabledAt(v)
return nil
+ case user.FieldSignupSource:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSignupSource(v)
+ return nil
+ case user.FieldLastLoginAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetLastLoginAt(v)
+ return nil
+ case user.FieldLastActiveAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetLastActiveAt(v)
+ return nil
case user.FieldBalanceNotifyEnabled:
v, ok := value.(bool)
if !ok {
@@ -30179,6 +34879,12 @@ func (m *UserMutation) ClearedFields() []string {
if m.FieldCleared(user.FieldTotpEnabledAt) {
fields = append(fields, user.FieldTotpEnabledAt)
}
+ if m.FieldCleared(user.FieldLastLoginAt) {
+ fields = append(fields, user.FieldLastLoginAt)
+ }
+ if m.FieldCleared(user.FieldLastActiveAt) {
+ fields = append(fields, user.FieldLastActiveAt)
+ }
if m.FieldCleared(user.FieldBalanceNotifyThreshold) {
fields = append(fields, user.FieldBalanceNotifyThreshold)
}
@@ -30205,6 +34911,12 @@ func (m *UserMutation) ClearField(name string) error {
case user.FieldTotpEnabledAt:
m.ClearTotpEnabledAt()
return nil
+ case user.FieldLastLoginAt:
+ m.ClearLastLoginAt()
+ return nil
+ case user.FieldLastActiveAt:
+ m.ClearLastActiveAt()
+ return nil
case user.FieldBalanceNotifyThreshold:
m.ClearBalanceNotifyThreshold()
return nil
@@ -30258,6 +34970,15 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldTotpEnabledAt:
m.ResetTotpEnabledAt()
return nil
+ case user.FieldSignupSource:
+ m.ResetSignupSource()
+ return nil
+ case user.FieldLastLoginAt:
+ m.ResetLastLoginAt()
+ return nil
+ case user.FieldLastActiveAt:
+ m.ResetLastActiveAt()
+ return nil
case user.FieldBalanceNotifyEnabled:
m.ResetBalanceNotifyEnabled()
return nil
@@ -30279,7 +35000,7 @@ func (m *UserMutation) ResetField(name string) error {
// AddedEdges returns all edge names that were set/added in this mutation.
func (m *UserMutation) AddedEdges() []string {
- edges := make([]string, 0, 10)
+ edges := make([]string, 0, 12)
if m.api_keys != nil {
edges = append(edges, user.EdgeAPIKeys)
}
@@ -30310,6 +35031,12 @@ func (m *UserMutation) AddedEdges() []string {
if m.payment_orders != nil {
edges = append(edges, user.EdgePaymentOrders)
}
+ if m.auth_identities != nil {
+ edges = append(edges, user.EdgeAuthIdentities)
+ }
+ if m.pending_auth_sessions != nil {
+ edges = append(edges, user.EdgePendingAuthSessions)
+ }
return edges
}
@@ -30377,13 +35104,25 @@ func (m *UserMutation) AddedIDs(name string) []ent.Value {
ids = append(ids, id)
}
return ids
+ case user.EdgeAuthIdentities:
+ ids := make([]ent.Value, 0, len(m.auth_identities))
+ for id := range m.auth_identities {
+ ids = append(ids, id)
+ }
+ return ids
+ case user.EdgePendingAuthSessions:
+ ids := make([]ent.Value, 0, len(m.pending_auth_sessions))
+ for id := range m.pending_auth_sessions {
+ ids = append(ids, id)
+ }
+ return ids
}
return nil
}
// RemovedEdges returns all edge names that were removed in this mutation.
func (m *UserMutation) RemovedEdges() []string {
- edges := make([]string, 0, 10)
+ edges := make([]string, 0, 12)
if m.removedapi_keys != nil {
edges = append(edges, user.EdgeAPIKeys)
}
@@ -30414,6 +35153,12 @@ func (m *UserMutation) RemovedEdges() []string {
if m.removedpayment_orders != nil {
edges = append(edges, user.EdgePaymentOrders)
}
+ if m.removedauth_identities != nil {
+ edges = append(edges, user.EdgeAuthIdentities)
+ }
+ if m.removedpending_auth_sessions != nil {
+ edges = append(edges, user.EdgePendingAuthSessions)
+ }
return edges
}
@@ -30481,13 +35226,25 @@ func (m *UserMutation) RemovedIDs(name string) []ent.Value {
ids = append(ids, id)
}
return ids
+ case user.EdgeAuthIdentities:
+ ids := make([]ent.Value, 0, len(m.removedauth_identities))
+ for id := range m.removedauth_identities {
+ ids = append(ids, id)
+ }
+ return ids
+ case user.EdgePendingAuthSessions:
+ ids := make([]ent.Value, 0, len(m.removedpending_auth_sessions))
+ for id := range m.removedpending_auth_sessions {
+ ids = append(ids, id)
+ }
+ return ids
}
return nil
}
// ClearedEdges returns all edge names that were cleared in this mutation.
func (m *UserMutation) ClearedEdges() []string {
- edges := make([]string, 0, 10)
+ edges := make([]string, 0, 12)
if m.clearedapi_keys {
edges = append(edges, user.EdgeAPIKeys)
}
@@ -30518,6 +35275,12 @@ func (m *UserMutation) ClearedEdges() []string {
if m.clearedpayment_orders {
edges = append(edges, user.EdgePaymentOrders)
}
+ if m.clearedauth_identities {
+ edges = append(edges, user.EdgeAuthIdentities)
+ }
+ if m.clearedpending_auth_sessions {
+ edges = append(edges, user.EdgePendingAuthSessions)
+ }
return edges
}
@@ -30545,6 +35308,10 @@ func (m *UserMutation) EdgeCleared(name string) bool {
return m.clearedpromo_code_usages
case user.EdgePaymentOrders:
return m.clearedpayment_orders
+ case user.EdgeAuthIdentities:
+ return m.clearedauth_identities
+ case user.EdgePendingAuthSessions:
+ return m.clearedpending_auth_sessions
}
return false
}
@@ -30591,6 +35358,12 @@ func (m *UserMutation) ResetEdge(name string) error {
case user.EdgePaymentOrders:
m.ResetPaymentOrders()
return nil
+ case user.EdgeAuthIdentities:
+ m.ResetAuthIdentities()
+ return nil
+ case user.EdgePendingAuthSessions:
+ m.ResetPendingAuthSessions()
+ return nil
}
return fmt.Errorf("unknown User edge %s", name)
}
diff --git a/backend/ent/paymentorder.go b/backend/ent/paymentorder.go
index 6ea3e709..b131b8c8 100644
--- a/backend/ent/paymentorder.go
+++ b/backend/ent/paymentorder.go
@@ -3,6 +3,7 @@
package ent
import (
+ "encoding/json"
"fmt"
"strings"
"time"
@@ -56,6 +57,10 @@ type PaymentOrder struct {
SubscriptionDays *int `json:"subscription_days,omitempty"`
// ProviderInstanceID holds the value of the "provider_instance_id" field.
ProviderInstanceID *string `json:"provider_instance_id,omitempty"`
+ // ProviderKey holds the value of the "provider_key" field.
+ ProviderKey *string `json:"provider_key,omitempty"`
+ // ProviderSnapshot holds the value of the "provider_snapshot" field.
+ ProviderSnapshot map[string]interface{} `json:"provider_snapshot,omitempty"`
// Status holds the value of the "status" field.
Status string `json:"status,omitempty"`
// RefundAmount holds the value of the "refund_amount" field.
@@ -123,13 +128,15 @@ func (*PaymentOrder) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
+ case paymentorder.FieldProviderSnapshot:
+ values[i] = new([]byte)
case paymentorder.FieldForceRefund:
values[i] = new(sql.NullBool)
case paymentorder.FieldAmount, paymentorder.FieldPayAmount, paymentorder.FieldFeeRate, paymentorder.FieldRefundAmount:
values[i] = new(sql.NullFloat64)
case paymentorder.FieldID, paymentorder.FieldUserID, paymentorder.FieldPlanID, paymentorder.FieldSubscriptionGroupID, paymentorder.FieldSubscriptionDays:
values[i] = new(sql.NullInt64)
- case paymentorder.FieldUserEmail, paymentorder.FieldUserName, paymentorder.FieldUserNotes, paymentorder.FieldRechargeCode, paymentorder.FieldOutTradeNo, paymentorder.FieldPaymentType, paymentorder.FieldPaymentTradeNo, paymentorder.FieldPayURL, paymentorder.FieldQrCode, paymentorder.FieldQrCodeImg, paymentorder.FieldOrderType, paymentorder.FieldProviderInstanceID, paymentorder.FieldStatus, paymentorder.FieldRefundReason, paymentorder.FieldRefundRequestReason, paymentorder.FieldRefundRequestedBy, paymentorder.FieldFailedReason, paymentorder.FieldClientIP, paymentorder.FieldSrcHost, paymentorder.FieldSrcURL:
+ case paymentorder.FieldUserEmail, paymentorder.FieldUserName, paymentorder.FieldUserNotes, paymentorder.FieldRechargeCode, paymentorder.FieldOutTradeNo, paymentorder.FieldPaymentType, paymentorder.FieldPaymentTradeNo, paymentorder.FieldPayURL, paymentorder.FieldQrCode, paymentorder.FieldQrCodeImg, paymentorder.FieldOrderType, paymentorder.FieldProviderInstanceID, paymentorder.FieldProviderKey, paymentorder.FieldStatus, paymentorder.FieldRefundReason, paymentorder.FieldRefundRequestReason, paymentorder.FieldRefundRequestedBy, paymentorder.FieldFailedReason, paymentorder.FieldClientIP, paymentorder.FieldSrcHost, paymentorder.FieldSrcURL:
values[i] = new(sql.NullString)
case paymentorder.FieldRefundAt, paymentorder.FieldRefundRequestedAt, paymentorder.FieldExpiresAt, paymentorder.FieldPaidAt, paymentorder.FieldCompletedAt, paymentorder.FieldFailedAt, paymentorder.FieldCreatedAt, paymentorder.FieldUpdatedAt:
values[i] = new(sql.NullTime)
@@ -276,6 +283,21 @@ func (_m *PaymentOrder) assignValues(columns []string, values []any) error {
_m.ProviderInstanceID = new(string)
*_m.ProviderInstanceID = value.String
}
+ case paymentorder.FieldProviderKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_key", values[i])
+ } else if value.Valid {
+ _m.ProviderKey = new(string)
+ *_m.ProviderKey = value.String
+ }
+ case paymentorder.FieldProviderSnapshot:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_snapshot", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.ProviderSnapshot); err != nil {
+ return fmt.Errorf("unmarshal field provider_snapshot: %w", err)
+ }
+ }
case paymentorder.FieldStatus:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field status", values[i])
@@ -508,6 +530,14 @@ func (_m *PaymentOrder) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
+ if v := _m.ProviderKey; v != nil {
+ builder.WriteString("provider_key=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ builder.WriteString("provider_snapshot=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ProviderSnapshot))
+ builder.WriteString(", ")
builder.WriteString("status=")
builder.WriteString(_m.Status)
builder.WriteString(", ")
diff --git a/backend/ent/paymentorder/paymentorder.go b/backend/ent/paymentorder/paymentorder.go
index 4467b2b6..62883794 100644
--- a/backend/ent/paymentorder/paymentorder.go
+++ b/backend/ent/paymentorder/paymentorder.go
@@ -52,6 +52,10 @@ const (
FieldSubscriptionDays = "subscription_days"
// FieldProviderInstanceID holds the string denoting the provider_instance_id field in the database.
FieldProviderInstanceID = "provider_instance_id"
+ // FieldProviderKey holds the string denoting the provider_key field in the database.
+ FieldProviderKey = "provider_key"
+ // FieldProviderSnapshot holds the string denoting the provider_snapshot field in the database.
+ FieldProviderSnapshot = "provider_snapshot"
// FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status"
// FieldRefundAmount holds the string denoting the refund_amount field in the database.
@@ -123,6 +127,8 @@ var Columns = []string{
FieldSubscriptionGroupID,
FieldSubscriptionDays,
FieldProviderInstanceID,
+ FieldProviderKey,
+ FieldProviderSnapshot,
FieldStatus,
FieldRefundAmount,
FieldRefundReason,
@@ -176,6 +182,8 @@ var (
OrderTypeValidator func(string) error
// ProviderInstanceIDValidator is a validator for the "provider_instance_id" field. It is called by the builders before save.
ProviderInstanceIDValidator func(string) error
+ // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ ProviderKeyValidator func(string) error
// DefaultStatus holds the default value on creation for the "status" field.
DefaultStatus string
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
@@ -301,6 +309,11 @@ func ByProviderInstanceID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldProviderInstanceID, opts...).ToFunc()
}
+// ByProviderKey orders the results by the provider_key field.
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
+}
+
// ByStatus orders the results by the status field.
func ByStatus(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStatus, opts...).ToFunc()
diff --git a/backend/ent/paymentorder/where.go b/backend/ent/paymentorder/where.go
index 78520fac..e96bf51e 100644
--- a/backend/ent/paymentorder/where.go
+++ b/backend/ent/paymentorder/where.go
@@ -150,6 +150,11 @@ func ProviderInstanceID(v string) predicate.PaymentOrder {
return predicate.PaymentOrder(sql.FieldEQ(FieldProviderInstanceID, v))
}
+// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
+func ProviderKey(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldProviderKey, v))
+}
+
// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
func Status(v string) predicate.PaymentOrder {
return predicate.PaymentOrder(sql.FieldEQ(FieldStatus, v))
@@ -1360,6 +1365,91 @@ func ProviderInstanceIDContainsFold(v string) predicate.PaymentOrder {
return predicate.PaymentOrder(sql.FieldContainsFold(FieldProviderInstanceID, v))
}
+// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
+func ProviderKeyEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
+func ProviderKeyNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyIn applies the In predicate on the "provider_key" field.
+func ProviderKeyIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
+func ProviderKeyNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyGT applies the GT predicate on the "provider_key" field.
+func ProviderKeyGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldProviderKey, v))
+}
+
+// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
+func ProviderKeyGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldProviderKey, v))
+}
+
+// ProviderKeyLT applies the LT predicate on the "provider_key" field.
+func ProviderKeyLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldProviderKey, v))
+}
+
+// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
+func ProviderKeyLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldProviderKey, v))
+}
+
+// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
+func ProviderKeyContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldProviderKey, v))
+}
+
+// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
+func ProviderKeyHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldProviderKey, v))
+}
+
+// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
+func ProviderKeyHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldProviderKey, v))
+}
+
+// ProviderKeyIsNil applies the IsNil predicate on the "provider_key" field.
+func ProviderKeyIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldProviderKey))
+}
+
+// ProviderKeyNotNil applies the NotNil predicate on the "provider_key" field.
+func ProviderKeyNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldProviderKey))
+}
+
+// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
+func ProviderKeyEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldProviderKey, v))
+}
+
+// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
+func ProviderKeyContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldProviderKey, v))
+}
+
+// ProviderSnapshotIsNil applies the IsNil predicate on the "provider_snapshot" field.
+func ProviderSnapshotIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldProviderSnapshot))
+}
+
+// ProviderSnapshotNotNil applies the NotNil predicate on the "provider_snapshot" field.
+func ProviderSnapshotNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldProviderSnapshot))
+}
+
// StatusEQ applies the EQ predicate on the "status" field.
func StatusEQ(v string) predicate.PaymentOrder {
return predicate.PaymentOrder(sql.FieldEQ(FieldStatus, v))
diff --git a/backend/ent/paymentorder_create.go b/backend/ent/paymentorder_create.go
index 03098339..3ee24f8e 100644
--- a/backend/ent/paymentorder_create.go
+++ b/backend/ent/paymentorder_create.go
@@ -225,6 +225,26 @@ func (_c *PaymentOrderCreate) SetNillableProviderInstanceID(v *string) *PaymentO
return _c
}
+// SetProviderKey sets the "provider_key" field.
+func (_c *PaymentOrderCreate) SetProviderKey(v string) *PaymentOrderCreate {
+ _c.mutation.SetProviderKey(v)
+ return _c
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableProviderKey(v *string) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetProviderKey(*v)
+ }
+ return _c
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (_c *PaymentOrderCreate) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderCreate {
+ _c.mutation.SetProviderSnapshot(v)
+ return _c
+}
+
// SetStatus sets the "status" field.
func (_c *PaymentOrderCreate) SetStatus(v string) *PaymentOrderCreate {
_c.mutation.SetStatus(v)
@@ -602,6 +622,11 @@ func (_c *PaymentOrderCreate) check() error {
return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)}
}
}
+ if v, ok := _c.mutation.ProviderKey(); ok {
+ if err := paymentorder.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)}
+ }
+ }
if _, ok := _c.mutation.Status(); !ok {
return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "PaymentOrder.status"`)}
}
@@ -748,6 +773,14 @@ func (_c *PaymentOrderCreate) createSpec() (*PaymentOrder, *sqlgraph.CreateSpec)
_spec.SetField(paymentorder.FieldProviderInstanceID, field.TypeString, value)
_node.ProviderInstanceID = &value
}
+ if value, ok := _c.mutation.ProviderKey(); ok {
+ _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value)
+ _node.ProviderKey = &value
+ }
+ if value, ok := _c.mutation.ProviderSnapshot(); ok {
+ _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value)
+ _node.ProviderSnapshot = value
+ }
if value, ok := _c.mutation.Status(); ok {
_spec.SetField(paymentorder.FieldStatus, field.TypeString, value)
_node.Status = value
@@ -1201,6 +1234,42 @@ func (u *PaymentOrderUpsert) ClearProviderInstanceID() *PaymentOrderUpsert {
return u
}
+// SetProviderKey sets the "provider_key" field.
+func (u *PaymentOrderUpsert) SetProviderKey(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldProviderKey, v)
+ return u
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateProviderKey() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldProviderKey)
+ return u
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (u *PaymentOrderUpsert) ClearProviderKey() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldProviderKey)
+ return u
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (u *PaymentOrderUpsert) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldProviderSnapshot, v)
+ return u
+}
+
+// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateProviderSnapshot() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldProviderSnapshot)
+ return u
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (u *PaymentOrderUpsert) ClearProviderSnapshot() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldProviderSnapshot)
+ return u
+}
+
// SetStatus sets the "status" field.
func (u *PaymentOrderUpsert) SetStatus(v string) *PaymentOrderUpsert {
u.Set(paymentorder.FieldStatus, v)
@@ -1880,6 +1949,48 @@ func (u *PaymentOrderUpsertOne) ClearProviderInstanceID() *PaymentOrderUpsertOne
})
}
+// SetProviderKey sets the "provider_key" field.
+func (u *PaymentOrderUpsertOne) SetProviderKey(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateProviderKey() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (u *PaymentOrderUpsertOne) ClearProviderKey() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearProviderKey()
+ })
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (u *PaymentOrderUpsertOne) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetProviderSnapshot(v)
+ })
+}
+
+// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateProviderSnapshot() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateProviderSnapshot()
+ })
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (u *PaymentOrderUpsertOne) ClearProviderSnapshot() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearProviderSnapshot()
+ })
+}
+
// SetStatus sets the "status" field.
func (u *PaymentOrderUpsertOne) SetStatus(v string) *PaymentOrderUpsertOne {
return u.Update(func(s *PaymentOrderUpsert) {
@@ -2770,6 +2881,48 @@ func (u *PaymentOrderUpsertBulk) ClearProviderInstanceID() *PaymentOrderUpsertBu
})
}
+// SetProviderKey sets the "provider_key" field.
+func (u *PaymentOrderUpsertBulk) SetProviderKey(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateProviderKey() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (u *PaymentOrderUpsertBulk) ClearProviderKey() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearProviderKey()
+ })
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (u *PaymentOrderUpsertBulk) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetProviderSnapshot(v)
+ })
+}
+
+// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateProviderSnapshot() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateProviderSnapshot()
+ })
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (u *PaymentOrderUpsertBulk) ClearProviderSnapshot() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearProviderSnapshot()
+ })
+}
+
// SetStatus sets the "status" field.
func (u *PaymentOrderUpsertBulk) SetStatus(v string) *PaymentOrderUpsertBulk {
return u.Update(func(s *PaymentOrderUpsert) {
diff --git a/backend/ent/paymentorder_update.go b/backend/ent/paymentorder_update.go
index 5978fc29..378e0dad 100644
--- a/backend/ent/paymentorder_update.go
+++ b/backend/ent/paymentorder_update.go
@@ -385,6 +385,38 @@ func (_u *PaymentOrderUpdate) ClearProviderInstanceID() *PaymentOrderUpdate {
return _u
}
+// SetProviderKey sets the "provider_key" field.
+func (_u *PaymentOrderUpdate) SetProviderKey(v string) *PaymentOrderUpdate {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableProviderKey(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (_u *PaymentOrderUpdate) ClearProviderKey() *PaymentOrderUpdate {
+ _u.mutation.ClearProviderKey()
+ return _u
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (_u *PaymentOrderUpdate) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpdate {
+ _u.mutation.SetProviderSnapshot(v)
+ return _u
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (_u *PaymentOrderUpdate) ClearProviderSnapshot() *PaymentOrderUpdate {
+ _u.mutation.ClearProviderSnapshot()
+ return _u
+}
+
// SetStatus sets the "status" field.
func (_u *PaymentOrderUpdate) SetStatus(v string) *PaymentOrderUpdate {
_u.mutation.SetStatus(v)
@@ -776,6 +808,11 @@ func (_u *PaymentOrderUpdate) check() error {
return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)}
}
}
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := paymentorder.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)}
+ }
+ }
if v, ok := _u.mutation.Status(); ok {
if err := paymentorder.StatusValidator(v); err != nil {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.status": %w`, err)}
@@ -910,6 +947,18 @@ func (_u *PaymentOrderUpdate) sqlSave(ctx context.Context) (_node int, err error
if _u.mutation.ProviderInstanceIDCleared() {
_spec.ClearField(paymentorder.FieldProviderInstanceID, field.TypeString)
}
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value)
+ }
+ if _u.mutation.ProviderKeyCleared() {
+ _spec.ClearField(paymentorder.FieldProviderKey, field.TypeString)
+ }
+ if value, ok := _u.mutation.ProviderSnapshot(); ok {
+ _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value)
+ }
+ if _u.mutation.ProviderSnapshotCleared() {
+ _spec.ClearField(paymentorder.FieldProviderSnapshot, field.TypeJSON)
+ }
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(paymentorder.FieldStatus, field.TypeString, value)
}
@@ -1399,6 +1448,38 @@ func (_u *PaymentOrderUpdateOne) ClearProviderInstanceID() *PaymentOrderUpdateOn
return _u
}
+// SetProviderKey sets the "provider_key" field.
+func (_u *PaymentOrderUpdateOne) SetProviderKey(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableProviderKey(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (_u *PaymentOrderUpdateOne) ClearProviderKey() *PaymentOrderUpdateOne {
+ _u.mutation.ClearProviderKey()
+ return _u
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (_u *PaymentOrderUpdateOne) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpdateOne {
+ _u.mutation.SetProviderSnapshot(v)
+ return _u
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (_u *PaymentOrderUpdateOne) ClearProviderSnapshot() *PaymentOrderUpdateOne {
+ _u.mutation.ClearProviderSnapshot()
+ return _u
+}
+
// SetStatus sets the "status" field.
func (_u *PaymentOrderUpdateOne) SetStatus(v string) *PaymentOrderUpdateOne {
_u.mutation.SetStatus(v)
@@ -1803,6 +1884,11 @@ func (_u *PaymentOrderUpdateOne) check() error {
return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)}
}
}
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := paymentorder.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)}
+ }
+ }
if v, ok := _u.mutation.Status(); ok {
if err := paymentorder.StatusValidator(v); err != nil {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.status": %w`, err)}
@@ -1954,6 +2040,18 @@ func (_u *PaymentOrderUpdateOne) sqlSave(ctx context.Context) (_node *PaymentOrd
if _u.mutation.ProviderInstanceIDCleared() {
_spec.ClearField(paymentorder.FieldProviderInstanceID, field.TypeString)
}
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value)
+ }
+ if _u.mutation.ProviderKeyCleared() {
+ _spec.ClearField(paymentorder.FieldProviderKey, field.TypeString)
+ }
+ if value, ok := _u.mutation.ProviderSnapshot(); ok {
+ _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value)
+ }
+ if _u.mutation.ProviderSnapshotCleared() {
+ _spec.ClearField(paymentorder.FieldProviderSnapshot, field.TypeJSON)
+ }
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(paymentorder.FieldStatus, field.TypeString, value)
}
diff --git a/backend/ent/pendingauthsession.go b/backend/ent/pendingauthsession.go
new file mode 100644
index 00000000..e77c065f
--- /dev/null
+++ b/backend/ent/pendingauthsession.go
@@ -0,0 +1,399 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PendingAuthSession is the model entity for the PendingAuthSession schema.
+type PendingAuthSession struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // SessionToken holds the value of the "session_token" field.
+ SessionToken string `json:"session_token,omitempty"`
+ // Intent holds the value of the "intent" field.
+ Intent string `json:"intent,omitempty"`
+ // ProviderType holds the value of the "provider_type" field.
+ ProviderType string `json:"provider_type,omitempty"`
+ // ProviderKey holds the value of the "provider_key" field.
+ ProviderKey string `json:"provider_key,omitempty"`
+ // ProviderSubject holds the value of the "provider_subject" field.
+ ProviderSubject string `json:"provider_subject,omitempty"`
+ // TargetUserID holds the value of the "target_user_id" field.
+ TargetUserID *int64 `json:"target_user_id,omitempty"`
+ // RedirectTo holds the value of the "redirect_to" field.
+ RedirectTo string `json:"redirect_to,omitempty"`
+ // ResolvedEmail holds the value of the "resolved_email" field.
+ ResolvedEmail string `json:"resolved_email,omitempty"`
+ // RegistrationPasswordHash holds the value of the "registration_password_hash" field.
+ RegistrationPasswordHash string `json:"registration_password_hash,omitempty"`
+ // UpstreamIdentityClaims holds the value of the "upstream_identity_claims" field.
+ UpstreamIdentityClaims map[string]interface{} `json:"upstream_identity_claims,omitempty"`
+ // LocalFlowState holds the value of the "local_flow_state" field.
+ LocalFlowState map[string]interface{} `json:"local_flow_state,omitempty"`
+ // BrowserSessionKey holds the value of the "browser_session_key" field.
+ BrowserSessionKey string `json:"browser_session_key,omitempty"`
+ // CompletionCodeHash holds the value of the "completion_code_hash" field.
+ CompletionCodeHash string `json:"completion_code_hash,omitempty"`
+ // CompletionCodeExpiresAt holds the value of the "completion_code_expires_at" field.
+ CompletionCodeExpiresAt *time.Time `json:"completion_code_expires_at,omitempty"`
+ // EmailVerifiedAt holds the value of the "email_verified_at" field.
+ EmailVerifiedAt *time.Time `json:"email_verified_at,omitempty"`
+ // PasswordVerifiedAt holds the value of the "password_verified_at" field.
+ PasswordVerifiedAt *time.Time `json:"password_verified_at,omitempty"`
+ // TotpVerifiedAt holds the value of the "totp_verified_at" field.
+ TotpVerifiedAt *time.Time `json:"totp_verified_at,omitempty"`
+ // ExpiresAt holds the value of the "expires_at" field.
+ ExpiresAt time.Time `json:"expires_at,omitempty"`
+ // ConsumedAt holds the value of the "consumed_at" field.
+ ConsumedAt *time.Time `json:"consumed_at,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the PendingAuthSessionQuery when eager-loading is set.
+ Edges PendingAuthSessionEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// PendingAuthSessionEdges holds the relations/edges for other nodes in the graph.
+type PendingAuthSessionEdges struct {
+ // TargetUser holds the value of the target_user edge.
+ TargetUser *User `json:"target_user,omitempty"`
+ // AdoptionDecision holds the value of the adoption_decision edge.
+ AdoptionDecision *IdentityAdoptionDecision `json:"adoption_decision,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [2]bool
+}
+
+// TargetUserOrErr returns the TargetUser value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e PendingAuthSessionEdges) TargetUserOrErr() (*User, error) {
+ if e.TargetUser != nil {
+ return e.TargetUser, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: user.Label}
+ }
+ return nil, &NotLoadedError{edge: "target_user"}
+}
+
+// AdoptionDecisionOrErr returns the AdoptionDecision value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e PendingAuthSessionEdges) AdoptionDecisionOrErr() (*IdentityAdoptionDecision, error) {
+ if e.AdoptionDecision != nil {
+ return e.AdoptionDecision, nil
+ } else if e.loadedTypes[1] {
+ return nil, &NotFoundError{label: identityadoptiondecision.Label}
+ }
+ return nil, &NotLoadedError{edge: "adoption_decision"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*PendingAuthSession) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case pendingauthsession.FieldUpstreamIdentityClaims, pendingauthsession.FieldLocalFlowState:
+ values[i] = new([]byte)
+ case pendingauthsession.FieldID, pendingauthsession.FieldTargetUserID:
+ values[i] = new(sql.NullInt64)
+ case pendingauthsession.FieldSessionToken, pendingauthsession.FieldIntent, pendingauthsession.FieldProviderType, pendingauthsession.FieldProviderKey, pendingauthsession.FieldProviderSubject, pendingauthsession.FieldRedirectTo, pendingauthsession.FieldResolvedEmail, pendingauthsession.FieldRegistrationPasswordHash, pendingauthsession.FieldBrowserSessionKey, pendingauthsession.FieldCompletionCodeHash:
+ values[i] = new(sql.NullString)
+ case pendingauthsession.FieldCreatedAt, pendingauthsession.FieldUpdatedAt, pendingauthsession.FieldCompletionCodeExpiresAt, pendingauthsession.FieldEmailVerifiedAt, pendingauthsession.FieldPasswordVerifiedAt, pendingauthsession.FieldTotpVerifiedAt, pendingauthsession.FieldExpiresAt, pendingauthsession.FieldConsumedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the PendingAuthSession fields.
+func (_m *PendingAuthSession) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case pendingauthsession.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case pendingauthsession.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case pendingauthsession.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case pendingauthsession.FieldSessionToken:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field session_token", values[i])
+ } else if value.Valid {
+ _m.SessionToken = value.String
+ }
+ case pendingauthsession.FieldIntent:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field intent", values[i])
+ } else if value.Valid {
+ _m.Intent = value.String
+ }
+ case pendingauthsession.FieldProviderType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_type", values[i])
+ } else if value.Valid {
+ _m.ProviderType = value.String
+ }
+ case pendingauthsession.FieldProviderKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_key", values[i])
+ } else if value.Valid {
+ _m.ProviderKey = value.String
+ }
+ case pendingauthsession.FieldProviderSubject:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_subject", values[i])
+ } else if value.Valid {
+ _m.ProviderSubject = value.String
+ }
+ case pendingauthsession.FieldTargetUserID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field target_user_id", values[i])
+ } else if value.Valid {
+ _m.TargetUserID = new(int64)
+ *_m.TargetUserID = value.Int64
+ }
+ case pendingauthsession.FieldRedirectTo:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field redirect_to", values[i])
+ } else if value.Valid {
+ _m.RedirectTo = value.String
+ }
+ case pendingauthsession.FieldResolvedEmail:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field resolved_email", values[i])
+ } else if value.Valid {
+ _m.ResolvedEmail = value.String
+ }
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field registration_password_hash", values[i])
+ } else if value.Valid {
+ _m.RegistrationPasswordHash = value.String
+ }
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field upstream_identity_claims", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.UpstreamIdentityClaims); err != nil {
+ return fmt.Errorf("unmarshal field upstream_identity_claims: %w", err)
+ }
+ }
+ case pendingauthsession.FieldLocalFlowState:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field local_flow_state", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.LocalFlowState); err != nil {
+ return fmt.Errorf("unmarshal field local_flow_state: %w", err)
+ }
+ }
+ case pendingauthsession.FieldBrowserSessionKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field browser_session_key", values[i])
+ } else if value.Valid {
+ _m.BrowserSessionKey = value.String
+ }
+ case pendingauthsession.FieldCompletionCodeHash:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field completion_code_hash", values[i])
+ } else if value.Valid {
+ _m.CompletionCodeHash = value.String
+ }
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field completion_code_expires_at", values[i])
+ } else if value.Valid {
+ _m.CompletionCodeExpiresAt = new(time.Time)
+ *_m.CompletionCodeExpiresAt = value.Time
+ }
+ case pendingauthsession.FieldEmailVerifiedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field email_verified_at", values[i])
+ } else if value.Valid {
+ _m.EmailVerifiedAt = new(time.Time)
+ *_m.EmailVerifiedAt = value.Time
+ }
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field password_verified_at", values[i])
+ } else if value.Valid {
+ _m.PasswordVerifiedAt = new(time.Time)
+ *_m.PasswordVerifiedAt = value.Time
+ }
+ case pendingauthsession.FieldTotpVerifiedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field totp_verified_at", values[i])
+ } else if value.Valid {
+ _m.TotpVerifiedAt = new(time.Time)
+ *_m.TotpVerifiedAt = value.Time
+ }
+ case pendingauthsession.FieldExpiresAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field expires_at", values[i])
+ } else if value.Valid {
+ _m.ExpiresAt = value.Time
+ }
+ case pendingauthsession.FieldConsumedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field consumed_at", values[i])
+ } else if value.Valid {
+ _m.ConsumedAt = new(time.Time)
+ *_m.ConsumedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the PendingAuthSession.
+// This includes values selected through modifiers, order, etc.
+func (_m *PendingAuthSession) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryTargetUser queries the "target_user" edge of the PendingAuthSession entity.
+func (_m *PendingAuthSession) QueryTargetUser() *UserQuery {
+ return NewPendingAuthSessionClient(_m.config).QueryTargetUser(_m)
+}
+
+// QueryAdoptionDecision queries the "adoption_decision" edge of the PendingAuthSession entity.
+func (_m *PendingAuthSession) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery {
+ return NewPendingAuthSessionClient(_m.config).QueryAdoptionDecision(_m)
+}
+
+// Update returns a builder for updating this PendingAuthSession.
+// Note that you need to call PendingAuthSession.Unwrap() before calling this method if this PendingAuthSession
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *PendingAuthSession) Update() *PendingAuthSessionUpdateOne {
+ return NewPendingAuthSessionClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the PendingAuthSession entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *PendingAuthSession) Unwrap() *PendingAuthSession {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: PendingAuthSession is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *PendingAuthSession) String() string {
+ var builder strings.Builder
+ builder.WriteString("PendingAuthSession(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("session_token=")
+ builder.WriteString(_m.SessionToken)
+ builder.WriteString(", ")
+ builder.WriteString("intent=")
+ builder.WriteString(_m.Intent)
+ builder.WriteString(", ")
+ builder.WriteString("provider_type=")
+ builder.WriteString(_m.ProviderType)
+ builder.WriteString(", ")
+ builder.WriteString("provider_key=")
+ builder.WriteString(_m.ProviderKey)
+ builder.WriteString(", ")
+ builder.WriteString("provider_subject=")
+ builder.WriteString(_m.ProviderSubject)
+ builder.WriteString(", ")
+ if v := _m.TargetUserID; v != nil {
+ builder.WriteString("target_user_id=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("redirect_to=")
+ builder.WriteString(_m.RedirectTo)
+ builder.WriteString(", ")
+ builder.WriteString("resolved_email=")
+ builder.WriteString(_m.ResolvedEmail)
+ builder.WriteString(", ")
+ builder.WriteString("registration_password_hash=")
+ builder.WriteString(_m.RegistrationPasswordHash)
+ builder.WriteString(", ")
+ builder.WriteString("upstream_identity_claims=")
+ builder.WriteString(fmt.Sprintf("%v", _m.UpstreamIdentityClaims))
+ builder.WriteString(", ")
+ builder.WriteString("local_flow_state=")
+ builder.WriteString(fmt.Sprintf("%v", _m.LocalFlowState))
+ builder.WriteString(", ")
+ builder.WriteString("browser_session_key=")
+ builder.WriteString(_m.BrowserSessionKey)
+ builder.WriteString(", ")
+ builder.WriteString("completion_code_hash=")
+ builder.WriteString(_m.CompletionCodeHash)
+ builder.WriteString(", ")
+ if v := _m.CompletionCodeExpiresAt; v != nil {
+ builder.WriteString("completion_code_expires_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.EmailVerifiedAt; v != nil {
+ builder.WriteString("email_verified_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.PasswordVerifiedAt; v != nil {
+ builder.WriteString("password_verified_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.TotpVerifiedAt; v != nil {
+ builder.WriteString("totp_verified_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("expires_at=")
+ builder.WriteString(_m.ExpiresAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ if v := _m.ConsumedAt; v != nil {
+ builder.WriteString("consumed_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// PendingAuthSessions is a parsable slice of PendingAuthSession.
+type PendingAuthSessions []*PendingAuthSession
diff --git a/backend/ent/pendingauthsession/pendingauthsession.go b/backend/ent/pendingauthsession/pendingauthsession.go
new file mode 100644
index 00000000..8a3ac9bf
--- /dev/null
+++ b/backend/ent/pendingauthsession/pendingauthsession.go
@@ -0,0 +1,279 @@
+// Code generated by ent, DO NOT EDIT.
+
+package pendingauthsession
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the pendingauthsession type in the database.
+ Label = "pending_auth_session"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldSessionToken holds the string denoting the session_token field in the database.
+ FieldSessionToken = "session_token"
+ // FieldIntent holds the string denoting the intent field in the database.
+ FieldIntent = "intent"
+ // FieldProviderType holds the string denoting the provider_type field in the database.
+ FieldProviderType = "provider_type"
+ // FieldProviderKey holds the string denoting the provider_key field in the database.
+ FieldProviderKey = "provider_key"
+ // FieldProviderSubject holds the string denoting the provider_subject field in the database.
+ FieldProviderSubject = "provider_subject"
+ // FieldTargetUserID holds the string denoting the target_user_id field in the database.
+ FieldTargetUserID = "target_user_id"
+ // FieldRedirectTo holds the string denoting the redirect_to field in the database.
+ FieldRedirectTo = "redirect_to"
+ // FieldResolvedEmail holds the string denoting the resolved_email field in the database.
+ FieldResolvedEmail = "resolved_email"
+ // FieldRegistrationPasswordHash holds the string denoting the registration_password_hash field in the database.
+ FieldRegistrationPasswordHash = "registration_password_hash"
+ // FieldUpstreamIdentityClaims holds the string denoting the upstream_identity_claims field in the database.
+ FieldUpstreamIdentityClaims = "upstream_identity_claims"
+ // FieldLocalFlowState holds the string denoting the local_flow_state field in the database.
+ FieldLocalFlowState = "local_flow_state"
+ // FieldBrowserSessionKey holds the string denoting the browser_session_key field in the database.
+ FieldBrowserSessionKey = "browser_session_key"
+ // FieldCompletionCodeHash holds the string denoting the completion_code_hash field in the database.
+ FieldCompletionCodeHash = "completion_code_hash"
+ // FieldCompletionCodeExpiresAt holds the string denoting the completion_code_expires_at field in the database.
+ FieldCompletionCodeExpiresAt = "completion_code_expires_at"
+ // FieldEmailVerifiedAt holds the string denoting the email_verified_at field in the database.
+ FieldEmailVerifiedAt = "email_verified_at"
+ // FieldPasswordVerifiedAt holds the string denoting the password_verified_at field in the database.
+ FieldPasswordVerifiedAt = "password_verified_at"
+ // FieldTotpVerifiedAt holds the string denoting the totp_verified_at field in the database.
+ FieldTotpVerifiedAt = "totp_verified_at"
+ // FieldExpiresAt holds the string denoting the expires_at field in the database.
+ FieldExpiresAt = "expires_at"
+ // FieldConsumedAt holds the string denoting the consumed_at field in the database.
+ FieldConsumedAt = "consumed_at"
+ // EdgeTargetUser holds the string denoting the target_user edge name in mutations.
+ EdgeTargetUser = "target_user"
+ // EdgeAdoptionDecision holds the string denoting the adoption_decision edge name in mutations.
+ EdgeAdoptionDecision = "adoption_decision"
+ // Table holds the table name of the pendingauthsession in the database.
+ Table = "pending_auth_sessions"
+ // TargetUserTable is the table that holds the target_user relation/edge.
+ TargetUserTable = "pending_auth_sessions"
+ // TargetUserInverseTable is the table name for the User entity.
+ // It exists in this package in order to avoid circular dependency with the "user" package.
+ TargetUserInverseTable = "users"
+ // TargetUserColumn is the table column denoting the target_user relation/edge.
+ TargetUserColumn = "target_user_id"
+ // AdoptionDecisionTable is the table that holds the adoption_decision relation/edge.
+ AdoptionDecisionTable = "identity_adoption_decisions"
+ // AdoptionDecisionInverseTable is the table name for the IdentityAdoptionDecision entity.
+ // It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package.
+ AdoptionDecisionInverseTable = "identity_adoption_decisions"
+ // AdoptionDecisionColumn is the table column denoting the adoption_decision relation/edge.
+ AdoptionDecisionColumn = "pending_auth_session_id"
+)
+
+// Columns holds all SQL columns for pendingauthsession fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldSessionToken,
+ FieldIntent,
+ FieldProviderType,
+ FieldProviderKey,
+ FieldProviderSubject,
+ FieldTargetUserID,
+ FieldRedirectTo,
+ FieldResolvedEmail,
+ FieldRegistrationPasswordHash,
+ FieldUpstreamIdentityClaims,
+ FieldLocalFlowState,
+ FieldBrowserSessionKey,
+ FieldCompletionCodeHash,
+ FieldCompletionCodeExpiresAt,
+ FieldEmailVerifiedAt,
+ FieldPasswordVerifiedAt,
+ FieldTotpVerifiedAt,
+ FieldExpiresAt,
+ FieldConsumedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save.
+ SessionTokenValidator func(string) error
+ // IntentValidator is a validator for the "intent" field. It is called by the builders before save.
+ IntentValidator func(string) error
+ // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ ProviderTypeValidator func(string) error
+ // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ ProviderKeyValidator func(string) error
+ // ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
+ ProviderSubjectValidator func(string) error
+ // DefaultRedirectTo holds the default value on creation for the "redirect_to" field.
+ DefaultRedirectTo string
+ // DefaultResolvedEmail holds the default value on creation for the "resolved_email" field.
+ DefaultResolvedEmail string
+ // DefaultRegistrationPasswordHash holds the default value on creation for the "registration_password_hash" field.
+ DefaultRegistrationPasswordHash string
+ // DefaultUpstreamIdentityClaims holds the default value on creation for the "upstream_identity_claims" field.
+ DefaultUpstreamIdentityClaims func() map[string]interface{}
+ // DefaultLocalFlowState holds the default value on creation for the "local_flow_state" field.
+ DefaultLocalFlowState func() map[string]interface{}
+ // DefaultBrowserSessionKey holds the default value on creation for the "browser_session_key" field.
+ DefaultBrowserSessionKey string
+ // DefaultCompletionCodeHash holds the default value on creation for the "completion_code_hash" field.
+ DefaultCompletionCodeHash string
+)
+
+// OrderOption defines the ordering options for the PendingAuthSession queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// BySessionToken orders the results by the session_token field.
+func BySessionToken(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSessionToken, opts...).ToFunc()
+}
+
+// ByIntent orders the results by the intent field.
+func ByIntent(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIntent, opts...).ToFunc()
+}
+
+// ByProviderType orders the results by the provider_type field.
+func ByProviderType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderType, opts...).ToFunc()
+}
+
+// ByProviderKey orders the results by the provider_key field.
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
+}
+
+// ByProviderSubject orders the results by the provider_subject field.
+func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderSubject, opts...).ToFunc()
+}
+
+// ByTargetUserID orders the results by the target_user_id field.
+func ByTargetUserID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTargetUserID, opts...).ToFunc()
+}
+
+// ByRedirectTo orders the results by the redirect_to field.
+func ByRedirectTo(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRedirectTo, opts...).ToFunc()
+}
+
+// ByResolvedEmail orders the results by the resolved_email field.
+func ByResolvedEmail(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldResolvedEmail, opts...).ToFunc()
+}
+
+// ByRegistrationPasswordHash orders the results by the registration_password_hash field.
+func ByRegistrationPasswordHash(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRegistrationPasswordHash, opts...).ToFunc()
+}
+
+// ByBrowserSessionKey orders the results by the browser_session_key field.
+func ByBrowserSessionKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBrowserSessionKey, opts...).ToFunc()
+}
+
+// ByCompletionCodeHash orders the results by the completion_code_hash field.
+func ByCompletionCodeHash(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCompletionCodeHash, opts...).ToFunc()
+}
+
+// ByCompletionCodeExpiresAt orders the results by the completion_code_expires_at field.
+func ByCompletionCodeExpiresAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCompletionCodeExpiresAt, opts...).ToFunc()
+}
+
+// ByEmailVerifiedAt orders the results by the email_verified_at field.
+func ByEmailVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldEmailVerifiedAt, opts...).ToFunc()
+}
+
+// ByPasswordVerifiedAt orders the results by the password_verified_at field.
+func ByPasswordVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPasswordVerifiedAt, opts...).ToFunc()
+}
+
+// ByTotpVerifiedAt orders the results by the totp_verified_at field.
+func ByTotpVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTotpVerifiedAt, opts...).ToFunc()
+}
+
+// ByExpiresAt orders the results by the expires_at field.
+func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
+}
+
+// ByConsumedAt orders the results by the consumed_at field.
+func ByConsumedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldConsumedAt, opts...).ToFunc()
+}
+
+// ByTargetUserField orders the results by target_user field.
+func ByTargetUserField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newTargetUserStep(), sql.OrderByField(field, opts...))
+ }
+}
+
+// ByAdoptionDecisionField orders the results by adoption_decision field.
+func ByAdoptionDecisionField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newTargetUserStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(TargetUserInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn),
+ )
+}
+func newAdoptionDecisionStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(AdoptionDecisionInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn),
+ )
+}
diff --git a/backend/ent/pendingauthsession/where.go b/backend/ent/pendingauthsession/where.go
new file mode 100644
index 00000000..cb316f47
--- /dev/null
+++ b/backend/ent/pendingauthsession/where.go
@@ -0,0 +1,1262 @@
+// Code generated by ent, DO NOT EDIT.
+
+package pendingauthsession
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// SessionToken applies equality check predicate on the "session_token" field. It's identical to SessionTokenEQ.
+func SessionToken(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldSessionToken, v))
+}
+
+// Intent applies equality check predicate on the "intent" field. It's identical to IntentEQ.
+func Intent(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldIntent, v))
+}
+
+// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ.
+func ProviderType(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
+func ProviderKey(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderSubject applies equality check predicate on the "provider_subject" field. It's identical to ProviderSubjectEQ.
+func ProviderSubject(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderSubject, v))
+}
+
+// TargetUserID applies equality check predicate on the "target_user_id" field. It's identical to TargetUserIDEQ.
+func TargetUserID(v int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldTargetUserID, v))
+}
+
+// RedirectTo applies equality check predicate on the "redirect_to" field. It's identical to RedirectToEQ.
+func RedirectTo(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldRedirectTo, v))
+}
+
+// ResolvedEmail applies equality check predicate on the "resolved_email" field. It's identical to ResolvedEmailEQ.
+func ResolvedEmail(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldResolvedEmail, v))
+}
+
+// RegistrationPasswordHash applies equality check predicate on the "registration_password_hash" field. It's identical to RegistrationPasswordHashEQ.
+func RegistrationPasswordHash(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldRegistrationPasswordHash, v))
+}
+
+// BrowserSessionKey applies equality check predicate on the "browser_session_key" field. It's identical to BrowserSessionKeyEQ.
+func BrowserSessionKey(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldBrowserSessionKey, v))
+}
+
+// CompletionCodeHash applies equality check predicate on the "completion_code_hash" field. It's identical to CompletionCodeHashEQ.
+func CompletionCodeHash(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeExpiresAt applies equality check predicate on the "completion_code_expires_at" field. It's identical to CompletionCodeExpiresAtEQ.
+func CompletionCodeExpiresAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeExpiresAt, v))
+}
+
+// EmailVerifiedAt applies equality check predicate on the "email_verified_at" field. It's identical to EmailVerifiedAtEQ.
+func EmailVerifiedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldEmailVerifiedAt, v))
+}
+
+// PasswordVerifiedAt applies equality check predicate on the "password_verified_at" field. It's identical to PasswordVerifiedAtEQ.
+func PasswordVerifiedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldPasswordVerifiedAt, v))
+}
+
+// TotpVerifiedAt applies equality check predicate on the "totp_verified_at" field. It's identical to TotpVerifiedAtEQ.
+func TotpVerifiedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldTotpVerifiedAt, v))
+}
+
+// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ.
+func ExpiresAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldExpiresAt, v))
+}
+
+// ConsumedAt applies equality check predicate on the "consumed_at" field. It's identical to ConsumedAtEQ.
+func ConsumedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldConsumedAt, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// SessionTokenEQ applies the EQ predicate on the "session_token" field.
+func SessionTokenEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldSessionToken, v))
+}
+
+// SessionTokenNEQ applies the NEQ predicate on the "session_token" field.
+func SessionTokenNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldSessionToken, v))
+}
+
+// SessionTokenIn applies the In predicate on the "session_token" field.
+func SessionTokenIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldSessionToken, vs...))
+}
+
+// SessionTokenNotIn applies the NotIn predicate on the "session_token" field.
+func SessionTokenNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldSessionToken, vs...))
+}
+
+// SessionTokenGT applies the GT predicate on the "session_token" field.
+func SessionTokenGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldSessionToken, v))
+}
+
+// SessionTokenGTE applies the GTE predicate on the "session_token" field.
+func SessionTokenGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldSessionToken, v))
+}
+
+// SessionTokenLT applies the LT predicate on the "session_token" field.
+func SessionTokenLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldSessionToken, v))
+}
+
+// SessionTokenLTE applies the LTE predicate on the "session_token" field.
+func SessionTokenLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldSessionToken, v))
+}
+
+// SessionTokenContains applies the Contains predicate on the "session_token" field.
+func SessionTokenContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldSessionToken, v))
+}
+
+// SessionTokenHasPrefix applies the HasPrefix predicate on the "session_token" field.
+func SessionTokenHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldSessionToken, v))
+}
+
+// SessionTokenHasSuffix applies the HasSuffix predicate on the "session_token" field.
+func SessionTokenHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldSessionToken, v))
+}
+
+// SessionTokenEqualFold applies the EqualFold predicate on the "session_token" field.
+func SessionTokenEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldSessionToken, v))
+}
+
+// SessionTokenContainsFold applies the ContainsFold predicate on the "session_token" field.
+func SessionTokenContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldSessionToken, v))
+}
+
+// IntentEQ applies the EQ predicate on the "intent" field.
+func IntentEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldIntent, v))
+}
+
+// IntentNEQ applies the NEQ predicate on the "intent" field.
+func IntentNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldIntent, v))
+}
+
+// IntentIn applies the In predicate on the "intent" field.
+func IntentIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldIntent, vs...))
+}
+
+// IntentNotIn applies the NotIn predicate on the "intent" field.
+func IntentNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldIntent, vs...))
+}
+
+// IntentGT applies the GT predicate on the "intent" field.
+func IntentGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldIntent, v))
+}
+
+// IntentGTE applies the GTE predicate on the "intent" field.
+func IntentGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldIntent, v))
+}
+
+// IntentLT applies the LT predicate on the "intent" field.
+func IntentLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldIntent, v))
+}
+
+// IntentLTE applies the LTE predicate on the "intent" field.
+func IntentLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldIntent, v))
+}
+
+// IntentContains applies the Contains predicate on the "intent" field.
+func IntentContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldIntent, v))
+}
+
+// IntentHasPrefix applies the HasPrefix predicate on the "intent" field.
+func IntentHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldIntent, v))
+}
+
+// IntentHasSuffix applies the HasSuffix predicate on the "intent" field.
+func IntentHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldIntent, v))
+}
+
+// IntentEqualFold applies the EqualFold predicate on the "intent" field.
+func IntentEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldIntent, v))
+}
+
+// IntentContainsFold applies the ContainsFold predicate on the "intent" field.
+func IntentContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldIntent, v))
+}
+
+// ProviderTypeEQ applies the EQ predicate on the "provider_type" field.
+func ProviderTypeEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field.
+func ProviderTypeNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderType, v))
+}
+
+// ProviderTypeIn applies the In predicate on the "provider_type" field.
+func ProviderTypeIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field.
+func ProviderTypeNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeGT applies the GT predicate on the "provider_type" field.
+func ProviderTypeGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldProviderType, v))
+}
+
+// ProviderTypeGTE applies the GTE predicate on the "provider_type" field.
+func ProviderTypeGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderType, v))
+}
+
+// ProviderTypeLT applies the LT predicate on the "provider_type" field.
+func ProviderTypeLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldProviderType, v))
+}
+
+// ProviderTypeLTE applies the LTE predicate on the "provider_type" field.
+func ProviderTypeLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderType, v))
+}
+
+// ProviderTypeContains applies the Contains predicate on the "provider_type" field.
+func ProviderTypeContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldProviderType, v))
+}
+
+// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field.
+func ProviderTypeHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderType, v))
+}
+
+// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field.
+func ProviderTypeHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderType, v))
+}
+
+// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field.
+func ProviderTypeEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderType, v))
+}
+
+// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field.
+func ProviderTypeContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderType, v))
+}
+
+// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
+func ProviderKeyEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
+func ProviderKeyNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyIn applies the In predicate on the "provider_key" field.
+func ProviderKeyIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
+func ProviderKeyNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyGT applies the GT predicate on the "provider_key" field.
+func ProviderKeyGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldProviderKey, v))
+}
+
+// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
+func ProviderKeyGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderKey, v))
+}
+
+// ProviderKeyLT applies the LT predicate on the "provider_key" field.
+func ProviderKeyLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldProviderKey, v))
+}
+
+// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
+func ProviderKeyLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderKey, v))
+}
+
+// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
+func ProviderKeyContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldProviderKey, v))
+}
+
+// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
+func ProviderKeyHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderKey, v))
+}
+
+// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
+func ProviderKeyHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderKey, v))
+}
+
+// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
+func ProviderKeyEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderKey, v))
+}
+
+// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
+func ProviderKeyContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderKey, v))
+}
+
+// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field.
+func ProviderSubjectEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderSubject, v))
+}
+
+// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field.
+func ProviderSubjectNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderSubject, v))
+}
+
+// ProviderSubjectIn applies the In predicate on the "provider_subject" field.
+func ProviderSubjectIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldProviderSubject, vs...))
+}
+
+// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field.
+func ProviderSubjectNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderSubject, vs...))
+}
+
+// ProviderSubjectGT applies the GT predicate on the "provider_subject" field.
+func ProviderSubjectGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldProviderSubject, v))
+}
+
+// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field.
+func ProviderSubjectGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderSubject, v))
+}
+
+// ProviderSubjectLT applies the LT predicate on the "provider_subject" field.
+func ProviderSubjectLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldProviderSubject, v))
+}
+
+// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field.
+func ProviderSubjectLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderSubject, v))
+}
+
+// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field.
+func ProviderSubjectContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldProviderSubject, v))
+}
+
+// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field.
+func ProviderSubjectHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderSubject, v))
+}
+
+// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field.
+func ProviderSubjectHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderSubject, v))
+}
+
+// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field.
+func ProviderSubjectEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderSubject, v))
+}
+
+// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field.
+func ProviderSubjectContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderSubject, v))
+}
+
+// TargetUserIDEQ applies the EQ predicate on the "target_user_id" field.
+func TargetUserIDEQ(v int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldTargetUserID, v))
+}
+
+// TargetUserIDNEQ applies the NEQ predicate on the "target_user_id" field.
+func TargetUserIDNEQ(v int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldTargetUserID, v))
+}
+
+// TargetUserIDIn applies the In predicate on the "target_user_id" field.
+func TargetUserIDIn(vs ...int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldTargetUserID, vs...))
+}
+
+// TargetUserIDNotIn applies the NotIn predicate on the "target_user_id" field.
+func TargetUserIDNotIn(vs ...int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldTargetUserID, vs...))
+}
+
+// TargetUserIDIsNil applies the IsNil predicate on the "target_user_id" field.
+func TargetUserIDIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldTargetUserID))
+}
+
+// TargetUserIDNotNil applies the NotNil predicate on the "target_user_id" field.
+func TargetUserIDNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldTargetUserID))
+}
+
+// RedirectToEQ applies the EQ predicate on the "redirect_to" field.
+func RedirectToEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldRedirectTo, v))
+}
+
+// RedirectToNEQ applies the NEQ predicate on the "redirect_to" field.
+func RedirectToNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldRedirectTo, v))
+}
+
+// RedirectToIn applies the In predicate on the "redirect_to" field.
+func RedirectToIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldRedirectTo, vs...))
+}
+
+// RedirectToNotIn applies the NotIn predicate on the "redirect_to" field.
+func RedirectToNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldRedirectTo, vs...))
+}
+
+// RedirectToGT applies the GT predicate on the "redirect_to" field.
+func RedirectToGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldRedirectTo, v))
+}
+
+// RedirectToGTE applies the GTE predicate on the "redirect_to" field.
+func RedirectToGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldRedirectTo, v))
+}
+
+// RedirectToLT applies the LT predicate on the "redirect_to" field.
+func RedirectToLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldRedirectTo, v))
+}
+
+// RedirectToLTE applies the LTE predicate on the "redirect_to" field.
+func RedirectToLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldRedirectTo, v))
+}
+
+// RedirectToContains applies the Contains predicate on the "redirect_to" field.
+func RedirectToContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldRedirectTo, v))
+}
+
+// RedirectToHasPrefix applies the HasPrefix predicate on the "redirect_to" field.
+func RedirectToHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldRedirectTo, v))
+}
+
+// RedirectToHasSuffix applies the HasSuffix predicate on the "redirect_to" field.
+func RedirectToHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldRedirectTo, v))
+}
+
+// RedirectToEqualFold applies the EqualFold predicate on the "redirect_to" field.
+func RedirectToEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldRedirectTo, v))
+}
+
+// RedirectToContainsFold applies the ContainsFold predicate on the "redirect_to" field.
+func RedirectToContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldRedirectTo, v))
+}
+
+// ResolvedEmailEQ applies the EQ predicate on the "resolved_email" field.
+func ResolvedEmailEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailNEQ applies the NEQ predicate on the "resolved_email" field.
+func ResolvedEmailNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailIn applies the In predicate on the "resolved_email" field.
+func ResolvedEmailIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldResolvedEmail, vs...))
+}
+
+// ResolvedEmailNotIn applies the NotIn predicate on the "resolved_email" field.
+func ResolvedEmailNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldResolvedEmail, vs...))
+}
+
+// ResolvedEmailGT applies the GT predicate on the "resolved_email" field.
+func ResolvedEmailGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailGTE applies the GTE predicate on the "resolved_email" field.
+func ResolvedEmailGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailLT applies the LT predicate on the "resolved_email" field.
+func ResolvedEmailLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailLTE applies the LTE predicate on the "resolved_email" field.
+func ResolvedEmailLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailContains applies the Contains predicate on the "resolved_email" field.
+func ResolvedEmailContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailHasPrefix applies the HasPrefix predicate on the "resolved_email" field.
+func ResolvedEmailHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailHasSuffix applies the HasSuffix predicate on the "resolved_email" field.
+func ResolvedEmailHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailEqualFold applies the EqualFold predicate on the "resolved_email" field.
+func ResolvedEmailEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailContainsFold applies the ContainsFold predicate on the "resolved_email" field.
+func ResolvedEmailContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldResolvedEmail, v))
+}
+
+// RegistrationPasswordHashEQ applies the EQ predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashNEQ applies the NEQ predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashIn applies the In predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldRegistrationPasswordHash, vs...))
+}
+
+// RegistrationPasswordHashNotIn applies the NotIn predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldRegistrationPasswordHash, vs...))
+}
+
+// RegistrationPasswordHashGT applies the GT predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashGTE applies the GTE predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashLT applies the LT predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashLTE applies the LTE predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashContains applies the Contains predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashHasPrefix applies the HasPrefix predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashHasSuffix applies the HasSuffix predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashEqualFold applies the EqualFold predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashContainsFold applies the ContainsFold predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldRegistrationPasswordHash, v))
+}
+
+// BrowserSessionKeyEQ applies the EQ predicate on the "browser_session_key" field.
+func BrowserSessionKeyEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyNEQ applies the NEQ predicate on the "browser_session_key" field.
+func BrowserSessionKeyNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyIn applies the In predicate on the "browser_session_key" field.
+func BrowserSessionKeyIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldBrowserSessionKey, vs...))
+}
+
+// BrowserSessionKeyNotIn applies the NotIn predicate on the "browser_session_key" field.
+func BrowserSessionKeyNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldBrowserSessionKey, vs...))
+}
+
+// BrowserSessionKeyGT applies the GT predicate on the "browser_session_key" field.
+func BrowserSessionKeyGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyGTE applies the GTE predicate on the "browser_session_key" field.
+func BrowserSessionKeyGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyLT applies the LT predicate on the "browser_session_key" field.
+func BrowserSessionKeyLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyLTE applies the LTE predicate on the "browser_session_key" field.
+func BrowserSessionKeyLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyContains applies the Contains predicate on the "browser_session_key" field.
+func BrowserSessionKeyContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyHasPrefix applies the HasPrefix predicate on the "browser_session_key" field.
+func BrowserSessionKeyHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyHasSuffix applies the HasSuffix predicate on the "browser_session_key" field.
+func BrowserSessionKeyHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyEqualFold applies the EqualFold predicate on the "browser_session_key" field.
+func BrowserSessionKeyEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyContainsFold applies the ContainsFold predicate on the "browser_session_key" field.
+func BrowserSessionKeyContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldBrowserSessionKey, v))
+}
+
+// CompletionCodeHashEQ applies the EQ predicate on the "completion_code_hash" field.
+func CompletionCodeHashEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashNEQ applies the NEQ predicate on the "completion_code_hash" field.
+func CompletionCodeHashNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashIn applies the In predicate on the "completion_code_hash" field.
+func CompletionCodeHashIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldCompletionCodeHash, vs...))
+}
+
+// CompletionCodeHashNotIn applies the NotIn predicate on the "completion_code_hash" field.
+func CompletionCodeHashNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldCompletionCodeHash, vs...))
+}
+
+// CompletionCodeHashGT applies the GT predicate on the "completion_code_hash" field.
+func CompletionCodeHashGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashGTE applies the GTE predicate on the "completion_code_hash" field.
+func CompletionCodeHashGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashLT applies the LT predicate on the "completion_code_hash" field.
+func CompletionCodeHashLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashLTE applies the LTE predicate on the "completion_code_hash" field.
+func CompletionCodeHashLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashContains applies the Contains predicate on the "completion_code_hash" field.
+func CompletionCodeHashContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashHasPrefix applies the HasPrefix predicate on the "completion_code_hash" field.
+func CompletionCodeHashHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashHasSuffix applies the HasSuffix predicate on the "completion_code_hash" field.
+func CompletionCodeHashHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashEqualFold applies the EqualFold predicate on the "completion_code_hash" field.
+func CompletionCodeHashEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashContainsFold applies the ContainsFold predicate on the "completion_code_hash" field.
+func CompletionCodeHashContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeExpiresAtEQ applies the EQ predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtNEQ applies the NEQ predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtIn applies the In predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldCompletionCodeExpiresAt, vs...))
+}
+
+// CompletionCodeExpiresAtNotIn applies the NotIn predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldCompletionCodeExpiresAt, vs...))
+}
+
+// CompletionCodeExpiresAtGT applies the GT predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtGTE applies the GTE predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtLT applies the LT predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtLTE applies the LTE predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtIsNil applies the IsNil predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldCompletionCodeExpiresAt))
+}
+
+// CompletionCodeExpiresAtNotNil applies the NotNil predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldCompletionCodeExpiresAt))
+}
+
+// EmailVerifiedAtEQ applies the EQ predicate on the "email_verified_at" field.
+func EmailVerifiedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtNEQ applies the NEQ predicate on the "email_verified_at" field.
+func EmailVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtIn applies the In predicate on the "email_verified_at" field.
+func EmailVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldEmailVerifiedAt, vs...))
+}
+
+// EmailVerifiedAtNotIn applies the NotIn predicate on the "email_verified_at" field.
+func EmailVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldEmailVerifiedAt, vs...))
+}
+
+// EmailVerifiedAtGT applies the GT predicate on the "email_verified_at" field.
+func EmailVerifiedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtGTE applies the GTE predicate on the "email_verified_at" field.
+func EmailVerifiedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtLT applies the LT predicate on the "email_verified_at" field.
+func EmailVerifiedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtLTE applies the LTE predicate on the "email_verified_at" field.
+func EmailVerifiedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtIsNil applies the IsNil predicate on the "email_verified_at" field.
+func EmailVerifiedAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldEmailVerifiedAt))
+}
+
+// EmailVerifiedAtNotNil applies the NotNil predicate on the "email_verified_at" field.
+func EmailVerifiedAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldEmailVerifiedAt))
+}
+
+// PasswordVerifiedAtEQ applies the EQ predicate on the "password_verified_at" field.
+func PasswordVerifiedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtNEQ applies the NEQ predicate on the "password_verified_at" field.
+func PasswordVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtIn applies the In predicate on the "password_verified_at" field.
+func PasswordVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldPasswordVerifiedAt, vs...))
+}
+
+// PasswordVerifiedAtNotIn applies the NotIn predicate on the "password_verified_at" field.
+func PasswordVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldPasswordVerifiedAt, vs...))
+}
+
+// PasswordVerifiedAtGT applies the GT predicate on the "password_verified_at" field.
+func PasswordVerifiedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtGTE applies the GTE predicate on the "password_verified_at" field.
+func PasswordVerifiedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtLT applies the LT predicate on the "password_verified_at" field.
+func PasswordVerifiedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtLTE applies the LTE predicate on the "password_verified_at" field.
+func PasswordVerifiedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtIsNil applies the IsNil predicate on the "password_verified_at" field.
+func PasswordVerifiedAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldPasswordVerifiedAt))
+}
+
+// PasswordVerifiedAtNotNil applies the NotNil predicate on the "password_verified_at" field.
+func PasswordVerifiedAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldPasswordVerifiedAt))
+}
+
+// TotpVerifiedAtEQ applies the EQ predicate on the "totp_verified_at" field.
+func TotpVerifiedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtNEQ applies the NEQ predicate on the "totp_verified_at" field.
+func TotpVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtIn applies the In predicate on the "totp_verified_at" field.
+func TotpVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldTotpVerifiedAt, vs...))
+}
+
+// TotpVerifiedAtNotIn applies the NotIn predicate on the "totp_verified_at" field.
+func TotpVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldTotpVerifiedAt, vs...))
+}
+
+// TotpVerifiedAtGT applies the GT predicate on the "totp_verified_at" field.
+func TotpVerifiedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtGTE applies the GTE predicate on the "totp_verified_at" field.
+func TotpVerifiedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtLT applies the LT predicate on the "totp_verified_at" field.
+func TotpVerifiedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtLTE applies the LTE predicate on the "totp_verified_at" field.
+func TotpVerifiedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtIsNil applies the IsNil predicate on the "totp_verified_at" field.
+func TotpVerifiedAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldTotpVerifiedAt))
+}
+
+// TotpVerifiedAtNotNil applies the NotNil predicate on the "totp_verified_at" field.
+func TotpVerifiedAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldTotpVerifiedAt))
+}
+
+// ExpiresAtEQ applies the EQ predicate on the "expires_at" field.
+func ExpiresAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldExpiresAt, v))
+}
+
+// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field.
+func ExpiresAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldExpiresAt, v))
+}
+
+// ExpiresAtIn applies the In predicate on the "expires_at" field.
+func ExpiresAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldExpiresAt, vs...))
+}
+
+// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field.
+func ExpiresAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldExpiresAt, vs...))
+}
+
+// ExpiresAtGT applies the GT predicate on the "expires_at" field.
+func ExpiresAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldExpiresAt, v))
+}
+
+// ExpiresAtGTE applies the GTE predicate on the "expires_at" field.
+func ExpiresAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldExpiresAt, v))
+}
+
+// ExpiresAtLT applies the LT predicate on the "expires_at" field.
+func ExpiresAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldExpiresAt, v))
+}
+
+// ExpiresAtLTE applies the LTE predicate on the "expires_at" field.
+func ExpiresAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldExpiresAt, v))
+}
+
+// ConsumedAtEQ applies the EQ predicate on the "consumed_at" field.
+func ConsumedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldConsumedAt, v))
+}
+
+// ConsumedAtNEQ applies the NEQ predicate on the "consumed_at" field.
+func ConsumedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldConsumedAt, v))
+}
+
+// ConsumedAtIn applies the In predicate on the "consumed_at" field.
+func ConsumedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldConsumedAt, vs...))
+}
+
+// ConsumedAtNotIn applies the NotIn predicate on the "consumed_at" field.
+func ConsumedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldConsumedAt, vs...))
+}
+
+// ConsumedAtGT applies the GT predicate on the "consumed_at" field.
+func ConsumedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldConsumedAt, v))
+}
+
+// ConsumedAtGTE applies the GTE predicate on the "consumed_at" field.
+func ConsumedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldConsumedAt, v))
+}
+
+// ConsumedAtLT applies the LT predicate on the "consumed_at" field.
+func ConsumedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldConsumedAt, v))
+}
+
+// ConsumedAtLTE applies the LTE predicate on the "consumed_at" field.
+func ConsumedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldConsumedAt, v))
+}
+
+// ConsumedAtIsNil applies the IsNil predicate on the "consumed_at" field.
+func ConsumedAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldConsumedAt))
+}
+
+// ConsumedAtNotNil applies the NotNil predicate on the "consumed_at" field.
+func ConsumedAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldConsumedAt))
+}
+
+// HasTargetUser applies the HasEdge predicate on the "target_user" edge.
+func HasTargetUser() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasTargetUserWith applies the HasEdge predicate on the "target_user" edge with a given conditions (other predicates).
+func HasTargetUserWith(preds ...predicate.User) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(func(s *sql.Selector) {
+ step := newTargetUserStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasAdoptionDecision applies the HasEdge predicate on the "adoption_decision" edge.
+func HasAdoptionDecision() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasAdoptionDecisionWith applies the HasEdge predicate on the "adoption_decision" edge with a given conditions (other predicates).
+func HasAdoptionDecisionWith(preds ...predicate.IdentityAdoptionDecision) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(func(s *sql.Selector) {
+ step := newAdoptionDecisionStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.PendingAuthSession) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.PendingAuthSession) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.PendingAuthSession) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.NotPredicates(p))
+}
diff --git a/backend/ent/pendingauthsession_create.go b/backend/ent/pendingauthsession_create.go
new file mode 100644
index 00000000..60276daa
--- /dev/null
+++ b/backend/ent/pendingauthsession_create.go
@@ -0,0 +1,1815 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PendingAuthSessionCreate is the builder for creating a PendingAuthSession entity.
+type PendingAuthSessionCreate struct {
+ config
+ mutation *PendingAuthSessionMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *PendingAuthSessionCreate) SetCreatedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableCreatedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *PendingAuthSessionCreate) SetUpdatedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableUpdatedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetSessionToken sets the "session_token" field.
+func (_c *PendingAuthSessionCreate) SetSessionToken(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetSessionToken(v)
+ return _c
+}
+
+// SetIntent sets the "intent" field.
+func (_c *PendingAuthSessionCreate) SetIntent(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetIntent(v)
+ return _c
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_c *PendingAuthSessionCreate) SetProviderType(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetProviderType(v)
+ return _c
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_c *PendingAuthSessionCreate) SetProviderKey(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetProviderKey(v)
+ return _c
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_c *PendingAuthSessionCreate) SetProviderSubject(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetProviderSubject(v)
+ return _c
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (_c *PendingAuthSessionCreate) SetTargetUserID(v int64) *PendingAuthSessionCreate {
+ _c.mutation.SetTargetUserID(v)
+ return _c
+}
+
+// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableTargetUserID(v *int64) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetTargetUserID(*v)
+ }
+ return _c
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (_c *PendingAuthSessionCreate) SetRedirectTo(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetRedirectTo(v)
+ return _c
+}
+
+// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableRedirectTo(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetRedirectTo(*v)
+ }
+ return _c
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (_c *PendingAuthSessionCreate) SetResolvedEmail(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetResolvedEmail(v)
+ return _c
+}
+
+// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableResolvedEmail(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetResolvedEmail(*v)
+ }
+ return _c
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (_c *PendingAuthSessionCreate) SetRegistrationPasswordHash(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetRegistrationPasswordHash(v)
+ return _c
+}
+
+// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetRegistrationPasswordHash(*v)
+ }
+ return _c
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (_c *PendingAuthSessionCreate) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionCreate {
+ _c.mutation.SetUpstreamIdentityClaims(v)
+ return _c
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (_c *PendingAuthSessionCreate) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionCreate {
+ _c.mutation.SetLocalFlowState(v)
+ return _c
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (_c *PendingAuthSessionCreate) SetBrowserSessionKey(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetBrowserSessionKey(v)
+ return _c
+}
+
+// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetBrowserSessionKey(*v)
+ }
+ return _c
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (_c *PendingAuthSessionCreate) SetCompletionCodeHash(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetCompletionCodeHash(v)
+ return _c
+}
+
+// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetCompletionCodeHash(*v)
+ }
+ return _c
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (_c *PendingAuthSessionCreate) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetCompletionCodeExpiresAt(v)
+ return _c
+}
+
+// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetCompletionCodeExpiresAt(*v)
+ }
+ return _c
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (_c *PendingAuthSessionCreate) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetEmailVerifiedAt(v)
+ return _c
+}
+
+// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetEmailVerifiedAt(*v)
+ }
+ return _c
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (_c *PendingAuthSessionCreate) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetPasswordVerifiedAt(v)
+ return _c
+}
+
+// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetPasswordVerifiedAt(*v)
+ }
+ return _c
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (_c *PendingAuthSessionCreate) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetTotpVerifiedAt(v)
+ return _c
+}
+
+// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetTotpVerifiedAt(*v)
+ }
+ return _c
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_c *PendingAuthSessionCreate) SetExpiresAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetExpiresAt(v)
+ return _c
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (_c *PendingAuthSessionCreate) SetConsumedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetConsumedAt(v)
+ return _c
+}
+
+// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetConsumedAt(*v)
+ }
+ return _c
+}
+
+// SetTargetUser sets the "target_user" edge to the User entity.
+func (_c *PendingAuthSessionCreate) SetTargetUser(v *User) *PendingAuthSessionCreate {
+ return _c.SetTargetUserID(v.ID)
+}
+
+// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID.
+func (_c *PendingAuthSessionCreate) SetAdoptionDecisionID(id int64) *PendingAuthSessionCreate {
+ _c.mutation.SetAdoptionDecisionID(id)
+ return _c
+}
+
+// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionCreate {
+ if id != nil {
+ _c = _c.SetAdoptionDecisionID(*id)
+ }
+ return _c
+}
+
+// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_c *PendingAuthSessionCreate) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionCreate {
+ return _c.SetAdoptionDecisionID(v.ID)
+}
+
+// Mutation returns the PendingAuthSessionMutation object of the builder.
+func (_c *PendingAuthSessionCreate) Mutation() *PendingAuthSessionMutation {
+ return _c.mutation
+}
+
+// Save creates the PendingAuthSession in the database.
+func (_c *PendingAuthSessionCreate) Save(ctx context.Context) (*PendingAuthSession, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *PendingAuthSessionCreate) SaveX(ctx context.Context) *PendingAuthSession {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *PendingAuthSessionCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *PendingAuthSessionCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *PendingAuthSessionCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := pendingauthsession.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := pendingauthsession.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.RedirectTo(); !ok {
+ v := pendingauthsession.DefaultRedirectTo
+ _c.mutation.SetRedirectTo(v)
+ }
+ if _, ok := _c.mutation.ResolvedEmail(); !ok {
+ v := pendingauthsession.DefaultResolvedEmail
+ _c.mutation.SetResolvedEmail(v)
+ }
+ if _, ok := _c.mutation.RegistrationPasswordHash(); !ok {
+ v := pendingauthsession.DefaultRegistrationPasswordHash
+ _c.mutation.SetRegistrationPasswordHash(v)
+ }
+ if _, ok := _c.mutation.UpstreamIdentityClaims(); !ok {
+ v := pendingauthsession.DefaultUpstreamIdentityClaims()
+ _c.mutation.SetUpstreamIdentityClaims(v)
+ }
+ if _, ok := _c.mutation.LocalFlowState(); !ok {
+ v := pendingauthsession.DefaultLocalFlowState()
+ _c.mutation.SetLocalFlowState(v)
+ }
+ if _, ok := _c.mutation.BrowserSessionKey(); !ok {
+ v := pendingauthsession.DefaultBrowserSessionKey
+ _c.mutation.SetBrowserSessionKey(v)
+ }
+ if _, ok := _c.mutation.CompletionCodeHash(); !ok {
+ v := pendingauthsession.DefaultCompletionCodeHash
+ _c.mutation.SetCompletionCodeHash(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *PendingAuthSessionCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PendingAuthSession.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "PendingAuthSession.updated_at"`)}
+ }
+ if _, ok := _c.mutation.SessionToken(); !ok {
+ return &ValidationError{Name: "session_token", err: errors.New(`ent: missing required field "PendingAuthSession.session_token"`)}
+ }
+ if v, ok := _c.mutation.SessionToken(); ok {
+ if err := pendingauthsession.SessionTokenValidator(v); err != nil {
+ return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Intent(); !ok {
+ return &ValidationError{Name: "intent", err: errors.New(`ent: missing required field "PendingAuthSession.intent"`)}
+ }
+ if v, ok := _c.mutation.Intent(); ok {
+ if err := pendingauthsession.IntentValidator(v); err != nil {
+ return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderType(); !ok {
+ return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "PendingAuthSession.provider_type"`)}
+ }
+ if v, ok := _c.mutation.ProviderType(); ok {
+ if err := pendingauthsession.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderKey(); !ok {
+ return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "PendingAuthSession.provider_key"`)}
+ }
+ if v, ok := _c.mutation.ProviderKey(); ok {
+ if err := pendingauthsession.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderSubject(); !ok {
+ return &ValidationError{Name: "provider_subject", err: errors.New(`ent: missing required field "PendingAuthSession.provider_subject"`)}
+ }
+ if v, ok := _c.mutation.ProviderSubject(); ok {
+ if err := pendingauthsession.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.RedirectTo(); !ok {
+ return &ValidationError{Name: "redirect_to", err: errors.New(`ent: missing required field "PendingAuthSession.redirect_to"`)}
+ }
+ if _, ok := _c.mutation.ResolvedEmail(); !ok {
+ return &ValidationError{Name: "resolved_email", err: errors.New(`ent: missing required field "PendingAuthSession.resolved_email"`)}
+ }
+ if _, ok := _c.mutation.RegistrationPasswordHash(); !ok {
+ return &ValidationError{Name: "registration_password_hash", err: errors.New(`ent: missing required field "PendingAuthSession.registration_password_hash"`)}
+ }
+ if _, ok := _c.mutation.UpstreamIdentityClaims(); !ok {
+ return &ValidationError{Name: "upstream_identity_claims", err: errors.New(`ent: missing required field "PendingAuthSession.upstream_identity_claims"`)}
+ }
+ if _, ok := _c.mutation.LocalFlowState(); !ok {
+ return &ValidationError{Name: "local_flow_state", err: errors.New(`ent: missing required field "PendingAuthSession.local_flow_state"`)}
+ }
+ if _, ok := _c.mutation.BrowserSessionKey(); !ok {
+ return &ValidationError{Name: "browser_session_key", err: errors.New(`ent: missing required field "PendingAuthSession.browser_session_key"`)}
+ }
+ if _, ok := _c.mutation.CompletionCodeHash(); !ok {
+ return &ValidationError{Name: "completion_code_hash", err: errors.New(`ent: missing required field "PendingAuthSession.completion_code_hash"`)}
+ }
+ if _, ok := _c.mutation.ExpiresAt(); !ok {
+ return &ValidationError{Name: "expires_at", err: errors.New(`ent: missing required field "PendingAuthSession.expires_at"`)}
+ }
+ return nil
+}
+
+func (_c *PendingAuthSessionCreate) sqlSave(ctx context.Context) (*PendingAuthSession, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *PendingAuthSessionCreate) createSpec() (*PendingAuthSession, *sqlgraph.CreateSpec) {
+ var (
+ _node = &PendingAuthSession{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.SessionToken(); ok {
+ _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value)
+ _node.SessionToken = value
+ }
+ if value, ok := _c.mutation.Intent(); ok {
+ _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value)
+ _node.Intent = value
+ }
+ if value, ok := _c.mutation.ProviderType(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value)
+ _node.ProviderType = value
+ }
+ if value, ok := _c.mutation.ProviderKey(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value)
+ _node.ProviderKey = value
+ }
+ if value, ok := _c.mutation.ProviderSubject(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value)
+ _node.ProviderSubject = value
+ }
+ if value, ok := _c.mutation.RedirectTo(); ok {
+ _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value)
+ _node.RedirectTo = value
+ }
+ if value, ok := _c.mutation.ResolvedEmail(); ok {
+ _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value)
+ _node.ResolvedEmail = value
+ }
+ if value, ok := _c.mutation.RegistrationPasswordHash(); ok {
+ _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value)
+ _node.RegistrationPasswordHash = value
+ }
+ if value, ok := _c.mutation.UpstreamIdentityClaims(); ok {
+ _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value)
+ _node.UpstreamIdentityClaims = value
+ }
+ if value, ok := _c.mutation.LocalFlowState(); ok {
+ _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value)
+ _node.LocalFlowState = value
+ }
+ if value, ok := _c.mutation.BrowserSessionKey(); ok {
+ _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value)
+ _node.BrowserSessionKey = value
+ }
+ if value, ok := _c.mutation.CompletionCodeHash(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value)
+ _node.CompletionCodeHash = value
+ }
+ if value, ok := _c.mutation.CompletionCodeExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value)
+ _node.CompletionCodeExpiresAt = &value
+ }
+ if value, ok := _c.mutation.EmailVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value)
+ _node.EmailVerifiedAt = &value
+ }
+ if value, ok := _c.mutation.PasswordVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value)
+ _node.PasswordVerifiedAt = &value
+ }
+ if value, ok := _c.mutation.TotpVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value)
+ _node.TotpVerifiedAt = &value
+ }
+ if value, ok := _c.mutation.ExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value)
+ _node.ExpiresAt = value
+ }
+ if value, ok := _c.mutation.ConsumedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value)
+ _node.ConsumedAt = &value
+ }
+ if nodes := _c.mutation.TargetUserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.TargetUserID = &nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.AdoptionDecisionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.PendingAuthSession.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.PendingAuthSessionUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *PendingAuthSessionCreate) OnConflict(opts ...sql.ConflictOption) *PendingAuthSessionUpsertOne {
+ _c.conflict = opts
+ return &PendingAuthSessionUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *PendingAuthSessionCreate) OnConflictColumns(columns ...string) *PendingAuthSessionUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &PendingAuthSessionUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // PendingAuthSessionUpsertOne is the builder for "upsert"-ing
+ // one PendingAuthSession node.
+ PendingAuthSessionUpsertOne struct {
+ create *PendingAuthSessionCreate
+ }
+
+ // PendingAuthSessionUpsert is the "OnConflict" setter.
+ PendingAuthSessionUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PendingAuthSessionUpsert) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateUpdatedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldUpdatedAt)
+ return u
+}
+
+// SetSessionToken sets the "session_token" field.
+func (u *PendingAuthSessionUpsert) SetSessionToken(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldSessionToken, v)
+ return u
+}
+
+// UpdateSessionToken sets the "session_token" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateSessionToken() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldSessionToken)
+ return u
+}
+
+// SetIntent sets the "intent" field.
+func (u *PendingAuthSessionUpsert) SetIntent(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldIntent, v)
+ return u
+}
+
+// UpdateIntent sets the "intent" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateIntent() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldIntent)
+ return u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *PendingAuthSessionUpsert) SetProviderType(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldProviderType, v)
+ return u
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateProviderType() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldProviderType)
+ return u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PendingAuthSessionUpsert) SetProviderKey(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldProviderKey, v)
+ return u
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateProviderKey() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldProviderKey)
+ return u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *PendingAuthSessionUpsert) SetProviderSubject(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldProviderSubject, v)
+ return u
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateProviderSubject() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldProviderSubject)
+ return u
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (u *PendingAuthSessionUpsert) SetTargetUserID(v int64) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldTargetUserID, v)
+ return u
+}
+
+// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateTargetUserID() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldTargetUserID)
+ return u
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (u *PendingAuthSessionUpsert) ClearTargetUserID() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldTargetUserID)
+ return u
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (u *PendingAuthSessionUpsert) SetRedirectTo(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldRedirectTo, v)
+ return u
+}
+
+// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateRedirectTo() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldRedirectTo)
+ return u
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (u *PendingAuthSessionUpsert) SetResolvedEmail(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldResolvedEmail, v)
+ return u
+}
+
+// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateResolvedEmail() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldResolvedEmail)
+ return u
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (u *PendingAuthSessionUpsert) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldRegistrationPasswordHash, v)
+ return u
+}
+
+// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldRegistrationPasswordHash)
+ return u
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (u *PendingAuthSessionUpsert) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldUpstreamIdentityClaims, v)
+ return u
+}
+
+// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldUpstreamIdentityClaims)
+ return u
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (u *PendingAuthSessionUpsert) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldLocalFlowState, v)
+ return u
+}
+
+// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateLocalFlowState() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldLocalFlowState)
+ return u
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (u *PendingAuthSessionUpsert) SetBrowserSessionKey(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldBrowserSessionKey, v)
+ return u
+}
+
+// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateBrowserSessionKey() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldBrowserSessionKey)
+ return u
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (u *PendingAuthSessionUpsert) SetCompletionCodeHash(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldCompletionCodeHash, v)
+ return u
+}
+
+// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateCompletionCodeHash() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldCompletionCodeHash)
+ return u
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsert) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldCompletionCodeExpiresAt, v)
+ return u
+}
+
+// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldCompletionCodeExpiresAt)
+ return u
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsert) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldCompletionCodeExpiresAt)
+ return u
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (u *PendingAuthSessionUpsert) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldEmailVerifiedAt, v)
+ return u
+}
+
+// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateEmailVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldEmailVerifiedAt)
+ return u
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (u *PendingAuthSessionUpsert) ClearEmailVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldEmailVerifiedAt)
+ return u
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (u *PendingAuthSessionUpsert) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldPasswordVerifiedAt, v)
+ return u
+}
+
+// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldPasswordVerifiedAt)
+ return u
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (u *PendingAuthSessionUpsert) ClearPasswordVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldPasswordVerifiedAt)
+ return u
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsert) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldTotpVerifiedAt, v)
+ return u
+}
+
+// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateTotpVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldTotpVerifiedAt)
+ return u
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsert) ClearTotpVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldTotpVerifiedAt)
+ return u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *PendingAuthSessionUpsert) SetExpiresAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldExpiresAt, v)
+ return u
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateExpiresAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldExpiresAt)
+ return u
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (u *PendingAuthSessionUpsert) SetConsumedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldConsumedAt, v)
+ return u
+}
+
+// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateConsumedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldConsumedAt)
+ return u
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (u *PendingAuthSessionUpsert) ClearConsumedAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldConsumedAt)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *PendingAuthSessionUpsertOne) UpdateNewValues() *PendingAuthSessionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(pendingauthsession.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *PendingAuthSessionUpsertOne) Ignore() *PendingAuthSessionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *PendingAuthSessionUpsertOne) DoNothing() *PendingAuthSessionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the PendingAuthSessionCreate.OnConflict
+// documentation for more info.
+func (u *PendingAuthSessionUpsertOne) Update(set func(*PendingAuthSessionUpsert)) *PendingAuthSessionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&PendingAuthSessionUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PendingAuthSessionUpsertOne) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateUpdatedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetSessionToken sets the "session_token" field.
+func (u *PendingAuthSessionUpsertOne) SetSessionToken(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetSessionToken(v)
+ })
+}
+
+// UpdateSessionToken sets the "session_token" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateSessionToken() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateSessionToken()
+ })
+}
+
+// SetIntent sets the "intent" field.
+func (u *PendingAuthSessionUpsertOne) SetIntent(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetIntent(v)
+ })
+}
+
+// UpdateIntent sets the "intent" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateIntent() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateIntent()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *PendingAuthSessionUpsertOne) SetProviderType(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateProviderType() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PendingAuthSessionUpsertOne) SetProviderKey(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateProviderKey() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *PendingAuthSessionUpsertOne) SetProviderSubject(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderSubject(v)
+ })
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateProviderSubject() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderSubject()
+ })
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (u *PendingAuthSessionUpsertOne) SetTargetUserID(v int64) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetTargetUserID(v)
+ })
+}
+
+// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateTargetUserID() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateTargetUserID()
+ })
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (u *PendingAuthSessionUpsertOne) ClearTargetUserID() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearTargetUserID()
+ })
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (u *PendingAuthSessionUpsertOne) SetRedirectTo(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetRedirectTo(v)
+ })
+}
+
+// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateRedirectTo() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateRedirectTo()
+ })
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (u *PendingAuthSessionUpsertOne) SetResolvedEmail(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetResolvedEmail(v)
+ })
+}
+
+// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateResolvedEmail() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateResolvedEmail()
+ })
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (u *PendingAuthSessionUpsertOne) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetRegistrationPasswordHash(v)
+ })
+}
+
+// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateRegistrationPasswordHash()
+ })
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (u *PendingAuthSessionUpsertOne) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetUpstreamIdentityClaims(v)
+ })
+}
+
+// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateUpstreamIdentityClaims()
+ })
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (u *PendingAuthSessionUpsertOne) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetLocalFlowState(v)
+ })
+}
+
+// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateLocalFlowState() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateLocalFlowState()
+ })
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (u *PendingAuthSessionUpsertOne) SetBrowserSessionKey(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetBrowserSessionKey(v)
+ })
+}
+
+// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateBrowserSessionKey() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateBrowserSessionKey()
+ })
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (u *PendingAuthSessionUpsertOne) SetCompletionCodeHash(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetCompletionCodeHash(v)
+ })
+}
+
+// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateCompletionCodeHash() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateCompletionCodeHash()
+ })
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsertOne) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetCompletionCodeExpiresAt(v)
+ })
+}
+
+// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateCompletionCodeExpiresAt()
+ })
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearCompletionCodeExpiresAt()
+ })
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetEmailVerifiedAt(v)
+ })
+}
+
+// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateEmailVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateEmailVerifiedAt()
+ })
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearEmailVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearEmailVerifiedAt()
+ })
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetPasswordVerifiedAt(v)
+ })
+}
+
+// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdatePasswordVerifiedAt()
+ })
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearPasswordVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearPasswordVerifiedAt()
+ })
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetTotpVerifiedAt(v)
+ })
+}
+
+// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateTotpVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateTotpVerifiedAt()
+ })
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearTotpVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearTotpVerifiedAt()
+ })
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *PendingAuthSessionUpsertOne) SetExpiresAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetExpiresAt(v)
+ })
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateExpiresAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateExpiresAt()
+ })
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (u *PendingAuthSessionUpsertOne) SetConsumedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetConsumedAt(v)
+ })
+}
+
+// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateConsumedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateConsumedAt()
+ })
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearConsumedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearConsumedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *PendingAuthSessionUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for PendingAuthSessionCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *PendingAuthSessionUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *PendingAuthSessionUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *PendingAuthSessionUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// PendingAuthSessionCreateBulk is the builder for creating many PendingAuthSession entities in bulk.
+type PendingAuthSessionCreateBulk struct {
+ config
+ err error
+ builders []*PendingAuthSessionCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the PendingAuthSession entities in the database.
+func (_c *PendingAuthSessionCreateBulk) Save(ctx context.Context) ([]*PendingAuthSession, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*PendingAuthSession, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*PendingAuthSessionMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *PendingAuthSessionCreateBulk) SaveX(ctx context.Context) []*PendingAuthSession {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *PendingAuthSessionCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *PendingAuthSessionCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.PendingAuthSession.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.PendingAuthSessionUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *PendingAuthSessionCreateBulk) OnConflict(opts ...sql.ConflictOption) *PendingAuthSessionUpsertBulk {
+ _c.conflict = opts
+ return &PendingAuthSessionUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *PendingAuthSessionCreateBulk) OnConflictColumns(columns ...string) *PendingAuthSessionUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &PendingAuthSessionUpsertBulk{
+ create: _c,
+ }
+}
+
+// PendingAuthSessionUpsertBulk is the builder for "upsert"-ing
+// a bulk of PendingAuthSession nodes.
+type PendingAuthSessionUpsertBulk struct {
+ create *PendingAuthSessionCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *PendingAuthSessionUpsertBulk) UpdateNewValues() *PendingAuthSessionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(pendingauthsession.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *PendingAuthSessionUpsertBulk) Ignore() *PendingAuthSessionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *PendingAuthSessionUpsertBulk) DoNothing() *PendingAuthSessionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the PendingAuthSessionCreateBulk.OnConflict
+// documentation for more info.
+func (u *PendingAuthSessionUpsertBulk) Update(set func(*PendingAuthSessionUpsert)) *PendingAuthSessionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&PendingAuthSessionUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateUpdatedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetSessionToken sets the "session_token" field.
+func (u *PendingAuthSessionUpsertBulk) SetSessionToken(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetSessionToken(v)
+ })
+}
+
+// UpdateSessionToken sets the "session_token" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateSessionToken() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateSessionToken()
+ })
+}
+
+// SetIntent sets the "intent" field.
+func (u *PendingAuthSessionUpsertBulk) SetIntent(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetIntent(v)
+ })
+}
+
+// UpdateIntent sets the "intent" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateIntent() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateIntent()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *PendingAuthSessionUpsertBulk) SetProviderType(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateProviderType() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PendingAuthSessionUpsertBulk) SetProviderKey(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateProviderKey() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *PendingAuthSessionUpsertBulk) SetProviderSubject(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderSubject(v)
+ })
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateProviderSubject() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderSubject()
+ })
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (u *PendingAuthSessionUpsertBulk) SetTargetUserID(v int64) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetTargetUserID(v)
+ })
+}
+
+// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateTargetUserID() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateTargetUserID()
+ })
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (u *PendingAuthSessionUpsertBulk) ClearTargetUserID() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearTargetUserID()
+ })
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (u *PendingAuthSessionUpsertBulk) SetRedirectTo(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetRedirectTo(v)
+ })
+}
+
+// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateRedirectTo() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateRedirectTo()
+ })
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (u *PendingAuthSessionUpsertBulk) SetResolvedEmail(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetResolvedEmail(v)
+ })
+}
+
+// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateResolvedEmail() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateResolvedEmail()
+ })
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (u *PendingAuthSessionUpsertBulk) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetRegistrationPasswordHash(v)
+ })
+}
+
+// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateRegistrationPasswordHash()
+ })
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (u *PendingAuthSessionUpsertBulk) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetUpstreamIdentityClaims(v)
+ })
+}
+
+// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateUpstreamIdentityClaims()
+ })
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (u *PendingAuthSessionUpsertBulk) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetLocalFlowState(v)
+ })
+}
+
+// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateLocalFlowState() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateLocalFlowState()
+ })
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (u *PendingAuthSessionUpsertBulk) SetBrowserSessionKey(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetBrowserSessionKey(v)
+ })
+}
+
+// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateBrowserSessionKey() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateBrowserSessionKey()
+ })
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (u *PendingAuthSessionUpsertBulk) SetCompletionCodeHash(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetCompletionCodeHash(v)
+ })
+}
+
+// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateCompletionCodeHash() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateCompletionCodeHash()
+ })
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetCompletionCodeExpiresAt(v)
+ })
+}
+
+// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateCompletionCodeExpiresAt()
+ })
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearCompletionCodeExpiresAt()
+ })
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetEmailVerifiedAt(v)
+ })
+}
+
+// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateEmailVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateEmailVerifiedAt()
+ })
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearEmailVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearEmailVerifiedAt()
+ })
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetPasswordVerifiedAt(v)
+ })
+}
+
+// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdatePasswordVerifiedAt()
+ })
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearPasswordVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearPasswordVerifiedAt()
+ })
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetTotpVerifiedAt(v)
+ })
+}
+
+// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateTotpVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateTotpVerifiedAt()
+ })
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearTotpVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearTotpVerifiedAt()
+ })
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetExpiresAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetExpiresAt(v)
+ })
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateExpiresAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateExpiresAt()
+ })
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetConsumedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetConsumedAt(v)
+ })
+}
+
+// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateConsumedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateConsumedAt()
+ })
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearConsumedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearConsumedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *PendingAuthSessionUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the PendingAuthSessionCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for PendingAuthSessionCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *PendingAuthSessionUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/pendingauthsession_delete.go b/backend/ent/pendingauthsession_delete.go
new file mode 100644
index 00000000..ee4fe605
--- /dev/null
+++ b/backend/ent/pendingauthsession_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// PendingAuthSessionDelete is the builder for deleting a PendingAuthSession entity.
+type PendingAuthSessionDelete struct {
+ config
+ hooks []Hook
+ mutation *PendingAuthSessionMutation
+}
+
+// Where appends a list predicates to the PendingAuthSessionDelete builder.
+func (_d *PendingAuthSessionDelete) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *PendingAuthSessionDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *PendingAuthSessionDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *PendingAuthSessionDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// PendingAuthSessionDeleteOne is the builder for deleting a single PendingAuthSession entity.
+type PendingAuthSessionDeleteOne struct {
+ _d *PendingAuthSessionDelete
+}
+
+// Where appends a list predicates to the PendingAuthSessionDelete builder.
+func (_d *PendingAuthSessionDeleteOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *PendingAuthSessionDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{pendingauthsession.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *PendingAuthSessionDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/pendingauthsession_query.go b/backend/ent/pendingauthsession_query.go
new file mode 100644
index 00000000..78e29cd2
--- /dev/null
+++ b/backend/ent/pendingauthsession_query.go
@@ -0,0 +1,717 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PendingAuthSessionQuery is the builder for querying PendingAuthSession entities.
+type PendingAuthSessionQuery struct {
+ config
+ ctx *QueryContext
+ order []pendingauthsession.OrderOption
+ inters []Interceptor
+ predicates []predicate.PendingAuthSession
+ withTargetUser *UserQuery
+ withAdoptionDecision *IdentityAdoptionDecisionQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the PendingAuthSessionQuery builder.
+func (_q *PendingAuthSessionQuery) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *PendingAuthSessionQuery) Limit(limit int) *PendingAuthSessionQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *PendingAuthSessionQuery) Offset(offset int) *PendingAuthSessionQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *PendingAuthSessionQuery) Unique(unique bool) *PendingAuthSessionQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *PendingAuthSessionQuery) Order(o ...pendingauthsession.OrderOption) *PendingAuthSessionQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryTargetUser chains the current query on the "target_user" edge.
+func (_q *PendingAuthSessionQuery) QueryTargetUser() *UserQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryAdoptionDecision chains the current query on the "adoption_decision" edge.
+func (_q *PendingAuthSessionQuery) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector),
+ sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first PendingAuthSession entity from the query.
+// Returns a *NotFoundError when no PendingAuthSession was found.
+func (_q *PendingAuthSessionQuery) First(ctx context.Context) (*PendingAuthSession, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{pendingauthsession.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) FirstX(ctx context.Context) *PendingAuthSession {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first PendingAuthSession ID from the query.
+// Returns a *NotFoundError when no PendingAuthSession ID was found.
+func (_q *PendingAuthSessionQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{pendingauthsession.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single PendingAuthSession entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one PendingAuthSession entity is found.
+// Returns a *NotFoundError when no PendingAuthSession entities are found.
+func (_q *PendingAuthSessionQuery) Only(ctx context.Context) (*PendingAuthSession, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{pendingauthsession.Label}
+ default:
+ return nil, &NotSingularError{pendingauthsession.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) OnlyX(ctx context.Context) *PendingAuthSession {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only PendingAuthSession ID in the query.
+// Returns a *NotSingularError when more than one PendingAuthSession ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *PendingAuthSessionQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{pendingauthsession.Label}
+ default:
+ err = &NotSingularError{pendingauthsession.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of PendingAuthSessions.
+func (_q *PendingAuthSessionQuery) All(ctx context.Context) ([]*PendingAuthSession, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*PendingAuthSession, *PendingAuthSessionQuery]()
+ return withInterceptors[[]*PendingAuthSession](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) AllX(ctx context.Context) []*PendingAuthSession {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of PendingAuthSession IDs.
+func (_q *PendingAuthSessionQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(pendingauthsession.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *PendingAuthSessionQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*PendingAuthSessionQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *PendingAuthSessionQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the PendingAuthSessionQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *PendingAuthSessionQuery) Clone() *PendingAuthSessionQuery {
+ if _q == nil {
+ return nil
+ }
+ return &PendingAuthSessionQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]pendingauthsession.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.PendingAuthSession{}, _q.predicates...),
+ withTargetUser: _q.withTargetUser.Clone(),
+ withAdoptionDecision: _q.withAdoptionDecision.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithTargetUser tells the query-builder to eager-load the nodes that are connected to
+// the "target_user" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *PendingAuthSessionQuery) WithTargetUser(opts ...func(*UserQuery)) *PendingAuthSessionQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withTargetUser = query
+ return _q
+}
+
+// WithAdoptionDecision tells the query-builder to eager-load the nodes that are connected to
+// the "adoption_decision" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *PendingAuthSessionQuery) WithAdoptionDecision(opts ...func(*IdentityAdoptionDecisionQuery)) *PendingAuthSessionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withAdoptionDecision = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.PendingAuthSession.Query().
+// GroupBy(pendingauthsession.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *PendingAuthSessionQuery) GroupBy(field string, fields ...string) *PendingAuthSessionGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &PendingAuthSessionGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = pendingauthsession.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.PendingAuthSession.Query().
+// Select(pendingauthsession.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *PendingAuthSessionQuery) Select(fields ...string) *PendingAuthSessionSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &PendingAuthSessionSelect{PendingAuthSessionQuery: _q}
+ sbuild.label = pendingauthsession.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a PendingAuthSessionSelect configured with the given aggregations.
+func (_q *PendingAuthSessionQuery) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *PendingAuthSessionQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !pendingauthsession.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *PendingAuthSessionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*PendingAuthSession, error) {
+ var (
+ nodes = []*PendingAuthSession{}
+ _spec = _q.querySpec()
+ loadedTypes = [2]bool{
+ _q.withTargetUser != nil,
+ _q.withAdoptionDecision != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*PendingAuthSession).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &PendingAuthSession{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withTargetUser; query != nil {
+ if err := _q.loadTargetUser(ctx, query, nodes, nil,
+ func(n *PendingAuthSession, e *User) { n.Edges.TargetUser = e }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withAdoptionDecision; query != nil {
+ if err := _q.loadAdoptionDecision(ctx, query, nodes, nil,
+ func(n *PendingAuthSession, e *IdentityAdoptionDecision) { n.Edges.AdoptionDecision = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *PendingAuthSessionQuery) loadTargetUser(ctx context.Context, query *UserQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *User)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*PendingAuthSession)
+ for i := range nodes {
+ if nodes[i].TargetUserID == nil {
+ continue
+ }
+ fk := *nodes[i].TargetUserID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(user.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "target_user_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+func (_q *PendingAuthSessionQuery) loadAdoptionDecision(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *IdentityAdoptionDecision)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*PendingAuthSession)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(identityadoptiondecision.FieldPendingAuthSessionID)
+ }
+ query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(pendingauthsession.AdoptionDecisionColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.PendingAuthSessionID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "pending_auth_session_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+
+func (_q *PendingAuthSessionQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *PendingAuthSessionQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID)
+ for i := range fields {
+ if fields[i] != pendingauthsession.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withTargetUser != nil {
+ _spec.Node.AddColumnOnce(pendingauthsession.FieldTargetUserID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *PendingAuthSessionQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(pendingauthsession.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = pendingauthsession.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *PendingAuthSessionQuery) ForUpdate(opts ...sql.LockOption) *PendingAuthSessionQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *PendingAuthSessionQuery) ForShare(opts ...sql.LockOption) *PendingAuthSessionQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// PendingAuthSessionGroupBy is the group-by builder for PendingAuthSession entities.
+type PendingAuthSessionGroupBy struct {
+ selector
+ build *PendingAuthSessionQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *PendingAuthSessionGroupBy) Aggregate(fns ...AggregateFunc) *PendingAuthSessionGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *PendingAuthSessionGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *PendingAuthSessionGroupBy) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// PendingAuthSessionSelect is the builder for selecting fields of PendingAuthSession entities.
+type PendingAuthSessionSelect struct {
+ *PendingAuthSessionQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *PendingAuthSessionSelect) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *PendingAuthSessionSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionSelect](ctx, _s.PendingAuthSessionQuery, _s, _s.inters, v)
+}
+
+func (_s *PendingAuthSessionSelect) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/pendingauthsession_update.go b/backend/ent/pendingauthsession_update.go
new file mode 100644
index 00000000..00066f69
--- /dev/null
+++ b/backend/ent/pendingauthsession_update.go
@@ -0,0 +1,1178 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PendingAuthSessionUpdate is the builder for updating PendingAuthSession entities.
+type PendingAuthSessionUpdate struct {
+ config
+ hooks []Hook
+ mutation *PendingAuthSessionMutation
+}
+
+// Where appends a list predicates to the PendingAuthSessionUpdate builder.
+func (_u *PendingAuthSessionUpdate) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *PendingAuthSessionUpdate) SetUpdatedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetSessionToken sets the "session_token" field.
+func (_u *PendingAuthSessionUpdate) SetSessionToken(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetSessionToken(v)
+ return _u
+}
+
+// SetNillableSessionToken sets the "session_token" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableSessionToken(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetSessionToken(*v)
+ }
+ return _u
+}
+
+// SetIntent sets the "intent" field.
+func (_u *PendingAuthSessionUpdate) SetIntent(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetIntent(v)
+ return _u
+}
+
+// SetNillableIntent sets the "intent" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableIntent(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetIntent(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *PendingAuthSessionUpdate) SetProviderType(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableProviderType(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *PendingAuthSessionUpdate) SetProviderKey(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableProviderKey(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_u *PendingAuthSessionUpdate) SetProviderSubject(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetProviderSubject(v)
+ return _u
+}
+
+// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableProviderSubject(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetProviderSubject(*v)
+ }
+ return _u
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (_u *PendingAuthSessionUpdate) SetTargetUserID(v int64) *PendingAuthSessionUpdate {
+ _u.mutation.SetTargetUserID(v)
+ return _u
+}
+
+// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableTargetUserID(v *int64) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetTargetUserID(*v)
+ }
+ return _u
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (_u *PendingAuthSessionUpdate) ClearTargetUserID() *PendingAuthSessionUpdate {
+ _u.mutation.ClearTargetUserID()
+ return _u
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (_u *PendingAuthSessionUpdate) SetRedirectTo(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetRedirectTo(v)
+ return _u
+}
+
+// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableRedirectTo(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetRedirectTo(*v)
+ }
+ return _u
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (_u *PendingAuthSessionUpdate) SetResolvedEmail(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetResolvedEmail(v)
+ return _u
+}
+
+// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableResolvedEmail(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetResolvedEmail(*v)
+ }
+ return _u
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (_u *PendingAuthSessionUpdate) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetRegistrationPasswordHash(v)
+ return _u
+}
+
+// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetRegistrationPasswordHash(*v)
+ }
+ return _u
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (_u *PendingAuthSessionUpdate) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpdate {
+ _u.mutation.SetUpstreamIdentityClaims(v)
+ return _u
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (_u *PendingAuthSessionUpdate) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpdate {
+ _u.mutation.SetLocalFlowState(v)
+ return _u
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (_u *PendingAuthSessionUpdate) SetBrowserSessionKey(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetBrowserSessionKey(v)
+ return _u
+}
+
+// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetBrowserSessionKey(*v)
+ }
+ return _u
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (_u *PendingAuthSessionUpdate) SetCompletionCodeHash(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetCompletionCodeHash(v)
+ return _u
+}
+
+// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetCompletionCodeHash(*v)
+ }
+ return _u
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (_u *PendingAuthSessionUpdate) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetCompletionCodeExpiresAt(v)
+ return _u
+}
+
+// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetCompletionCodeExpiresAt(*v)
+ }
+ return _u
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (_u *PendingAuthSessionUpdate) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpdate {
+ _u.mutation.ClearCompletionCodeExpiresAt()
+ return _u
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (_u *PendingAuthSessionUpdate) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetEmailVerifiedAt(v)
+ return _u
+}
+
+// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetEmailVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (_u *PendingAuthSessionUpdate) ClearEmailVerifiedAt() *PendingAuthSessionUpdate {
+ _u.mutation.ClearEmailVerifiedAt()
+ return _u
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (_u *PendingAuthSessionUpdate) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetPasswordVerifiedAt(v)
+ return _u
+}
+
+// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetPasswordVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (_u *PendingAuthSessionUpdate) ClearPasswordVerifiedAt() *PendingAuthSessionUpdate {
+ _u.mutation.ClearPasswordVerifiedAt()
+ return _u
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (_u *PendingAuthSessionUpdate) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetTotpVerifiedAt(v)
+ return _u
+}
+
+// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetTotpVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (_u *PendingAuthSessionUpdate) ClearTotpVerifiedAt() *PendingAuthSessionUpdate {
+ _u.mutation.ClearTotpVerifiedAt()
+ return _u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_u *PendingAuthSessionUpdate) SetExpiresAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetExpiresAt(v)
+ return _u
+}
+
+// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableExpiresAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetExpiresAt(*v)
+ }
+ return _u
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (_u *PendingAuthSessionUpdate) SetConsumedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetConsumedAt(v)
+ return _u
+}
+
+// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetConsumedAt(*v)
+ }
+ return _u
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (_u *PendingAuthSessionUpdate) ClearConsumedAt() *PendingAuthSessionUpdate {
+ _u.mutation.ClearConsumedAt()
+ return _u
+}
+
+// SetTargetUser sets the "target_user" edge to the User entity.
+func (_u *PendingAuthSessionUpdate) SetTargetUser(v *User) *PendingAuthSessionUpdate {
+ return _u.SetTargetUserID(v.ID)
+}
+
+// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID.
+func (_u *PendingAuthSessionUpdate) SetAdoptionDecisionID(id int64) *PendingAuthSessionUpdate {
+ _u.mutation.SetAdoptionDecisionID(id)
+ return _u
+}
+
+// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionUpdate {
+ if id != nil {
+ _u = _u.SetAdoptionDecisionID(*id)
+ }
+ return _u
+}
+
+// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_u *PendingAuthSessionUpdate) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionUpdate {
+ return _u.SetAdoptionDecisionID(v.ID)
+}
+
+// Mutation returns the PendingAuthSessionMutation object of the builder.
+func (_u *PendingAuthSessionUpdate) Mutation() *PendingAuthSessionMutation {
+ return _u.mutation
+}
+
+// ClearTargetUser clears the "target_user" edge to the User entity.
+func (_u *PendingAuthSessionUpdate) ClearTargetUser() *PendingAuthSessionUpdate {
+ _u.mutation.ClearTargetUser()
+ return _u
+}
+
+// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_u *PendingAuthSessionUpdate) ClearAdoptionDecision() *PendingAuthSessionUpdate {
+ _u.mutation.ClearAdoptionDecision()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *PendingAuthSessionUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *PendingAuthSessionUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *PendingAuthSessionUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *PendingAuthSessionUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *PendingAuthSessionUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := pendingauthsession.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *PendingAuthSessionUpdate) check() error {
+ if v, ok := _u.mutation.SessionToken(); ok {
+ if err := pendingauthsession.SessionTokenValidator(v); err != nil {
+ return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Intent(); ok {
+ if err := pendingauthsession.IntentValidator(v); err != nil {
+ return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := pendingauthsession.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := pendingauthsession.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderSubject(); ok {
+ if err := pendingauthsession.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *PendingAuthSessionUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.SessionToken(); ok {
+ _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Intent(); ok {
+ _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderSubject(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RedirectTo(); ok {
+ _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ResolvedEmail(); ok {
+ _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RegistrationPasswordHash(); ok {
+ _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.UpstreamIdentityClaims(); ok {
+ _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.LocalFlowState(); ok {
+ _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.BrowserSessionKey(); ok {
+ _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.CompletionCodeHash(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.CompletionCodeExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value)
+ }
+ if _u.mutation.CompletionCodeExpiresAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.EmailVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.EmailVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.PasswordVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.PasswordVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.TotpVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.TotpVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.ExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ConsumedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value)
+ }
+ if _u.mutation.ConsumedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldConsumedAt, field.TypeTime)
+ }
+ if _u.mutation.TargetUserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.TargetUserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.AdoptionDecisionCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AdoptionDecisionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{pendingauthsession.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// PendingAuthSessionUpdateOne is the builder for updating a single PendingAuthSession entity.
+type PendingAuthSessionUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *PendingAuthSessionMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetUpdatedAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetSessionToken sets the "session_token" field.
+func (_u *PendingAuthSessionUpdateOne) SetSessionToken(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetSessionToken(v)
+ return _u
+}
+
+// SetNillableSessionToken sets the "session_token" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableSessionToken(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetSessionToken(*v)
+ }
+ return _u
+}
+
+// SetIntent sets the "intent" field.
+func (_u *PendingAuthSessionUpdateOne) SetIntent(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetIntent(v)
+ return _u
+}
+
+// SetNillableIntent sets the "intent" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableIntent(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetIntent(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *PendingAuthSessionUpdateOne) SetProviderType(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableProviderType(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *PendingAuthSessionUpdateOne) SetProviderKey(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableProviderKey(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_u *PendingAuthSessionUpdateOne) SetProviderSubject(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetProviderSubject(v)
+ return _u
+}
+
+// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableProviderSubject(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetProviderSubject(*v)
+ }
+ return _u
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (_u *PendingAuthSessionUpdateOne) SetTargetUserID(v int64) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetTargetUserID(v)
+ return _u
+}
+
+// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableTargetUserID(v *int64) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetTargetUserID(*v)
+ }
+ return _u
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (_u *PendingAuthSessionUpdateOne) ClearTargetUserID() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearTargetUserID()
+ return _u
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (_u *PendingAuthSessionUpdateOne) SetRedirectTo(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetRedirectTo(v)
+ return _u
+}
+
+// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableRedirectTo(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetRedirectTo(*v)
+ }
+ return _u
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (_u *PendingAuthSessionUpdateOne) SetResolvedEmail(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetResolvedEmail(v)
+ return _u
+}
+
+// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableResolvedEmail(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetResolvedEmail(*v)
+ }
+ return _u
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (_u *PendingAuthSessionUpdateOne) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetRegistrationPasswordHash(v)
+ return _u
+}
+
+// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetRegistrationPasswordHash(*v)
+ }
+ return _u
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (_u *PendingAuthSessionUpdateOne) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetUpstreamIdentityClaims(v)
+ return _u
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (_u *PendingAuthSessionUpdateOne) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetLocalFlowState(v)
+ return _u
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (_u *PendingAuthSessionUpdateOne) SetBrowserSessionKey(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetBrowserSessionKey(v)
+ return _u
+}
+
+// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetBrowserSessionKey(*v)
+ }
+ return _u
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (_u *PendingAuthSessionUpdateOne) SetCompletionCodeHash(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetCompletionCodeHash(v)
+ return _u
+}
+
+// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetCompletionCodeHash(*v)
+ }
+ return _u
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetCompletionCodeExpiresAt(v)
+ return _u
+}
+
+// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetCompletionCodeExpiresAt(*v)
+ }
+ return _u
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearCompletionCodeExpiresAt()
+ return _u
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetEmailVerifiedAt(v)
+ return _u
+}
+
+// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetEmailVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearEmailVerifiedAt() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearEmailVerifiedAt()
+ return _u
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetPasswordVerifiedAt(v)
+ return _u
+}
+
+// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetPasswordVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearPasswordVerifiedAt() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearPasswordVerifiedAt()
+ return _u
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetTotpVerifiedAt(v)
+ return _u
+}
+
+// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetTotpVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearTotpVerifiedAt() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearTotpVerifiedAt()
+ return _u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetExpiresAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetExpiresAt(v)
+ return _u
+}
+
+// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableExpiresAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetExpiresAt(*v)
+ }
+ return _u
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetConsumedAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetConsumedAt(v)
+ return _u
+}
+
+// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetConsumedAt(*v)
+ }
+ return _u
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearConsumedAt() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearConsumedAt()
+ return _u
+}
+
+// SetTargetUser sets the "target_user" edge to the User entity.
+func (_u *PendingAuthSessionUpdateOne) SetTargetUser(v *User) *PendingAuthSessionUpdateOne {
+ return _u.SetTargetUserID(v.ID)
+}
+
+// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID.
+func (_u *PendingAuthSessionUpdateOne) SetAdoptionDecisionID(id int64) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetAdoptionDecisionID(id)
+ return _u
+}
+
+// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionUpdateOne {
+ if id != nil {
+ _u = _u.SetAdoptionDecisionID(*id)
+ }
+ return _u
+}
+
+// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_u *PendingAuthSessionUpdateOne) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionUpdateOne {
+ return _u.SetAdoptionDecisionID(v.ID)
+}
+
+// Mutation returns the PendingAuthSessionMutation object of the builder.
+func (_u *PendingAuthSessionUpdateOne) Mutation() *PendingAuthSessionMutation {
+ return _u.mutation
+}
+
+// ClearTargetUser clears the "target_user" edge to the User entity.
+func (_u *PendingAuthSessionUpdateOne) ClearTargetUser() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearTargetUser()
+ return _u
+}
+
+// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_u *PendingAuthSessionUpdateOne) ClearAdoptionDecision() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearAdoptionDecision()
+ return _u
+}
+
+// Where appends a list predicates to the PendingAuthSessionUpdate builder.
+func (_u *PendingAuthSessionUpdateOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *PendingAuthSessionUpdateOne) Select(field string, fields ...string) *PendingAuthSessionUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated PendingAuthSession entity.
+func (_u *PendingAuthSessionUpdateOne) Save(ctx context.Context) (*PendingAuthSession, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *PendingAuthSessionUpdateOne) SaveX(ctx context.Context) *PendingAuthSession {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *PendingAuthSessionUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *PendingAuthSessionUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *PendingAuthSessionUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := pendingauthsession.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *PendingAuthSessionUpdateOne) check() error {
+ if v, ok := _u.mutation.SessionToken(); ok {
+ if err := pendingauthsession.SessionTokenValidator(v); err != nil {
+ return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Intent(); ok {
+ if err := pendingauthsession.IntentValidator(v); err != nil {
+ return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := pendingauthsession.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := pendingauthsession.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderSubject(); ok {
+ if err := pendingauthsession.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *PendingAuthSessionUpdateOne) sqlSave(ctx context.Context) (_node *PendingAuthSession, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "PendingAuthSession.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID)
+ for _, f := range fields {
+ if !pendingauthsession.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != pendingauthsession.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.SessionToken(); ok {
+ _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Intent(); ok {
+ _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderSubject(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RedirectTo(); ok {
+ _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ResolvedEmail(); ok {
+ _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RegistrationPasswordHash(); ok {
+ _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.UpstreamIdentityClaims(); ok {
+ _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.LocalFlowState(); ok {
+ _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.BrowserSessionKey(); ok {
+ _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.CompletionCodeHash(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.CompletionCodeExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value)
+ }
+ if _u.mutation.CompletionCodeExpiresAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.EmailVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.EmailVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.PasswordVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.PasswordVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.TotpVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.TotpVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.ExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ConsumedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value)
+ }
+ if _u.mutation.ConsumedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldConsumedAt, field.TypeTime)
+ }
+ if _u.mutation.TargetUserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.TargetUserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.AdoptionDecisionCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AdoptionDecisionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &PendingAuthSession{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{pendingauthsession.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go
index ef551940..0aa90b90 100644
--- a/backend/ent/predicate/predicate.go
+++ b/backend/ent/predicate/predicate.go
@@ -21,6 +21,12 @@ type Announcement func(*sql.Selector)
// AnnouncementRead is the predicate function for announcementread builders.
type AnnouncementRead func(*sql.Selector)
+// AuthIdentity is the predicate function for authidentity builders.
+type AuthIdentity func(*sql.Selector)
+
+// AuthIdentityChannel is the predicate function for authidentitychannel builders.
+type AuthIdentityChannel func(*sql.Selector)
+
// ErrorPassthroughRule is the predicate function for errorpassthroughrule builders.
type ErrorPassthroughRule func(*sql.Selector)
@@ -30,6 +36,9 @@ type Group func(*sql.Selector)
// IdempotencyRecord is the predicate function for idempotencyrecord builders.
type IdempotencyRecord func(*sql.Selector)
+// IdentityAdoptionDecision is the predicate function for identityadoptiondecision builders.
+type IdentityAdoptionDecision func(*sql.Selector)
+
// PaymentAuditLog is the predicate function for paymentauditlog builders.
type PaymentAuditLog func(*sql.Selector)
@@ -39,6 +48,9 @@ type PaymentOrder func(*sql.Selector)
// PaymentProviderInstance is the predicate function for paymentproviderinstance builders.
type PaymentProviderInstance func(*sql.Selector)
+// PendingAuthSession is the predicate function for pendingauthsession builders.
+type PendingAuthSession func(*sql.Selector)
+
// PromoCode is the predicate function for promocode builders.
type PromoCode func(*sql.Selector)
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index fbdd08c7..bdb7f7a9 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -10,12 +10,16 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
@@ -309,6 +313,120 @@ func init() {
announcementreadDescCreatedAt := announcementreadFields[3].Descriptor()
// announcementread.DefaultCreatedAt holds the default value on creation for the created_at field.
announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time)
+ authidentityMixin := schema.AuthIdentity{}.Mixin()
+ authidentityMixinFields0 := authidentityMixin[0].Fields()
+ _ = authidentityMixinFields0
+ authidentityFields := schema.AuthIdentity{}.Fields()
+ _ = authidentityFields
+ // authidentityDescCreatedAt is the schema descriptor for created_at field.
+ authidentityDescCreatedAt := authidentityMixinFields0[0].Descriptor()
+ // authidentity.DefaultCreatedAt holds the default value on creation for the created_at field.
+ authidentity.DefaultCreatedAt = authidentityDescCreatedAt.Default.(func() time.Time)
+ // authidentityDescUpdatedAt is the schema descriptor for updated_at field.
+ authidentityDescUpdatedAt := authidentityMixinFields0[1].Descriptor()
+ // authidentity.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ authidentity.DefaultUpdatedAt = authidentityDescUpdatedAt.Default.(func() time.Time)
+ // authidentity.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ authidentity.UpdateDefaultUpdatedAt = authidentityDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // authidentityDescProviderType is the schema descriptor for provider_type field.
+ authidentityDescProviderType := authidentityFields[1].Descriptor()
+ // authidentity.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ authidentity.ProviderTypeValidator = func() func(string) error {
+ validators := authidentityDescProviderType.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ validators[2].(func(string) error),
+ }
+ return func(provider_type string) error {
+ for _, fn := range fns {
+ if err := fn(provider_type); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // authidentityDescProviderKey is the schema descriptor for provider_key field.
+ authidentityDescProviderKey := authidentityFields[2].Descriptor()
+ // authidentity.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ authidentity.ProviderKeyValidator = authidentityDescProviderKey.Validators[0].(func(string) error)
+ // authidentityDescProviderSubject is the schema descriptor for provider_subject field.
+ authidentityDescProviderSubject := authidentityFields[3].Descriptor()
+ // authidentity.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
+ authidentity.ProviderSubjectValidator = authidentityDescProviderSubject.Validators[0].(func(string) error)
+ // authidentityDescMetadata is the schema descriptor for metadata field.
+ authidentityDescMetadata := authidentityFields[6].Descriptor()
+ // authidentity.DefaultMetadata holds the default value on creation for the metadata field.
+ authidentity.DefaultMetadata = authidentityDescMetadata.Default.(func() map[string]interface{})
+ authidentitychannelMixin := schema.AuthIdentityChannel{}.Mixin()
+ authidentitychannelMixinFields0 := authidentitychannelMixin[0].Fields()
+ _ = authidentitychannelMixinFields0
+ authidentitychannelFields := schema.AuthIdentityChannel{}.Fields()
+ _ = authidentitychannelFields
+ // authidentitychannelDescCreatedAt is the schema descriptor for created_at field.
+ authidentitychannelDescCreatedAt := authidentitychannelMixinFields0[0].Descriptor()
+ // authidentitychannel.DefaultCreatedAt holds the default value on creation for the created_at field.
+ authidentitychannel.DefaultCreatedAt = authidentitychannelDescCreatedAt.Default.(func() time.Time)
+ // authidentitychannelDescUpdatedAt is the schema descriptor for updated_at field.
+ authidentitychannelDescUpdatedAt := authidentitychannelMixinFields0[1].Descriptor()
+ // authidentitychannel.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ authidentitychannel.DefaultUpdatedAt = authidentitychannelDescUpdatedAt.Default.(func() time.Time)
+ // authidentitychannel.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ authidentitychannel.UpdateDefaultUpdatedAt = authidentitychannelDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // authidentitychannelDescProviderType is the schema descriptor for provider_type field.
+ authidentitychannelDescProviderType := authidentitychannelFields[1].Descriptor()
+ // authidentitychannel.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ authidentitychannel.ProviderTypeValidator = func() func(string) error {
+ validators := authidentitychannelDescProviderType.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ validators[2].(func(string) error),
+ }
+ return func(provider_type string) error {
+ for _, fn := range fns {
+ if err := fn(provider_type); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // authidentitychannelDescProviderKey is the schema descriptor for provider_key field.
+ authidentitychannelDescProviderKey := authidentitychannelFields[2].Descriptor()
+ // authidentitychannel.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ authidentitychannel.ProviderKeyValidator = authidentitychannelDescProviderKey.Validators[0].(func(string) error)
+ // authidentitychannelDescChannel is the schema descriptor for channel field.
+ authidentitychannelDescChannel := authidentitychannelFields[3].Descriptor()
+ // authidentitychannel.ChannelValidator is a validator for the "channel" field. It is called by the builders before save.
+ authidentitychannel.ChannelValidator = func() func(string) error {
+ validators := authidentitychannelDescChannel.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(channel string) error {
+ for _, fn := range fns {
+ if err := fn(channel); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // authidentitychannelDescChannelAppID is the schema descriptor for channel_app_id field.
+ authidentitychannelDescChannelAppID := authidentitychannelFields[4].Descriptor()
+ // authidentitychannel.ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save.
+ authidentitychannel.ChannelAppIDValidator = authidentitychannelDescChannelAppID.Validators[0].(func(string) error)
+ // authidentitychannelDescChannelSubject is the schema descriptor for channel_subject field.
+ authidentitychannelDescChannelSubject := authidentitychannelFields[5].Descriptor()
+ // authidentitychannel.ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save.
+ authidentitychannel.ChannelSubjectValidator = authidentitychannelDescChannelSubject.Validators[0].(func(string) error)
+ // authidentitychannelDescMetadata is the schema descriptor for metadata field.
+ authidentitychannelDescMetadata := authidentitychannelFields[6].Descriptor()
+ // authidentitychannel.DefaultMetadata holds the default value on creation for the metadata field.
+ authidentitychannel.DefaultMetadata = authidentitychannelDescMetadata.Default.(func() map[string]interface{})
errorpassthroughruleMixin := schema.ErrorPassthroughRule{}.Mixin()
errorpassthroughruleMixinFields0 := errorpassthroughruleMixin[0].Fields()
_ = errorpassthroughruleMixinFields0
@@ -512,6 +630,33 @@ func init() {
idempotencyrecordDescErrorReason := idempotencyrecordFields[6].Descriptor()
// idempotencyrecord.ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save.
idempotencyrecord.ErrorReasonValidator = idempotencyrecordDescErrorReason.Validators[0].(func(string) error)
+ identityadoptiondecisionMixin := schema.IdentityAdoptionDecision{}.Mixin()
+ identityadoptiondecisionMixinFields0 := identityadoptiondecisionMixin[0].Fields()
+ _ = identityadoptiondecisionMixinFields0
+ identityadoptiondecisionFields := schema.IdentityAdoptionDecision{}.Fields()
+ _ = identityadoptiondecisionFields
+ // identityadoptiondecisionDescCreatedAt is the schema descriptor for created_at field.
+ identityadoptiondecisionDescCreatedAt := identityadoptiondecisionMixinFields0[0].Descriptor()
+ // identityadoptiondecision.DefaultCreatedAt holds the default value on creation for the created_at field.
+ identityadoptiondecision.DefaultCreatedAt = identityadoptiondecisionDescCreatedAt.Default.(func() time.Time)
+ // identityadoptiondecisionDescUpdatedAt is the schema descriptor for updated_at field.
+ identityadoptiondecisionDescUpdatedAt := identityadoptiondecisionMixinFields0[1].Descriptor()
+ // identityadoptiondecision.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ identityadoptiondecision.DefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.Default.(func() time.Time)
+ // identityadoptiondecision.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ identityadoptiondecision.UpdateDefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // identityadoptiondecisionDescAdoptDisplayName is the schema descriptor for adopt_display_name field.
+ identityadoptiondecisionDescAdoptDisplayName := identityadoptiondecisionFields[2].Descriptor()
+ // identityadoptiondecision.DefaultAdoptDisplayName holds the default value on creation for the adopt_display_name field.
+ identityadoptiondecision.DefaultAdoptDisplayName = identityadoptiondecisionDescAdoptDisplayName.Default.(bool)
+ // identityadoptiondecisionDescAdoptAvatar is the schema descriptor for adopt_avatar field.
+ identityadoptiondecisionDescAdoptAvatar := identityadoptiondecisionFields[3].Descriptor()
+ // identityadoptiondecision.DefaultAdoptAvatar holds the default value on creation for the adopt_avatar field.
+ identityadoptiondecision.DefaultAdoptAvatar = identityadoptiondecisionDescAdoptAvatar.Default.(bool)
+ // identityadoptiondecisionDescDecidedAt is the schema descriptor for decided_at field.
+ identityadoptiondecisionDescDecidedAt := identityadoptiondecisionFields[4].Descriptor()
+ // identityadoptiondecision.DefaultDecidedAt holds the default value on creation for the decided_at field.
+ identityadoptiondecision.DefaultDecidedAt = identityadoptiondecisionDescDecidedAt.Default.(func() time.Time)
paymentauditlogFields := schema.PaymentAuditLog{}.Fields()
_ = paymentauditlogFields
// paymentauditlogDescOrderID is the schema descriptor for order_id field.
@@ -578,38 +723,42 @@ func init() {
paymentorderDescProviderInstanceID := paymentorderFields[18].Descriptor()
// paymentorder.ProviderInstanceIDValidator is a validator for the "provider_instance_id" field. It is called by the builders before save.
paymentorder.ProviderInstanceIDValidator = paymentorderDescProviderInstanceID.Validators[0].(func(string) error)
+ // paymentorderDescProviderKey is the schema descriptor for provider_key field.
+ paymentorderDescProviderKey := paymentorderFields[19].Descriptor()
+ // paymentorder.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ paymentorder.ProviderKeyValidator = paymentorderDescProviderKey.Validators[0].(func(string) error)
// paymentorderDescStatus is the schema descriptor for status field.
- paymentorderDescStatus := paymentorderFields[19].Descriptor()
+ paymentorderDescStatus := paymentorderFields[21].Descriptor()
// paymentorder.DefaultStatus holds the default value on creation for the status field.
paymentorder.DefaultStatus = paymentorderDescStatus.Default.(string)
// paymentorder.StatusValidator is a validator for the "status" field. It is called by the builders before save.
paymentorder.StatusValidator = paymentorderDescStatus.Validators[0].(func(string) error)
// paymentorderDescRefundAmount is the schema descriptor for refund_amount field.
- paymentorderDescRefundAmount := paymentorderFields[20].Descriptor()
+ paymentorderDescRefundAmount := paymentorderFields[22].Descriptor()
// paymentorder.DefaultRefundAmount holds the default value on creation for the refund_amount field.
paymentorder.DefaultRefundAmount = paymentorderDescRefundAmount.Default.(float64)
// paymentorderDescForceRefund is the schema descriptor for force_refund field.
- paymentorderDescForceRefund := paymentorderFields[23].Descriptor()
+ paymentorderDescForceRefund := paymentorderFields[25].Descriptor()
// paymentorder.DefaultForceRefund holds the default value on creation for the force_refund field.
paymentorder.DefaultForceRefund = paymentorderDescForceRefund.Default.(bool)
// paymentorderDescRefundRequestedBy is the schema descriptor for refund_requested_by field.
- paymentorderDescRefundRequestedBy := paymentorderFields[26].Descriptor()
+ paymentorderDescRefundRequestedBy := paymentorderFields[28].Descriptor()
// paymentorder.RefundRequestedByValidator is a validator for the "refund_requested_by" field. It is called by the builders before save.
paymentorder.RefundRequestedByValidator = paymentorderDescRefundRequestedBy.Validators[0].(func(string) error)
// paymentorderDescClientIP is the schema descriptor for client_ip field.
- paymentorderDescClientIP := paymentorderFields[32].Descriptor()
+ paymentorderDescClientIP := paymentorderFields[34].Descriptor()
// paymentorder.ClientIPValidator is a validator for the "client_ip" field. It is called by the builders before save.
paymentorder.ClientIPValidator = paymentorderDescClientIP.Validators[0].(func(string) error)
// paymentorderDescSrcHost is the schema descriptor for src_host field.
- paymentorderDescSrcHost := paymentorderFields[33].Descriptor()
+ paymentorderDescSrcHost := paymentorderFields[35].Descriptor()
// paymentorder.SrcHostValidator is a validator for the "src_host" field. It is called by the builders before save.
paymentorder.SrcHostValidator = paymentorderDescSrcHost.Validators[0].(func(string) error)
// paymentorderDescCreatedAt is the schema descriptor for created_at field.
- paymentorderDescCreatedAt := paymentorderFields[35].Descriptor()
+ paymentorderDescCreatedAt := paymentorderFields[37].Descriptor()
// paymentorder.DefaultCreatedAt holds the default value on creation for the created_at field.
paymentorder.DefaultCreatedAt = paymentorderDescCreatedAt.Default.(func() time.Time)
// paymentorderDescUpdatedAt is the schema descriptor for updated_at field.
- paymentorderDescUpdatedAt := paymentorderFields[36].Descriptor()
+ paymentorderDescUpdatedAt := paymentorderFields[38].Descriptor()
// paymentorder.DefaultUpdatedAt holds the default value on creation for the updated_at field.
paymentorder.DefaultUpdatedAt = paymentorderDescUpdatedAt.Default.(func() time.Time)
// paymentorder.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
@@ -682,6 +831,113 @@ func init() {
paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time)
// paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
paymentproviderinstance.UpdateDefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.UpdateDefault.(func() time.Time)
+ pendingauthsessionMixin := schema.PendingAuthSession{}.Mixin()
+ pendingauthsessionMixinFields0 := pendingauthsessionMixin[0].Fields()
+ _ = pendingauthsessionMixinFields0
+ pendingauthsessionFields := schema.PendingAuthSession{}.Fields()
+ _ = pendingauthsessionFields
+ // pendingauthsessionDescCreatedAt is the schema descriptor for created_at field.
+ pendingauthsessionDescCreatedAt := pendingauthsessionMixinFields0[0].Descriptor()
+ // pendingauthsession.DefaultCreatedAt holds the default value on creation for the created_at field.
+ pendingauthsession.DefaultCreatedAt = pendingauthsessionDescCreatedAt.Default.(func() time.Time)
+ // pendingauthsessionDescUpdatedAt is the schema descriptor for updated_at field.
+ pendingauthsessionDescUpdatedAt := pendingauthsessionMixinFields0[1].Descriptor()
+ // pendingauthsession.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ pendingauthsession.DefaultUpdatedAt = pendingauthsessionDescUpdatedAt.Default.(func() time.Time)
+ // pendingauthsession.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ pendingauthsession.UpdateDefaultUpdatedAt = pendingauthsessionDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // pendingauthsessionDescSessionToken is the schema descriptor for session_token field.
+ pendingauthsessionDescSessionToken := pendingauthsessionFields[0].Descriptor()
+ // pendingauthsession.SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save.
+ pendingauthsession.SessionTokenValidator = func() func(string) error {
+ validators := pendingauthsessionDescSessionToken.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(session_token string) error {
+ for _, fn := range fns {
+ if err := fn(session_token); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // pendingauthsessionDescIntent is the schema descriptor for intent field.
+ pendingauthsessionDescIntent := pendingauthsessionFields[1].Descriptor()
+ // pendingauthsession.IntentValidator is a validator for the "intent" field. It is called by the builders before save.
+ pendingauthsession.IntentValidator = func() func(string) error {
+ validators := pendingauthsessionDescIntent.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ validators[2].(func(string) error),
+ }
+ return func(intent string) error {
+ for _, fn := range fns {
+ if err := fn(intent); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // pendingauthsessionDescProviderType is the schema descriptor for provider_type field.
+ pendingauthsessionDescProviderType := pendingauthsessionFields[2].Descriptor()
+ // pendingauthsession.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ pendingauthsession.ProviderTypeValidator = func() func(string) error {
+ validators := pendingauthsessionDescProviderType.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ validators[2].(func(string) error),
+ }
+ return func(provider_type string) error {
+ for _, fn := range fns {
+ if err := fn(provider_type); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // pendingauthsessionDescProviderKey is the schema descriptor for provider_key field.
+ pendingauthsessionDescProviderKey := pendingauthsessionFields[3].Descriptor()
+ // pendingauthsession.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ pendingauthsession.ProviderKeyValidator = pendingauthsessionDescProviderKey.Validators[0].(func(string) error)
+ // pendingauthsessionDescProviderSubject is the schema descriptor for provider_subject field.
+ pendingauthsessionDescProviderSubject := pendingauthsessionFields[4].Descriptor()
+ // pendingauthsession.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
+ pendingauthsession.ProviderSubjectValidator = pendingauthsessionDescProviderSubject.Validators[0].(func(string) error)
+ // pendingauthsessionDescRedirectTo is the schema descriptor for redirect_to field.
+ pendingauthsessionDescRedirectTo := pendingauthsessionFields[6].Descriptor()
+ // pendingauthsession.DefaultRedirectTo holds the default value on creation for the redirect_to field.
+ pendingauthsession.DefaultRedirectTo = pendingauthsessionDescRedirectTo.Default.(string)
+ // pendingauthsessionDescResolvedEmail is the schema descriptor for resolved_email field.
+ pendingauthsessionDescResolvedEmail := pendingauthsessionFields[7].Descriptor()
+ // pendingauthsession.DefaultResolvedEmail holds the default value on creation for the resolved_email field.
+ pendingauthsession.DefaultResolvedEmail = pendingauthsessionDescResolvedEmail.Default.(string)
+ // pendingauthsessionDescRegistrationPasswordHash is the schema descriptor for registration_password_hash field.
+ pendingauthsessionDescRegistrationPasswordHash := pendingauthsessionFields[8].Descriptor()
+ // pendingauthsession.DefaultRegistrationPasswordHash holds the default value on creation for the registration_password_hash field.
+ pendingauthsession.DefaultRegistrationPasswordHash = pendingauthsessionDescRegistrationPasswordHash.Default.(string)
+ // pendingauthsessionDescUpstreamIdentityClaims is the schema descriptor for upstream_identity_claims field.
+ pendingauthsessionDescUpstreamIdentityClaims := pendingauthsessionFields[9].Descriptor()
+ // pendingauthsession.DefaultUpstreamIdentityClaims holds the default value on creation for the upstream_identity_claims field.
+ pendingauthsession.DefaultUpstreamIdentityClaims = pendingauthsessionDescUpstreamIdentityClaims.Default.(func() map[string]interface{})
+ // pendingauthsessionDescLocalFlowState is the schema descriptor for local_flow_state field.
+ pendingauthsessionDescLocalFlowState := pendingauthsessionFields[10].Descriptor()
+ // pendingauthsession.DefaultLocalFlowState holds the default value on creation for the local_flow_state field.
+ pendingauthsession.DefaultLocalFlowState = pendingauthsessionDescLocalFlowState.Default.(func() map[string]interface{})
+ // pendingauthsessionDescBrowserSessionKey is the schema descriptor for browser_session_key field.
+ pendingauthsessionDescBrowserSessionKey := pendingauthsessionFields[11].Descriptor()
+ // pendingauthsession.DefaultBrowserSessionKey holds the default value on creation for the browser_session_key field.
+ pendingauthsession.DefaultBrowserSessionKey = pendingauthsessionDescBrowserSessionKey.Default.(string)
+ // pendingauthsessionDescCompletionCodeHash is the schema descriptor for completion_code_hash field.
+ pendingauthsessionDescCompletionCodeHash := pendingauthsessionFields[12].Descriptor()
+ // pendingauthsession.DefaultCompletionCodeHash holds the default value on creation for the completion_code_hash field.
+ pendingauthsession.DefaultCompletionCodeHash = pendingauthsessionDescCompletionCodeHash.Default.(string)
promocodeFields := schema.PromoCode{}.Fields()
_ = promocodeFields
// promocodeDescCode is the schema descriptor for code field.
@@ -1297,20 +1553,26 @@ func init() {
userDescTotpEnabled := userFields[9].Descriptor()
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
+ // userDescSignupSource is the schema descriptor for signup_source field.
+ userDescSignupSource := userFields[11].Descriptor()
+ // user.DefaultSignupSource holds the default value on creation for the signup_source field.
+ user.DefaultSignupSource = userDescSignupSource.Default.(string)
+ // user.SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save.
+ user.SignupSourceValidator = userDescSignupSource.Validators[0].(func(string) error)
// userDescBalanceNotifyEnabled is the schema descriptor for balance_notify_enabled field.
- userDescBalanceNotifyEnabled := userFields[11].Descriptor()
+ userDescBalanceNotifyEnabled := userFields[14].Descriptor()
// user.DefaultBalanceNotifyEnabled holds the default value on creation for the balance_notify_enabled field.
user.DefaultBalanceNotifyEnabled = userDescBalanceNotifyEnabled.Default.(bool)
// userDescBalanceNotifyThresholdType is the schema descriptor for balance_notify_threshold_type field.
- userDescBalanceNotifyThresholdType := userFields[12].Descriptor()
+ userDescBalanceNotifyThresholdType := userFields[15].Descriptor()
// user.DefaultBalanceNotifyThresholdType holds the default value on creation for the balance_notify_threshold_type field.
user.DefaultBalanceNotifyThresholdType = userDescBalanceNotifyThresholdType.Default.(string)
// userDescBalanceNotifyExtraEmails is the schema descriptor for balance_notify_extra_emails field.
- userDescBalanceNotifyExtraEmails := userFields[14].Descriptor()
+ userDescBalanceNotifyExtraEmails := userFields[17].Descriptor()
// user.DefaultBalanceNotifyExtraEmails holds the default value on creation for the balance_notify_extra_emails field.
user.DefaultBalanceNotifyExtraEmails = userDescBalanceNotifyExtraEmails.Default.(string)
// userDescTotalRecharged is the schema descriptor for total_recharged field.
- userDescTotalRecharged := userFields[15].Descriptor()
+ userDescTotalRecharged := userFields[18].Descriptor()
// user.DefaultTotalRecharged holds the default value on creation for the total_recharged field.
user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
diff --git a/backend/ent/schema/auth_identity.go b/backend/ent/schema/auth_identity.go
new file mode 100644
index 00000000..0b1b56ab
--- /dev/null
+++ b/backend/ent/schema/auth_identity.go
@@ -0,0 +1,94 @@
+package schema
+
+import (
+ "fmt"
+
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+var authProviderTypes = map[string]struct{}{
+ "email": {},
+ "linuxdo": {},
+ "oidc": {},
+ "wechat": {},
+}
+
+func validateAuthProviderType(value string) error {
+ if _, ok := authProviderTypes[value]; ok {
+ return nil
+ }
+ return fmt.Errorf("invalid auth provider type %q", value)
+}
+
+// AuthIdentity stores the canonical login identity for an account.
+type AuthIdentity struct {
+ ent.Schema
+}
+
+func (AuthIdentity) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "auth_identities"},
+ }
+}
+
+func (AuthIdentity) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (AuthIdentity) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("user_id"),
+ field.String("provider_type").
+ MaxLen(20).
+ NotEmpty().
+ Validate(validateAuthProviderType),
+ field.String("provider_key").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("provider_subject").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.Time("verified_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.String("issuer").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.JSON("metadata", map[string]any{}).
+ Default(func() map[string]any { return map[string]any{} }).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+ }
+}
+
+func (AuthIdentity) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("user", User.Type).
+ Ref("auth_identities").
+ Field("user_id").
+ Required().
+ Unique(),
+ edge.To("channels", AuthIdentityChannel.Type).
+ Annotations(entsql.OnDelete(entsql.Cascade)),
+ edge.To("adoption_decisions", IdentityAdoptionDecision.Type),
+ }
+}
+
+func (AuthIdentity) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("provider_type", "provider_key", "provider_subject").Unique(),
+ index.Fields("user_id"),
+ index.Fields("user_id", "provider_type"),
+ }
+}
diff --git a/backend/ent/schema/auth_identity_channel.go b/backend/ent/schema/auth_identity_channel.go
new file mode 100644
index 00000000..69f2ad02
--- /dev/null
+++ b/backend/ent/schema/auth_identity_channel.go
@@ -0,0 +1,72 @@
+package schema
+
+import (
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// AuthIdentityChannel stores channel-scoped identifiers for a canonical identity.
+type AuthIdentityChannel struct {
+ ent.Schema
+}
+
+func (AuthIdentityChannel) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "auth_identity_channels"},
+ }
+}
+
+func (AuthIdentityChannel) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (AuthIdentityChannel) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("identity_id"),
+ field.String("provider_type").
+ MaxLen(20).
+ NotEmpty().
+ Validate(validateAuthProviderType),
+ field.String("provider_key").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("channel").
+ MaxLen(20).
+ NotEmpty(),
+ field.String("channel_app_id").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("channel_subject").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.JSON("metadata", map[string]any{}).
+ Default(func() map[string]any { return map[string]any{} }).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+ }
+}
+
+func (AuthIdentityChannel) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("identity", AuthIdentity.Type).
+ Ref("channels").
+ Field("identity_id").
+ Required().
+ Unique(),
+ }
+}
+
+func (AuthIdentityChannel) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("provider_type", "provider_key", "channel", "channel_app_id", "channel_subject").Unique(),
+ index.Fields("identity_id"),
+ }
+}
diff --git a/backend/ent/schema/auth_identity_schema_test.go b/backend/ent/schema/auth_identity_schema_test.go
new file mode 100644
index 00000000..fbb93236
--- /dev/null
+++ b/backend/ent/schema/auth_identity_schema_test.go
@@ -0,0 +1,168 @@
+package schema
+
+import (
+ "testing"
+
+ "entgo.io/ent"
+ "entgo.io/ent/entc/load"
+ "entgo.io/ent/schema/field"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthIdentityFoundationSchemas(t *testing.T) {
+ spec, err := (&load.Config{Path: "."}).Load()
+ require.NoError(t, err)
+
+ schemas := map[string]*load.Schema{}
+ for _, schema := range spec.Schemas {
+ schemas[schema.Name] = schema
+ }
+
+ authIdentity := requireSchema(t, schemas, "AuthIdentity")
+ requireSchemaFields(t, authIdentity,
+ "user_id",
+ "provider_type",
+ "provider_key",
+ "provider_subject",
+ "verified_at",
+ "issuer",
+ "metadata",
+ )
+ requireHasUniqueIndex(t, authIdentity, "provider_type", "provider_key", "provider_subject")
+
+ authIdentityChannel := requireSchema(t, schemas, "AuthIdentityChannel")
+ requireSchemaFields(t, authIdentityChannel,
+ "identity_id",
+ "provider_type",
+ "provider_key",
+ "channel",
+ "channel_app_id",
+ "channel_subject",
+ "metadata",
+ )
+ requireHasUniqueIndex(t, authIdentityChannel, "provider_type", "provider_key", "channel", "channel_app_id", "channel_subject")
+
+ pendingAuthSession := requireSchema(t, schemas, "PendingAuthSession")
+ requireSchemaFields(t, pendingAuthSession,
+ "intent",
+ "provider_type",
+ "provider_key",
+ "provider_subject",
+ "target_user_id",
+ "redirect_to",
+ "resolved_email",
+ "registration_password_hash",
+ "upstream_identity_claims",
+ "local_flow_state",
+ "browser_session_key",
+ "completion_code_hash",
+ "completion_code_expires_at",
+ "email_verified_at",
+ "password_verified_at",
+ "totp_verified_at",
+ "expires_at",
+ "consumed_at",
+ )
+
+ adoptionDecision := requireSchema(t, schemas, "IdentityAdoptionDecision")
+ requireSchemaFields(t, adoptionDecision,
+ "pending_auth_session_id",
+ "identity_id",
+ "adopt_display_name",
+ "adopt_avatar",
+ "decided_at",
+ )
+ requireHasUniqueIndex(t, adoptionDecision, "pending_auth_session_id")
+
+ userSchema := requireSchema(t, schemas, "User")
+ requireSchemaFields(t, userSchema, "signup_source", "last_login_at", "last_active_at")
+ signupSource := requireSchemaField(t, userSchema, "signup_source")
+ require.Equal(t, field.TypeString, signupSource.Info.Type)
+ require.True(t, signupSource.Default)
+ require.Equal(t, "email", signupSource.DefaultValue)
+ require.Equal(t, 1, signupSource.Validators)
+
+ validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source")
+ for _, value := range []string{"email", "linuxdo", "wechat", "oidc"} {
+ require.NoError(t, validator(value))
+ }
+ require.Error(t, validator("github"))
+}
+
+func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema {
+ t.Helper()
+
+ schema, ok := schemas[name]
+ require.True(t, ok, "schema %s should exist", name)
+ return schema
+}
+
+func requireSchemaFields(t *testing.T, schema *load.Schema, names ...string) {
+ t.Helper()
+
+ fields := map[string]struct{}{}
+ for _, field := range schema.Fields {
+ fields[field.Name] = struct{}{}
+ }
+
+ for _, name := range names {
+ _, ok := fields[name]
+ require.True(t, ok, "schema %s should include field %s", schema.Name, name)
+ }
+}
+
+func requireSchemaField(t *testing.T, schema *load.Schema, name string) *load.Field {
+ t.Helper()
+
+ for _, schemaField := range schema.Fields {
+ if schemaField.Name == name {
+ return schemaField
+ }
+ }
+
+ require.Failf(t, "missing schema field", "schema %s should include field %s", schema.Name, name)
+ return nil
+}
+
+func requireStringFieldValidator(t *testing.T, fields []ent.Field, name string) func(string) error {
+ t.Helper()
+
+ for _, entField := range fields {
+ descriptor := entField.Descriptor()
+ if descriptor.Name != name {
+ continue
+ }
+ require.NotEmpty(t, descriptor.Validators, "field %s should include a validator", name)
+ validator, ok := descriptor.Validators[0].(func(string) error)
+ require.True(t, ok, "field %s validator should be func(string) error", name)
+ return validator
+ }
+
+ require.Failf(t, "missing field validator", "schema should include field %s", name)
+ return nil
+}
+
+func requireHasUniqueIndex(t *testing.T, schema *load.Schema, fields ...string) {
+ t.Helper()
+
+ for _, index := range schema.Indexes {
+ if !index.Unique {
+ continue
+ }
+ if len(index.Fields) != len(fields) {
+ continue
+ }
+ match := true
+ for i := range fields {
+ if index.Fields[i] != fields[i] {
+ match = false
+ break
+ }
+ }
+ if match {
+ return
+ }
+ }
+
+ require.Failf(t, "missing unique index", "schema %s should include unique index on %v", schema.Name, fields)
+}
diff --git a/backend/ent/schema/identity_adoption_decision.go b/backend/ent/schema/identity_adoption_decision.go
new file mode 100644
index 00000000..9fdd26fb
--- /dev/null
+++ b/backend/ent/schema/identity_adoption_decision.go
@@ -0,0 +1,70 @@
+package schema
+
+import (
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// IdentityAdoptionDecision stores the one-time profile adoption choice captured during a pending auth flow.
+type IdentityAdoptionDecision struct {
+ ent.Schema
+}
+
+func (IdentityAdoptionDecision) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "identity_adoption_decisions"},
+ }
+}
+
+func (IdentityAdoptionDecision) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (IdentityAdoptionDecision) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("pending_auth_session_id"),
+ field.Int64("identity_id").
+ Optional().
+ Nillable(),
+ field.Bool("adopt_display_name").
+ Default(false),
+ field.Bool("adopt_avatar").
+ Default(false),
+ field.Time("decided_at").
+ Immutable().
+ Default(time.Now).
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ }
+}
+
+func (IdentityAdoptionDecision) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("pending_auth_session", PendingAuthSession.Type).
+ Ref("adoption_decision").
+ Field("pending_auth_session_id").
+ Required().
+ Unique(),
+ edge.From("identity", AuthIdentity.Type).
+ Ref("adoption_decisions").
+ Field("identity_id").
+ Unique(),
+ }
+}
+
+func (IdentityAdoptionDecision) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("pending_auth_session_id").Unique(),
+ index.Fields("identity_id"),
+ }
+}
diff --git a/backend/ent/schema/payment_order.go b/backend/ent/schema/payment_order.go
index a9576d2a..d25d1e5e 100644
--- a/backend/ent/schema/payment_order.go
+++ b/backend/ent/schema/payment_order.go
@@ -91,6 +91,13 @@ func (PaymentOrder) Fields() []ent.Field {
Optional().
Nillable().
MaxLen(64),
+ field.String("provider_key").
+ Optional().
+ Nillable().
+ MaxLen(30),
+ field.JSON("provider_snapshot", map[string]any{}).
+ Optional().
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
// 状态
field.String("status").
@@ -178,7 +185,9 @@ func (PaymentOrder) Edges() []ent.Edge {
func (PaymentOrder) Indexes() []ent.Index {
return []ent.Index{
- index.Fields("out_trade_no"),
+ index.Fields("out_trade_no").
+ Unique().
+ Annotations(entsql.IndexWhere("out_trade_no <> ''")),
index.Fields("user_id"),
index.Fields("status"),
index.Fields("expires_at"),
diff --git a/backend/ent/schema/pending_auth_session.go b/backend/ent/schema/pending_auth_session.go
new file mode 100644
index 00000000..7e95f085
--- /dev/null
+++ b/backend/ent/schema/pending_auth_session.go
@@ -0,0 +1,135 @@
+package schema
+
+import (
+ "fmt"
+
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+var pendingAuthIntents = map[string]struct{}{
+ "login": {},
+ "bind_current_user": {},
+ "adopt_existing_user_by_email": {},
+}
+
+func validatePendingAuthIntent(value string) error {
+ if _, ok := pendingAuthIntents[value]; ok {
+ return nil
+ }
+ return fmt.Errorf("invalid pending auth intent %q", value)
+}
+
+// PendingAuthSession stores a short-lived post-auth decision session.
+type PendingAuthSession struct {
+ ent.Schema
+}
+
+func (PendingAuthSession) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "pending_auth_sessions"},
+ }
+}
+
+func (PendingAuthSession) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (PendingAuthSession) Fields() []ent.Field {
+ return []ent.Field{
+ field.String("session_token").
+ MaxLen(255).
+ NotEmpty(),
+ field.String("intent").
+ MaxLen(40).
+ NotEmpty().
+ Validate(validatePendingAuthIntent),
+ field.String("provider_type").
+ MaxLen(20).
+ NotEmpty().
+ Validate(validateAuthProviderType),
+ field.String("provider_key").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("provider_subject").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.Int64("target_user_id").
+ Optional().
+ Nillable(),
+ field.String("redirect_to").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("resolved_email").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("registration_password_hash").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.JSON("upstream_identity_claims", map[string]any{}).
+ Default(func() map[string]any { return map[string]any{} }).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+ field.JSON("local_flow_state", map[string]any{}).
+ Default(func() map[string]any { return map[string]any{} }).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+ field.String("browser_session_key").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("completion_code_hash").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.Time("completion_code_expires_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("email_verified_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("password_verified_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("totp_verified_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("expires_at").
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("consumed_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ }
+}
+
+func (PendingAuthSession) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("target_user", User.Type).
+ Ref("pending_auth_sessions").
+ Field("target_user_id").
+ Unique(),
+ edge.To("adoption_decision", IdentityAdoptionDecision.Type).
+ Annotations(entsql.OnDelete(entsql.Cascade)).
+ Unique(),
+ }
+}
+
+func (PendingAuthSession) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("session_token").Unique(),
+ index.Fields("target_user_id"),
+ index.Fields("expires_at"),
+ index.Fields("provider_type", "provider_key", "provider_subject"),
+ index.Fields("completion_code_hash"),
+ }
+}
diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go
index ef52e985..c0f0bdc1 100644
--- a/backend/ent/schema/user.go
+++ b/backend/ent/schema/user.go
@@ -1,6 +1,8 @@
package schema
import (
+ "fmt"
+
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/domain"
@@ -72,6 +74,24 @@ func (User) Fields() []ent.Field {
field.Time("totp_enabled_at").
Optional().
Nillable(),
+ field.String("signup_source").
+ Validate(func(value string) error {
+ switch value {
+ case "email", "linuxdo", "wechat", "oidc":
+ return nil
+ default:
+ return fmt.Errorf("must be one of email, linuxdo, wechat, oidc")
+ }
+ }).
+ Default("email"),
+ field.Time("last_login_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("last_active_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
// 余额不足通知
field.Bool("balance_notify_enabled").
@@ -104,6 +124,9 @@ func (User) Edges() []ent.Edge {
edge.To("attribute_values", UserAttributeValue.Type),
edge.To("promo_code_usages", PromoCodeUsage.Type),
edge.To("payment_orders", PaymentOrder.Type),
+ edge.To("auth_identities", AuthIdentity.Type).
+ Annotations(entsql.OnDelete(entsql.Cascade)),
+ edge.To("pending_auth_sessions", PendingAuthSession.Type),
}
}
diff --git a/backend/ent/tx.go b/backend/ent/tx.go
index bb3139d5..bde3e35b 100644
--- a/backend/ent/tx.go
+++ b/backend/ent/tx.go
@@ -24,18 +24,26 @@ type Tx struct {
Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient
+ // AuthIdentity is the client for interacting with the AuthIdentity builders.
+ AuthIdentity *AuthIdentityClient
+ // AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders.
+ AuthIdentityChannel *AuthIdentityChannelClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// IdempotencyRecord is the client for interacting with the IdempotencyRecord builders.
IdempotencyRecord *IdempotencyRecordClient
+ // IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders.
+ IdentityAdoptionDecision *IdentityAdoptionDecisionClient
// PaymentAuditLog is the client for interacting with the PaymentAuditLog builders.
PaymentAuditLog *PaymentAuditLogClient
// PaymentOrder is the client for interacting with the PaymentOrder builders.
PaymentOrder *PaymentOrderClient
// PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders.
PaymentProviderInstance *PaymentProviderInstanceClient
+ // PendingAuthSession is the client for interacting with the PendingAuthSession builders.
+ PendingAuthSession *PendingAuthSessionClient
// PromoCode is the client for interacting with the PromoCode builders.
PromoCode *PromoCodeClient
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
@@ -202,12 +210,16 @@ func (tx *Tx) init() {
tx.AccountGroup = NewAccountGroupClient(tx.config)
tx.Announcement = NewAnnouncementClient(tx.config)
tx.AnnouncementRead = NewAnnouncementReadClient(tx.config)
+ tx.AuthIdentity = NewAuthIdentityClient(tx.config)
+ tx.AuthIdentityChannel = NewAuthIdentityChannelClient(tx.config)
tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config)
tx.Group = NewGroupClient(tx.config)
tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config)
+ tx.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(tx.config)
tx.PaymentAuditLog = NewPaymentAuditLogClient(tx.config)
tx.PaymentOrder = NewPaymentOrderClient(tx.config)
tx.PaymentProviderInstance = NewPaymentProviderInstanceClient(tx.config)
+ tx.PendingAuthSession = NewPendingAuthSessionClient(tx.config)
tx.PromoCode = NewPromoCodeClient(tx.config)
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
tx.Proxy = NewProxyClient(tx.config)
diff --git a/backend/ent/user.go b/backend/ent/user.go
index 9fa91f74..66f33623 100644
--- a/backend/ent/user.go
+++ b/backend/ent/user.go
@@ -45,6 +45,12 @@ type User struct {
TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
+ // SignupSource holds the value of the "signup_source" field.
+ SignupSource string `json:"signup_source,omitempty"`
+ // LastLoginAt holds the value of the "last_login_at" field.
+ LastLoginAt *time.Time `json:"last_login_at,omitempty"`
+ // LastActiveAt holds the value of the "last_active_at" field.
+ LastActiveAt *time.Time `json:"last_active_at,omitempty"`
// BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field.
BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"`
// BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field.
@@ -83,11 +89,15 @@ type UserEdges struct {
PromoCodeUsages []*PromoCodeUsage `json:"promo_code_usages,omitempty"`
// PaymentOrders holds the value of the payment_orders edge.
PaymentOrders []*PaymentOrder `json:"payment_orders,omitempty"`
+ // AuthIdentities holds the value of the auth_identities edge.
+ AuthIdentities []*AuthIdentity `json:"auth_identities,omitempty"`
+ // PendingAuthSessions holds the value of the pending_auth_sessions edge.
+ PendingAuthSessions []*PendingAuthSession `json:"pending_auth_sessions,omitempty"`
// UserAllowedGroups holds the value of the user_allowed_groups edge.
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
- loadedTypes [11]bool
+ loadedTypes [13]bool
}
// APIKeysOrErr returns the APIKeys value or an error if the edge
@@ -180,10 +190,28 @@ func (e UserEdges) PaymentOrdersOrErr() ([]*PaymentOrder, error) {
return nil, &NotLoadedError{edge: "payment_orders"}
}
+// AuthIdentitiesOrErr returns the AuthIdentities value or an error if the edge
+// was not loaded in eager-loading.
+func (e UserEdges) AuthIdentitiesOrErr() ([]*AuthIdentity, error) {
+ if e.loadedTypes[10] {
+ return e.AuthIdentities, nil
+ }
+ return nil, &NotLoadedError{edge: "auth_identities"}
+}
+
+// PendingAuthSessionsOrErr returns the PendingAuthSessions value or an error if the edge
+// was not loaded in eager-loading.
+func (e UserEdges) PendingAuthSessionsOrErr() ([]*PendingAuthSession, error) {
+ if e.loadedTypes[11] {
+ return e.PendingAuthSessions, nil
+ }
+ return nil, &NotLoadedError{edge: "pending_auth_sessions"}
+}
+
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
- if e.loadedTypes[10] {
+ if e.loadedTypes[12] {
return e.UserAllowedGroups, nil
}
return nil, &NotLoadedError{edge: "user_allowed_groups"}
@@ -200,9 +228,9 @@ func (*User) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency:
values[i] = new(sql.NullInt64)
- case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
+ case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldSignupSource, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
values[i] = new(sql.NullString)
- case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
+ case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt, user.FieldLastLoginAt, user.FieldLastActiveAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@@ -312,6 +340,26 @@ func (_m *User) assignValues(columns []string, values []any) error {
_m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time
}
+ case user.FieldSignupSource:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field signup_source", values[i])
+ } else if value.Valid {
+ _m.SignupSource = value.String
+ }
+ case user.FieldLastLoginAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field last_login_at", values[i])
+ } else if value.Valid {
+ _m.LastLoginAt = new(time.Time)
+ *_m.LastLoginAt = value.Time
+ }
+ case user.FieldLastActiveAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field last_active_at", values[i])
+ } else if value.Valid {
+ _m.LastActiveAt = new(time.Time)
+ *_m.LastActiveAt = value.Time
+ }
case user.FieldBalanceNotifyEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field balance_notify_enabled", values[i])
@@ -406,6 +454,16 @@ func (_m *User) QueryPaymentOrders() *PaymentOrderQuery {
return NewUserClient(_m.config).QueryPaymentOrders(_m)
}
+// QueryAuthIdentities queries the "auth_identities" edge of the User entity.
+func (_m *User) QueryAuthIdentities() *AuthIdentityQuery {
+ return NewUserClient(_m.config).QueryAuthIdentities(_m)
+}
+
+// QueryPendingAuthSessions queries the "pending_auth_sessions" edge of the User entity.
+func (_m *User) QueryPendingAuthSessions() *PendingAuthSessionQuery {
+ return NewUserClient(_m.config).QueryPendingAuthSessions(_m)
+}
+
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery {
return NewUserClient(_m.config).QueryUserAllowedGroups(_m)
@@ -482,6 +540,19 @@ func (_m *User) String() string {
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
+ builder.WriteString("signup_source=")
+ builder.WriteString(_m.SignupSource)
+ builder.WriteString(", ")
+ if v := _m.LastLoginAt; v != nil {
+ builder.WriteString("last_login_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.LastActiveAt; v != nil {
+ builder.WriteString("last_active_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
builder.WriteString("balance_notify_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled))
builder.WriteString(", ")
diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go
index d88a3a38..567e3b14 100644
--- a/backend/ent/user/user.go
+++ b/backend/ent/user/user.go
@@ -43,6 +43,12 @@ const (
FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at"
+ // FieldSignupSource holds the string denoting the signup_source field in the database.
+ FieldSignupSource = "signup_source"
+ // FieldLastLoginAt holds the string denoting the last_login_at field in the database.
+ FieldLastLoginAt = "last_login_at"
+ // FieldLastActiveAt holds the string denoting the last_active_at field in the database.
+ FieldLastActiveAt = "last_active_at"
// FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database.
FieldBalanceNotifyEnabled = "balance_notify_enabled"
// FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database.
@@ -73,6 +79,10 @@ const (
EdgePromoCodeUsages = "promo_code_usages"
// EdgePaymentOrders holds the string denoting the payment_orders edge name in mutations.
EdgePaymentOrders = "payment_orders"
+ // EdgeAuthIdentities holds the string denoting the auth_identities edge name in mutations.
+ EdgeAuthIdentities = "auth_identities"
+ // EdgePendingAuthSessions holds the string denoting the pending_auth_sessions edge name in mutations.
+ EdgePendingAuthSessions = "pending_auth_sessions"
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
EdgeUserAllowedGroups = "user_allowed_groups"
// Table holds the table name of the user in the database.
@@ -145,6 +155,20 @@ const (
PaymentOrdersInverseTable = "payment_orders"
// PaymentOrdersColumn is the table column denoting the payment_orders relation/edge.
PaymentOrdersColumn = "user_id"
+ // AuthIdentitiesTable is the table that holds the auth_identities relation/edge.
+ AuthIdentitiesTable = "auth_identities"
+ // AuthIdentitiesInverseTable is the table name for the AuthIdentity entity.
+ // It exists in this package in order to avoid circular dependency with the "authidentity" package.
+ AuthIdentitiesInverseTable = "auth_identities"
+ // AuthIdentitiesColumn is the table column denoting the auth_identities relation/edge.
+ AuthIdentitiesColumn = "user_id"
+ // PendingAuthSessionsTable is the table that holds the pending_auth_sessions relation/edge.
+ PendingAuthSessionsTable = "pending_auth_sessions"
+ // PendingAuthSessionsInverseTable is the table name for the PendingAuthSession entity.
+ // It exists in this package in order to avoid circular dependency with the "pendingauthsession" package.
+ PendingAuthSessionsInverseTable = "pending_auth_sessions"
+ // PendingAuthSessionsColumn is the table column denoting the pending_auth_sessions relation/edge.
+ PendingAuthSessionsColumn = "target_user_id"
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
UserAllowedGroupsTable = "user_allowed_groups"
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
@@ -171,6 +195,9 @@ var Columns = []string{
FieldTotpSecretEncrypted,
FieldTotpEnabled,
FieldTotpEnabledAt,
+ FieldSignupSource,
+ FieldLastLoginAt,
+ FieldLastActiveAt,
FieldBalanceNotifyEnabled,
FieldBalanceNotifyThresholdType,
FieldBalanceNotifyThreshold,
@@ -232,6 +259,10 @@ var (
DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool
+ // DefaultSignupSource holds the default value on creation for the "signup_source" field.
+ DefaultSignupSource string
+ // SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save.
+ SignupSourceValidator func(string) error
// DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field.
DefaultBalanceNotifyEnabled bool
// DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field.
@@ -320,6 +351,21 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
}
+// BySignupSource orders the results by the signup_source field.
+func BySignupSource(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSignupSource, opts...).ToFunc()
+}
+
+// ByLastLoginAt orders the results by the last_login_at field.
+func ByLastLoginAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldLastLoginAt, opts...).ToFunc()
+}
+
+// ByLastActiveAt orders the results by the last_active_at field.
+func ByLastActiveAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldLastActiveAt, opts...).ToFunc()
+}
+
// ByBalanceNotifyEnabled orders the results by the balance_notify_enabled field.
func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc()
@@ -485,6 +531,34 @@ func ByPaymentOrders(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
}
}
+// ByAuthIdentitiesCount orders the results by auth_identities count.
+func ByAuthIdentitiesCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newAuthIdentitiesStep(), opts...)
+ }
+}
+
+// ByAuthIdentities orders the results by auth_identities terms.
+func ByAuthIdentities(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newAuthIdentitiesStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
+// ByPendingAuthSessionsCount orders the results by pending_auth_sessions count.
+func ByPendingAuthSessionsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newPendingAuthSessionsStep(), opts...)
+ }
+}
+
+// ByPendingAuthSessions orders the results by pending_auth_sessions terms.
+func ByPendingAuthSessions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
@@ -568,6 +642,20 @@ func newPaymentOrdersStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, PaymentOrdersTable, PaymentOrdersColumn),
)
}
+func newAuthIdentitiesStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(AuthIdentitiesInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn),
+ )
+}
+func newPendingAuthSessionsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(PendingAuthSessionsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn),
+ )
+}
func newUserAllowedGroupsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go
index 2788aa7a..cbcfcc26 100644
--- a/backend/ent/user/where.go
+++ b/backend/ent/user/where.go
@@ -125,6 +125,21 @@ func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
+// SignupSource applies equality check predicate on the "signup_source" field. It's identical to SignupSourceEQ.
+func SignupSource(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldSignupSource, v))
+}
+
+// LastLoginAt applies equality check predicate on the "last_login_at" field. It's identical to LastLoginAtEQ.
+func LastLoginAt(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldLastLoginAt, v))
+}
+
+// LastActiveAt applies equality check predicate on the "last_active_at" field. It's identical to LastActiveAtEQ.
+func LastActiveAt(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldLastActiveAt, v))
+}
+
// BalanceNotifyEnabled applies equality check predicate on the "balance_notify_enabled" field. It's identical to BalanceNotifyEnabledEQ.
func BalanceNotifyEnabled(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
@@ -885,6 +900,171 @@ func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
}
+// SignupSourceEQ applies the EQ predicate on the "signup_source" field.
+func SignupSourceEQ(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldSignupSource, v))
+}
+
+// SignupSourceNEQ applies the NEQ predicate on the "signup_source" field.
+func SignupSourceNEQ(v string) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldSignupSource, v))
+}
+
+// SignupSourceIn applies the In predicate on the "signup_source" field.
+func SignupSourceIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldIn(FieldSignupSource, vs...))
+}
+
+// SignupSourceNotIn applies the NotIn predicate on the "signup_source" field.
+func SignupSourceNotIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldSignupSource, vs...))
+}
+
+// SignupSourceGT applies the GT predicate on the "signup_source" field.
+func SignupSourceGT(v string) predicate.User {
+ return predicate.User(sql.FieldGT(FieldSignupSource, v))
+}
+
+// SignupSourceGTE applies the GTE predicate on the "signup_source" field.
+func SignupSourceGTE(v string) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldSignupSource, v))
+}
+
+// SignupSourceLT applies the LT predicate on the "signup_source" field.
+func SignupSourceLT(v string) predicate.User {
+ return predicate.User(sql.FieldLT(FieldSignupSource, v))
+}
+
+// SignupSourceLTE applies the LTE predicate on the "signup_source" field.
+func SignupSourceLTE(v string) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldSignupSource, v))
+}
+
+// SignupSourceContains applies the Contains predicate on the "signup_source" field.
+func SignupSourceContains(v string) predicate.User {
+ return predicate.User(sql.FieldContains(FieldSignupSource, v))
+}
+
+// SignupSourceHasPrefix applies the HasPrefix predicate on the "signup_source" field.
+func SignupSourceHasPrefix(v string) predicate.User {
+ return predicate.User(sql.FieldHasPrefix(FieldSignupSource, v))
+}
+
+// SignupSourceHasSuffix applies the HasSuffix predicate on the "signup_source" field.
+func SignupSourceHasSuffix(v string) predicate.User {
+ return predicate.User(sql.FieldHasSuffix(FieldSignupSource, v))
+}
+
+// SignupSourceEqualFold applies the EqualFold predicate on the "signup_source" field.
+func SignupSourceEqualFold(v string) predicate.User {
+ return predicate.User(sql.FieldEqualFold(FieldSignupSource, v))
+}
+
+// SignupSourceContainsFold applies the ContainsFold predicate on the "signup_source" field.
+func SignupSourceContainsFold(v string) predicate.User {
+ return predicate.User(sql.FieldContainsFold(FieldSignupSource, v))
+}
+
+// LastLoginAtEQ applies the EQ predicate on the "last_login_at" field.
+func LastLoginAtEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldLastLoginAt, v))
+}
+
+// LastLoginAtNEQ applies the NEQ predicate on the "last_login_at" field.
+func LastLoginAtNEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldLastLoginAt, v))
+}
+
+// LastLoginAtIn applies the In predicate on the "last_login_at" field.
+func LastLoginAtIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldIn(FieldLastLoginAt, vs...))
+}
+
+// LastLoginAtNotIn applies the NotIn predicate on the "last_login_at" field.
+func LastLoginAtNotIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldLastLoginAt, vs...))
+}
+
+// LastLoginAtGT applies the GT predicate on the "last_login_at" field.
+func LastLoginAtGT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGT(FieldLastLoginAt, v))
+}
+
+// LastLoginAtGTE applies the GTE predicate on the "last_login_at" field.
+func LastLoginAtGTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldLastLoginAt, v))
+}
+
+// LastLoginAtLT applies the LT predicate on the "last_login_at" field.
+func LastLoginAtLT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLT(FieldLastLoginAt, v))
+}
+
+// LastLoginAtLTE applies the LTE predicate on the "last_login_at" field.
+func LastLoginAtLTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldLastLoginAt, v))
+}
+
+// LastLoginAtIsNil applies the IsNil predicate on the "last_login_at" field.
+func LastLoginAtIsNil() predicate.User {
+ return predicate.User(sql.FieldIsNull(FieldLastLoginAt))
+}
+
+// LastLoginAtNotNil applies the NotNil predicate on the "last_login_at" field.
+func LastLoginAtNotNil() predicate.User {
+ return predicate.User(sql.FieldNotNull(FieldLastLoginAt))
+}
+
+// LastActiveAtEQ applies the EQ predicate on the "last_active_at" field.
+func LastActiveAtEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldLastActiveAt, v))
+}
+
+// LastActiveAtNEQ applies the NEQ predicate on the "last_active_at" field.
+func LastActiveAtNEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldLastActiveAt, v))
+}
+
+// LastActiveAtIn applies the In predicate on the "last_active_at" field.
+func LastActiveAtIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldIn(FieldLastActiveAt, vs...))
+}
+
+// LastActiveAtNotIn applies the NotIn predicate on the "last_active_at" field.
+func LastActiveAtNotIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldLastActiveAt, vs...))
+}
+
+// LastActiveAtGT applies the GT predicate on the "last_active_at" field.
+func LastActiveAtGT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGT(FieldLastActiveAt, v))
+}
+
+// LastActiveAtGTE applies the GTE predicate on the "last_active_at" field.
+func LastActiveAtGTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldLastActiveAt, v))
+}
+
+// LastActiveAtLT applies the LT predicate on the "last_active_at" field.
+func LastActiveAtLT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLT(FieldLastActiveAt, v))
+}
+
+// LastActiveAtLTE applies the LTE predicate on the "last_active_at" field.
+func LastActiveAtLTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldLastActiveAt, v))
+}
+
+// LastActiveAtIsNil applies the IsNil predicate on the "last_active_at" field.
+func LastActiveAtIsNil() predicate.User {
+ return predicate.User(sql.FieldIsNull(FieldLastActiveAt))
+}
+
+// LastActiveAtNotNil applies the NotNil predicate on the "last_active_at" field.
+func LastActiveAtNotNil() predicate.User {
+ return predicate.User(sql.FieldNotNull(FieldLastActiveAt))
+}
+
// BalanceNotifyEnabledEQ applies the EQ predicate on the "balance_notify_enabled" field.
func BalanceNotifyEnabledEQ(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
@@ -1345,6 +1525,52 @@ func HasPaymentOrdersWith(preds ...predicate.PaymentOrder) predicate.User {
})
}
+// HasAuthIdentities applies the HasEdge predicate on the "auth_identities" edge.
+func HasAuthIdentities() predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasAuthIdentitiesWith applies the HasEdge predicate on the "auth_identities" edge with a given conditions (other predicates).
+func HasAuthIdentitiesWith(preds ...predicate.AuthIdentity) predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := newAuthIdentitiesStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasPendingAuthSessions applies the HasEdge predicate on the "pending_auth_sessions" edge.
+func HasPendingAuthSessions() predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasPendingAuthSessionsWith applies the HasEdge predicate on the "pending_auth_sessions" edge with a given conditions (other predicates).
+func HasPendingAuthSessionsWith(preds ...predicate.PendingAuthSession) predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := newPendingAuthSessionsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
func HasUserAllowedGroups() predicate.User {
return predicate.User(func(s *sql.Selector) {
diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go
index fbc64f9c..db95e813 100644
--- a/backend/ent/user_create.go
+++ b/backend/ent/user_create.go
@@ -13,8 +13,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
@@ -211,6 +213,48 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
return _c
}
+// SetSignupSource sets the "signup_source" field.
+func (_c *UserCreate) SetSignupSource(v string) *UserCreate {
+ _c.mutation.SetSignupSource(v)
+ return _c
+}
+
+// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
+func (_c *UserCreate) SetNillableSignupSource(v *string) *UserCreate {
+ if v != nil {
+ _c.SetSignupSource(*v)
+ }
+ return _c
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (_c *UserCreate) SetLastLoginAt(v time.Time) *UserCreate {
+ _c.mutation.SetLastLoginAt(v)
+ return _c
+}
+
+// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
+func (_c *UserCreate) SetNillableLastLoginAt(v *time.Time) *UserCreate {
+ if v != nil {
+ _c.SetLastLoginAt(*v)
+ }
+ return _c
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (_c *UserCreate) SetLastActiveAt(v time.Time) *UserCreate {
+ _c.mutation.SetLastActiveAt(v)
+ return _c
+}
+
+// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
+func (_c *UserCreate) SetNillableLastActiveAt(v *time.Time) *UserCreate {
+ if v != nil {
+ _c.SetLastActiveAt(*v)
+ }
+ return _c
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_c *UserCreate) SetBalanceNotifyEnabled(v bool) *UserCreate {
_c.mutation.SetBalanceNotifyEnabled(v)
@@ -431,6 +475,36 @@ func (_c *UserCreate) AddPaymentOrders(v ...*PaymentOrder) *UserCreate {
return _c.AddPaymentOrderIDs(ids...)
}
+// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
+func (_c *UserCreate) AddAuthIdentityIDs(ids ...int64) *UserCreate {
+ _c.mutation.AddAuthIdentityIDs(ids...)
+ return _c
+}
+
+// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
+func (_c *UserCreate) AddAuthIdentities(v ...*AuthIdentity) *UserCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddAuthIdentityIDs(ids...)
+}
+
+// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
+func (_c *UserCreate) AddPendingAuthSessionIDs(ids ...int64) *UserCreate {
+ _c.mutation.AddPendingAuthSessionIDs(ids...)
+ return _c
+}
+
+// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_c *UserCreate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddPendingAuthSessionIDs(ids...)
+}
+
// Mutation returns the UserMutation object of the builder.
func (_c *UserCreate) Mutation() *UserMutation {
return _c.mutation
@@ -510,6 +584,10 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v)
}
+ if _, ok := _c.mutation.SignupSource(); !ok {
+ v := user.DefaultSignupSource
+ _c.mutation.SetSignupSource(v)
+ }
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
v := user.DefaultBalanceNotifyEnabled
_c.mutation.SetBalanceNotifyEnabled(v)
@@ -589,6 +667,14 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
}
+ if _, ok := _c.mutation.SignupSource(); !ok {
+ return &ValidationError{Name: "signup_source", err: errors.New(`ent: missing required field "User.signup_source"`)}
+ }
+ if v, ok := _c.mutation.SignupSource(); ok {
+ if err := user.SignupSourceValidator(v); err != nil {
+ return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
+ }
+ }
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)}
}
@@ -684,6 +770,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value
}
+ if value, ok := _c.mutation.SignupSource(); ok {
+ _spec.SetField(user.FieldSignupSource, field.TypeString, value)
+ _node.SignupSource = value
+ }
+ if value, ok := _c.mutation.LastLoginAt(); ok {
+ _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
+ _node.LastLoginAt = &value
+ }
+ if value, ok := _c.mutation.LastActiveAt(); ok {
+ _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
+ _node.LastActiveAt = &value
+ }
if value, ok := _c.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
_node.BalanceNotifyEnabled = value
@@ -868,6 +966,38 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
}
_spec.Edges = append(_spec.Edges, edge)
}
+ if nodes := _c.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
return _node, _spec
}
@@ -1106,6 +1236,54 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
return u
}
+// SetSignupSource sets the "signup_source" field.
+func (u *UserUpsert) SetSignupSource(v string) *UserUpsert {
+ u.Set(user.FieldSignupSource, v)
+ return u
+}
+
+// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
+func (u *UserUpsert) UpdateSignupSource() *UserUpsert {
+ u.SetExcluded(user.FieldSignupSource)
+ return u
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (u *UserUpsert) SetLastLoginAt(v time.Time) *UserUpsert {
+ u.Set(user.FieldLastLoginAt, v)
+ return u
+}
+
+// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
+func (u *UserUpsert) UpdateLastLoginAt() *UserUpsert {
+ u.SetExcluded(user.FieldLastLoginAt)
+ return u
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (u *UserUpsert) ClearLastLoginAt() *UserUpsert {
+ u.SetNull(user.FieldLastLoginAt)
+ return u
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (u *UserUpsert) SetLastActiveAt(v time.Time) *UserUpsert {
+ u.Set(user.FieldLastActiveAt, v)
+ return u
+}
+
+// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
+func (u *UserUpsert) UpdateLastActiveAt() *UserUpsert {
+ u.SetExcluded(user.FieldLastActiveAt)
+ return u
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (u *UserUpsert) ClearLastActiveAt() *UserUpsert {
+ u.SetNull(user.FieldLastActiveAt)
+ return u
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsert) SetBalanceNotifyEnabled(v bool) *UserUpsert {
u.Set(user.FieldBalanceNotifyEnabled, v)
@@ -1446,6 +1624,62 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
})
}
+// SetSignupSource sets the "signup_source" field.
+func (u *UserUpsertOne) SetSignupSource(v string) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetSignupSource(v)
+ })
+}
+
+// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateSignupSource() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateSignupSource()
+ })
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (u *UserUpsertOne) SetLastLoginAt(v time.Time) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetLastLoginAt(v)
+ })
+}
+
+// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateLastLoginAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateLastLoginAt()
+ })
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (u *UserUpsertOne) ClearLastLoginAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearLastLoginAt()
+ })
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (u *UserUpsertOne) SetLastActiveAt(v time.Time) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetLastActiveAt(v)
+ })
+}
+
+// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateLastActiveAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateLastActiveAt()
+ })
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (u *UserUpsertOne) ClearLastActiveAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearLastActiveAt()
+ })
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsertOne) SetBalanceNotifyEnabled(v bool) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
@@ -1965,6 +2199,62 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
})
}
+// SetSignupSource sets the "signup_source" field.
+func (u *UserUpsertBulk) SetSignupSource(v string) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetSignupSource(v)
+ })
+}
+
+// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateSignupSource() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateSignupSource()
+ })
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (u *UserUpsertBulk) SetLastLoginAt(v time.Time) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetLastLoginAt(v)
+ })
+}
+
+// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateLastLoginAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateLastLoginAt()
+ })
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (u *UserUpsertBulk) ClearLastLoginAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearLastLoginAt()
+ })
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (u *UserUpsertBulk) SetLastActiveAt(v time.Time) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetLastActiveAt(v)
+ })
+}
+
+// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateLastActiveAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateLastActiveAt()
+ })
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (u *UserUpsertBulk) ClearLastActiveAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearLastActiveAt()
+ })
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsertBulk) SetBalanceNotifyEnabled(v bool) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
diff --git a/backend/ent/user_query.go b/backend/ent/user_query.go
index 113d87ac..f1ee5cfe 100644
--- a/backend/ent/user_query.go
+++ b/backend/ent/user_query.go
@@ -15,8 +15,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
@@ -44,6 +46,8 @@ type UserQuery struct {
withAttributeValues *UserAttributeValueQuery
withPromoCodeUsages *PromoCodeUsageQuery
withPaymentOrders *PaymentOrderQuery
+ withAuthIdentities *AuthIdentityQuery
+ withPendingAuthSessions *PendingAuthSessionQuery
withUserAllowedGroups *UserAllowedGroupQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
@@ -302,6 +306,50 @@ func (_q *UserQuery) QueryPaymentOrders() *PaymentOrderQuery {
return query
}
+// QueryAuthIdentities chains the current query on the "auth_identities" edge.
+func (_q *UserQuery) QueryAuthIdentities() *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, selector),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryPendingAuthSessions chains the current query on the "pending_auth_sessions" edge.
+func (_q *UserQuery) QueryPendingAuthSessions() *PendingAuthSessionQuery {
+ query := (&PendingAuthSessionClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, selector),
+ sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: _q.config}).Query()
@@ -526,6 +574,8 @@ func (_q *UserQuery) Clone() *UserQuery {
withAttributeValues: _q.withAttributeValues.Clone(),
withPromoCodeUsages: _q.withPromoCodeUsages.Clone(),
withPaymentOrders: _q.withPaymentOrders.Clone(),
+ withAuthIdentities: _q.withAuthIdentities.Clone(),
+ withPendingAuthSessions: _q.withPendingAuthSessions.Clone(),
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
@@ -643,6 +693,28 @@ func (_q *UserQuery) WithPaymentOrders(opts ...func(*PaymentOrderQuery)) *UserQu
return _q
}
+// WithAuthIdentities tells the query-builder to eager-load the nodes that are connected to
+// the "auth_identities" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *UserQuery) WithAuthIdentities(opts ...func(*AuthIdentityQuery)) *UserQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withAuthIdentities = query
+ return _q
+}
+
+// WithPendingAuthSessions tells the query-builder to eager-load the nodes that are connected to
+// the "pending_auth_sessions" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *UserQuery) WithPendingAuthSessions(opts ...func(*PendingAuthSessionQuery)) *UserQuery {
+ query := (&PendingAuthSessionClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withPendingAuthSessions = query
+ return _q
+}
+
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
@@ -732,7 +804,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var (
nodes = []*User{}
_spec = _q.querySpec()
- loadedTypes = [11]bool{
+ loadedTypes = [13]bool{
_q.withAPIKeys != nil,
_q.withRedeemCodes != nil,
_q.withSubscriptions != nil,
@@ -743,6 +815,8 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
_q.withAttributeValues != nil,
_q.withPromoCodeUsages != nil,
_q.withPaymentOrders != nil,
+ _q.withAuthIdentities != nil,
+ _q.withPendingAuthSessions != nil,
_q.withUserAllowedGroups != nil,
}
)
@@ -839,6 +913,22 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err
}
}
+ if query := _q.withAuthIdentities; query != nil {
+ if err := _q.loadAuthIdentities(ctx, query, nodes,
+ func(n *User) { n.Edges.AuthIdentities = []*AuthIdentity{} },
+ func(n *User, e *AuthIdentity) { n.Edges.AuthIdentities = append(n.Edges.AuthIdentities, e) }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withPendingAuthSessions; query != nil {
+ if err := _q.loadPendingAuthSessions(ctx, query, nodes,
+ func(n *User) { n.Edges.PendingAuthSessions = []*PendingAuthSession{} },
+ func(n *User, e *PendingAuthSession) {
+ n.Edges.PendingAuthSessions = append(n.Edges.PendingAuthSessions, e)
+ }); err != nil {
+ return nil, err
+ }
+ }
if query := _q.withUserAllowedGroups; query != nil {
if err := _q.loadUserAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
@@ -1186,6 +1276,69 @@ func (_q *UserQuery) loadPaymentOrders(ctx context.Context, query *PaymentOrderQ
}
return nil
}
+func (_q *UserQuery) loadAuthIdentities(ctx context.Context, query *AuthIdentityQuery, nodes []*User, init func(*User), assign func(*User, *AuthIdentity)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*User)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(authidentity.FieldUserID)
+ }
+ query.Where(predicate.AuthIdentity(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(user.AuthIdentitiesColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.UserID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+func (_q *UserQuery) loadPendingAuthSessions(ctx context.Context, query *PendingAuthSessionQuery, nodes []*User, init func(*User), assign func(*User, *PendingAuthSession)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*User)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(pendingauthsession.FieldTargetUserID)
+ }
+ query.Where(predicate.PendingAuthSession(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(user.PendingAuthSessionsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.TargetUserID
+ if fk == nil {
+ return fmt.Errorf(`foreign-key "target_user_id" is nil for node %v`, n.ID)
+ }
+ node, ok := nodeids[*fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "target_user_id" returned %v for node %v`, *fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go
index 6b355247..677eeb6b 100644
--- a/backend/ent/user_update.go
+++ b/backend/ent/user_update.go
@@ -13,8 +13,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
@@ -243,6 +245,60 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
return _u
}
+// SetSignupSource sets the "signup_source" field.
+func (_u *UserUpdate) SetSignupSource(v string) *UserUpdate {
+ _u.mutation.SetSignupSource(v)
+ return _u
+}
+
+// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableSignupSource(v *string) *UserUpdate {
+ if v != nil {
+ _u.SetSignupSource(*v)
+ }
+ return _u
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (_u *UserUpdate) SetLastLoginAt(v time.Time) *UserUpdate {
+ _u.mutation.SetLastLoginAt(v)
+ return _u
+}
+
+// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableLastLoginAt(v *time.Time) *UserUpdate {
+ if v != nil {
+ _u.SetLastLoginAt(*v)
+ }
+ return _u
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (_u *UserUpdate) ClearLastLoginAt() *UserUpdate {
+ _u.mutation.ClearLastLoginAt()
+ return _u
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (_u *UserUpdate) SetLastActiveAt(v time.Time) *UserUpdate {
+ _u.mutation.SetLastActiveAt(v)
+ return _u
+}
+
+// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableLastActiveAt(v *time.Time) *UserUpdate {
+ if v != nil {
+ _u.SetLastActiveAt(*v)
+ }
+ return _u
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (_u *UserUpdate) ClearLastActiveAt() *UserUpdate {
+ _u.mutation.ClearLastActiveAt()
+ return _u
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_u *UserUpdate) SetBalanceNotifyEnabled(v bool) *UserUpdate {
_u.mutation.SetBalanceNotifyEnabled(v)
@@ -483,6 +539,36 @@ func (_u *UserUpdate) AddPaymentOrders(v ...*PaymentOrder) *UserUpdate {
return _u.AddPaymentOrderIDs(ids...)
}
+// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
+func (_u *UserUpdate) AddAuthIdentityIDs(ids ...int64) *UserUpdate {
+ _u.mutation.AddAuthIdentityIDs(ids...)
+ return _u
+}
+
+// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
+func (_u *UserUpdate) AddAuthIdentities(v ...*AuthIdentity) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAuthIdentityIDs(ids...)
+}
+
+// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
+func (_u *UserUpdate) AddPendingAuthSessionIDs(ids ...int64) *UserUpdate {
+ _u.mutation.AddPendingAuthSessionIDs(ids...)
+ return _u
+}
+
+// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_u *UserUpdate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddPendingAuthSessionIDs(ids...)
+}
+
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdate) Mutation() *UserMutation {
return _u.mutation
@@ -698,6 +784,48 @@ func (_u *UserUpdate) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdate {
return _u.RemovePaymentOrderIDs(ids...)
}
+// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity.
+func (_u *UserUpdate) ClearAuthIdentities() *UserUpdate {
+ _u.mutation.ClearAuthIdentities()
+ return _u
+}
+
+// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs.
+func (_u *UserUpdate) RemoveAuthIdentityIDs(ids ...int64) *UserUpdate {
+ _u.mutation.RemoveAuthIdentityIDs(ids...)
+ return _u
+}
+
+// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities.
+func (_u *UserUpdate) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAuthIdentityIDs(ids...)
+}
+
+// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_u *UserUpdate) ClearPendingAuthSessions() *UserUpdate {
+ _u.mutation.ClearPendingAuthSessions()
+ return _u
+}
+
+// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs.
+func (_u *UserUpdate) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdate {
+ _u.mutation.RemovePendingAuthSessionIDs(ids...)
+ return _u
+}
+
+// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities.
+func (_u *UserUpdate) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemovePendingAuthSessionIDs(ids...)
+}
+
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
@@ -767,6 +895,11 @@ func (_u *UserUpdate) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
+ if v, ok := _u.mutation.SignupSource(); ok {
+ if err := user.SignupSourceValidator(v); err != nil {
+ return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
+ }
+ }
return nil
}
@@ -836,6 +969,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
+ if value, ok := _u.mutation.SignupSource(); ok {
+ _spec.SetField(user.FieldSignupSource, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.LastLoginAt(); ok {
+ _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastLoginAtCleared() {
+ _spec.ClearField(user.FieldLastLoginAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.LastActiveAt(); ok {
+ _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastActiveAtCleared() {
+ _spec.ClearField(user.FieldLastActiveAt, field.TypeTime)
+ }
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
}
@@ -1322,6 +1470,96 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
+ if _u.mutation.AuthIdentitiesCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.PendingAuthSessionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label}
@@ -1548,6 +1786,60 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
return _u
}
+// SetSignupSource sets the "signup_source" field.
+func (_u *UserUpdateOne) SetSignupSource(v string) *UserUpdateOne {
+ _u.mutation.SetSignupSource(v)
+ return _u
+}
+
+// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableSignupSource(v *string) *UserUpdateOne {
+ if v != nil {
+ _u.SetSignupSource(*v)
+ }
+ return _u
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (_u *UserUpdateOne) SetLastLoginAt(v time.Time) *UserUpdateOne {
+ _u.mutation.SetLastLoginAt(v)
+ return _u
+}
+
+// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableLastLoginAt(v *time.Time) *UserUpdateOne {
+ if v != nil {
+ _u.SetLastLoginAt(*v)
+ }
+ return _u
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (_u *UserUpdateOne) ClearLastLoginAt() *UserUpdateOne {
+ _u.mutation.ClearLastLoginAt()
+ return _u
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (_u *UserUpdateOne) SetLastActiveAt(v time.Time) *UserUpdateOne {
+ _u.mutation.SetLastActiveAt(v)
+ return _u
+}
+
+// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableLastActiveAt(v *time.Time) *UserUpdateOne {
+ if v != nil {
+ _u.SetLastActiveAt(*v)
+ }
+ return _u
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (_u *UserUpdateOne) ClearLastActiveAt() *UserUpdateOne {
+ _u.mutation.ClearLastActiveAt()
+ return _u
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_u *UserUpdateOne) SetBalanceNotifyEnabled(v bool) *UserUpdateOne {
_u.mutation.SetBalanceNotifyEnabled(v)
@@ -1788,6 +2080,36 @@ func (_u *UserUpdateOne) AddPaymentOrders(v ...*PaymentOrder) *UserUpdateOne {
return _u.AddPaymentOrderIDs(ids...)
}
+// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
+func (_u *UserUpdateOne) AddAuthIdentityIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.AddAuthIdentityIDs(ids...)
+ return _u
+}
+
+// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
+func (_u *UserUpdateOne) AddAuthIdentities(v ...*AuthIdentity) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAuthIdentityIDs(ids...)
+}
+
+// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
+func (_u *UserUpdateOne) AddPendingAuthSessionIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.AddPendingAuthSessionIDs(ids...)
+ return _u
+}
+
+// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_u *UserUpdateOne) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddPendingAuthSessionIDs(ids...)
+}
+
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdateOne) Mutation() *UserMutation {
return _u.mutation
@@ -2003,6 +2325,48 @@ func (_u *UserUpdateOne) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdateOne
return _u.RemovePaymentOrderIDs(ids...)
}
+// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity.
+func (_u *UserUpdateOne) ClearAuthIdentities() *UserUpdateOne {
+ _u.mutation.ClearAuthIdentities()
+ return _u
+}
+
+// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs.
+func (_u *UserUpdateOne) RemoveAuthIdentityIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.RemoveAuthIdentityIDs(ids...)
+ return _u
+}
+
+// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities.
+func (_u *UserUpdateOne) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAuthIdentityIDs(ids...)
+}
+
+// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_u *UserUpdateOne) ClearPendingAuthSessions() *UserUpdateOne {
+ _u.mutation.ClearPendingAuthSessions()
+ return _u
+}
+
+// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs.
+func (_u *UserUpdateOne) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.RemovePendingAuthSessionIDs(ids...)
+ return _u
+}
+
+// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities.
+func (_u *UserUpdateOne) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemovePendingAuthSessionIDs(ids...)
+}
+
// Where appends a list predicates to the UserUpdate builder.
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
_u.mutation.Where(ps...)
@@ -2085,6 +2449,11 @@ func (_u *UserUpdateOne) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
+ if v, ok := _u.mutation.SignupSource(); ok {
+ if err := user.SignupSourceValidator(v); err != nil {
+ return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
+ }
+ }
return nil
}
@@ -2171,6 +2540,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
+ if value, ok := _u.mutation.SignupSource(); ok {
+ _spec.SetField(user.FieldSignupSource, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.LastLoginAt(); ok {
+ _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastLoginAtCleared() {
+ _spec.ClearField(user.FieldLastLoginAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.LastActiveAt(); ok {
+ _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastActiveAtCleared() {
+ _spec.ClearField(user.FieldLastActiveAt, field.TypeTime)
+ }
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
}
@@ -2657,6 +3041,96 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
+ if _u.mutation.AuthIdentitiesCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.PendingAuthSessionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
_node = &User{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
diff --git a/backend/go.mod b/backend/go.mod
index 05e13981..2067a03b 100644
--- a/backend/go.mod
+++ b/backend/go.mod
@@ -1,9 +1,9 @@
module github.com/Wei-Shaw/sub2api
-go 1.25.0
+go 1.26.2
require (
- connectrpc.com/connect v1.19.1
+ connectrpc.com/connect v1.19.2
entgo.io/ent v0.14.5
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/alitto/pond/v2 v2.6.2
@@ -41,6 +41,7 @@ require (
github.com/zeromicro/go-zero v1.9.4
go.uber.org/zap v1.24.0
golang.org/x/crypto v0.49.0
+ golang.org/x/image v0.39.0
golang.org/x/net v0.52.0
golang.org/x/sync v0.20.0
golang.org/x/term v0.41.0
@@ -105,7 +106,6 @@ require (
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
- github.com/google/subcommands v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
@@ -174,10 +174,10 @@ require (
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
- golang.org/x/mod v0.33.0 // indirect
+ golang.org/x/mod v0.34.0 // indirect
golang.org/x/sys v0.42.0 // indirect
- golang.org/x/text v0.35.0 // indirect
- golang.org/x/tools v0.42.0 // indirect
+ golang.org/x/text v0.36.0 // indirect
+ golang.org/x/tools v0.43.0 // indirect
google.golang.org/grpc v1.75.1 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
modernc.org/libc v1.67.6 // indirect
diff --git a/backend/go.sum b/backend/go.sum
index 691a7fd6..7c9621ef 100644
--- a/backend/go.sum
+++ b/backend/go.sum
@@ -1,7 +1,7 @@
ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 h1:E0wvcUXTkgyN4wy4LGtNzMNGMytJN8afmIWXJVMi4cc=
ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9/go.mod h1:Oe1xWPuu5q9LzyrWfbZmEZxFYeu4BHTyzfjeW2aZp/w=
-connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14=
-connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w=
+connectrpc.com/connect v1.19.2 h1:McQ83FGdzL+t60peksi0gXC7MQ/iLKgLduAnThbM0mo=
+connectrpc.com/connect v1.19.2/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w=
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
entgo.io/ent v0.14.5 h1:Rj2WOYJtCkWyFo6a+5wB3EfBRP0rnx1fMk6gGA0UUe4=
@@ -164,8 +164,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
-github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
-github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
@@ -185,8 +183,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
-github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
-github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -222,8 +218,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
-github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
-github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@@ -257,8 +251,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
-github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
-github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@@ -288,8 +280,6 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
-github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
-github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
@@ -322,8 +312,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
-github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
-github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
@@ -419,8 +407,10 @@ golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
-golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
-golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
+golang.org/x/image v0.39.0 h1:skVYidAEVKgn8lZ602XO75asgXBgLj9G/FE3RbuPFww=
+golang.org/x/image v0.39.0/go.mod h1:sIbmppfU+xFLPIG0FoVUTvyBMmgng1/XAMhQ2ft0hpA=
+golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
+golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
@@ -438,12 +428,12 @@ golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
-golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
-golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
+golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
+golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
-golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
-golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
+golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
+golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index 87e3ff5a..7bb48ecc 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -52,6 +52,11 @@ const (
ConnectionPoolIsolationAccountProxy = "account_proxy"
)
+// DefaultUpstreamResponseReadMaxBytes 上游非流式响应体的默认读取上限。
+// 128 MB 可容纳 2-3 张 4K PNG(base64 膨胀 33%,单张 4K PNG 最坏约 67MB base64)。
+// 可通过 gateway.upstream_response_read_max_bytes 配置项覆盖。
+const DefaultUpstreamResponseReadMaxBytes int64 = 128 * 1024 * 1024
+
type Config struct {
Server ServerConfig `mapstructure:"server"`
Log LogConfig `mapstructure:"log"`
@@ -65,6 +70,7 @@ type Config struct {
JWT JWTConfig `mapstructure:"jwt"`
Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
+ WeChat WeChatConnectConfig `mapstructure:"wechat_connect"`
OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
@@ -185,26 +191,47 @@ type LinuxDoConnectConfig struct {
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
+type WeChatConnectConfig struct {
+ Enabled bool `mapstructure:"enabled"`
+ AppID string `mapstructure:"app_id"`
+ AppSecret string `mapstructure:"app_secret"`
+ OpenAppID string `mapstructure:"open_app_id"`
+ OpenAppSecret string `mapstructure:"open_app_secret"`
+ MPAppID string `mapstructure:"mp_app_id"`
+ MPAppSecret string `mapstructure:"mp_app_secret"`
+ MobileAppID string `mapstructure:"mobile_app_id"`
+ MobileAppSecret string `mapstructure:"mobile_app_secret"`
+ OpenEnabled bool `mapstructure:"open_enabled"`
+ MPEnabled bool `mapstructure:"mp_enabled"`
+ MobileEnabled bool `mapstructure:"mobile_enabled"`
+ Mode string `mapstructure:"mode"`
+ Scopes string `mapstructure:"scopes"`
+ RedirectURL string `mapstructure:"redirect_url"`
+ FrontendRedirectURL string `mapstructure:"frontend_redirect_url"`
+}
+
type OIDCConnectConfig struct {
- Enabled bool `mapstructure:"enabled"`
- ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等
- ClientID string `mapstructure:"client_id"`
- ClientSecret string `mapstructure:"client_secret"`
- IssuerURL string `mapstructure:"issuer_url"`
- DiscoveryURL string `mapstructure:"discovery_url"`
- AuthorizeURL string `mapstructure:"authorize_url"`
- TokenURL string `mapstructure:"token_url"`
- UserInfoURL string `mapstructure:"userinfo_url"`
- JWKSURL string `mapstructure:"jwks_url"`
- Scopes string `mapstructure:"scopes"` // 默认 "openid email profile"
- RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
- FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback)
- TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
- UsePKCE bool `mapstructure:"use_pkce"`
- ValidateIDToken bool `mapstructure:"validate_id_token"`
- AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256"
- ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120
- RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false
+ Enabled bool `mapstructure:"enabled"`
+ ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等
+ ClientID string `mapstructure:"client_id"`
+ ClientSecret string `mapstructure:"client_secret"`
+ IssuerURL string `mapstructure:"issuer_url"`
+ DiscoveryURL string `mapstructure:"discovery_url"`
+ AuthorizeURL string `mapstructure:"authorize_url"`
+ TokenURL string `mapstructure:"token_url"`
+ UserInfoURL string `mapstructure:"userinfo_url"`
+ JWKSURL string `mapstructure:"jwks_url"`
+ Scopes string `mapstructure:"scopes"` // 默认 "openid email profile"
+ RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
+ FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback)
+ TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
+ UsePKCE bool `mapstructure:"use_pkce"`
+ ValidateIDToken bool `mapstructure:"validate_id_token"`
+ UsePKCEExplicit bool `mapstructure:"-" yaml:"-"`
+ ValidateIDTokenExplicit bool `mapstructure:"-" yaml:"-"`
+ AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256"
+ ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120
+ RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
// 为空时,服务端会尝试一组常见字段名。
@@ -213,6 +240,225 @@ type OIDCConnectConfig struct {
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
+const (
+ defaultWeChatConnectMode = "open"
+ defaultWeChatConnectScopes = "snsapi_login"
+ defaultWeChatConnectFrontendRedirect = "/auth/wechat/callback"
+)
+
+func firstNonEmptyString(values ...string) string {
+ for _, value := range values {
+ if trimmed := strings.TrimSpace(value); trimmed != "" {
+ return trimmed
+ }
+ }
+ return ""
+}
+
+func normalizeWeChatConnectMode(raw string) string {
+ switch strings.ToLower(strings.TrimSpace(raw)) {
+ case "mp":
+ return "mp"
+ case "mobile":
+ return "mobile"
+ default:
+ return defaultWeChatConnectMode
+ }
+}
+
+func normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled bool, mode string) string {
+ mode = normalizeWeChatConnectMode(mode)
+ switch mode {
+ case "open":
+ if openEnabled {
+ return "open"
+ }
+ case "mp":
+ if mpEnabled {
+ return "mp"
+ }
+ case "mobile":
+ if mobileEnabled {
+ return "mobile"
+ }
+ }
+ switch {
+ case openEnabled:
+ return "open"
+ case mpEnabled:
+ return "mp"
+ case mobileEnabled:
+ return "mobile"
+ default:
+ return mode
+ }
+}
+
+func defaultWeChatConnectScopesForMode(mode string) string {
+ switch normalizeWeChatConnectMode(mode) {
+ case "mp":
+ return "snsapi_userinfo"
+ case "mobile":
+ return ""
+ default:
+ return defaultWeChatConnectScopes
+ }
+}
+
+func normalizeWeChatConnectScopes(raw, mode string) string {
+ switch normalizeWeChatConnectMode(mode) {
+ case "mp":
+ switch strings.TrimSpace(raw) {
+ case "snsapi_base":
+ return "snsapi_base"
+ case "snsapi_userinfo":
+ return "snsapi_userinfo"
+ default:
+ return defaultWeChatConnectScopesForMode(mode)
+ }
+ case "mobile":
+ return ""
+ default:
+ return defaultWeChatConnectScopes
+ }
+}
+
+func shouldApplyLegacyWeChatEnv(configKey, envKey string) bool {
+ if viper.InConfig(configKey) {
+ return false
+ }
+ _, hasNewEnv := os.LookupEnv(envKey)
+ return !hasNewEnv
+}
+
+func hasExplicitConfigOrEnv(configKey, envKey string) bool {
+ if viper.InConfig(configKey) {
+ return true
+ }
+ _, ok := os.LookupEnv(envKey)
+ return ok
+}
+
+func applyLegacyWeChatConnectEnvCompatibility(cfg *WeChatConnectConfig) {
+ if cfg == nil {
+ return
+ }
+
+ legacyOpenAppID := ""
+ if shouldApplyLegacyWeChatEnv("wechat_connect.open_app_id", "WECHAT_CONNECT_OPEN_APP_ID") &&
+ shouldApplyLegacyWeChatEnv("wechat_connect.app_id", "WECHAT_CONNECT_APP_ID") {
+ legacyOpenAppID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID"))
+ if legacyOpenAppID != "" {
+ cfg.OpenAppID = legacyOpenAppID
+ }
+ }
+
+ legacyOpenAppSecret := ""
+ if shouldApplyLegacyWeChatEnv("wechat_connect.open_app_secret", "WECHAT_CONNECT_OPEN_APP_SECRET") &&
+ shouldApplyLegacyWeChatEnv("wechat_connect.app_secret", "WECHAT_CONNECT_APP_SECRET") {
+ legacyOpenAppSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET"))
+ if legacyOpenAppSecret != "" {
+ cfg.OpenAppSecret = legacyOpenAppSecret
+ }
+ }
+
+ legacyMPAppID := ""
+ if shouldApplyLegacyWeChatEnv("wechat_connect.mp_app_id", "WECHAT_CONNECT_MP_APP_ID") &&
+ shouldApplyLegacyWeChatEnv("wechat_connect.app_id", "WECHAT_CONNECT_APP_ID") {
+ legacyMPAppID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID"))
+ if legacyMPAppID != "" {
+ cfg.MPAppID = legacyMPAppID
+ }
+ }
+
+ legacyMPAppSecret := ""
+ if shouldApplyLegacyWeChatEnv("wechat_connect.mp_app_secret", "WECHAT_CONNECT_MP_APP_SECRET") &&
+ shouldApplyLegacyWeChatEnv("wechat_connect.app_secret", "WECHAT_CONNECT_APP_SECRET") {
+ legacyMPAppSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET"))
+ if legacyMPAppSecret != "" {
+ cfg.MPAppSecret = legacyMPAppSecret
+ }
+ }
+
+ if shouldApplyLegacyWeChatEnv("wechat_connect.frontend_redirect_url", "WECHAT_CONNECT_FRONTEND_REDIRECT_URL") {
+ if legacyFrontend := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL")); legacyFrontend != "" {
+ cfg.FrontendRedirectURL = legacyFrontend
+ }
+ }
+
+ hasLegacyOpen := legacyOpenAppID != "" && legacyOpenAppSecret != ""
+ hasLegacyMP := legacyMPAppID != "" && legacyMPAppSecret != ""
+
+ if shouldApplyLegacyWeChatEnv("wechat_connect.enabled", "WECHAT_CONNECT_ENABLED") && (hasLegacyOpen || hasLegacyMP) {
+ cfg.Enabled = true
+ }
+ if shouldApplyLegacyWeChatEnv("wechat_connect.open_enabled", "WECHAT_CONNECT_OPEN_ENABLED") && hasLegacyOpen {
+ cfg.OpenEnabled = true
+ }
+ if shouldApplyLegacyWeChatEnv("wechat_connect.mp_enabled", "WECHAT_CONNECT_MP_ENABLED") && hasLegacyMP {
+ cfg.MPEnabled = true
+ }
+ if shouldApplyLegacyWeChatEnv("wechat_connect.mode", "WECHAT_CONNECT_MODE") {
+ switch {
+ case hasLegacyMP && !hasLegacyOpen:
+ cfg.Mode = "mp"
+ case hasLegacyOpen:
+ cfg.Mode = "open"
+ }
+ }
+ if shouldApplyLegacyWeChatEnv("wechat_connect.scopes", "WECHAT_CONNECT_SCOPES") {
+ switch {
+ case hasLegacyMP && !hasLegacyOpen:
+ cfg.Scopes = defaultWeChatConnectScopesForMode("mp")
+ case hasLegacyOpen:
+ cfg.Scopes = defaultWeChatConnectScopesForMode("open")
+ }
+ }
+}
+
+func normalizeWeChatConnectConfig(cfg *WeChatConnectConfig) {
+ if cfg == nil {
+ return
+ }
+
+ cfg.AppID = strings.TrimSpace(cfg.AppID)
+ cfg.AppSecret = strings.TrimSpace(cfg.AppSecret)
+ cfg.OpenAppID = strings.TrimSpace(cfg.OpenAppID)
+ cfg.OpenAppSecret = strings.TrimSpace(cfg.OpenAppSecret)
+ cfg.MPAppID = strings.TrimSpace(cfg.MPAppID)
+ cfg.MPAppSecret = strings.TrimSpace(cfg.MPAppSecret)
+ cfg.MobileAppID = strings.TrimSpace(cfg.MobileAppID)
+ cfg.MobileAppSecret = strings.TrimSpace(cfg.MobileAppSecret)
+ cfg.Mode = normalizeWeChatConnectMode(cfg.Mode)
+ cfg.RedirectURL = strings.TrimSpace(cfg.RedirectURL)
+ cfg.FrontendRedirectURL = strings.TrimSpace(cfg.FrontendRedirectURL)
+
+ cfg.AppID = firstNonEmptyString(cfg.AppID, cfg.OpenAppID, cfg.MPAppID, cfg.MobileAppID)
+ cfg.AppSecret = firstNonEmptyString(cfg.AppSecret, cfg.OpenAppSecret, cfg.MPAppSecret, cfg.MobileAppSecret)
+ cfg.OpenAppID = firstNonEmptyString(cfg.OpenAppID, cfg.AppID)
+ cfg.OpenAppSecret = firstNonEmptyString(cfg.OpenAppSecret, cfg.AppSecret)
+ cfg.MPAppID = firstNonEmptyString(cfg.MPAppID, cfg.AppID)
+ cfg.MPAppSecret = firstNonEmptyString(cfg.MPAppSecret, cfg.AppSecret)
+ cfg.MobileAppID = firstNonEmptyString(cfg.MobileAppID, cfg.AppID)
+ cfg.MobileAppSecret = firstNonEmptyString(cfg.MobileAppSecret, cfg.AppSecret)
+
+ if !cfg.OpenEnabled && !cfg.MPEnabled && !cfg.MobileEnabled && cfg.Enabled {
+ switch cfg.Mode {
+ case "mp":
+ cfg.MPEnabled = true
+ case "mobile":
+ cfg.MobileEnabled = true
+ default:
+ cfg.OpenEnabled = true
+ }
+ }
+ cfg.Mode = normalizeWeChatConnectStoredMode(cfg.OpenEnabled, cfg.MPEnabled, cfg.MobileEnabled, cfg.Mode)
+ cfg.Scopes = normalizeWeChatConnectScopes(cfg.Scopes, cfg.Mode)
+ if cfg.FrontendRedirectURL == "" {
+ cfg.FrontendRedirectURL = defaultWeChatConnectFrontendRedirect
+ }
+}
+
// TokenRefreshConfig OAuth token自动刷新配置
type TokenRefreshConfig struct {
// 是否启用自动刷新
@@ -1067,6 +1313,8 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
+ applyLegacyWeChatConnectEnvCompatibility(&cfg.WeChat)
+ normalizeWeChatConnectConfig(&cfg.WeChat)
cfg.OIDC.ProviderName = strings.TrimSpace(cfg.OIDC.ProviderName)
cfg.OIDC.ClientID = strings.TrimSpace(cfg.OIDC.ClientID)
cfg.OIDC.ClientSecret = strings.TrimSpace(cfg.OIDC.ClientSecret)
@@ -1084,6 +1332,8 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.OIDC.UserInfoEmailPath = strings.TrimSpace(cfg.OIDC.UserInfoEmailPath)
cfg.OIDC.UserInfoIDPath = strings.TrimSpace(cfg.OIDC.UserInfoIDPath)
cfg.OIDC.UserInfoUsernamePath = strings.TrimSpace(cfg.OIDC.UserInfoUsernamePath)
+ cfg.OIDC.UsePKCEExplicit = hasExplicitConfigOrEnv("oidc_connect.use_pkce", "OIDC_CONNECT_USE_PKCE")
+ cfg.OIDC.ValidateIDTokenExplicit = hasExplicitConfigOrEnv("oidc_connect.validate_id_token", "OIDC_CONNECT_VALIDATE_ID_TOKEN")
cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
@@ -1262,6 +1512,24 @@ func setDefaults() {
viper.SetDefault("linuxdo_connect.userinfo_id_path", "")
viper.SetDefault("linuxdo_connect.userinfo_username_path", "")
+ // WeChat Connect OAuth 登录
+ viper.SetDefault("wechat_connect.enabled", false)
+ viper.SetDefault("wechat_connect.app_id", "")
+ viper.SetDefault("wechat_connect.app_secret", "")
+ viper.SetDefault("wechat_connect.open_app_id", "")
+ viper.SetDefault("wechat_connect.open_app_secret", "")
+ viper.SetDefault("wechat_connect.mp_app_id", "")
+ viper.SetDefault("wechat_connect.mp_app_secret", "")
+ viper.SetDefault("wechat_connect.mobile_app_id", "")
+ viper.SetDefault("wechat_connect.mobile_app_secret", "")
+ viper.SetDefault("wechat_connect.open_enabled", false)
+ viper.SetDefault("wechat_connect.mp_enabled", false)
+ viper.SetDefault("wechat_connect.mobile_enabled", false)
+ viper.SetDefault("wechat_connect.mode", defaultWeChatConnectMode)
+ viper.SetDefault("wechat_connect.scopes", defaultWeChatConnectScopes)
+ viper.SetDefault("wechat_connect.redirect_url", "")
+ viper.SetDefault("wechat_connect.frontend_redirect_url", defaultWeChatConnectFrontendRedirect)
+
// Generic OIDC OAuth 登录
viper.SetDefault("oidc_connect.enabled", false)
viper.SetDefault("oidc_connect.provider_name", "OIDC")
@@ -1277,7 +1545,7 @@ func setDefaults() {
viper.SetDefault("oidc_connect.redirect_url", "")
viper.SetDefault("oidc_connect.frontend_redirect_url", "/auth/oidc/callback")
viper.SetDefault("oidc_connect.token_auth_method", "client_secret_post")
- viper.SetDefault("oidc_connect.use_pkce", false)
+ viper.SetDefault("oidc_connect.use_pkce", true)
viper.SetDefault("oidc_connect.validate_id_token", true)
viper.SetDefault("oidc_connect.allowed_signing_algs", "RS256,ES256,PS256")
viper.SetDefault("oidc_connect.clock_skew_seconds", 120)
@@ -1476,7 +1744,7 @@ func setDefaults() {
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
- viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
+ viper.SetDefault("gateway.upstream_response_read_max_bytes", DefaultUpstreamResponseReadMaxBytes)
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
viper.SetDefault("gateway.gemini_debug_response_headers", false)
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
@@ -1713,9 +1981,6 @@ func (c *Config) Validate() error {
default:
return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
}
- if method == "none" && !c.LinuxDo.UsePKCE {
- return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none")
- }
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
@@ -1746,6 +2011,45 @@ func (c *Config) Validate() error {
warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL)
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL)
}
+ if c.WeChat.Enabled {
+ weChat := c.WeChat
+ normalizeWeChatConnectConfig(&weChat)
+
+ if weChat.OpenEnabled {
+ if strings.TrimSpace(weChat.OpenAppID) == "" {
+ return fmt.Errorf("wechat_connect.open_app_id is required when wechat_connect.open_enabled=true")
+ }
+ if strings.TrimSpace(weChat.OpenAppSecret) == "" {
+ return fmt.Errorf("wechat_connect.open_app_secret is required when wechat_connect.open_enabled=true")
+ }
+ }
+ if weChat.MPEnabled {
+ if strings.TrimSpace(weChat.MPAppID) == "" {
+ return fmt.Errorf("wechat_connect.mp_app_id is required when wechat_connect.mp_enabled=true")
+ }
+ if strings.TrimSpace(weChat.MPAppSecret) == "" {
+ return fmt.Errorf("wechat_connect.mp_app_secret is required when wechat_connect.mp_enabled=true")
+ }
+ }
+ if weChat.MobileEnabled {
+ if strings.TrimSpace(weChat.MobileAppID) == "" {
+ return fmt.Errorf("wechat_connect.mobile_app_id is required when wechat_connect.mobile_enabled=true")
+ }
+ if strings.TrimSpace(weChat.MobileAppSecret) == "" {
+ return fmt.Errorf("wechat_connect.mobile_app_secret is required when wechat_connect.mobile_enabled=true")
+ }
+ }
+ if v := strings.TrimSpace(weChat.RedirectURL); v != "" {
+ if err := ValidateAbsoluteHTTPURL(v); err != nil {
+ return fmt.Errorf("wechat_connect.redirect_url invalid: %w", err)
+ }
+ warnIfInsecureURL("wechat_connect.redirect_url", v)
+ }
+ if err := ValidateFrontendRedirectURL(weChat.FrontendRedirectURL); err != nil {
+ return fmt.Errorf("wechat_connect.frontend_redirect_url invalid: %w", err)
+ }
+ warnIfInsecureURL("wechat_connect.frontend_redirect_url", weChat.FrontendRedirectURL)
+ }
if c.OIDC.Enabled {
if strings.TrimSpace(c.OIDC.ClientID) == "" {
return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true")
@@ -1769,9 +2073,6 @@ func (c *Config) Validate() error {
default:
return fmt.Errorf("oidc_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
}
- if method == "none" && !c.OIDC.UsePKCE {
- return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.token_auth_method=none")
- }
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
strings.TrimSpace(c.OIDC.ClientSecret) == "" {
return fmt.Errorf("oidc_connect.client_secret is required when oidc_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go
index cf58316c..6ba86aa1 100644
--- a/backend/internal/config/config_test.go
+++ b/backend/internal/config/config_test.go
@@ -225,6 +225,52 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) {
}
}
+func TestLoadWeChatConnectConfigFromLegacyEnv(t *testing.T) {
+ resetViperWithJWTSecret(t)
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
+ t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app")
+ t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wx-mp-secret")
+ t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/legacy-callback")
+
+ cfg, err := Load()
+ require.NoError(t, err)
+ require.True(t, cfg.WeChat.Enabled)
+ require.True(t, cfg.WeChat.OpenEnabled)
+ require.True(t, cfg.WeChat.MPEnabled)
+ require.False(t, cfg.WeChat.MobileEnabled)
+ require.Equal(t, "open", cfg.WeChat.Mode)
+ require.Equal(t, "wx-open-app", cfg.WeChat.OpenAppID)
+ require.Equal(t, "wx-open-secret", cfg.WeChat.OpenAppSecret)
+ require.Equal(t, "wx-mp-app", cfg.WeChat.MPAppID)
+ require.Equal(t, "wx-mp-secret", cfg.WeChat.MPAppSecret)
+ require.Equal(t, "/auth/wechat/legacy-callback", cfg.WeChat.FrontendRedirectURL)
+}
+
+func TestLoadDefaultOIDCSecurityDefaults(t *testing.T) {
+ resetViperWithJWTSecret(t)
+
+ cfg, err := Load()
+ require.NoError(t, err)
+ require.True(t, cfg.OIDC.UsePKCE)
+ require.True(t, cfg.OIDC.ValidateIDToken)
+ require.False(t, cfg.OIDC.UsePKCEExplicit)
+ require.False(t, cfg.OIDC.ValidateIDTokenExplicit)
+}
+
+func TestLoadExplicitOIDCSecurityDefaultsFromEnvMarksFlagsExplicit(t *testing.T) {
+ resetViperWithJWTSecret(t)
+ t.Setenv("OIDC_CONNECT_USE_PKCE", "false")
+ t.Setenv("OIDC_CONNECT_VALIDATE_ID_TOKEN", "false")
+
+ cfg, err := Load()
+ require.NoError(t, err)
+ require.False(t, cfg.OIDC.UsePKCE)
+ require.False(t, cfg.OIDC.ValidateIDToken)
+ require.True(t, cfg.OIDC.UsePKCEExplicit)
+ require.True(t, cfg.OIDC.ValidateIDTokenExplicit)
+}
+
func TestLoadForcedCodexInstructionsTemplate(t *testing.T) {
resetViperWithJWTSecret(t)
@@ -334,7 +380,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
cfg.LinuxDo.ClientSecret = "test-secret"
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
- cfg.LinuxDo.UsePKCE = false
+ cfg.LinuxDo.UsePKCE = true
cfg.LinuxDo.FrontendRedirectURL = "javascript:alert(1)"
err = cfg.Validate()
@@ -346,7 +392,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
}
}
-func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
+func TestValidateLinuxDoAllowsDisablingPKCEForCompatibility(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
@@ -363,11 +409,8 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
cfg.LinuxDo.UsePKCE = false
err = cfg.Validate()
- if err == nil {
- t.Fatalf("Validate() expected error when token_auth_method=none and use_pkce=false, got nil")
- }
- if !strings.Contains(err.Error(), "linuxdo_connect.use_pkce") {
- t.Fatalf("Validate() expected use_pkce error, got: %v", err)
+ if err != nil {
+ t.Fatalf("Validate() expected LinuxDo config without PKCE to pass for compatibility, got: %v", err)
}
}
@@ -389,6 +432,7 @@ func TestValidateOIDCScopesMustContainOpenID(t *testing.T) {
cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback"
cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback"
cfg.OIDC.Scopes = "profile email"
+ cfg.OIDC.UsePKCE = true
err = cfg.Validate()
if err == nil {
@@ -418,6 +462,7 @@ func TestValidateOIDCAllowsIssuerOnlyEndpointsWithDiscoveryFallback(t *testing.T
cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback"
cfg.OIDC.Scopes = "openid email profile"
cfg.OIDC.ValidateIDToken = true
+ cfg.OIDC.UsePKCE = true
err = cfg.Validate()
if err != nil {
@@ -425,6 +470,35 @@ func TestValidateOIDCAllowsIssuerOnlyEndpointsWithDiscoveryFallback(t *testing.T
}
}
+func TestValidateOIDCAllowsExplicitCompatibilityOverridesForPKCEAndIDTokenValidation(t *testing.T) {
+ resetViperWithJWTSecret(t)
+
+ cfg, err := Load()
+ if err != nil {
+ t.Fatalf("Load() error: %v", err)
+ }
+
+ cfg.OIDC.Enabled = true
+ cfg.OIDC.ClientID = "oidc-client"
+ cfg.OIDC.ClientSecret = "oidc-secret"
+ cfg.OIDC.IssuerURL = "https://issuer.example.com"
+ cfg.OIDC.AuthorizeURL = "https://issuer.example.com/auth"
+ cfg.OIDC.TokenURL = "https://issuer.example.com/token"
+ cfg.OIDC.UserInfoURL = "https://issuer.example.com/userinfo"
+ cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback"
+ cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback"
+ cfg.OIDC.Scopes = "openid email profile"
+ cfg.OIDC.UsePKCE = false
+ cfg.OIDC.ValidateIDToken = false
+ cfg.OIDC.JWKSURL = ""
+ cfg.OIDC.AllowedSigningAlgs = ""
+
+ err = cfg.Validate()
+ if err != nil {
+ t.Fatalf("Validate() expected OIDC config without PKCE/id_token validation to pass for compatibility, got: %v", err)
+ }
+}
+
func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
resetViperWithJWTSecret(t)
@@ -840,6 +914,7 @@ func TestValidateConfigWithLinuxDoEnabled(t *testing.T) {
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback"
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
+ cfg.LinuxDo.UsePKCE = true
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate() unexpected error: %v", err)
@@ -990,6 +1065,7 @@ func TestValidateConfigErrors(t *testing.T) {
name: "linuxdo client id required",
mutate: func(c *Config) {
c.LinuxDo.Enabled = true
+ c.LinuxDo.UsePKCE = true
c.LinuxDo.ClientID = ""
},
wantErr: "linuxdo_connect.client_id",
@@ -998,6 +1074,7 @@ func TestValidateConfigErrors(t *testing.T) {
name: "linuxdo token auth method",
mutate: func(c *Config) {
c.LinuxDo.Enabled = true
+ c.LinuxDo.UsePKCE = true
c.LinuxDo.ClientID = "client"
c.LinuxDo.ClientSecret = "secret"
c.LinuxDo.AuthorizeURL = "https://example.com/authorize"
diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go
index cba3ae21..ddeaab02 100644
--- a/backend/internal/handler/admin/admin_basic_handlers_test.go
+++ b/backend/internal/handler/admin/admin_basic_handlers_test.go
@@ -23,6 +23,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router.GET("/api/v1/admin/users", userHandler.List)
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
+ router.POST("/api/v1/admin/users/:id/auth-identities", userHandler.BindAuthIdentity)
router.POST("/api/v1/admin/users", userHandler.Create)
router.PUT("/api/v1/admin/users/:id", userHandler.Update)
router.DELETE("/api/v1/admin/users/:id", userHandler.Delete)
@@ -75,8 +76,26 @@ func TestUserHandlerEndpoints(t *testing.T) {
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
+ bindBody := map[string]any{
+ "provider_type": "wechat",
+ "provider_key": "wechat-main",
+ "provider_subject": "union-123",
+ "metadata": map[string]any{"source": "admin-repair"},
+ "channel": map[string]any{
+ "channel": "open",
+ "channel_app_id": "wx-open",
+ "channel_subject": "openid-123",
+ },
+ }
+ body, _ := json.Marshal(bindBody)
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/auth-identities", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2}
- body, _ := json.Marshal(createBody)
+ body, _ = json.Marshal(createBody)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
@@ -113,6 +132,33 @@ func TestUserHandlerEndpoints(t *testing.T) {
require.Equal(t, http.StatusOK, rec.Code)
}
+func TestUserHandlerBindAuthIdentityMapsRequest(t *testing.T) {
+ router, adminSvc := setupAdminRouter()
+
+ body, err := json.Marshal(map[string]any{
+ "provider_type": "oidc",
+ "provider_key": "https://issuer.example",
+ "provider_subject": "subject-123",
+ "issuer": "https://issuer.example",
+ "metadata": map[string]any{"report_id": 12},
+ })
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/9/auth-identities", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, int64(9), adminSvc.boundAuthIdentityFor)
+ require.NotNil(t, adminSvc.boundAuthIdentity)
+ require.Equal(t, "oidc", adminSvc.boundAuthIdentity.ProviderType)
+ require.Equal(t, "https://issuer.example", adminSvc.boundAuthIdentity.ProviderKey)
+ require.Equal(t, "subject-123", adminSvc.boundAuthIdentity.ProviderSubject)
+ require.Nil(t, adminSvc.boundAuthIdentity.Channel)
+ require.Equal(t, float64(12), adminSvc.boundAuthIdentity.Metadata["report_id"])
+}
+
func TestGroupHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go
index 6d1ef1b6..3a395342 100644
--- a/backend/internal/handler/admin/admin_service_stub_test.go
+++ b/backend/internal/handler/admin/admin_service_stub_test.go
@@ -17,6 +17,8 @@ type stubAdminService struct {
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
+ boundAuthIdentity *service.AdminBindAuthIdentityInput
+ boundAuthIdentityFor int64
createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64
@@ -42,6 +44,14 @@ type stubAdminService struct {
sortOrder string
calls int
}
+ lastListUsers struct {
+ page int
+ pageSize int
+ filters service.UserListFilters
+ sortBy string
+ sortOrder string
+ calls int
+ }
lastListProxies struct {
protocol string
status string
@@ -127,6 +137,12 @@ func newStubAdminService() *stubAdminService {
}
func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters, sortBy, sortOrder string) ([]service.User, int64, error) {
+ s.lastListUsers.page = page
+ s.lastListUsers.pageSize = pageSize
+ s.lastListUsers.filters = filters
+ s.lastListUsers.sortBy = sortBy
+ s.lastListUsers.sortOrder = sortOrder
+ s.lastListUsers.calls++
return s.users, int64(len(s.users)), nil
}
@@ -167,6 +183,52 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64,
return map[string]any{"user_id": userID}, nil
}
+func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int64, input service.AdminBindAuthIdentityInput) (*service.AdminBoundAuthIdentity, error) {
+ s.boundAuthIdentityFor = userID
+ copied := input
+ if input.Metadata != nil {
+ copied.Metadata = map[string]any{}
+ for key, value := range input.Metadata {
+ copied.Metadata[key] = value
+ }
+ }
+ if input.Channel != nil {
+ channel := *input.Channel
+ if input.Channel.Metadata != nil {
+ channel.Metadata = map[string]any{}
+ for key, value := range input.Channel.Metadata {
+ channel.Metadata[key] = value
+ }
+ }
+ copied.Channel = &channel
+ }
+ s.boundAuthIdentity = &copied
+
+ now := time.Now().UTC()
+ result := &service.AdminBoundAuthIdentity{
+ UserID: userID,
+ ProviderType: input.ProviderType,
+ ProviderKey: input.ProviderKey,
+ ProviderSubject: input.ProviderSubject,
+ VerifiedAt: &now,
+ Issuer: input.Issuer,
+ Metadata: input.Metadata,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ if input.Channel != nil {
+ result.Channel = &service.AdminBoundAuthIdentityChannel{
+ Channel: input.Channel.Channel,
+ ChannelAppID: input.Channel.ChannelAppID,
+ ChannelSubject: input.Channel.ChannelSubject,
+ Metadata: input.Channel.Metadata,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ }
+ return result, nil
+}
+
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) {
return s.groups, int64(len(s.groups)), nil
}
diff --git a/backend/internal/handler/admin/payment_handler.go b/backend/internal/handler/admin/payment_handler.go
index b0ed6aed..84359cd9 100644
--- a/backend/internal/handler/admin/payment_handler.go
+++ b/backend/internal/handler/admin/payment_handler.go
@@ -3,6 +3,7 @@ package admin
import (
"strconv"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -66,7 +67,7 @@ func (h *PaymentHandler) ListOrders(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- response.Paginated(c, orders, int64(total), page, pageSize)
+ response.Paginated(c, sanitizeAdminPaymentOrdersForResponse(orders), int64(total), page, pageSize)
}
// GetOrderDetail returns detailed information about a single order.
@@ -82,7 +83,7 @@ func (h *PaymentHandler) GetOrderDetail(c *gin.Context) {
return
}
auditLogs, _ := h.paymentService.GetOrderAuditLogs(c.Request.Context(), orderID)
- response.Success(c, gin.H{"order": order, "auditLogs": auditLogs})
+ response.Success(c, gin.H{"order": sanitizeAdminPaymentOrderForResponse(order), "auditLogs": auditLogs})
}
// CancelOrder cancels a pending order (admin).
@@ -114,6 +115,26 @@ func (h *PaymentHandler) RetryFulfillment(c *gin.Context) {
response.Success(c, gin.H{"message": "fulfillment retried"})
}
+func sanitizeAdminPaymentOrdersForResponse(orders []*dbent.PaymentOrder) []*dbent.PaymentOrder {
+ if len(orders) == 0 {
+ return orders
+ }
+ out := make([]*dbent.PaymentOrder, 0, len(orders))
+ for _, order := range orders {
+ out = append(out, sanitizeAdminPaymentOrderForResponse(order))
+ }
+ return out
+}
+
+func sanitizeAdminPaymentOrderForResponse(order *dbent.PaymentOrder) *dbent.PaymentOrder {
+ if order == nil {
+ return nil
+ }
+ cloned := *order
+ cloned.ProviderSnapshot = nil
+ return &cloned
+}
+
// AdminProcessRefundRequest is the request body for admin refund processing.
type AdminProcessRefundRequest struct {
Amount float64 `json:"amount"`
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index bec0f126..a882d1a1 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -43,6 +43,15 @@ func scopesContainOpenID(scopes string) bool {
return false
}
+func firstNonEmpty(values ...string) string {
+ for _, value := range values {
+ if trimmed := strings.TrimSpace(value); trimmed != "" {
+ return trimmed
+ }
+ }
+ return ""
+}
+
// SettingHandler 系统设置处理器
type SettingHandler struct {
settingService *service.SettingService
@@ -73,6 +82,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ authSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
// Check if ops monitoring is enabled (respects config.ops.enabled)
opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context())
@@ -93,114 +107,136 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
paymentCfg = &service.PaymentConfig{}
}
- response.Success(c, dto.SystemSettings{
- RegistrationEnabled: settings.RegistrationEnabled,
- EmailVerifyEnabled: settings.EmailVerifyEnabled,
- RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
- PromoCodeEnabled: settings.PromoCodeEnabled,
- PasswordResetEnabled: settings.PasswordResetEnabled,
- FrontendURL: settings.FrontendURL,
- InvitationCodeEnabled: settings.InvitationCodeEnabled,
- TotpEnabled: settings.TotpEnabled,
- TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
- SMTPHost: settings.SMTPHost,
- SMTPPort: settings.SMTPPort,
- SMTPUsername: settings.SMTPUsername,
- SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
- SMTPFrom: settings.SMTPFrom,
- SMTPFromName: settings.SMTPFromName,
- SMTPUseTLS: settings.SMTPUseTLS,
- TurnstileEnabled: settings.TurnstileEnabled,
- TurnstileSiteKey: settings.TurnstileSiteKey,
- TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
- LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled,
- LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
- LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
- LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
- OIDCConnectEnabled: settings.OIDCConnectEnabled,
- OIDCConnectProviderName: settings.OIDCConnectProviderName,
- OIDCConnectClientID: settings.OIDCConnectClientID,
- OIDCConnectClientSecretConfigured: settings.OIDCConnectClientSecretConfigured,
- OIDCConnectIssuerURL: settings.OIDCConnectIssuerURL,
- OIDCConnectDiscoveryURL: settings.OIDCConnectDiscoveryURL,
- OIDCConnectAuthorizeURL: settings.OIDCConnectAuthorizeURL,
- OIDCConnectTokenURL: settings.OIDCConnectTokenURL,
- OIDCConnectUserInfoURL: settings.OIDCConnectUserInfoURL,
- OIDCConnectJWKSURL: settings.OIDCConnectJWKSURL,
- OIDCConnectScopes: settings.OIDCConnectScopes,
- OIDCConnectRedirectURL: settings.OIDCConnectRedirectURL,
- OIDCConnectFrontendRedirectURL: settings.OIDCConnectFrontendRedirectURL,
- OIDCConnectTokenAuthMethod: settings.OIDCConnectTokenAuthMethod,
- OIDCConnectUsePKCE: settings.OIDCConnectUsePKCE,
- OIDCConnectValidateIDToken: settings.OIDCConnectValidateIDToken,
- OIDCConnectAllowedSigningAlgs: settings.OIDCConnectAllowedSigningAlgs,
- OIDCConnectClockSkewSeconds: settings.OIDCConnectClockSkewSeconds,
- OIDCConnectRequireEmailVerified: settings.OIDCConnectRequireEmailVerified,
- OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath,
- OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath,
- OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath,
- SiteName: settings.SiteName,
- SiteLogo: settings.SiteLogo,
- SiteSubtitle: settings.SiteSubtitle,
- APIBaseURL: settings.APIBaseURL,
- ContactInfo: settings.ContactInfo,
- DocURL: settings.DocURL,
- HomeContent: settings.HomeContent,
- HideCcsImportButton: settings.HideCcsImportButton,
- PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
- PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
- TableDefaultPageSize: settings.TableDefaultPageSize,
- TablePageSizeOptions: settings.TablePageSizeOptions,
- CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
- CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
- DefaultConcurrency: settings.DefaultConcurrency,
- DefaultBalance: settings.DefaultBalance,
- DefaultSubscriptions: defaultSubscriptions,
- EnableModelFallback: settings.EnableModelFallback,
- FallbackModelAnthropic: settings.FallbackModelAnthropic,
- FallbackModelOpenAI: settings.FallbackModelOpenAI,
- FallbackModelGemini: settings.FallbackModelGemini,
- FallbackModelAntigravity: settings.FallbackModelAntigravity,
- EnableIdentityPatch: settings.EnableIdentityPatch,
- IdentityPatchPrompt: settings.IdentityPatchPrompt,
- OpsMonitoringEnabled: opsEnabled && settings.OpsMonitoringEnabled,
- OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
- OpsQueryModeDefault: settings.OpsQueryModeDefault,
- OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
- MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
- MaxClaudeCodeVersion: settings.MaxClaudeCodeVersion,
- AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
- BackendModeEnabled: settings.BackendModeEnabled,
- EnableFingerprintUnification: settings.EnableFingerprintUnification,
- EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
- EnableCCHSigning: settings.EnableCCHSigning,
- WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
- BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
- BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
- BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
- AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
- AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails),
- PaymentEnabled: paymentCfg.Enabled,
- PaymentMinAmount: paymentCfg.MinAmount,
- PaymentMaxAmount: paymentCfg.MaxAmount,
- PaymentDailyLimit: paymentCfg.DailyLimit,
- PaymentOrderTimeoutMin: paymentCfg.OrderTimeoutMin,
- PaymentMaxPendingOrders: paymentCfg.MaxPendingOrders,
- PaymentEnabledTypes: paymentCfg.EnabledTypes,
- PaymentBalanceDisabled: paymentCfg.BalanceDisabled,
- PaymentBalanceRechargeMultiplier: paymentCfg.BalanceRechargeMultiplier,
- PaymentRechargeFeeRate: paymentCfg.RechargeFeeRate,
- PaymentLoadBalanceStrat: paymentCfg.LoadBalanceStrategy,
- PaymentProductNamePrefix: paymentCfg.ProductNamePrefix,
- PaymentProductNameSuffix: paymentCfg.ProductNameSuffix,
- PaymentHelpImageURL: paymentCfg.HelpImageURL,
- PaymentHelpText: paymentCfg.HelpText,
- PaymentCancelRateLimitEnabled: paymentCfg.CancelRateLimitEnabled,
- PaymentCancelRateLimitMax: paymentCfg.CancelRateLimitMax,
- PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow,
- PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit,
- PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode,
- })
+ payload := dto.SystemSettings{
+ RegistrationEnabled: settings.RegistrationEnabled,
+ EmailVerifyEnabled: settings.EmailVerifyEnabled,
+ RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
+ PromoCodeEnabled: settings.PromoCodeEnabled,
+ PasswordResetEnabled: settings.PasswordResetEnabled,
+ FrontendURL: settings.FrontendURL,
+ InvitationCodeEnabled: settings.InvitationCodeEnabled,
+ TotpEnabled: settings.TotpEnabled,
+ TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
+ SMTPHost: settings.SMTPHost,
+ SMTPPort: settings.SMTPPort,
+ SMTPUsername: settings.SMTPUsername,
+ SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
+ SMTPFrom: settings.SMTPFrom,
+ SMTPFromName: settings.SMTPFromName,
+ SMTPUseTLS: settings.SMTPUseTLS,
+ TurnstileEnabled: settings.TurnstileEnabled,
+ TurnstileSiteKey: settings.TurnstileSiteKey,
+ TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
+ LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled,
+ LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
+ LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
+ LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
+ WeChatConnectEnabled: settings.WeChatConnectEnabled,
+ WeChatConnectAppID: settings.WeChatConnectAppID,
+ WeChatConnectAppSecretConfigured: settings.WeChatConnectAppSecretConfigured,
+ WeChatConnectOpenAppID: settings.WeChatConnectOpenAppID,
+ WeChatConnectOpenAppSecretConfigured: settings.WeChatConnectOpenAppSecretConfigured,
+ WeChatConnectMPAppID: settings.WeChatConnectMPAppID,
+ WeChatConnectMPAppSecretConfigured: settings.WeChatConnectMPAppSecretConfigured,
+ WeChatConnectMobileAppID: settings.WeChatConnectMobileAppID,
+ WeChatConnectMobileAppSecretConfigured: settings.WeChatConnectMobileAppSecretConfigured,
+ WeChatConnectOpenEnabled: settings.WeChatConnectOpenEnabled,
+ WeChatConnectMPEnabled: settings.WeChatConnectMPEnabled,
+ WeChatConnectMobileEnabled: settings.WeChatConnectMobileEnabled,
+ WeChatConnectMode: settings.WeChatConnectMode,
+ WeChatConnectScopes: settings.WeChatConnectScopes,
+ WeChatConnectRedirectURL: settings.WeChatConnectRedirectURL,
+ WeChatConnectFrontendRedirectURL: settings.WeChatConnectFrontendRedirectURL,
+ OIDCConnectEnabled: settings.OIDCConnectEnabled,
+ OIDCConnectProviderName: settings.OIDCConnectProviderName,
+ OIDCConnectClientID: settings.OIDCConnectClientID,
+ OIDCConnectClientSecretConfigured: settings.OIDCConnectClientSecretConfigured,
+ OIDCConnectIssuerURL: settings.OIDCConnectIssuerURL,
+ OIDCConnectDiscoveryURL: settings.OIDCConnectDiscoveryURL,
+ OIDCConnectAuthorizeURL: settings.OIDCConnectAuthorizeURL,
+ OIDCConnectTokenURL: settings.OIDCConnectTokenURL,
+ OIDCConnectUserInfoURL: settings.OIDCConnectUserInfoURL,
+ OIDCConnectJWKSURL: settings.OIDCConnectJWKSURL,
+ OIDCConnectScopes: settings.OIDCConnectScopes,
+ OIDCConnectRedirectURL: settings.OIDCConnectRedirectURL,
+ OIDCConnectFrontendRedirectURL: settings.OIDCConnectFrontendRedirectURL,
+ OIDCConnectTokenAuthMethod: settings.OIDCConnectTokenAuthMethod,
+ OIDCConnectUsePKCE: settings.OIDCConnectUsePKCE,
+ OIDCConnectValidateIDToken: settings.OIDCConnectValidateIDToken,
+ OIDCConnectAllowedSigningAlgs: settings.OIDCConnectAllowedSigningAlgs,
+ OIDCConnectClockSkewSeconds: settings.OIDCConnectClockSkewSeconds,
+ OIDCConnectRequireEmailVerified: settings.OIDCConnectRequireEmailVerified,
+ OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath,
+ OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath,
+ OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath,
+ SiteName: settings.SiteName,
+ SiteLogo: settings.SiteLogo,
+ SiteSubtitle: settings.SiteSubtitle,
+ APIBaseURL: settings.APIBaseURL,
+ ContactInfo: settings.ContactInfo,
+ DocURL: settings.DocURL,
+ HomeContent: settings.HomeContent,
+ HideCcsImportButton: settings.HideCcsImportButton,
+ PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
+ PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
+ TableDefaultPageSize: settings.TableDefaultPageSize,
+ TablePageSizeOptions: settings.TablePageSizeOptions,
+ CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
+ CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
+ DefaultConcurrency: settings.DefaultConcurrency,
+ DefaultBalance: settings.DefaultBalance,
+ DefaultSubscriptions: defaultSubscriptions,
+ EnableModelFallback: settings.EnableModelFallback,
+ FallbackModelAnthropic: settings.FallbackModelAnthropic,
+ FallbackModelOpenAI: settings.FallbackModelOpenAI,
+ FallbackModelGemini: settings.FallbackModelGemini,
+ FallbackModelAntigravity: settings.FallbackModelAntigravity,
+ EnableIdentityPatch: settings.EnableIdentityPatch,
+ IdentityPatchPrompt: settings.IdentityPatchPrompt,
+ OpsMonitoringEnabled: opsEnabled && settings.OpsMonitoringEnabled,
+ OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
+ OpsQueryModeDefault: settings.OpsQueryModeDefault,
+ OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
+ MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
+ MaxClaudeCodeVersion: settings.MaxClaudeCodeVersion,
+ AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
+ BackendModeEnabled: settings.BackendModeEnabled,
+ EnableFingerprintUnification: settings.EnableFingerprintUnification,
+ EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
+ EnableCCHSigning: settings.EnableCCHSigning,
+ WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
+ PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource,
+ PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource,
+ PaymentVisibleMethodAlipayEnabled: settings.PaymentVisibleMethodAlipayEnabled,
+ PaymentVisibleMethodWxpayEnabled: settings.PaymentVisibleMethodWxpayEnabled,
+ OpenAIAdvancedSchedulerEnabled: settings.OpenAIAdvancedSchedulerEnabled,
+ BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
+ BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
+ BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
+ AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
+ AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails),
+ PaymentEnabled: paymentCfg.Enabled,
+ PaymentMinAmount: paymentCfg.MinAmount,
+ PaymentMaxAmount: paymentCfg.MaxAmount,
+ PaymentDailyLimit: paymentCfg.DailyLimit,
+ PaymentOrderTimeoutMin: paymentCfg.OrderTimeoutMin,
+ PaymentMaxPendingOrders: paymentCfg.MaxPendingOrders,
+ PaymentEnabledTypes: paymentCfg.EnabledTypes,
+ PaymentBalanceDisabled: paymentCfg.BalanceDisabled,
+ PaymentBalanceRechargeMultiplier: paymentCfg.BalanceRechargeMultiplier,
+ PaymentRechargeFeeRate: paymentCfg.RechargeFeeRate,
+ PaymentLoadBalanceStrat: paymentCfg.LoadBalanceStrategy,
+ PaymentProductNamePrefix: paymentCfg.ProductNamePrefix,
+ PaymentProductNameSuffix: paymentCfg.ProductNameSuffix,
+ PaymentHelpImageURL: paymentCfg.HelpImageURL,
+ PaymentHelpText: paymentCfg.HelpText,
+ PaymentCancelRateLimitEnabled: paymentCfg.CancelRateLimitEnabled,
+ PaymentCancelRateLimitMax: paymentCfg.CancelRateLimitMax,
+ PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow,
+ PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit,
+ PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode,
+ }
+ response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
}
// UpdateSettingsRequest 更新设置请求
@@ -235,6 +271,24 @@ type UpdateSettingsRequest struct {
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
+ // WeChat Connect OAuth 登录
+ WeChatConnectEnabled bool `json:"wechat_connect_enabled"`
+ WeChatConnectAppID string `json:"wechat_connect_app_id"`
+ WeChatConnectAppSecret string `json:"wechat_connect_app_secret"`
+ WeChatConnectOpenAppID string `json:"wechat_connect_open_app_id"`
+ WeChatConnectOpenAppSecret string `json:"wechat_connect_open_app_secret"`
+ WeChatConnectMPAppID string `json:"wechat_connect_mp_app_id"`
+ WeChatConnectMPAppSecret string `json:"wechat_connect_mp_app_secret"`
+ WeChatConnectMobileAppID string `json:"wechat_connect_mobile_app_id"`
+ WeChatConnectMobileAppSecret string `json:"wechat_connect_mobile_app_secret"`
+ WeChatConnectOpenEnabled bool `json:"wechat_connect_open_enabled"`
+ WeChatConnectMPEnabled bool `json:"wechat_connect_mp_enabled"`
+ WeChatConnectMobileEnabled bool `json:"wechat_connect_mobile_enabled"`
+ WeChatConnectMode string `json:"wechat_connect_mode"`
+ WeChatConnectScopes string `json:"wechat_connect_scopes"`
+ WeChatConnectRedirectURL string `json:"wechat_connect_redirect_url"`
+ WeChatConnectFrontendRedirectURL string `json:"wechat_connect_frontend_redirect_url"`
+
// Generic OIDC OAuth 登录
OIDCConnectEnabled bool `json:"oidc_connect_enabled"`
OIDCConnectProviderName string `json:"oidc_connect_provider_name"`
@@ -250,8 +304,8 @@ type UpdateSettingsRequest struct {
OIDCConnectRedirectURL string `json:"oidc_connect_redirect_url"`
OIDCConnectFrontendRedirectURL string `json:"oidc_connect_frontend_redirect_url"`
OIDCConnectTokenAuthMethod string `json:"oidc_connect_token_auth_method"`
- OIDCConnectUsePKCE bool `json:"oidc_connect_use_pkce"`
- OIDCConnectValidateIDToken bool `json:"oidc_connect_validate_id_token"`
+ OIDCConnectUsePKCE *bool `json:"oidc_connect_use_pkce"`
+ OIDCConnectValidateIDToken *bool `json:"oidc_connect_validate_id_token"`
OIDCConnectAllowedSigningAlgs string `json:"oidc_connect_allowed_signing_algs"`
OIDCConnectClockSkewSeconds int `json:"oidc_connect_clock_skew_seconds"`
OIDCConnectRequireEmailVerified bool `json:"oidc_connect_require_email_verified"`
@@ -276,9 +330,30 @@ type UpdateSettingsRequest struct {
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
// 默认配置
- DefaultConcurrency int `json:"default_concurrency"`
- DefaultBalance float64 `json:"default_balance"`
- DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
+ DefaultConcurrency int `json:"default_concurrency"`
+ DefaultBalance float64 `json:"default_balance"`
+ DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
+ AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
+ AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
+ AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"`
+ AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"`
+ AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"`
+ AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"`
+ AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"`
+ AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"`
+ AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"`
+ AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"`
+ AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"`
+ AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"`
+ AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"`
+ AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"`
+ AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"`
+ AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"`
+ AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"`
+ AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
+ AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
+ AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
+ ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -311,6 +386,15 @@ type UpdateSettingsRequest struct {
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
EnableCCHSigning *bool `json:"enable_cch_signing"`
+ // Payment visible method routing
+ PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
+ PaymentVisibleMethodWxpaySource *string `json:"payment_visible_method_wxpay_source"`
+ PaymentVisibleMethodAlipayEnabled *bool `json:"payment_visible_method_alipay_enabled"`
+ PaymentVisibleMethodWxpayEnabled *bool `json:"payment_visible_method_wxpay_enabled"`
+
+ // OpenAI account scheduling
+ OpenAIAdvancedSchedulerEnabled *bool `json:"openai_advanced_scheduler_enabled"`
+
// Balance low notification
BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
@@ -357,6 +441,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ previousAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
// 验证参数
if req.DefaultConcurrency < 1 {
@@ -381,6 +470,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.SMTPPort = 587
}
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
+ req.AuthSourceDefaultEmailSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultEmailSubscriptions)
+ req.AuthSourceDefaultLinuxDoSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultLinuxDoSubscriptions)
+ req.AuthSourceDefaultOIDCSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultOIDCSubscriptions)
+ req.AuthSourceDefaultWeChatSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultWeChatSubscriptions)
// SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置
// 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置
@@ -459,7 +552,141 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
+ if req.WeChatConnectEnabled {
+ req.WeChatConnectAppID = strings.TrimSpace(req.WeChatConnectAppID)
+ req.WeChatConnectAppSecret = strings.TrimSpace(req.WeChatConnectAppSecret)
+ req.WeChatConnectOpenAppID = strings.TrimSpace(req.WeChatConnectOpenAppID)
+ req.WeChatConnectOpenAppSecret = strings.TrimSpace(req.WeChatConnectOpenAppSecret)
+ req.WeChatConnectMPAppID = strings.TrimSpace(req.WeChatConnectMPAppID)
+ req.WeChatConnectMPAppSecret = strings.TrimSpace(req.WeChatConnectMPAppSecret)
+ req.WeChatConnectMobileAppID = strings.TrimSpace(req.WeChatConnectMobileAppID)
+ req.WeChatConnectMobileAppSecret = strings.TrimSpace(req.WeChatConnectMobileAppSecret)
+ req.WeChatConnectMode = strings.ToLower(strings.TrimSpace(req.WeChatConnectMode))
+ req.WeChatConnectScopes = strings.TrimSpace(req.WeChatConnectScopes)
+ req.WeChatConnectRedirectURL = strings.TrimSpace(req.WeChatConnectRedirectURL)
+ req.WeChatConnectFrontendRedirectURL = strings.TrimSpace(req.WeChatConnectFrontendRedirectURL)
+ req.WeChatConnectAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectAppID, previousSettings.WeChatConnectAppID))
+ req.WeChatConnectRedirectURL = strings.TrimSpace(firstNonEmpty(req.WeChatConnectRedirectURL, previousSettings.WeChatConnectRedirectURL))
+ req.WeChatConnectFrontendRedirectURL = strings.TrimSpace(firstNonEmpty(req.WeChatConnectFrontendRedirectURL, previousSettings.WeChatConnectFrontendRedirectURL))
+ if req.WeChatConnectMode == "" {
+ req.WeChatConnectMode = strings.ToLower(strings.TrimSpace(previousSettings.WeChatConnectMode))
+ }
+ if req.WeChatConnectScopes == "" {
+ req.WeChatConnectScopes = strings.TrimSpace(previousSettings.WeChatConnectScopes)
+ }
+
+ if req.WeChatConnectMPEnabled && req.WeChatConnectMobileEnabled {
+ response.BadRequest(c, "WeChat Official Account and Mobile App cannot be enabled at the same time")
+ return
+ }
+ if req.WeChatConnectMode != "" {
+ switch req.WeChatConnectMode {
+ case "open", "mp", "mobile":
+ default:
+ response.BadRequest(c, "WeChat mode must be open, mp, or mobile")
+ return
+ }
+ }
+ if !req.WeChatConnectOpenEnabled && !req.WeChatConnectMPEnabled && !req.WeChatConnectMobileEnabled {
+ switch req.WeChatConnectMode {
+ case "mp":
+ req.WeChatConnectMPEnabled = true
+ case "mobile":
+ req.WeChatConnectMobileEnabled = true
+ default:
+ req.WeChatConnectOpenEnabled = true
+ }
+ }
+ if req.WeChatConnectMode == "" {
+ if req.WeChatConnectMPEnabled {
+ req.WeChatConnectMode = "mp"
+ } else if req.WeChatConnectMobileEnabled {
+ req.WeChatConnectMode = "mobile"
+ } else {
+ req.WeChatConnectMode = "open"
+ }
+ }
+
+ req.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectOpenAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectOpenAppID, previousSettings.WeChatConnectAppID))
+ req.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMPAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectMPAppID, previousSettings.WeChatConnectAppID))
+ req.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMobileAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectMobileAppID, previousSettings.WeChatConnectAppID))
+
+ if req.WeChatConnectOpenAppSecret == "" {
+ req.WeChatConnectOpenAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectOpenAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret))
+ }
+ if req.WeChatConnectMPAppSecret == "" {
+ req.WeChatConnectMPAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectMPAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret))
+ }
+ if req.WeChatConnectMobileAppSecret == "" {
+ req.WeChatConnectMobileAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectMobileAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret))
+ }
+ if req.WeChatConnectAppSecret == "" {
+ req.WeChatConnectAppSecret = strings.TrimSpace(firstNonEmpty(req.WeChatConnectOpenAppSecret, req.WeChatConnectMPAppSecret, req.WeChatConnectMobileAppSecret, previousSettings.WeChatConnectAppSecret))
+ }
+
+ if req.WeChatConnectOpenEnabled {
+ if req.WeChatConnectOpenAppID == "" {
+ response.BadRequest(c, "WeChat PC App ID is required when enabled")
+ return
+ }
+ if req.WeChatConnectOpenAppSecret == "" {
+ response.BadRequest(c, "WeChat PC App Secret is required when enabled")
+ return
+ }
+ }
+ if req.WeChatConnectMPEnabled {
+ if req.WeChatConnectMPAppID == "" {
+ response.BadRequest(c, "WeChat Official Account App ID is required when enabled")
+ return
+ }
+ if req.WeChatConnectMPAppSecret == "" {
+ response.BadRequest(c, "WeChat Official Account App Secret is required when enabled")
+ return
+ }
+ }
+ if req.WeChatConnectMobileEnabled {
+ if req.WeChatConnectMobileAppID == "" {
+ response.BadRequest(c, "WeChat Mobile App ID is required when enabled")
+ return
+ }
+ if req.WeChatConnectMobileAppSecret == "" {
+ response.BadRequest(c, "WeChat Mobile App Secret is required when enabled")
+ return
+ }
+ }
+
+ if req.WeChatConnectScopes == "" {
+ if req.WeChatConnectMPEnabled {
+ req.WeChatConnectScopes = service.DefaultWeChatConnectScopesForMode("mp")
+ } else {
+ req.WeChatConnectScopes = service.DefaultWeChatConnectScopesForMode(req.WeChatConnectMode)
+ }
+ }
+ if req.WeChatConnectOpenEnabled || req.WeChatConnectMPEnabled {
+ if req.WeChatConnectRedirectURL == "" {
+ response.BadRequest(c, "WeChat Redirect URL is required when web oauth is enabled")
+ return
+ }
+ if err := config.ValidateAbsoluteHTTPURL(req.WeChatConnectRedirectURL); err != nil {
+ response.BadRequest(c, "WeChat Redirect URL must be an absolute http(s) URL")
+ return
+ }
+ if req.WeChatConnectFrontendRedirectURL == "" {
+ req.WeChatConnectFrontendRedirectURL = "/auth/wechat/callback"
+ }
+ if err := config.ValidateFrontendRedirectURL(req.WeChatConnectFrontendRedirectURL); err != nil {
+ response.BadRequest(c, "WeChat Frontend Redirect URL is invalid")
+ return
+ }
+ }
+ }
+
// Generic OIDC 参数验证
+ oidcUsePKCE, oidcValidateIDToken, err := h.settingService.OIDCSecurityWriteDefaults(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
if req.OIDCConnectEnabled {
req.OIDCConnectProviderName = strings.TrimSpace(req.OIDCConnectProviderName)
req.OIDCConnectClientID = strings.TrimSpace(req.OIDCConnectClientID)
@@ -478,10 +705,35 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.OIDCConnectUserInfoEmailPath = strings.TrimSpace(req.OIDCConnectUserInfoEmailPath)
req.OIDCConnectUserInfoIDPath = strings.TrimSpace(req.OIDCConnectUserInfoIDPath)
req.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(req.OIDCConnectUserInfoUsernamePath)
-
- if req.OIDCConnectProviderName == "" {
- req.OIDCConnectProviderName = "OIDC"
+ req.OIDCConnectProviderName = strings.TrimSpace(firstNonEmpty(req.OIDCConnectProviderName, previousSettings.OIDCConnectProviderName, "OIDC"))
+ req.OIDCConnectClientID = strings.TrimSpace(firstNonEmpty(req.OIDCConnectClientID, previousSettings.OIDCConnectClientID))
+ req.OIDCConnectIssuerURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectIssuerURL, previousSettings.OIDCConnectIssuerURL))
+ req.OIDCConnectDiscoveryURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectDiscoveryURL, previousSettings.OIDCConnectDiscoveryURL))
+ req.OIDCConnectAuthorizeURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectAuthorizeURL, previousSettings.OIDCConnectAuthorizeURL))
+ req.OIDCConnectTokenURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectTokenURL, previousSettings.OIDCConnectTokenURL))
+ req.OIDCConnectUserInfoURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoURL, previousSettings.OIDCConnectUserInfoURL))
+ req.OIDCConnectJWKSURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectJWKSURL, previousSettings.OIDCConnectJWKSURL))
+ req.OIDCConnectScopes = strings.TrimSpace(firstNonEmpty(req.OIDCConnectScopes, previousSettings.OIDCConnectScopes, "openid email profile"))
+ req.OIDCConnectRedirectURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectRedirectURL, previousSettings.OIDCConnectRedirectURL))
+ req.OIDCConnectFrontendRedirectURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectFrontendRedirectURL, previousSettings.OIDCConnectFrontendRedirectURL, "/auth/oidc/callback"))
+ req.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(firstNonEmpty(req.OIDCConnectTokenAuthMethod, previousSettings.OIDCConnectTokenAuthMethod, "client_secret_post")))
+ req.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(firstNonEmpty(req.OIDCConnectAllowedSigningAlgs, previousSettings.OIDCConnectAllowedSigningAlgs, "RS256,ES256,PS256"))
+ req.OIDCConnectUserInfoEmailPath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoEmailPath, previousSettings.OIDCConnectUserInfoEmailPath))
+ req.OIDCConnectUserInfoIDPath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoIDPath, previousSettings.OIDCConnectUserInfoIDPath))
+ req.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoUsernamePath, previousSettings.OIDCConnectUserInfoUsernamePath))
+ if req.OIDCConnectUsePKCE != nil {
+ oidcUsePKCE = *req.OIDCConnectUsePKCE
}
+ if req.OIDCConnectValidateIDToken != nil {
+ oidcValidateIDToken = *req.OIDCConnectValidateIDToken
+ }
+ if req.OIDCConnectClockSkewSeconds == 0 {
+ req.OIDCConnectClockSkewSeconds = previousSettings.OIDCConnectClockSkewSeconds
+ if req.OIDCConnectClockSkewSeconds == 0 {
+ req.OIDCConnectClockSkewSeconds = 120
+ }
+ }
+
if req.OIDCConnectClientID == "" {
response.BadRequest(c, "OIDC Client ID is required when enabled")
return
@@ -544,19 +796,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.BadRequest(c, "OIDC Token Auth Method must be one of client_secret_post/client_secret_basic/none")
return
}
- if req.OIDCConnectTokenAuthMethod == "none" && !req.OIDCConnectUsePKCE {
- response.BadRequest(c, "OIDC PKCE must be enabled when token_auth_method=none")
- return
- }
if req.OIDCConnectClockSkewSeconds < 0 || req.OIDCConnectClockSkewSeconds > 600 {
response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600")
return
}
- if req.OIDCConnectValidateIDToken {
- if req.OIDCConnectAllowedSigningAlgs == "" {
- response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
- return
- }
+ if oidcValidateIDToken && req.OIDCConnectAllowedSigningAlgs == "" {
+ response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
+ return
}
if req.OIDCConnectJWKSURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectJWKSURL); err != nil {
@@ -805,6 +1051,22 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
+ WeChatConnectEnabled: req.WeChatConnectEnabled,
+ WeChatConnectAppID: req.WeChatConnectAppID,
+ WeChatConnectAppSecret: req.WeChatConnectAppSecret,
+ WeChatConnectOpenAppID: req.WeChatConnectOpenAppID,
+ WeChatConnectOpenAppSecret: req.WeChatConnectOpenAppSecret,
+ WeChatConnectMPAppID: req.WeChatConnectMPAppID,
+ WeChatConnectMPAppSecret: req.WeChatConnectMPAppSecret,
+ WeChatConnectMobileAppID: req.WeChatConnectMobileAppID,
+ WeChatConnectMobileAppSecret: req.WeChatConnectMobileAppSecret,
+ WeChatConnectOpenEnabled: req.WeChatConnectOpenEnabled,
+ WeChatConnectMPEnabled: req.WeChatConnectMPEnabled,
+ WeChatConnectMobileEnabled: req.WeChatConnectMobileEnabled,
+ WeChatConnectMode: req.WeChatConnectMode,
+ WeChatConnectScopes: req.WeChatConnectScopes,
+ WeChatConnectRedirectURL: req.WeChatConnectRedirectURL,
+ WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL,
OIDCConnectEnabled: req.OIDCConnectEnabled,
OIDCConnectProviderName: req.OIDCConnectProviderName,
OIDCConnectClientID: req.OIDCConnectClientID,
@@ -819,8 +1081,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
OIDCConnectRedirectURL: req.OIDCConnectRedirectURL,
OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL,
OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod,
- OIDCConnectUsePKCE: req.OIDCConnectUsePKCE,
- OIDCConnectValidateIDToken: req.OIDCConnectValidateIDToken,
+ OIDCConnectUsePKCE: oidcUsePKCE,
+ OIDCConnectValidateIDToken: oidcValidateIDToken,
OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs,
OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds,
OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified,
@@ -897,6 +1159,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.EnableCCHSigning
}(),
+ PaymentVisibleMethodAlipaySource: func() string {
+ if req.PaymentVisibleMethodAlipaySource != nil {
+ return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource)
+ }
+ return previousSettings.PaymentVisibleMethodAlipaySource
+ }(),
+ PaymentVisibleMethodWxpaySource: func() string {
+ if req.PaymentVisibleMethodWxpaySource != nil {
+ return strings.TrimSpace(*req.PaymentVisibleMethodWxpaySource)
+ }
+ return previousSettings.PaymentVisibleMethodWxpaySource
+ }(),
+ PaymentVisibleMethodAlipayEnabled: func() bool {
+ if req.PaymentVisibleMethodAlipayEnabled != nil {
+ return *req.PaymentVisibleMethodAlipayEnabled
+ }
+ return previousSettings.PaymentVisibleMethodAlipayEnabled
+ }(),
+ PaymentVisibleMethodWxpayEnabled: func() bool {
+ if req.PaymentVisibleMethodWxpayEnabled != nil {
+ return *req.PaymentVisibleMethodWxpayEnabled
+ }
+ return previousSettings.PaymentVisibleMethodWxpayEnabled
+ }(),
+ OpenAIAdvancedSchedulerEnabled: func() bool {
+ if req.OpenAIAdvancedSchedulerEnabled != nil {
+ return *req.OpenAIAdvancedSchedulerEnabled
+ }
+ return previousSettings.OpenAIAdvancedSchedulerEnabled
+ }(),
BalanceLowNotifyEnabled: func() bool {
if req.BalanceLowNotifyEnabled != nil {
return *req.BalanceLowNotifyEnabled
@@ -929,7 +1221,38 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}(),
}
- if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
+ authSourceDefaults := &service.AuthSourceDefaultSettings{
+ Email: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultEmailConcurrency, previousAuthSourceDefaults.Email.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultEmailSubscriptions, previousAuthSourceDefaults.Email.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnSignup, previousAuthSourceDefaults.Email.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnFirstBind, previousAuthSourceDefaults.Email.GrantOnFirstBind),
+ },
+ LinuxDo: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultLinuxDoBalance, previousAuthSourceDefaults.LinuxDo.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultLinuxDoConcurrency, previousAuthSourceDefaults.LinuxDo.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultLinuxDoSubscriptions, previousAuthSourceDefaults.LinuxDo.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnSignup, previousAuthSourceDefaults.LinuxDo.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnFirstBind, previousAuthSourceDefaults.LinuxDo.GrantOnFirstBind),
+ },
+ OIDC: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultOIDCBalance, previousAuthSourceDefaults.OIDC.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultOIDCConcurrency, previousAuthSourceDefaults.OIDC.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultOIDCSubscriptions, previousAuthSourceDefaults.OIDC.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnSignup, previousAuthSourceDefaults.OIDC.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnFirstBind, previousAuthSourceDefaults.OIDC.GrantOnFirstBind),
+ },
+ WeChat: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultWeChatBalance, previousAuthSourceDefaults.WeChat.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultWeChatConcurrency, previousAuthSourceDefaults.WeChat.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultWeChatSubscriptions, previousAuthSourceDefaults.WeChat.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind),
+ },
+ ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
+ }
+ if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil {
response.ErrorFrom(c, err)
return
}
@@ -969,7 +1292,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
- h.auditSettingsUpdate(c, previousSettings, settings, req)
+ h.auditSettingsUpdate(c, previousSettings, settings, previousAuthSourceDefaults, authSourceDefaults, req)
// 重新获取设置返回
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
@@ -977,6 +1300,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ updatedAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions))
for _, sub := range updatedSettings.DefaultSubscriptions {
updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{
@@ -994,113 +1322,135 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
updatedPaymentCfg = &service.PaymentConfig{}
}
- response.Success(c, dto.SystemSettings{
- RegistrationEnabled: updatedSettings.RegistrationEnabled,
- EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
- RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
- PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
- PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
- FrontendURL: updatedSettings.FrontendURL,
- InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
- TotpEnabled: updatedSettings.TotpEnabled,
- TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
- SMTPHost: updatedSettings.SMTPHost,
- SMTPPort: updatedSettings.SMTPPort,
- SMTPUsername: updatedSettings.SMTPUsername,
- SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
- SMTPFrom: updatedSettings.SMTPFrom,
- SMTPFromName: updatedSettings.SMTPFromName,
- SMTPUseTLS: updatedSettings.SMTPUseTLS,
- TurnstileEnabled: updatedSettings.TurnstileEnabled,
- TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
- TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
- LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled,
- LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
- LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
- LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
- OIDCConnectEnabled: updatedSettings.OIDCConnectEnabled,
- OIDCConnectProviderName: updatedSettings.OIDCConnectProviderName,
- OIDCConnectClientID: updatedSettings.OIDCConnectClientID,
- OIDCConnectClientSecretConfigured: updatedSettings.OIDCConnectClientSecretConfigured,
- OIDCConnectIssuerURL: updatedSettings.OIDCConnectIssuerURL,
- OIDCConnectDiscoveryURL: updatedSettings.OIDCConnectDiscoveryURL,
- OIDCConnectAuthorizeURL: updatedSettings.OIDCConnectAuthorizeURL,
- OIDCConnectTokenURL: updatedSettings.OIDCConnectTokenURL,
- OIDCConnectUserInfoURL: updatedSettings.OIDCConnectUserInfoURL,
- OIDCConnectJWKSURL: updatedSettings.OIDCConnectJWKSURL,
- OIDCConnectScopes: updatedSettings.OIDCConnectScopes,
- OIDCConnectRedirectURL: updatedSettings.OIDCConnectRedirectURL,
- OIDCConnectFrontendRedirectURL: updatedSettings.OIDCConnectFrontendRedirectURL,
- OIDCConnectTokenAuthMethod: updatedSettings.OIDCConnectTokenAuthMethod,
- OIDCConnectUsePKCE: updatedSettings.OIDCConnectUsePKCE,
- OIDCConnectValidateIDToken: updatedSettings.OIDCConnectValidateIDToken,
- OIDCConnectAllowedSigningAlgs: updatedSettings.OIDCConnectAllowedSigningAlgs,
- OIDCConnectClockSkewSeconds: updatedSettings.OIDCConnectClockSkewSeconds,
- OIDCConnectRequireEmailVerified: updatedSettings.OIDCConnectRequireEmailVerified,
- OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath,
- OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath,
- OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath,
- SiteName: updatedSettings.SiteName,
- SiteLogo: updatedSettings.SiteLogo,
- SiteSubtitle: updatedSettings.SiteSubtitle,
- APIBaseURL: updatedSettings.APIBaseURL,
- ContactInfo: updatedSettings.ContactInfo,
- DocURL: updatedSettings.DocURL,
- HomeContent: updatedSettings.HomeContent,
- HideCcsImportButton: updatedSettings.HideCcsImportButton,
- PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
- PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
- TableDefaultPageSize: updatedSettings.TableDefaultPageSize,
- TablePageSizeOptions: updatedSettings.TablePageSizeOptions,
- CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
- CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
- DefaultConcurrency: updatedSettings.DefaultConcurrency,
- DefaultBalance: updatedSettings.DefaultBalance,
- DefaultSubscriptions: updatedDefaultSubscriptions,
- EnableModelFallback: updatedSettings.EnableModelFallback,
- FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
- FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
- FallbackModelGemini: updatedSettings.FallbackModelGemini,
- FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
- EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
- IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
- OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled,
- OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
- OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
- OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
- MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
- MaxClaudeCodeVersion: updatedSettings.MaxClaudeCodeVersion,
- AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
- BackendModeEnabled: updatedSettings.BackendModeEnabled,
- EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
- EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
- EnableCCHSigning: updatedSettings.EnableCCHSigning,
- BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
- BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
- BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL,
- AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled,
- AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails),
- PaymentEnabled: updatedPaymentCfg.Enabled,
- PaymentMinAmount: updatedPaymentCfg.MinAmount,
- PaymentMaxAmount: updatedPaymentCfg.MaxAmount,
- PaymentDailyLimit: updatedPaymentCfg.DailyLimit,
- PaymentOrderTimeoutMin: updatedPaymentCfg.OrderTimeoutMin,
- PaymentMaxPendingOrders: updatedPaymentCfg.MaxPendingOrders,
- PaymentEnabledTypes: updatedPaymentCfg.EnabledTypes,
- PaymentBalanceDisabled: updatedPaymentCfg.BalanceDisabled,
- PaymentBalanceRechargeMultiplier: updatedPaymentCfg.BalanceRechargeMultiplier,
- PaymentRechargeFeeRate: updatedPaymentCfg.RechargeFeeRate,
- PaymentLoadBalanceStrat: updatedPaymentCfg.LoadBalanceStrategy,
- PaymentProductNamePrefix: updatedPaymentCfg.ProductNamePrefix,
- PaymentProductNameSuffix: updatedPaymentCfg.ProductNameSuffix,
- PaymentHelpImageURL: updatedPaymentCfg.HelpImageURL,
- PaymentHelpText: updatedPaymentCfg.HelpText,
- PaymentCancelRateLimitEnabled: updatedPaymentCfg.CancelRateLimitEnabled,
- PaymentCancelRateLimitMax: updatedPaymentCfg.CancelRateLimitMax,
- PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow,
- PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit,
- PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode,
- })
+ payload := dto.SystemSettings{
+ RegistrationEnabled: updatedSettings.RegistrationEnabled,
+ EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
+ RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
+ PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
+ PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
+ FrontendURL: updatedSettings.FrontendURL,
+ InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
+ TotpEnabled: updatedSettings.TotpEnabled,
+ TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
+ SMTPHost: updatedSettings.SMTPHost,
+ SMTPPort: updatedSettings.SMTPPort,
+ SMTPUsername: updatedSettings.SMTPUsername,
+ SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
+ SMTPFrom: updatedSettings.SMTPFrom,
+ SMTPFromName: updatedSettings.SMTPFromName,
+ SMTPUseTLS: updatedSettings.SMTPUseTLS,
+ TurnstileEnabled: updatedSettings.TurnstileEnabled,
+ TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
+ TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
+ LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled,
+ LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
+ LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
+ LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
+ WeChatConnectEnabled: updatedSettings.WeChatConnectEnabled,
+ WeChatConnectAppID: updatedSettings.WeChatConnectAppID,
+ WeChatConnectAppSecretConfigured: updatedSettings.WeChatConnectAppSecretConfigured,
+ WeChatConnectOpenAppID: updatedSettings.WeChatConnectOpenAppID,
+ WeChatConnectOpenAppSecretConfigured: updatedSettings.WeChatConnectOpenAppSecretConfigured,
+ WeChatConnectMPAppID: updatedSettings.WeChatConnectMPAppID,
+ WeChatConnectMPAppSecretConfigured: updatedSettings.WeChatConnectMPAppSecretConfigured,
+ WeChatConnectMobileAppID: updatedSettings.WeChatConnectMobileAppID,
+ WeChatConnectMobileAppSecretConfigured: updatedSettings.WeChatConnectMobileAppSecretConfigured,
+ WeChatConnectOpenEnabled: updatedSettings.WeChatConnectOpenEnabled,
+ WeChatConnectMPEnabled: updatedSettings.WeChatConnectMPEnabled,
+ WeChatConnectMobileEnabled: updatedSettings.WeChatConnectMobileEnabled,
+ WeChatConnectMode: updatedSettings.WeChatConnectMode,
+ WeChatConnectScopes: updatedSettings.WeChatConnectScopes,
+ WeChatConnectRedirectURL: updatedSettings.WeChatConnectRedirectURL,
+ WeChatConnectFrontendRedirectURL: updatedSettings.WeChatConnectFrontendRedirectURL,
+ OIDCConnectEnabled: updatedSettings.OIDCConnectEnabled,
+ OIDCConnectProviderName: updatedSettings.OIDCConnectProviderName,
+ OIDCConnectClientID: updatedSettings.OIDCConnectClientID,
+ OIDCConnectClientSecretConfigured: updatedSettings.OIDCConnectClientSecretConfigured,
+ OIDCConnectIssuerURL: updatedSettings.OIDCConnectIssuerURL,
+ OIDCConnectDiscoveryURL: updatedSettings.OIDCConnectDiscoveryURL,
+ OIDCConnectAuthorizeURL: updatedSettings.OIDCConnectAuthorizeURL,
+ OIDCConnectTokenURL: updatedSettings.OIDCConnectTokenURL,
+ OIDCConnectUserInfoURL: updatedSettings.OIDCConnectUserInfoURL,
+ OIDCConnectJWKSURL: updatedSettings.OIDCConnectJWKSURL,
+ OIDCConnectScopes: updatedSettings.OIDCConnectScopes,
+ OIDCConnectRedirectURL: updatedSettings.OIDCConnectRedirectURL,
+ OIDCConnectFrontendRedirectURL: updatedSettings.OIDCConnectFrontendRedirectURL,
+ OIDCConnectTokenAuthMethod: updatedSettings.OIDCConnectTokenAuthMethod,
+ OIDCConnectUsePKCE: updatedSettings.OIDCConnectUsePKCE,
+ OIDCConnectValidateIDToken: updatedSettings.OIDCConnectValidateIDToken,
+ OIDCConnectAllowedSigningAlgs: updatedSettings.OIDCConnectAllowedSigningAlgs,
+ OIDCConnectClockSkewSeconds: updatedSettings.OIDCConnectClockSkewSeconds,
+ OIDCConnectRequireEmailVerified: updatedSettings.OIDCConnectRequireEmailVerified,
+ OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath,
+ OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath,
+ OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath,
+ SiteName: updatedSettings.SiteName,
+ SiteLogo: updatedSettings.SiteLogo,
+ SiteSubtitle: updatedSettings.SiteSubtitle,
+ APIBaseURL: updatedSettings.APIBaseURL,
+ ContactInfo: updatedSettings.ContactInfo,
+ DocURL: updatedSettings.DocURL,
+ HomeContent: updatedSettings.HomeContent,
+ HideCcsImportButton: updatedSettings.HideCcsImportButton,
+ PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
+ PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
+ TableDefaultPageSize: updatedSettings.TableDefaultPageSize,
+ TablePageSizeOptions: updatedSettings.TablePageSizeOptions,
+ CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
+ CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
+ DefaultConcurrency: updatedSettings.DefaultConcurrency,
+ DefaultBalance: updatedSettings.DefaultBalance,
+ DefaultSubscriptions: updatedDefaultSubscriptions,
+ EnableModelFallback: updatedSettings.EnableModelFallback,
+ FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
+ FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
+ FallbackModelGemini: updatedSettings.FallbackModelGemini,
+ FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
+ EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
+ IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
+ OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled,
+ OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
+ OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
+ OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
+ MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
+ MaxClaudeCodeVersion: updatedSettings.MaxClaudeCodeVersion,
+ AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
+ BackendModeEnabled: updatedSettings.BackendModeEnabled,
+ EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
+ EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
+ EnableCCHSigning: updatedSettings.EnableCCHSigning,
+ PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource,
+ PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource,
+ PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled,
+ PaymentVisibleMethodWxpayEnabled: updatedSettings.PaymentVisibleMethodWxpayEnabled,
+ OpenAIAdvancedSchedulerEnabled: updatedSettings.OpenAIAdvancedSchedulerEnabled,
+ BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
+ BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
+ BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL,
+ AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled,
+ AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails),
+ PaymentEnabled: updatedPaymentCfg.Enabled,
+ PaymentMinAmount: updatedPaymentCfg.MinAmount,
+ PaymentMaxAmount: updatedPaymentCfg.MaxAmount,
+ PaymentDailyLimit: updatedPaymentCfg.DailyLimit,
+ PaymentOrderTimeoutMin: updatedPaymentCfg.OrderTimeoutMin,
+ PaymentMaxPendingOrders: updatedPaymentCfg.MaxPendingOrders,
+ PaymentEnabledTypes: updatedPaymentCfg.EnabledTypes,
+ PaymentBalanceDisabled: updatedPaymentCfg.BalanceDisabled,
+ PaymentBalanceRechargeMultiplier: updatedPaymentCfg.BalanceRechargeMultiplier,
+ PaymentRechargeFeeRate: updatedPaymentCfg.RechargeFeeRate,
+ PaymentLoadBalanceStrat: updatedPaymentCfg.LoadBalanceStrategy,
+ PaymentProductNamePrefix: updatedPaymentCfg.ProductNamePrefix,
+ PaymentProductNameSuffix: updatedPaymentCfg.ProductNameSuffix,
+ PaymentHelpImageURL: updatedPaymentCfg.HelpImageURL,
+ PaymentHelpText: updatedPaymentCfg.HelpText,
+ PaymentCancelRateLimitEnabled: updatedPaymentCfg.CancelRateLimitEnabled,
+ PaymentCancelRateLimitMax: updatedPaymentCfg.CancelRateLimitMax,
+ PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow,
+ PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit,
+ PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode,
+ }
+ response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
}
// hasPaymentFields returns true if any payment-related field was explicitly provided.
@@ -1117,12 +1467,12 @@ func hasPaymentFields(req UpdateSettingsRequest) bool {
req.PaymentCancelRateLimitUnit != nil || req.PaymentCancelRateLimitMode != nil
}
-func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) {
+func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) {
if before == nil || after == nil {
return
}
- changed := diffSettings(before, after, req)
+ changed := diffSettings(before, after, beforeAuthSourceDefaults, afterAuthSourceDefaults, req)
if len(changed) == 0 {
return
}
@@ -1137,7 +1487,7 @@ func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.Sys
)
}
-func diffSettings(before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) []string {
+func diffSettings(before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) []string {
changed := make([]string, 0, 20)
if before.RegistrationEnabled != after.RegistrationEnabled {
changed = append(changed, "registration_enabled")
@@ -1205,6 +1555,54 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL {
changed = append(changed, "linuxdo_connect_redirect_url")
}
+ if before.WeChatConnectEnabled != after.WeChatConnectEnabled {
+ changed = append(changed, "wechat_connect_enabled")
+ }
+ if before.WeChatConnectAppID != after.WeChatConnectAppID {
+ changed = append(changed, "wechat_connect_app_id")
+ }
+ if req.WeChatConnectAppSecret != "" {
+ changed = append(changed, "wechat_connect_app_secret")
+ }
+ if before.WeChatConnectOpenAppID != after.WeChatConnectOpenAppID {
+ changed = append(changed, "wechat_connect_open_app_id")
+ }
+ if req.WeChatConnectOpenAppSecret != "" {
+ changed = append(changed, "wechat_connect_open_app_secret")
+ }
+ if before.WeChatConnectMPAppID != after.WeChatConnectMPAppID {
+ changed = append(changed, "wechat_connect_mp_app_id")
+ }
+ if req.WeChatConnectMPAppSecret != "" {
+ changed = append(changed, "wechat_connect_mp_app_secret")
+ }
+ if before.WeChatConnectMobileAppID != after.WeChatConnectMobileAppID {
+ changed = append(changed, "wechat_connect_mobile_app_id")
+ }
+ if req.WeChatConnectMobileAppSecret != "" {
+ changed = append(changed, "wechat_connect_mobile_app_secret")
+ }
+ if before.WeChatConnectOpenEnabled != after.WeChatConnectOpenEnabled {
+ changed = append(changed, "wechat_connect_open_enabled")
+ }
+ if before.WeChatConnectMPEnabled != after.WeChatConnectMPEnabled {
+ changed = append(changed, "wechat_connect_mp_enabled")
+ }
+ if before.WeChatConnectMobileEnabled != after.WeChatConnectMobileEnabled {
+ changed = append(changed, "wechat_connect_mobile_enabled")
+ }
+ if before.WeChatConnectMode != after.WeChatConnectMode {
+ changed = append(changed, "wechat_connect_mode")
+ }
+ if before.WeChatConnectScopes != after.WeChatConnectScopes {
+ changed = append(changed, "wechat_connect_scopes")
+ }
+ if before.WeChatConnectRedirectURL != after.WeChatConnectRedirectURL {
+ changed = append(changed, "wechat_connect_redirect_url")
+ }
+ if before.WeChatConnectFrontendRedirectURL != after.WeChatConnectFrontendRedirectURL {
+ changed = append(changed, "wechat_connect_frontend_redirect_url")
+ }
if before.OIDCConnectEnabled != after.OIDCConnectEnabled {
changed = append(changed, "oidc_connect_enabled")
}
@@ -1376,6 +1774,21 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EnableCCHSigning != after.EnableCCHSigning {
changed = append(changed, "enable_cch_signing")
}
+ if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource {
+ changed = append(changed, "payment_visible_method_alipay_source")
+ }
+ if before.PaymentVisibleMethodWxpaySource != after.PaymentVisibleMethodWxpaySource {
+ changed = append(changed, "payment_visible_method_wxpay_source")
+ }
+ if before.PaymentVisibleMethodAlipayEnabled != after.PaymentVisibleMethodAlipayEnabled {
+ changed = append(changed, "payment_visible_method_alipay_enabled")
+ }
+ if before.PaymentVisibleMethodWxpayEnabled != after.PaymentVisibleMethodWxpayEnabled {
+ changed = append(changed, "payment_visible_method_wxpay_enabled")
+ }
+ if before.OpenAIAdvancedSchedulerEnabled != after.OpenAIAdvancedSchedulerEnabled {
+ changed = append(changed, "openai_advanced_scheduler_enabled")
+ }
// Balance & quota notification
if before.BalanceLowNotifyEnabled != after.BalanceLowNotifyEnabled {
changed = append(changed, "balance_low_notify_enabled")
@@ -1392,6 +1805,50 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if !equalNotifyEmailEntries(before.AccountQuotaNotifyEmails, after.AccountQuotaNotifyEmails) {
changed = append(changed, "account_quota_notify_emails")
}
+ changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
+ return changed
+}
+
+func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSourceDefaultSettings, after *service.AuthSourceDefaultSettings) []string {
+ if before == nil {
+ before = &service.AuthSourceDefaultSettings{}
+ }
+ if after == nil {
+ after = &service.AuthSourceDefaultSettings{}
+ }
+
+ type providerDefaultGrantField struct {
+ name string
+ before service.ProviderDefaultGrantSettings
+ after service.ProviderDefaultGrantSettings
+ }
+
+ fields := []providerDefaultGrantField{
+ {name: "email", before: before.Email, after: after.Email},
+ {name: "linuxdo", before: before.LinuxDo, after: after.LinuxDo},
+ {name: "oidc", before: before.OIDC, after: after.OIDC},
+ {name: "wechat", before: before.WeChat, after: after.WeChat},
+ }
+ for _, field := range fields {
+ if field.before.Balance != field.after.Balance {
+ changed = append(changed, "auth_source_default_"+field.name+"_balance")
+ }
+ if field.before.Concurrency != field.after.Concurrency {
+ changed = append(changed, "auth_source_default_"+field.name+"_concurrency")
+ }
+ if !equalDefaultSubscriptions(field.before.Subscriptions, field.after.Subscriptions) {
+ changed = append(changed, "auth_source_default_"+field.name+"_subscriptions")
+ }
+ if field.before.GrantOnSignup != field.after.GrantOnSignup {
+ changed = append(changed, "auth_source_default_"+field.name+"_grant_on_signup")
+ }
+ if field.before.GrantOnFirstBind != field.after.GrantOnFirstBind {
+ changed = append(changed, "auth_source_default_"+field.name+"_grant_on_first_bind")
+ }
+ }
+ if before.ForceEmailOnThirdPartySignup != after.ForceEmailOnThirdPartySignup {
+ changed = append(changed, "force_email_on_third_party_signup")
+ }
return changed
}
@@ -1412,6 +1869,84 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto
return normalized
}
+func normalizeOptionalDefaultSubscriptions(input *[]dto.DefaultSubscriptionSetting) *[]dto.DefaultSubscriptionSetting {
+ if input == nil {
+ return nil
+ }
+ normalized := normalizeDefaultSubscriptions(*input)
+ return &normalized
+}
+
+func float64ValueOrDefault(value *float64, fallback float64) float64 {
+ if value == nil {
+ return fallback
+ }
+ return *value
+}
+
+func intValueOrDefault(value *int, fallback int) int {
+ if value == nil {
+ return fallback
+ }
+ return *value
+}
+
+func boolValueOrDefault(value *bool, fallback bool) bool {
+ if value == nil {
+ return fallback
+ }
+ return *value
+}
+
+func defaultSubscriptionsValueOrDefault(input *[]dto.DefaultSubscriptionSetting, fallback []service.DefaultSubscriptionSetting) []service.DefaultSubscriptionSetting {
+ if input == nil {
+ return fallback
+ }
+ result := make([]service.DefaultSubscriptionSetting, 0, len(*input))
+ for _, item := range *input {
+ result = append(result, service.DefaultSubscriptionSetting{
+ GroupID: item.GroupID,
+ ValidityDays: item.ValidityDays,
+ })
+ }
+ return result
+}
+
+func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults *service.AuthSourceDefaultSettings) map[string]any {
+ data := make(map[string]any)
+ raw, err := json.Marshal(settings)
+ if err == nil {
+ _ = json.Unmarshal(raw, &data)
+ }
+ if authSourceDefaults == nil {
+ authSourceDefaults = &service.AuthSourceDefaultSettings{}
+ }
+
+ data["auth_source_default_email_balance"] = authSourceDefaults.Email.Balance
+ data["auth_source_default_email_concurrency"] = authSourceDefaults.Email.Concurrency
+ data["auth_source_default_email_subscriptions"] = authSourceDefaults.Email.Subscriptions
+ data["auth_source_default_email_grant_on_signup"] = authSourceDefaults.Email.GrantOnSignup
+ data["auth_source_default_email_grant_on_first_bind"] = authSourceDefaults.Email.GrantOnFirstBind
+ data["auth_source_default_linuxdo_balance"] = authSourceDefaults.LinuxDo.Balance
+ data["auth_source_default_linuxdo_concurrency"] = authSourceDefaults.LinuxDo.Concurrency
+ data["auth_source_default_linuxdo_subscriptions"] = authSourceDefaults.LinuxDo.Subscriptions
+ data["auth_source_default_linuxdo_grant_on_signup"] = authSourceDefaults.LinuxDo.GrantOnSignup
+ data["auth_source_default_linuxdo_grant_on_first_bind"] = authSourceDefaults.LinuxDo.GrantOnFirstBind
+ data["auth_source_default_oidc_balance"] = authSourceDefaults.OIDC.Balance
+ data["auth_source_default_oidc_concurrency"] = authSourceDefaults.OIDC.Concurrency
+ data["auth_source_default_oidc_subscriptions"] = authSourceDefaults.OIDC.Subscriptions
+ data["auth_source_default_oidc_grant_on_signup"] = authSourceDefaults.OIDC.GrantOnSignup
+ data["auth_source_default_oidc_grant_on_first_bind"] = authSourceDefaults.OIDC.GrantOnFirstBind
+ data["auth_source_default_wechat_balance"] = authSourceDefaults.WeChat.Balance
+ data["auth_source_default_wechat_concurrency"] = authSourceDefaults.WeChat.Concurrency
+ data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions
+ data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup
+ data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind
+ data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup
+
+ return data
+}
+
func equalStringSlice(a, b []string) bool {
if len(a) != len(b) {
return false
diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
new file mode 100644
index 00000000..9a33a93a
--- /dev/null
+++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
@@ -0,0 +1,503 @@
+package admin
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type settingHandlerRepoStub struct {
+ values map[string]string
+ lastUpdates map[string]string
+}
+
+func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ panic("unexpected GetValue call")
+}
+
+func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error {
+ panic("unexpected Set call")
+}
+
+func (s *settingHandlerRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *settingHandlerRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ s.lastUpdates = make(map[string]string, len(settings))
+ for key, value := range settings {
+ s.lastUpdates[key] = value
+ if s.values == nil {
+ s.values = map[string]string{}
+ }
+ s.values[key] = value
+ }
+ return nil
+}
+
+func (s *settingHandlerRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ out := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ out[key] = value
+ }
+ return out, nil
+}
+
+func (s *settingHandlerRepoStub) Delete(ctx context.Context, key string) error {
+ panic("unexpected Delete call")
+}
+
+type failingAuthSourceSettingsRepoStub struct {
+ values map[string]string
+ err error
+}
+
+func (s *failingAuthSourceSettingsRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *failingAuthSourceSettingsRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ panic("unexpected GetValue call")
+}
+
+func (s *failingAuthSourceSettingsRepoStub) Set(ctx context.Context, key, value string) error {
+ panic("unexpected Set call")
+}
+
+func (s *failingAuthSourceSettingsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *failingAuthSourceSettingsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ if _, ok := settings[service.SettingKeyAuthSourceDefaultEmailBalance]; ok {
+ return s.err
+ }
+ for key, value := range settings {
+ if s.values == nil {
+ s.values = map[string]string{}
+ }
+ s.values[key] = value
+ }
+ return nil
+}
+
+func (s *failingAuthSourceSettingsRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ out := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ out[key] = value
+ }
+ return out, nil
+}
+
+func (s *failingAuthSourceSettingsRepoStub) Delete(ctx context.Context, key string) error {
+ panic("unexpected Delete call")
+}
+
+func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
+ service.SettingKeyForceEmailOnThirdPartySignup: "true",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil)
+
+ handler.GetSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, 9.5, data["auth_source_default_email_balance"])
+ require.Equal(t, float64(8), data["auth_source_default_email_concurrency"])
+ require.Equal(t, true, data["force_email_on_third_party_signup"])
+
+ subscriptions, ok := data["auth_source_default_email_subscriptions"].([]any)
+ require.True(t, ok)
+ require.Len(t, subscriptions, 1)
+}
+
+func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "false",
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false",
+ service.SettingKeyForceEmailOnThirdPartySignup: "true",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "registration_enabled": true,
+ "promo_code_enabled": true,
+ "auth_source_default_email_balance": 12.75,
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "12.75000000", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance])
+ require.Equal(t, "8", repo.values[service.SettingKeyAuthSourceDefaultEmailConcurrency])
+ require.Equal(t, `[{"group_id":31,"validity_days":15}]`, repo.values[service.SettingKeyAuthSourceDefaultEmailSubscriptions])
+ require.Equal(t, "true", repo.values[service.SettingKeyForceEmailOnThirdPartySignup])
+
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, 12.75, data["auth_source_default_email_balance"])
+ require.Equal(t, float64(8), data["auth_source_default_email_concurrency"])
+ require.Equal(t, true, data["force_email_on_third_party_signup"])
+}
+
+func TestSettingHandler_UpdateSettings_PersistsPaymentVisibleMethodsAndAdvancedScheduler(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyPromoCodeEnabled: "true",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "promo_code_enabled": true,
+ "payment_visible_method_alipay_source": "easypay",
+ "payment_visible_method_wxpay_source": "wxpay",
+ "payment_visible_method_alipay_enabled": true,
+ "payment_visible_method_wxpay_enabled": false,
+ "openai_advanced_scheduler_enabled": true,
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, service.VisibleMethodSourceEasyPayAlipay, repo.values[service.SettingPaymentVisibleMethodAlipaySource])
+ require.Equal(t, service.VisibleMethodSourceOfficialWechat, repo.values[service.SettingPaymentVisibleMethodWxpaySource])
+ require.Equal(t, "true", repo.values[service.SettingPaymentVisibleMethodAlipayEnabled])
+ require.Equal(t, "false", repo.values[service.SettingPaymentVisibleMethodWxpayEnabled])
+ require.Equal(t, "true", repo.values["openai_advanced_scheduler_enabled"])
+
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, service.VisibleMethodSourceEasyPayAlipay, data["payment_visible_method_alipay_source"])
+ require.Equal(t, service.VisibleMethodSourceOfficialWechat, data["payment_visible_method_wxpay_source"])
+ require.Equal(t, true, data["payment_visible_method_alipay_enabled"])
+ require.Equal(t, false, data["payment_visible_method_wxpay_enabled"])
+ require.Equal(t, true, data["openai_advanced_scheduler_enabled"])
+}
+
+func TestSettingHandler_UpdateSettings_PreservesLegacyBlankPaymentVisibleMethodSource(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingPaymentVisibleMethodAlipayEnabled: "true",
+ service.SettingPaymentVisibleMethodAlipaySource: "",
+ service.SettingPaymentVisibleMethodWxpayEnabled: "false",
+ service.SettingPaymentVisibleMethodWxpaySource: "",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "promo_code_enabled": false,
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "", repo.values[service.SettingPaymentVisibleMethodAlipaySource])
+ require.Equal(t, "true", repo.values[service.SettingPaymentVisibleMethodAlipayEnabled])
+}
+
+func TestSettingHandler_UpdateSettings_PersistsExplicitFalseOIDCCompatibilityFlags(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingKeyOIDCConnectEnabled: "true",
+ service.SettingKeyOIDCConnectProviderName: "OIDC",
+ service.SettingKeyOIDCConnectClientID: "oidc-client",
+ service.SettingKeyOIDCConnectClientSecret: "oidc-secret",
+ service.SettingKeyOIDCConnectIssuerURL: "https://issuer.example.com",
+ service.SettingKeyOIDCConnectAuthorizeURL: "https://issuer.example.com/auth",
+ service.SettingKeyOIDCConnectTokenURL: "https://issuer.example.com/token",
+ service.SettingKeyOIDCConnectUserInfoURL: "https://issuer.example.com/userinfo",
+ service.SettingKeyOIDCConnectJWKSURL: "https://issuer.example.com/jwks",
+ service.SettingKeyOIDCConnectScopes: "openid email profile",
+ service.SettingKeyOIDCConnectRedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
+ service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
+ service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
+ service.SettingKeyOIDCConnectUsePKCE: "true",
+ service.SettingKeyOIDCConnectValidateIDToken: "true",
+ service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256",
+ service.SettingKeyOIDCConnectClockSkewSeconds: "120",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "promo_code_enabled": true,
+ "oidc_connect_enabled": true,
+ "oidc_connect_use_pkce": false,
+ "oidc_connect_validate_id_token": false,
+ "oidc_connect_allowed_signing_algs": "",
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectUsePKCE])
+ require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectValidateIDToken])
+
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, false, data["oidc_connect_use_pkce"])
+ require.Equal(t, false, data["oidc_connect_validate_id_token"])
+}
+
+func TestSettingHandler_UpdateSettings_DoesNotSolidifyImplicitOIDCSecurityDefaultsOnLegacyUpgrade(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingKeyOIDCConnectEnabled: "true",
+ service.SettingKeyOIDCConnectProviderName: "OIDC",
+ service.SettingKeyOIDCConnectClientID: "oidc-client",
+ service.SettingKeyOIDCConnectClientSecret: "oidc-secret",
+ service.SettingKeyOIDCConnectIssuerURL: "https://issuer.example.com",
+ service.SettingKeyOIDCConnectAuthorizeURL: "https://issuer.example.com/auth",
+ service.SettingKeyOIDCConnectTokenURL: "https://issuer.example.com/token",
+ service.SettingKeyOIDCConnectUserInfoURL: "https://issuer.example.com/userinfo",
+ service.SettingKeyOIDCConnectJWKSURL: "https://issuer.example.com/jwks",
+ service.SettingKeyOIDCConnectScopes: "openid email profile",
+ service.SettingKeyOIDCConnectRedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
+ service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
+ service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
+ service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256",
+ service.SettingKeyOIDCConnectClockSkewSeconds: "120",
+ service.SettingKeyOIDCConnectRequireEmailVerified: "false",
+ service.SettingKeyOIDCConnectUserInfoEmailPath: "",
+ service.SettingKeyOIDCConnectUserInfoIDPath: "",
+ service.SettingKeyOIDCConnectUserInfoUsernamePath: "",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{
+ Default: config.DefaultConfig{UserConcurrency: 5},
+ OIDC: config.OIDCConnectConfig{
+ Enabled: true,
+ ProviderName: "OIDC",
+ ClientID: "oidc-client",
+ ClientSecret: "oidc-secret",
+ IssuerURL: "https://issuer.example.com",
+ AuthorizeURL: "https://issuer.example.com/auth",
+ TokenURL: "https://issuer.example.com/token",
+ UserInfoURL: "https://issuer.example.com/userinfo",
+ JWKSURL: "https://issuer.example.com/jwks",
+ Scopes: "openid email profile",
+ RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ ValidateIDToken: true,
+ AllowedSigningAlgs: "RS256",
+ ClockSkewSeconds: 120,
+ },
+ })
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "promo_code_enabled": true,
+ "oidc_connect_enabled": true,
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectUsePKCE])
+ require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectValidateIDToken])
+}
+
+func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyPromoCodeEnabled: "true",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "promo_code_enabled": true,
+ "payment_visible_method_alipay_source": "bogus",
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.NotContains(t, repo.values, service.SettingPaymentVisibleMethodAlipaySource)
+}
+
+func TestSettingHandler_UpdateSettings_DoesNotPersistPartialSystemSettingsWhenAuthSourceDefaultsFail(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &failingAuthSourceSettingsRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "false",
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
+ },
+ err: errors.New("write auth source defaults failed"),
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "registration_enabled": true,
+ "promo_code_enabled": true,
+ "auth_source_default_email_balance": 12.75,
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+ require.Equal(t, "false", repo.values[service.SettingKeyRegistrationEnabled])
+ require.Equal(t, "9.5", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance])
+}
+
+func TestDiffSettings_IncludesAuthSourceDefaultsAndForceEmail(t *testing.T) {
+ changed := diffSettings(
+ &service.SystemSettings{},
+ &service.SystemSettings{},
+ &service.AuthSourceDefaultSettings{
+ Email: service.ProviderDefaultGrantSettings{
+ Balance: 0,
+ Concurrency: 5,
+ Subscriptions: nil,
+ GrantOnSignup: true,
+ GrantOnFirstBind: false,
+ },
+ ForceEmailOnThirdPartySignup: false,
+ },
+ &service.AuthSourceDefaultSettings{
+ Email: service.ProviderDefaultGrantSettings{
+ Balance: 12.5,
+ Concurrency: 7,
+ Subscriptions: []service.DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 30}},
+ GrantOnSignup: false,
+ GrantOnFirstBind: true,
+ },
+ ForceEmailOnThirdPartySignup: true,
+ },
+ UpdateSettingsRequest{},
+ )
+
+ require.Contains(t, changed, "auth_source_default_email_balance")
+ require.Contains(t, changed, "auth_source_default_email_concurrency")
+ require.Contains(t, changed, "auth_source_default_email_subscriptions")
+ require.Contains(t, changed, "auth_source_default_email_grant_on_signup")
+ require.Contains(t, changed, "auth_source_default_email_grant_on_first_bind")
+ require.Contains(t, changed, "force_email_on_third_party_signup")
+}
diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go
index 1453bd07..b2ed9d18 100644
--- a/backend/internal/handler/admin/user_handler.go
+++ b/backend/internal/handler/admin/user_handler.go
@@ -66,6 +66,22 @@ type UpdateBalanceRequest struct {
Notes string `json:"notes"`
}
+type BindUserAuthIdentityRequest struct {
+ ProviderType string `json:"provider_type"`
+ ProviderKey string `json:"provider_key"`
+ ProviderSubject string `json:"provider_subject"`
+ Issuer *string `json:"issuer"`
+ Metadata map[string]any `json:"metadata"`
+ Channel *BindUserAuthIdentityChannelRequest `json:"channel"`
+}
+
+type BindUserAuthIdentityChannelRequest struct {
+ Channel string `json:"channel"`
+ ChannelAppID string `json:"channel_app_id"`
+ ChannelSubject string `json:"channel_subject"`
+ Metadata map[string]any `json:"metadata"`
+}
+
// List handles listing all users with pagination
// GET /api/v1/admin/users
// Query params:
@@ -172,6 +188,45 @@ func (h *UserHandler) GetByID(c *gin.Context) {
response.Success(c, dto.UserFromServiceAdmin(user))
}
+// BindAuthIdentity manually binds a canonical auth identity to a user.
+// POST /api/v1/admin/users/:id/auth-identities
+func (h *UserHandler) BindAuthIdentity(c *gin.Context) {
+ userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.BadRequest(c, "Invalid user ID")
+ return
+ }
+
+ var req BindUserAuthIdentityRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ input := service.AdminBindAuthIdentityInput{
+ ProviderType: req.ProviderType,
+ ProviderKey: req.ProviderKey,
+ ProviderSubject: req.ProviderSubject,
+ Issuer: req.Issuer,
+ Metadata: req.Metadata,
+ }
+ if req.Channel != nil {
+ input.Channel = &service.AdminBindAuthIdentityChannelInput{
+ Channel: req.Channel.Channel,
+ ChannelAppID: req.Channel.ChannelAppID,
+ ChannelSubject: req.Channel.ChannelSubject,
+ Metadata: req.Channel.Metadata,
+ }
+ }
+
+ result, err := h.adminService.BindUserAuthIdentity(c.Request.Context(), userID, input)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, result)
+}
+
// Create handles creating a new user
// POST /api/v1/admin/users
func (h *UserHandler) Create(c *gin.Context) {
diff --git a/backend/internal/handler/admin/user_handler_activity_test.go b/backend/internal/handler/admin/user_handler_activity_test.go
new file mode 100644
index 00000000..bfba2408
--- /dev/null
+++ b/backend/internal/handler/admin/user_handler_activity_test.go
@@ -0,0 +1,114 @@
+//go:build unit
+
+package admin
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestUserHandlerListIncludesActivityFieldsAndSortParams(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ lastLoginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC)
+ lastActiveAt := lastLoginAt.Add(30 * time.Minute)
+ lastUsedAt := lastLoginAt.Add(90 * time.Minute)
+
+ adminSvc := newStubAdminService()
+ adminSvc.users = []service.User{
+ {
+ ID: 7,
+ Email: "activity@example.com",
+ Username: "activity-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ LastActiveAt: &lastActiveAt,
+ LastUsedAt: &lastUsedAt,
+ CreatedAt: lastLoginAt.Add(-24 * time.Hour),
+ UpdatedAt: lastLoginAt,
+ },
+ }
+ handler := NewUserHandler(adminSvc, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(
+ http.MethodGet,
+ "/api/v1/admin/users?sort_by=last_used_at&sort_order=asc&search=activity",
+ nil,
+ )
+
+ handler.List(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ require.Equal(t, "last_used_at", adminSvc.lastListUsers.sortBy)
+ require.Equal(t, "asc", adminSvc.lastListUsers.sortOrder)
+ require.Equal(t, "activity", adminSvc.lastListUsers.filters.Search)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ Items []struct {
+ LastActiveAt *time.Time `json:"last_active_at"`
+ LastUsedAt *time.Time `json:"last_used_at"`
+ } `json:"items"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Len(t, resp.Data.Items, 1)
+ require.WithinDuration(t, lastActiveAt, *resp.Data.Items[0].LastActiveAt, time.Second)
+ require.WithinDuration(t, lastUsedAt, *resp.Data.Items[0].LastUsedAt, time.Second)
+}
+
+func TestUserHandlerGetByIDIncludesActivityFields(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ lastLoginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC)
+ lastActiveAt := lastLoginAt.Add(30 * time.Minute)
+ lastUsedAt := lastLoginAt.Add(90 * time.Minute)
+
+ adminSvc := newStubAdminService()
+ adminSvc.users = []service.User{
+ {
+ ID: 8,
+ Email: "detail@example.com",
+ Username: "detail-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ LastActiveAt: &lastActiveAt,
+ LastUsedAt: &lastUsedAt,
+ CreatedAt: lastLoginAt.Add(-24 * time.Hour),
+ UpdatedAt: lastLoginAt,
+ },
+ }
+ handler := NewUserHandler(adminSvc, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Params = gin.Params{{Key: "id", Value: "8"}}
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/8", nil)
+
+ handler.GetByID(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ LastActiveAt *time.Time `json:"last_active_at"`
+ LastUsedAt *time.Time `json:"last_used_at"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.WithinDuration(t, lastActiveAt, *resp.Data.LastActiveAt, time.Second)
+ require.WithinDuration(t, lastUsedAt, *resp.Data.LastUsedAt, time.Second)
+}
diff --git a/backend/internal/handler/auth_current_user_test.go b/backend/internal/handler/auth_current_user_test.go
new file mode 100644
index 00000000..cb3e4ba5
--- /dev/null
+++ b/backend/internal/handler/auth_current_user_test.go
@@ -0,0 +1,86 @@
+//go:build unit
+
+package handler
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthHandlerGetCurrentUserReturnsProfileCompatibilityFields(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC)
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 31,
+ Email: "me@example.com",
+ Username: "linuxdo-handle",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ AvatarURL: "https://cdn.example.com/linuxdo.png",
+ AvatarSource: "remote_url",
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-31",
+ VerifiedAt: &verifiedAt,
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ "avatar_url": "https://cdn.example.com/linuxdo.png",
+ },
+ },
+ },
+ }
+
+ handler := &AuthHandler{
+ userService: service.NewUserService(repo, nil, nil, nil),
+ }
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 31})
+
+ handler.GetCurrentUser(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, true, resp.Data["email_bound"])
+ require.Equal(t, true, resp.Data["linuxdo_bound"])
+ require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"])
+
+ authBindings, ok := resp.Data["auth_bindings"].(map[string]any)
+ require.True(t, ok)
+ linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, true, linuxdoBinding["bound"])
+
+ avatarSource, ok := resp.Data["avatar_source"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "linuxdo", avatarSource["provider"])
+ require.Equal(t, "linuxdo", avatarSource["source"])
+
+ profileSources, ok := resp.Data["profile_sources"].(map[string]any)
+ require.True(t, ok)
+ usernameSource, ok := profileSources["username"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "linuxdo", usernameSource["provider"])
+ require.Equal(t, "linuxdo", usernameSource["source"])
+}
diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go
index f4ddf890..dc68a466 100644
--- a/backend/internal/handler/auth_handler.go
+++ b/backend/internal/handler/auth_handler.go
@@ -1,11 +1,13 @@
package handler
import (
+ "context"
"log/slog"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -76,9 +78,24 @@ type AuthResponse struct {
User *dto.User `json:"user"`
}
+func ensureLoginUserActive(user *service.User) error {
+ if user == nil {
+ return infraerrors.Unauthorized("INVALID_USER", "user not found")
+ }
+ if !user.IsActive() {
+ return service.ErrUserNotActive
+ }
+ return nil
+}
+
// respondWithTokenPair 生成 Token 对并返回认证响应
// 如果 Token 对生成失败,回退到只返回 Access Token(向后兼容)
func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) {
+ if err := ensureLoginUserActive(user); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
if err != nil {
slog.Error("failed to generate token pair", "error", err, "user_id", user.ID)
@@ -104,6 +121,34 @@ func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) {
})
}
+func (h *AuthHandler) ensureBackendModeAllowsUser(ctx context.Context, user *service.User) error {
+ if user == nil {
+ return infraerrors.Unauthorized("INVALID_USER", "user not found")
+ }
+ if h == nil || !h.isBackendModeEnabled(ctx) || user.IsAdmin() {
+ return nil
+ }
+ return infraerrors.Forbidden("BACKEND_MODE_ADMIN_ONLY", "Backend mode is active. Only admin login is allowed.")
+}
+
+func (h *AuthHandler) ensureBackendModeAllowsNewUserLogin(ctx context.Context) error {
+ if h == nil || !h.isBackendModeEnabled(ctx) {
+ return nil
+ }
+ return infraerrors.Forbidden("BACKEND_MODE_ADMIN_ONLY", "Backend mode is active. Only admin login is allowed.")
+}
+
+func (h *AuthHandler) isBackendModeEnabled(ctx context.Context) bool {
+ if h == nil || h.settingSvc == nil {
+ return false
+ }
+ settings, err := h.settingSvc.GetPublicSettings(ctx)
+ if err == nil && settings != nil {
+ return settings.BackendModeEnabled
+ }
+ return h.settingSvc.IsBackendModeEnabled(ctx)
+}
+
// Register handles user registration
// POST /api/v1/auth/register
func (h *AuthHandler) Register(c *gin.Context) {
@@ -177,6 +222,11 @@ func (h *AuthHandler) Login(c *gin.Context) {
}
_ = token // token 由 authService.Login 返回但此处由 respondWithTokenPair 重新生成
+ if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
// Check if TOTP 2FA is enabled for this user
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
// Create a temporary login session for 2FA
@@ -194,11 +244,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
return
}
- // Backend mode: only admin can login
- if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
- response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
- return
- }
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
h.respondWithTokenPair(c, user)
}
@@ -262,16 +308,80 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
-
- // Backend mode: only admin can login (check BEFORE deleting session)
- if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
- response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
+ if err := ensureLoginUserActive(user); err != nil {
+ response.ErrorFrom(c, err)
return
}
+ if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ if session.PendingOAuthBind != nil {
+ pendingSvc, err := h.pendingIdentityService()
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ pendingSession, err := pendingSvc.GetBrowserSession(
+ c.Request.Context(),
+ session.PendingOAuthBind.PendingSessionToken,
+ session.PendingOAuthBind.BrowserSessionKey,
+ )
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, pendingSession.ID, oauthAdoptionDecisionRequest{})
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthBinding(
+ c.Request.Context(),
+ h.entClient(),
+ h.authService,
+ h.userService,
+ pendingSession,
+ decision,
+ &user.ID,
+ true,
+ true,
+ ); err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+ return
+ }
+ if _, err := pendingSvc.ConsumeBrowserSession(
+ c.Request.Context(),
+ pendingSession.SessionToken,
+ pendingSession.BrowserSessionKey,
+ ); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ secureCookie := isRequestHTTPS(c)
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+
+ user, err = h.userService.GetByID(c.Request.Context(), session.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ }
+
// Delete the login session (only after all checks pass)
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
+ if session.PendingOAuthBind == nil {
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ }
+
h.respondWithTokenPair(c, user)
}
@@ -290,8 +400,14 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
return
}
+ identities, err := h.userService.GetProfileIdentitySummaries(c.Request.Context(), subject.UserID, user)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
type UserResponse struct {
- *dto.User
+ userProfileResponse
RunMode string `json:"run_mode"`
}
@@ -300,7 +416,10 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
runMode = h.cfg.RunMode
}
- response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode})
+ response.Success(c, UserResponse{
+ userProfileResponse: userProfileResponseFromService(user, identities),
+ RunMode: runMode,
+ })
}
// ValidatePromoCodeRequest 验证优惠码请求
@@ -578,6 +697,8 @@ func (h *AuthHandler) Logout(c *gin.Context) {
// 不影响登出流程
}
}
+ h.consumePendingOAuthSessionOnLogout(c)
+ clearOAuthLogoutCookies(c)
response.Success(c, LogoutResponse{
Message: "Logged out successfully",
@@ -598,7 +719,7 @@ func (h *AuthHandler) RevokeAllSessions(c *gin.Context) {
return
}
- if err := h.authService.RevokeAllUserSessions(c.Request.Context(), subject.UserID); err != nil {
+ if err := h.authService.RevokeAllUserTokens(c.Request.Context(), subject.UserID); err != nil {
slog.Error("failed to revoke all sessions", "user_id", subject.UserID, "error", err)
response.InternalError(c, "Failed to revoke sessions")
return
diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go
index 0c7c2da7..2ef05963 100644
--- a/backend/internal/handler/auth_linuxdo_oauth.go
+++ b/backend/internal/handler/auth_linuxdo_oauth.go
@@ -2,6 +2,8 @@ package handler
import (
"context"
+ "crypto/hmac"
+ "crypto/sha256"
"encoding/base64"
"errors"
"fmt"
@@ -13,10 +15,13 @@ import (
"time"
"unicode/utf8"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -25,17 +30,24 @@ import (
)
const (
- linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo"
- linuxDoOAuthStateCookieName = "linuxdo_oauth_state"
- linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier"
- linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect"
- linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
- linuxDoOAuthDefaultRedirectTo = "/dashboard"
- linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback"
+ linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo"
+ oauthBindAccessTokenCookiePath = "/api/v1/auth/oauth"
+ linuxDoOAuthStateCookieName = "linuxdo_oauth_state"
+ linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier"
+ linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect"
+ linuxDoOAuthIntentCookieName = "linuxdo_oauth_intent"
+ linuxDoOAuthBindUserCookieName = "linuxdo_oauth_bind_user"
+ oauthBindAccessTokenCookieName = "oauth_bind_access_token"
+ linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
+ linuxDoOAuthDefaultRedirectTo = "/dashboard"
+ linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback"
linuxDoOAuthMaxRedirectLen = 2048
linuxDoOAuthMaxFragmentValueLen = 512
linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-")
+
+ oauthIntentLogin = "login"
+ oauthIntentBindCurrentUser = "bind_current_user"
)
type linuxDoTokenResponse struct {
@@ -87,9 +99,29 @@ func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) {
redirectTo = linuxDoOAuthDefaultRedirectTo
}
+ browserSessionKey, err := generateOAuthPendingBrowserSession()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err))
+ return
+ }
+
secureCookie := isRequestHTTPS(c)
setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie)
setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie)
+ intent := normalizeOAuthIntent(c.Query("intent"))
+ setCookie(c, linuxDoOAuthIntentCookieName, encodeCookieValue(intent), linuxDoOAuthCookieMaxAgeSec, secureCookie)
+ setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie)
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ if intent == oauthIntentBindCurrentUser {
+ bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ setCookie(c, linuxDoOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), linuxDoOAuthCookieMaxAgeSec, secureCookie)
+ } else {
+ clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
+ }
codeChallenge := ""
if cfg.UsePKCE {
@@ -148,6 +180,8 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
+ clearCookie(c, linuxDoOAuthIntentCookieName, secureCookie)
+ clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
}()
expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName)
@@ -161,6 +195,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
if redirectTo == "" {
redirectTo = linuxDoOAuthDefaultRedirectTo
}
+ browserSessionKey, _ := readOAuthPendingBrowserCookie(c)
+ if strings.TrimSpace(browserSessionKey) == "" {
+ redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "")
+ return
+ }
+ intent, _ := readCookieDecoded(c, linuxDoOAuthIntentCookieName)
+ intent = normalizeOAuthIntent(intent)
codeVerifier := ""
if cfg.UsePKCE {
@@ -198,52 +239,204 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
return
}
- email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
+ email, username, subject, displayName, avatarURL, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
if err != nil {
log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err)
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "")
return
}
+ compatEmail := strings.TrimSpace(email)
// 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。
// 统一使用基于 subject 的稳定合成邮箱来做账号绑定。
if subject != "" {
email = linuxDoSyntheticEmail(subject)
}
-
- // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
- tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
- if err != nil {
- if errors.Is(err, service.ErrOAuthInvitationRequired) {
- pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
- if tokenErr != nil {
- redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
- return
- }
- fragment := url.Values{}
- fragment.Set("error", "invitation_required")
- fragment.Set("pending_oauth_token", pendingToken)
- fragment.Set("redirect", redirectTo)
- redirectWithFragment(c, frontendCallback, fragment)
+ identityKey := service.PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: subject,
+ }
+ upstreamClaims := map[string]any{
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "suggested_display_name": displayName,
+ "suggested_avatar_url": avatarURL,
+ }
+ if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) {
+ upstreamClaims["compat_email"] = compatEmail
+ }
+ if intent == oauthIntentBindCurrentUser {
+ targetUserID, err := h.readOAuthBindUserIDFromCookie(c, linuxDoOAuthBindUserCookieName)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth bind target", "")
return
}
- // 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
- redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentBindCurrentUser,
+ Identity: identityKey,
+ TargetUserID: &targetUserID,
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth bind", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
return
}
- fragment := url.Values{}
- fragment.Set("access_token", tokenPair.AccessToken)
- fragment.Set("refresh_token", tokenPair.RefreshToken)
- fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
- fragment.Set("token_type", "Bearer")
- fragment.Set("redirect", redirectTo)
- redirectWithFragment(c, frontendCallback, fragment)
+ existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityKey)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if existingIdentityUser != nil {
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin,
+ Identity: identityKey,
+ TargetUserID: &existingIdentityUser.ID,
+ ResolvedEmail: existingIdentityUser.Email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ compatEmailUser, err := h.findLinuxDoCompatEmailUser(c.Request.Context(), compatEmail)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if err := h.createLinuxDoOAuthChoicePendingSession(
+ c,
+ identityKey,
+ email,
+ email,
+ redirectTo,
+ browserSessionKey,
+ upstreamClaims,
+ compatEmail,
+ compatEmailUser,
+ h.isForceEmailOnThirdPartySignup(c.Request.Context()),
+ ); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+}
+
+func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email string) (*dbent.User, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ email = strings.TrimSpace(strings.ToLower(email))
+ if email == "" ||
+ strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
+ return nil, nil
+ }
+
+ userEntity, err := client.User.Query().
+ Where(userNormalizedEmailPredicate(email)).
+ Order(dbent.Asc(dbuser.FieldID)).
+ All(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err)
+ }
+ switch len(userEntity) {
+ case 0:
+ return nil, nil
+ case 1:
+ return userEntity[0], nil
+ default:
+ return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users")
+ }
+}
+
+func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
+ c *gin.Context,
+ identity service.PendingAuthIdentityKey,
+ suggestedEmail string,
+ resolvedEmail string,
+ redirectTo string,
+ browserSessionKey string,
+ upstreamClaims map[string]any,
+ compatEmail string,
+ compatEmailUser *dbent.User,
+ forceEmailOnSignup bool,
+) error {
+ suggestionEmail := strings.TrimSpace(suggestedEmail)
+ canonicalEmail := strings.TrimSpace(resolvedEmail)
+ if suggestionEmail == "" {
+ suggestionEmail = canonicalEmail
+ }
+
+ completionResponse := map[string]any{
+ "step": oauthPendingChoiceStep,
+ "adoption_required": true,
+ "redirect": strings.TrimSpace(redirectTo),
+ "email": suggestionEmail,
+ "resolved_email": canonicalEmail,
+ "existing_account_email": "",
+ "existing_account_bindable": false,
+ "create_account_allowed": true,
+ "force_email_on_signup": forceEmailOnSignup,
+ "choice_reason": "third_party_signup",
+ }
+ if strings.TrimSpace(compatEmail) != "" {
+ completionResponse["compat_email"] = strings.TrimSpace(compatEmail)
+ }
+ resolvedChoiceEmail := suggestionEmail
+ if compatEmailUser != nil {
+ completionResponse["email"] = strings.TrimSpace(compatEmailUser.Email)
+ completionResponse["existing_account_email"] = strings.TrimSpace(compatEmailUser.Email)
+ completionResponse["existing_account_bindable"] = true
+ completionResponse["choice_reason"] = "compat_email_match"
+ resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email)
+ }
+ if forceEmailOnSignup && compatEmailUser == nil {
+ completionResponse["choice_reason"] = "force_email_on_signup"
+ }
+
+ var targetUserID *int64
+ if compatEmailUser != nil && compatEmailUser.ID > 0 {
+ targetUserID = &compatEmailUser.ID
+ }
+
+ return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin,
+ Identity: identity,
+ TargetUserID: targetUserID,
+ ResolvedEmail: resolvedChoiceEmail,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: completionResponse,
+ })
}
type completeLinuxDoOAuthRequest struct {
- PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
- InvitationCode string `json:"invitation_code" binding:"required"`
+ InvitationCode string `json:"invitation_code" binding:"required"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
@@ -256,17 +449,87 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
return
}
- email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
+ secureCookie := isRequestHTTPS(c)
+ sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil {
- c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
return
}
-
- tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
+ return
+ }
+ pendingSvc, err := h.pendingIdentityService()
if err != nil {
response.ErrorFrom(c, err)
return
}
+ session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ } else if handled {
+ c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
+ return
+ } else {
+ session = updatedSession
+ }
+ if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ email := strings.TrimSpace(session.ResolvedEmail)
+ username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
+ if email == "" || username == "" {
+ response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid"))
+ return
+ }
+
+ client := h.entClient()
+ if client == nil {
+ response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
+ return
+ }
+ if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil {
+ respondPendingOAuthBindingApplyError(c, err)
+ return
+ }
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
+ AdoptDisplayName: req.AdoptDisplayName,
+ AdoptAvatar: req.AdoptAvatar,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil {
+ respondPendingOAuthBindingApplyError(c, err)
+ return
+ }
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken,
@@ -303,7 +566,7 @@ func linuxDoExchangeCode(
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
- if cfg.UsePKCE {
+ if strings.TrimSpace(codeVerifier) != "" {
form.Set("code_verifier", codeVerifier)
}
@@ -353,11 +616,11 @@ func linuxDoFetchUserInfo(
ctx context.Context,
cfg config.LinuxDoConnectConfig,
token *linuxDoTokenResponse,
-) (email string, username string, subject string, err error) {
+) (email string, username string, subject string, displayName string, avatarURL string, err error) {
client := req.C().SetTimeout(30 * time.Second)
authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken)
if err != nil {
- return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
+ return "", "", "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
}
resp, err := client.R().
@@ -366,16 +629,16 @@ func linuxDoFetchUserInfo(
SetHeader("Authorization", authorization).
Get(cfg.UserInfoURL)
if err != nil {
- return "", "", "", fmt.Errorf("request userinfo: %w", err)
+ return "", "", "", "", "", fmt.Errorf("request userinfo: %w", err)
}
if !resp.IsSuccessState() {
- return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
+ return "", "", "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
}
return linuxDoParseUserInfo(resp.String(), cfg)
}
-func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) {
+func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, displayName string, avatarURL string, err error) {
email = firstNonEmpty(
getGJSON(body, cfg.UserInfoEmailPath),
getGJSON(body, "email"),
@@ -400,12 +663,29 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s
getGJSON(body, "user.id"),
)
+ displayName = firstNonEmpty(
+ getGJSON(body, "name"),
+ getGJSON(body, "nickname"),
+ getGJSON(body, "display_name"),
+ getGJSON(body, "user.name"),
+ getGJSON(body, "user.username"),
+ username,
+ )
+ avatarURL = firstNonEmpty(
+ getGJSON(body, "avatar_url"),
+ getGJSON(body, "avatar"),
+ getGJSON(body, "picture"),
+ getGJSON(body, "profile_image_url"),
+ getGJSON(body, "user.avatar"),
+ getGJSON(body, "user.avatar_url"),
+ )
+
subject = strings.TrimSpace(subject)
if subject == "" {
- return "", "", "", errors.New("userinfo missing id field")
+ return "", "", "", "", "", errors.New("userinfo missing id field")
}
if !isSafeLinuxDoSubject(subject) {
- return "", "", "", errors.New("userinfo returned invalid id field")
+ return "", "", "", "", "", errors.New("userinfo returned invalid id field")
}
email = strings.TrimSpace(email)
@@ -418,8 +698,13 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s
if username == "" {
username = "linuxdo_" + subject
}
+ displayName = strings.TrimSpace(displayName)
+ if displayName == "" {
+ displayName = username
+ }
+ avatarURL = strings.TrimSpace(avatarURL)
- return email, username, subject, nil
+ return email, username, subject, displayName, avatarURL, nil
}
func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) {
@@ -436,7 +721,7 @@ func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, cod
q.Set("scope", cfg.Scopes)
}
q.Set("state", state)
- if cfg.UsePKCE {
+ if strings.TrimSpace(codeChallenge) != "" {
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
}
@@ -670,6 +955,30 @@ func clearCookie(c *gin.Context, name string, secure bool) {
})
}
+func clearOAuthBindAccessTokenCookie(c *gin.Context, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthBindAccessTokenCookieName,
+ Value: "",
+ Path: oauthBindAccessTokenCookiePath,
+ MaxAge: -1,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func setOAuthBindAccessTokenCookie(c *gin.Context, token string, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthBindAccessTokenCookieName,
+ Value: url.QueryEscape(strings.TrimSpace(token)),
+ Path: oauthBindAccessTokenCookiePath,
+ MaxAge: linuxDoOAuthCookieMaxAgeSec,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
func truncateFragmentValue(value string) string {
value = strings.TrimSpace(value)
if value == "" {
@@ -728,3 +1037,127 @@ func linuxDoSyntheticEmail(subject string) string {
}
return "linuxdo-" + subject + service.LinuxDoConnectSyntheticEmailDomain
}
+
+func normalizeOAuthIntent(raw string) string {
+ switch strings.ToLower(strings.TrimSpace(raw)) {
+ case "", oauthIntentLogin:
+ return oauthIntentLogin
+ case "bind", oauthIntentBindCurrentUser:
+ return oauthIntentBindCurrentUser
+ default:
+ return oauthIntentLogin
+ }
+}
+
+func (h *AuthHandler) buildOAuthBindUserCookieFromContext(c *gin.Context) (string, error) {
+ userID, err := h.resolveOAuthBindTargetUserID(c)
+ if err != nil || userID == nil || *userID <= 0 {
+ return "", infraerrors.Unauthorized("UNAUTHORIZED", "authentication required")
+ }
+ return buildOAuthBindUserCookieValue(*userID, h.oauthBindCookieSecret())
+}
+
+func (h *AuthHandler) PrepareOAuthBindAccessTokenCookie(c *gin.Context) {
+ const bearerPrefix = "Bearer "
+
+ authHeader := strings.TrimSpace(c.GetHeader("Authorization"))
+ if !strings.HasPrefix(strings.ToLower(authHeader), strings.ToLower(bearerPrefix)) {
+ response.ErrorFrom(c, infraerrors.Unauthorized("UNAUTHORIZED", "authentication required"))
+ return
+ }
+
+ token := strings.TrimSpace(authHeader[len(bearerPrefix):])
+ if token == "" {
+ response.ErrorFrom(c, infraerrors.Unauthorized("UNAUTHORIZED", "authentication required"))
+ return
+ }
+
+ setOAuthBindAccessTokenCookie(c, token, isRequestHTTPS(c))
+ c.Status(http.StatusNoContent)
+ c.Writer.WriteHeaderNow()
+}
+
+func (h *AuthHandler) resolveOAuthBindTargetUserID(c *gin.Context) (*int64, error) {
+ if subject, ok := servermiddleware.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
+ return &subject.UserID, nil
+ }
+ if h == nil || h.authService == nil || h.userService == nil {
+ return nil, service.ErrInvalidToken
+ }
+
+ ck, err := c.Request.Cookie(oauthBindAccessTokenCookieName)
+ clearOAuthBindAccessTokenCookie(c, isRequestHTTPS(c))
+ if err != nil {
+ return nil, err
+ }
+
+ tokenString, err := url.QueryUnescape(strings.TrimSpace(ck.Value))
+ if err != nil {
+ return nil, err
+ }
+ if tokenString == "" {
+ return nil, service.ErrInvalidToken
+ }
+
+ claims, err := h.authService.ValidateToken(tokenString)
+ if err != nil {
+ return nil, err
+ }
+ user, err := h.userService.GetByID(c.Request.Context(), claims.UserID)
+ if err != nil {
+ return nil, err
+ }
+ if user == nil || !user.IsActive() || claims.TokenVersion != user.TokenVersion {
+ return nil, service.ErrInvalidToken
+ }
+ return &user.ID, nil
+}
+
+func (h *AuthHandler) readOAuthBindUserIDFromCookie(c *gin.Context, cookieName string) (int64, error) {
+ value, err := readCookieDecoded(c, cookieName)
+ if err != nil {
+ return 0, err
+ }
+ return parseOAuthBindUserCookieValue(value, h.oauthBindCookieSecret())
+}
+
+func (h *AuthHandler) oauthBindCookieSecret() string {
+ if h == nil || h.cfg == nil {
+ return ""
+ }
+ return strings.TrimSpace(h.cfg.JWT.Secret)
+}
+
+func buildOAuthBindUserCookieValue(userID int64, secret string) (string, error) {
+ secret = strings.TrimSpace(secret)
+ if userID <= 0 || secret == "" {
+ return "", errors.New("invalid oauth bind cookie input")
+ }
+ payload := strconv.FormatInt(userID, 10)
+ mac := hmac.New(sha256.New, []byte(secret))
+ _, _ = mac.Write([]byte(payload))
+ signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
+ return payload + "." + signature, nil
+}
+
+func parseOAuthBindUserCookieValue(value string, secret string) (int64, error) {
+ secret = strings.TrimSpace(secret)
+ if secret == "" {
+ return 0, errors.New("missing oauth bind cookie secret")
+ }
+ payload, signature, ok := strings.Cut(strings.TrimSpace(value), ".")
+ if !ok || payload == "" || signature == "" {
+ return 0, errors.New("invalid oauth bind cookie")
+ }
+ mac := hmac.New(sha256.New, []byte(secret))
+ _, _ = mac.Write([]byte(payload))
+ expectedSignature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
+ if !hmac.Equal([]byte(signature), []byte(expectedSignature)) {
+ return 0, errors.New("invalid oauth bind cookie signature")
+ }
+ userID, err := strconv.ParseInt(payload, 10, 64)
+ if err != nil || userID <= 0 {
+ return 0, errors.New("invalid oauth bind cookie user")
+ }
+ return userID, nil
+}
diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go
index ff169c52..8b01ab41 100644
--- a/backend/internal/handler/auth_linuxdo_oauth_test.go
+++ b/backend/internal/handler/auth_linuxdo_oauth_test.go
@@ -1,10 +1,24 @@
package handler
import (
+ "bytes"
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
"strings"
"testing"
+ "time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
+ servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -41,11 +55,13 @@ func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) {
UserInfoURL: "https://connect.linux.do/api/user",
}
- email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg)
+ email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":123,"username":"alice","name":"Alice","avatar_url":"https://cdn.example/avatar.png"}`, cfg)
require.NoError(t, err)
require.Equal(t, "123", subject)
require.Equal(t, "alice", username)
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
+ require.Equal(t, "Alice", displayName)
+ require.Equal(t, "https://cdn.example/avatar.png", avatarURL)
}
func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) {
@@ -53,11 +69,13 @@ func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) {
UserInfoURL: "https://connect.linux.do/api/user",
}
- email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg)
+ email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg)
require.NoError(t, err)
require.Equal(t, "123", subject)
require.Equal(t, "linuxdo_123", username)
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
+ require.Equal(t, "linuxdo_123", displayName)
+ require.Equal(t, "", avatarURL)
}
func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) {
@@ -65,11 +83,11 @@ func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) {
UserInfoURL: "https://connect.linux.do/api/user",
}
- _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg)
+ _, _, _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg)
require.Error(t, err)
tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1)
- _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg)
+ _, _, _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg)
require.Error(t, err)
}
@@ -106,3 +124,906 @@ func TestSingleLineStripsWhitespace(t *testing.T) {
require.Equal(t, "hello world", singleLine("hello\r\nworld"))
require.Equal(t, "", singleLine("\n\t\r"))
}
+
+func TestLinuxDoOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) {
+ handler := newLinuxDoOAuthTestHandler(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: "https://connect.linux.do/oauth/authorize",
+ TokenURL: "https://connect.linux.do/oauth/token",
+ UserInfoURL: "https://connect.linux.do/api/user",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=/settings/connections", nil)
+ c.Request = req
+ c.Set(string(servermiddleware.ContextKeyUser), servermiddleware.AuthSubject{UserID: 42})
+
+ handler.LinuxDoOAuthStart(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ require.Contains(t, location, "connect.linux.do/oauth/authorize")
+ require.Contains(t, location, "client_id=linuxdo-client")
+ require.Contains(t, location, "code_challenge=")
+
+ cookies := recorder.Result().Cookies()
+ require.NotNil(t, findCookie(cookies, linuxDoOAuthStateCookieName))
+ require.NotNil(t, findCookie(cookies, linuxDoOAuthRedirectCookie))
+ require.NotNil(t, findCookie(cookies, linuxDoOAuthVerifierCookie))
+ require.NotNil(t, findCookie(cookies, oauthPendingBrowserCookieName))
+
+ intentCookie := findCookie(cookies, linuxDoOAuthIntentCookieName)
+ require.NotNil(t, intentCookie)
+ require.Equal(t, oauthIntentBindCurrentUser, decodeCookieValueForTest(t, intentCookie.Value))
+
+ bindCookie := findCookie(cookies, linuxDoOAuthBindUserCookieName)
+ require.NotNil(t, bindCookie)
+ userID, err := parseOAuthBindUserCookieValue(decodeCookieValueForTest(t, bindCookie.Value), "test-secret")
+ require.NoError(t, err)
+ require.Equal(t, int64(42), userID)
+}
+
+func TestLinuxDoOAuthStartOmitsPKCEWhenDisabled(t *testing.T) {
+ handler := newLinuxDoOAuthTestHandler(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: "https://connect.linux.do/oauth/authorize",
+ TokenURL: "https://connect.linux.do/oauth/token",
+ UserInfoURL: "https://connect.linux.do/api/user",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: false,
+ })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/start?redirect=/dashboard", nil)
+
+ handler.LinuxDoOAuthStart(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.NotContains(t, recorder.Header().Get("Location"), "code_challenge=")
+ require.Nil(t, findCookie(recorder.Result().Cookies(), linuxDoOAuthVerifierCookie))
+}
+
+func TestLinuxDoOAuthCallbackAllowsMissingVerifierWhenPKCEDisabled(t *testing.T) {
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ require.NoError(t, r.ParseForm())
+ require.Empty(t, r.PostForm.Get("code_verifier"))
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"id":"compat-subject","username":"linuxdo_user","name":"LinuxDo Display"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: upstream.URL + "/authorize",
+ TokenURL: upstream.URL + "/token",
+ UserInfoURL: upstream.URL + "/userinfo",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: false,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=linuxdo-code&state=state-123", nil)
+ req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.LinuxDoOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
+ require.NotNil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+}
+
+func TestLinuxDoOAuthBindStartAcceptsAccessTokenCookie(t *testing.T) {
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: "https://connect.linux.do/oauth/authorize",
+ TokenURL: "https://connect.linux.do/oauth/token",
+ UserInfoURL: "https://connect.linux.do/api/user",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ user, err := client.User.Create().
+ SetEmail("bind-cookie@example.com").
+ SetUsername("bind-cookie-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(context.Background())
+ require.NoError(t, err)
+
+ token, err := handler.authService.GenerateToken(&service.User{
+ ID: user.ID,
+ Email: user.Email,
+ Username: user.Username,
+ PasswordHash: user.PasswordHash,
+ Role: user.Role,
+ Status: user.Status,
+ })
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=/settings/connections", nil)
+ req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: token, Path: oauthBindAccessTokenCookiePath})
+ c.Request = req
+
+ handler.LinuxDoOAuthStart(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+
+ bindCookie := findCookie(recorder.Result().Cookies(), linuxDoOAuthBindUserCookieName)
+ require.NotNil(t, bindCookie)
+ userID, err := parseOAuthBindUserCookieValue(decodeCookieValueForTest(t, bindCookie.Value), "test-secret")
+ require.NoError(t, err)
+ require.Equal(t, user.ID, userID)
+
+ accessTokenCookie := findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName)
+ require.NotNil(t, accessTokenCookie)
+ require.Equal(t, -1, accessTokenCookie.MaxAge)
+}
+
+func TestPrepareOAuthBindAccessTokenCookieSetsHttpOnlyCookie(t *testing.T) {
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{})
+ t.Cleanup(func() { _ = client.Close() })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/bind-token", nil)
+ req.Header.Set("Authorization", "Bearer access-token-value")
+ c.Request = req
+
+ handler.PrepareOAuthBindAccessTokenCookie(c)
+
+ require.Equal(t, http.StatusNoContent, recorder.Code)
+ accessTokenCookie := findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName)
+ require.NotNil(t, accessTokenCookie)
+ require.Equal(t, oauthBindAccessTokenCookiePath, accessTokenCookie.Path)
+ require.Equal(t, linuxDoOAuthCookieMaxAgeSec, accessTokenCookie.MaxAge)
+ require.True(t, accessTokenCookie.HttpOnly)
+ require.Equal(t, url.QueryEscape("access-token-value"), accessTokenCookie.Value)
+}
+
+func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) {
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"id":"321","username":"linuxdo_user","name":"LinuxDo Display","avatar_url":"https://cdn.example/linuxdo.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: upstream.URL + "/authorize",
+ TokenURL: upstream.URL + "/token",
+ UserInfoURL: upstream.URL + "/userinfo",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail(linuxDoSyntheticEmail("321")).
+ SetUsername("legacy-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.AuthIdentity.Create().
+ SetUserID(existingUser.ID).
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("321").
+ SetMetadata(map[string]any{"username": "legacy-user"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-123&state=state-123", nil)
+ req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-123"))
+ req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.LinuxDoOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, existingUser.ID, *session.TargetUserID)
+ require.Equal(t, linuxDoSyntheticEmail("321"), session.ResolvedEmail)
+ require.Equal(t, "LinuxDo Display", session.UpstreamIdentityClaims["suggested_display_name"])
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "/dashboard", completion["redirect"])
+ _, hasAccessToken := completion["access_token"]
+ require.False(t, hasAccessToken)
+ _, hasRefreshToken := completion["refresh_token"]
+ require.False(t, hasRefreshToken)
+ require.Nil(t, completion["error"])
+}
+
+func TestLinuxDoOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) {
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"id":"654","username":"linuxdo_disabled","name":"LinuxDo Disabled"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: upstream.URL + "/authorize",
+ TokenURL: upstream.URL + "/token",
+ UserInfoURL: upstream.URL + "/userinfo",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail(linuxDoSyntheticEmail("654")).
+ SetUsername("disabled-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusDisabled).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.AuthIdentity.Create().
+ SetUserID(existingUser.ID).
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("654").
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-disabled&state=state-disabled", nil)
+ req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-disabled"))
+ req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-disabled"))
+ req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled"))
+ c.Request = req
+
+ handler.LinuxDoOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+ assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE")
+
+ count, err := client.PendingAuthSession.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, count)
+}
+
+func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) {
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"id":"321","email":"legacy@example.com","username":"linuxdo_user","name":"LinuxDo Display","avatar_url":"https://cdn.example/linuxdo.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: upstream.URL + "/authorize",
+ TokenURL: upstream.URL + "/token",
+ UserInfoURL: upstream.URL + "/userinfo",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail(" Legacy@Example.com ").
+ SetUsername("legacy-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-compat&state=state-compat", nil)
+ req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-compat"))
+ req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-compat"))
+ req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-compat"))
+ c.Request = req
+
+ handler.LinuxDoOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, existingUser.ID, *session.TargetUserID)
+ require.Equal(t, strings.TrimSpace(existingUser.Email), session.ResolvedEmail)
+ require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "/dashboard", completion["redirect"])
+ require.Equal(t, oauthPendingChoiceStep, completion["step"])
+ require.Equal(t, strings.TrimSpace(existingUser.Email), completion["email"])
+ require.Equal(t, strings.TrimSpace(existingUser.Email), completion["existing_account_email"])
+ require.Equal(t, true, completion["existing_account_bindable"])
+ require.Equal(t, "compat_email_match", completion["choice_reason"])
+ _, hasAccessToken := completion["access_token"]
+ require.False(t, hasAccessToken)
+}
+
+func TestLinuxDoOAuthCallbackCreatesChoicePendingSessionWhenSignupRequiresInvite(t *testing.T) {
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"id":"654","username":"linuxdo_invite","name":"Need Invite","avatar_url":"https://cdn.example/invite.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, true, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: upstream.URL + "/authorize",
+ TokenURL: upstream.URL + "/token",
+ UserInfoURL: upstream.URL + "/userinfo",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-456&state=state-456", nil)
+ req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-456"))
+ req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-456"))
+ req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-456"))
+ c.Request = req
+
+ handler.LinuxDoOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ ctx := context.Background()
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.Nil(t, session.TargetUserID)
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, oauthPendingChoiceStep, completion["step"])
+ require.Equal(t, "/dashboard", completion["redirect"])
+ require.Equal(t, "third_party_signup", completion["choice_reason"])
+}
+
+func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCurrentUser(t *testing.T) {
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"id":"999","username":"bind_user","name":"Bind Display","avatar_url":"https://cdn.example/bind.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: upstream.URL + "/authorize",
+ TokenURL: upstream.URL + "/token",
+ UserInfoURL: upstream.URL + "/userinfo",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ currentUser, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("current-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-bind&state=state-bind", nil)
+ req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-bind"))
+ req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/settings/connections"))
+ req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-bind"))
+ req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentBindCurrentUser))
+ req.AddCookie(encodedCookie(linuxDoOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-bind"))
+ c.Request = req
+
+ handler.LinuxDoOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentBindCurrentUser, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, currentUser.ID, *session.TargetUserID)
+ require.Equal(t, linuxDoSyntheticEmail("999"), session.ResolvedEmail)
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "/settings/connections", completion["redirect"])
+ require.Empty(t, completion["access_token"])
+ require.Equal(t, "Bind Display", session.UpstreamIdentityClaims["suggested_display_name"])
+
+ userCount, err := client.User.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, userCount)
+}
+
+func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("linuxdo-complete-session").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-subject-1").
+ SetResolvedEmail("linuxdo-subject-1@linuxdo-connect.invalid").
+ SetBrowserSessionKey("linuxdo-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ "suggested_display_name": "LinuxDo Display",
+ "suggested_avatar_url": "https://cdn.example/linuxdo.png",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ AdoptAvatar: true,
+ })
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser")})
+ c.Request = req
+
+ handler.CompleteLinuxDoOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.NotEmpty(t, responseData["access_token"])
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ(session.ResolvedEmail)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "LinuxDo Display", userEntity.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("linuxdo-subject-1"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+ require.Equal(t, "LinuxDo Display", identity.Metadata["display_name"])
+ require.Equal(t, "https://cdn.example/linuxdo.png", identity.Metadata["avatar_url"])
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.True(t, decision.AdoptDisplayName)
+ require.True(t, decision.AdoptAvatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
+
+func TestCompleteLinuxDoOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("linuxdo-complete-invalid-session").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-invalid-subject-1").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("linuxdo-invalid-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": "bind_login_required",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-invalid-browser")})
+ c.Request = req
+
+ handler.CompleteLinuxDoOAuthRegistration(c)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestCompleteLinuxDoOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("linuxdo-complete-choice-session").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-choice-subject-1").
+ SetResolvedEmail("linuxdo-choice-subject-1@linuxdo-connect.invalid").
+ SetBrowserSessionKey("linuxdo-choice-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": oauthPendingChoiceStep,
+ "redirect": "/dashboard",
+ "email": "fresh@example.com",
+ "resolved_email": "fresh@example.com",
+ "force_email_on_signup": true,
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-choice-browser")})
+ c.Request = req
+
+ handler.CompleteLinuxDoOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.Equal(t, "pending_session", responseData["auth_result"])
+ require.Equal(t, oauthPendingChoiceStep, responseData["step"])
+ require.Equal(t, true, responseData["force_email_on_signup"])
+ require.Empty(t, responseData["access_token"])
+
+ userCount, err := client.User.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestCompleteLinuxDoOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("linuxdo-complete-no-adoption-session").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-subject-no-adoption").
+ SetResolvedEmail("linuxdo-subject-no-adoption@linuxdo-connect.invalid").
+ SetBrowserSessionKey("linuxdo-browser-no-adoption").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ "suggested_display_name": "LinuxDo Legacy",
+ "suggested_avatar_url": "https://cdn.example/linuxdo-legacy.png",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser-no-adoption")})
+ c.Request = req
+
+ handler.CompleteLinuxDoOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.NotEmpty(t, responseData["access_token"])
+ require.NotEmpty(t, responseData["refresh_token"])
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ(session.ResolvedEmail)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "linuxdo_user", userEntity.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("linuxdo-subject-no-adoption"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.False(t, decision.AdoptDisplayName)
+ require.False(t, decision.AdoptAvatar)
+}
+
+func TestCompleteLinuxDoOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ existingOwner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(existingOwner.ID).
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-conflict-subject").
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("linuxdo-complete-conflict-session").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-conflict-subject").
+ SetResolvedEmail("linuxdo-conflict-subject@linuxdo-connect.invalid").
+ SetBrowserSessionKey("linuxdo-conflict-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-conflict-browser")})
+ c.Request = req
+
+ handler.CompleteLinuxDoOAuthRegistration(c)
+
+ require.Equal(t, http.StatusConflict, recorder.Code)
+ payload := decodeJSONBody(t, recorder)
+ require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"])
+
+ userCount, err := client.User.Query().
+ Where(dbuser.EmailEQ("linuxdo-conflict-subject@linuxdo-connect.invalid")).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler {
+ t.Helper()
+ handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg)
+ return handler
+}
+
+func newLinuxDoOAuthHandlerAndClient(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+ handler, client := newOAuthPendingFlowTestHandler(t, invitationEnabled)
+ handler.settingSvc = nil
+ handler.cfg = &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ LinuxDo: oauthCfg,
+ }
+ return handler, client
+}
diff --git a/backend/internal/handler/auth_oauth_logout_test.go b/backend/internal/handler/auth_oauth_logout_test.go
new file mode 100644
index 00000000..0d4f94b1
--- /dev/null
+++ b/backend/internal/handler/auth_oauth_logout_test.go
@@ -0,0 +1,68 @@
+package handler
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestLogoutClearsOAuthStateCookiesAndConsumesPendingSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("logout-pending-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("logout-subject-123").
+ SetBrowserSessionKey("logout-browser-session-key").
+ SetResolvedEmail("logout@example.com").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil)
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("logout-browser-session-key")})
+ req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-access-token"})
+ req.AddCookie(&http.Cookie{Name: linuxDoOAuthStateCookieName, Value: encodeCookieValue("linuxdo-state")})
+ req.AddCookie(&http.Cookie{Name: oidcOAuthStateCookieName, Value: encodeCookieValue("oidc-state")})
+ req.AddCookie(&http.Cookie{Name: wechatOAuthStateCookieName, Value: encodeCookieValue("wechat-state")})
+ req.AddCookie(&http.Cookie{Name: wechatPaymentOAuthStateName, Value: encodeCookieValue("wechat-payment-state")})
+ ginCtx.Request = req
+
+ handler.Logout(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ cookies := recorder.Result().Cookies()
+ for _, name := range []string{
+ oauthPendingSessionCookieName,
+ oauthPendingBrowserCookieName,
+ oauthBindAccessTokenCookieName,
+ linuxDoOAuthStateCookieName,
+ oidcOAuthStateCookieName,
+ wechatOAuthStateCookieName,
+ wechatPaymentOAuthStateName,
+ } {
+ cookie := findCookie(cookies, name)
+ require.NotNil(t, cookie, name)
+ require.Equal(t, -1, cookie.MaxAge, name)
+ require.True(t, cookie.HttpOnly, name)
+ }
+
+ storedSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+}
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
new file mode 100644
index 00000000..604ad903
--- /dev/null
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -0,0 +1,1944 @@
+package handler
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ entsql "entgo.io/ent/dialect/sql"
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ oauthPendingBrowserCookiePath = "/api/v1/auth/oauth"
+ oauthPendingBrowserCookieName = "oauth_pending_browser_session"
+ oauthPendingSessionCookiePath = "/api/v1/auth/oauth"
+ oauthPendingSessionCookieName = "oauth_pending_session"
+ oauthPendingCookieMaxAgeSec = 10 * 60
+ oauthPendingChoiceStep = "choose_account_action_required"
+
+ oauthCompletionResponseKey = "completion_response"
+)
+
+var pendingOAuthCreateAccountPreCommitHook func(context.Context, *dbent.PendingAuthSession) error
+
+type oauthPendingSessionPayload struct {
+ Intent string
+ Identity service.PendingAuthIdentityKey
+ TargetUserID *int64
+ ResolvedEmail string
+ RedirectTo string
+ BrowserSessionKey string
+ UpstreamIdentityClaims map[string]any
+ CompletionResponse map[string]any
+}
+
+type oauthAdoptionDecisionRequest struct {
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
+}
+
+type bindPendingOAuthLoginRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ Password string `json:"password" binding:"required"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
+}
+
+type createPendingOAuthAccountRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ VerifyCode string `json:"verify_code,omitempty"`
+ Password string `json:"password" binding:"required,min=6"`
+ InvitationCode string `json:"invitation_code,omitempty"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
+}
+
+type sendPendingOAuthVerifyCodeRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ TurnstileToken string `json:"turnstile_token,omitempty"`
+ PendingAuthToken string `json:"pending_auth_token,omitempty"`
+ PendingOAuthToken string `json:"pending_oauth_token,omitempty"`
+}
+
+func (r bindPendingOAuthLoginRequest) adoptionDecision() oauthAdoptionDecisionRequest {
+ return oauthAdoptionDecisionRequest{
+ AdoptDisplayName: r.AdoptDisplayName,
+ AdoptAvatar: r.AdoptAvatar,
+ }
+}
+
+func (r createPendingOAuthAccountRequest) adoptionDecision() oauthAdoptionDecisionRequest {
+ return oauthAdoptionDecisionRequest{
+ AdoptDisplayName: r.AdoptDisplayName,
+ AdoptAvatar: r.AdoptAvatar,
+ }
+}
+
+func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) {
+ if h == nil || h.authService == nil || h.authService.EntClient() == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+ return service.NewAuthPendingIdentityService(h.authService.EntClient()), nil
+}
+
+func generateOAuthPendingBrowserSession() (string, error) {
+ return oauth.GenerateState()
+}
+
+func setOAuthPendingBrowserCookie(c *gin.Context, sessionKey string, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthPendingBrowserCookieName,
+ Value: encodeCookieValue(sessionKey),
+ Path: oauthPendingBrowserCookiePath,
+ MaxAge: oauthPendingCookieMaxAgeSec,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func clearOAuthPendingBrowserCookie(c *gin.Context, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthPendingBrowserCookieName,
+ Value: "",
+ Path: oauthPendingBrowserCookiePath,
+ MaxAge: -1,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func readOAuthPendingBrowserCookie(c *gin.Context) (string, error) {
+ return readCookieDecoded(c, oauthPendingBrowserCookieName)
+}
+
+func setOAuthPendingSessionCookie(c *gin.Context, sessionToken string, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthPendingSessionCookieName,
+ Value: encodeCookieValue(sessionToken),
+ Path: oauthPendingSessionCookiePath,
+ MaxAge: oauthPendingCookieMaxAgeSec,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func clearOAuthPendingSessionCookie(c *gin.Context, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthPendingSessionCookieName,
+ Value: "",
+ Path: oauthPendingSessionCookiePath,
+ MaxAge: -1,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func readOAuthPendingSessionCookie(c *gin.Context) (string, error) {
+ return readCookieDecoded(c, oauthPendingSessionCookieName)
+}
+
+func redirectToFrontendCallback(c *gin.Context, frontendCallback string) {
+ u, err := url.Parse(frontendCallback)
+ if err != nil {
+ c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
+ return
+ }
+ if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") {
+ c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
+ return
+ }
+ u.Fragment = ""
+ c.Header("Cache-Control", "no-store")
+ c.Header("Pragma", "no-cache")
+ c.Redirect(http.StatusFound, u.String())
+}
+
+func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPendingSessionPayload) error {
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ return err
+ }
+
+ session, err := svc.CreatePendingSession(c.Request.Context(), service.CreatePendingAuthSessionInput{
+ Intent: strings.TrimSpace(payload.Intent),
+ Identity: payload.Identity,
+ TargetUserID: payload.TargetUserID,
+ ResolvedEmail: strings.TrimSpace(payload.ResolvedEmail),
+ RedirectTo: strings.TrimSpace(payload.RedirectTo),
+ BrowserSessionKey: strings.TrimSpace(payload.BrowserSessionKey),
+ UpstreamIdentityClaims: payload.UpstreamIdentityClaims,
+ LocalFlowState: map[string]any{
+ oauthCompletionResponseKey: payload.CompletionResponse,
+ },
+ })
+ if err != nil {
+ return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err)
+ }
+
+ setOAuthPendingSessionCookie(c, session.SessionToken, isRequestHTTPS(c))
+ return nil
+}
+
+func readCompletionResponse(session map[string]any) (map[string]any, bool) {
+ if len(session) == 0 {
+ return nil, false
+ }
+ value, ok := session[oauthCompletionResponseKey]
+ if !ok {
+ return nil, false
+ }
+ result, ok := value.(map[string]any)
+ if !ok {
+ return nil, false
+ }
+ return result, true
+}
+
+func clonePendingMap(values map[string]any) map[string]any {
+ if len(values) == 0 {
+ return map[string]any{}
+ }
+ cloned := make(map[string]any, len(values))
+ for key, value := range values {
+ cloned[key] = value
+ }
+ return cloned
+}
+
+func mergePendingCompletionResponse(session *dbent.PendingAuthSession, overrides map[string]any) map[string]any {
+ payload, _ := readCompletionResponse(session.LocalFlowState)
+ merged := clonePendingMap(payload)
+ if strings.TrimSpace(session.RedirectTo) != "" {
+ if _, exists := merged["redirect"]; !exists {
+ merged["redirect"] = session.RedirectTo
+ }
+ }
+ for key, value := range overrides {
+ if value == nil {
+ delete(merged, key)
+ continue
+ }
+ merged[key] = value
+ }
+ applySuggestedProfileToCompletionResponse(merged, session.UpstreamIdentityClaims)
+ return merged
+}
+
+func pendingSessionStringValue(values map[string]any, key string) string {
+ if len(values) == 0 {
+ return ""
+ }
+ raw, ok := values[key]
+ if !ok {
+ return ""
+ }
+ value, ok := raw.(string)
+ if !ok {
+ return ""
+ }
+ return strings.TrimSpace(value)
+}
+
+func pendingSessionWantsInvitation(payload map[string]any) bool {
+ return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required")
+}
+
+func pendingOAuthCompletionCanIssueTokenPair(session *dbent.PendingAuthSession, payload map[string]any) bool {
+ if session == nil {
+ return false
+ }
+ if !strings.EqualFold(strings.TrimSpace(session.Intent), oauthIntentLogin) {
+ return false
+ }
+ if session.TargetUserID == nil || *session.TargetUserID <= 0 {
+ return false
+ }
+ if pendingSessionWantsInvitation(payload) {
+ return false
+ }
+ return strings.TrimSpace(pendingSessionStringValue(payload, "step")) == ""
+}
+
+func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSession) error {
+ if session == nil {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+ if strings.TrimSpace(session.Intent) != oauthIntentLogin {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+ if session.TargetUserID != nil && *session.TargetUserID > 0 {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+ payload, _ := readCompletionResponse(session.LocalFlowState)
+ if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "bind_login_required") {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+ return nil
+}
+
+func buildLegacyCompleteRegistrationPendingResponse(
+ session *dbent.PendingAuthSession,
+ forceEmailOnSignup bool,
+ emailVerificationRequired bool,
+) map[string]any {
+ completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, map[string]any{
+ "step": oauthPendingChoiceStep,
+ "adoption_required": true,
+ "create_account_allowed": true,
+ "force_email_on_signup": forceEmailOnSignup,
+ }))
+
+ if email := strings.TrimSpace(session.ResolvedEmail); email != "" {
+ if _, exists := completionResponse["email"]; !exists {
+ completionResponse["email"] = email
+ }
+ if _, exists := completionResponse["resolved_email"]; !exists {
+ completionResponse["resolved_email"] = email
+ }
+ }
+ if _, exists := completionResponse["choice_reason"]; !exists {
+ switch {
+ case forceEmailOnSignup:
+ completionResponse["choice_reason"] = "force_email_on_signup"
+ case emailVerificationRequired:
+ completionResponse["choice_reason"] = "email_verification_required"
+ default:
+ completionResponse["choice_reason"] = "third_party_signup"
+ }
+ }
+ return completionResponse
+}
+
+func (h *AuthHandler) legacyCompleteRegistrationSessionStatus(
+ c *gin.Context,
+ session *dbent.PendingAuthSession,
+) (*dbent.PendingAuthSession, bool, error) {
+ if session == nil {
+ return nil, false, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+
+ payload := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil))
+ if step := pendingSessionStringValue(payload, "step"); step != "" {
+ return session, true, nil
+ }
+
+ emailVerificationRequired := h != nil && h.authService != nil && h.authService.IsEmailVerifyEnabled(c.Request.Context())
+ forceEmailOnSignup := h.isForceEmailOnThirdPartySignup(c.Request.Context())
+ if !emailVerificationRequired && !forceEmailOnSignup {
+ return session, false, nil
+ }
+
+ client := h.entClient()
+ if client == nil {
+ return nil, false, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ updatedSession, err := updatePendingOAuthSessionProgress(
+ c.Request.Context(),
+ client,
+ session,
+ strings.TrimSpace(session.Intent),
+ strings.TrimSpace(session.ResolvedEmail),
+ nil,
+ buildLegacyCompleteRegistrationPendingResponse(session, forceEmailOnSignup, emailVerificationRequired),
+ )
+ if err != nil {
+ return nil, false, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err)
+ }
+ return updatedSession, true, nil
+}
+
+func (r oauthAdoptionDecisionRequest) hasDecision() bool {
+ return r.AdoptDisplayName != nil || r.AdoptAvatar != nil
+}
+
+func bindOptionalOAuthAdoptionDecision(c *gin.Context) (oauthAdoptionDecisionRequest, error) {
+ var req oauthAdoptionDecisionRequest
+ if c == nil || c.Request == nil || c.Request.Body == nil {
+ return req, nil
+ }
+ if err := c.ShouldBindJSON(&req); err != nil {
+ if errors.Is(err, io.EOF) {
+ return req, nil
+ }
+ return req, err
+ }
+ return req, nil
+}
+
+func cloneOAuthMetadata(values map[string]any) map[string]any {
+ if len(values) == 0 {
+ return map[string]any{}
+ }
+ cloned := make(map[string]any, len(values))
+ for key, value := range values {
+ cloned[key] = value
+ }
+ return cloned
+}
+
+func mergeOAuthMetadata(base map[string]any, overlay map[string]any) map[string]any {
+ merged := cloneOAuthMetadata(base)
+ for key, value := range overlay {
+ merged[key] = value
+ }
+ return merged
+}
+
+func normalizeAdoptedOAuthDisplayName(value string) string {
+ value = strings.TrimSpace(value)
+ if len([]rune(value)) > 100 {
+ value = string([]rune(value)[:100])
+ }
+ return value
+}
+
+func (h *AuthHandler) entClient() *dbent.Client {
+ if h == nil || h.authService == nil {
+ return nil
+ }
+ return h.authService.EntClient()
+}
+
+func (h *AuthHandler) isForceEmailOnThirdPartySignup(ctx context.Context) bool {
+ if h == nil || h.settingSvc == nil {
+ return false
+ }
+ defaults, err := h.settingSvc.GetAuthSourceDefaultSettings(ctx)
+ if err != nil || defaults == nil {
+ return false
+ }
+ return defaults.ForceEmailOnThirdPartySignup
+}
+
+func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity service.PendingAuthIdentityKey) (*dbent.User, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ record, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)),
+ authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(identity.ProviderSubject)),
+ ).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, nil
+ }
+ return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
+ }
+ return findActiveUserByID(ctx, client, record.UserID)
+}
+
+func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") }
+func (h *AuthHandler) BindOIDCOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "oidc") }
+func (h *AuthHandler) BindWeChatOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "wechat") }
+func (h *AuthHandler) BindPendingOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "") }
+
+func (h *AuthHandler) CreateLinuxDoOAuthAccount(c *gin.Context) {
+ h.createPendingOAuthAccount(c, "linuxdo")
+}
+
+func (h *AuthHandler) CreateOIDCOAuthAccount(c *gin.Context) { h.createPendingOAuthAccount(c, "oidc") }
+
+func (h *AuthHandler) CreateWeChatOAuthAccount(c *gin.Context) {
+ h.createPendingOAuthAccount(c, "wechat")
+}
+
+func (h *AuthHandler) CreatePendingOAuthAccount(c *gin.Context) {
+ h.createPendingOAuthAccount(c, "")
+}
+
+// SendPendingOAuthVerifyCode sends a verification code for a browser-bound
+// pending OAuth account-creation flow.
+// POST /api/v1/auth/oauth/pending/send-verify-code
+func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
+ var req sendPendingOAuthVerifyCodeRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ _, session, _, err := readPendingOAuthBrowserSession(c, h)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ client := h.entClient()
+ if client == nil {
+ response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
+ return
+ }
+
+ email := strings.TrimSpace(strings.ToLower(req.Email))
+ if existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email); err == nil && existingUser != nil {
+ session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
+ return
+ } else if err != nil && !errors.Is(err, service.ErrUserNotFound) {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, SendVerifyCodeResponse{
+ Message: "Verification code sent successfully",
+ Countdown: result.Countdown,
+ })
+}
+
+func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
+ c *gin.Context,
+ sessionID int64,
+ req oauthAdoptionDecisionRequest,
+) (*dbent.IdentityAdoptionDecision, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ existing, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(sessionID)).
+ Only(c.Request.Context())
+ if err != nil && !dbent.IsNotFound(err) {
+ return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_LOAD_FAILED", "failed to load oauth profile adoption decision").WithCause(err)
+ }
+ if existing != nil && !req.hasDecision() {
+ return existing, nil
+ }
+ if existing == nil && !req.hasDecision() {
+ return nil, nil
+ }
+
+ input := service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: sessionID,
+ }
+ if existing != nil {
+ input.AdoptDisplayName = existing.AdoptDisplayName
+ input.AdoptAvatar = existing.AdoptAvatar
+ input.IdentityID = existing.IdentityID
+ }
+ if req.AdoptDisplayName != nil {
+ input.AdoptDisplayName = *req.AdoptDisplayName
+ }
+ if req.AdoptAvatar != nil {
+ input.AdoptAvatar = *req.AdoptAvatar
+ }
+
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ return nil, err
+ }
+ decision, err := svc.UpsertAdoptionDecision(c.Request.Context(), input)
+ if err != nil {
+ return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
+ }
+ return decision, nil
+}
+
+func (h *AuthHandler) ensurePendingOAuthAdoptionDecision(
+ c *gin.Context,
+ sessionID int64,
+ req oauthAdoptionDecisionRequest,
+) (*dbent.IdentityAdoptionDecision, error) {
+ decision, err := h.upsertPendingOAuthAdoptionDecision(c, sessionID, req)
+ if err != nil {
+ return nil, err
+ }
+ if decision != nil {
+ return decision, nil
+ }
+
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ return nil, err
+ }
+ decision, err = svc.UpsertAdoptionDecision(c.Request.Context(), service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: sessionID,
+ })
+ if err != nil {
+ return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
+ }
+ return decision, nil
+}
+
+func updatePendingOAuthSessionProgress(
+ ctx context.Context,
+ client *dbent.Client,
+ session *dbent.PendingAuthSession,
+ intent string,
+ resolvedEmail string,
+ targetUserID *int64,
+ completionResponse map[string]any,
+) (*dbent.PendingAuthSession, error) {
+ if client == nil || session == nil {
+ return nil, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
+ }
+
+ localFlowState := clonePendingMap(session.LocalFlowState)
+ localFlowState[oauthCompletionResponseKey] = clonePendingMap(completionResponse)
+
+ update := client.PendingAuthSession.UpdateOneID(session.ID).
+ SetIntent(strings.TrimSpace(intent)).
+ SetResolvedEmail(strings.TrimSpace(resolvedEmail)).
+ SetLocalFlowState(localFlowState)
+ if targetUserID != nil && *targetUserID > 0 {
+ update = update.SetTargetUserID(*targetUserID)
+ } else {
+ update = update.ClearTargetUserID()
+ }
+ return update.Save(ctx)
+}
+
+func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) {
+ if session == nil {
+ return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
+ }
+ if session.TargetUserID != nil && *session.TargetUserID > 0 {
+ return *session.TargetUserID, nil
+ }
+ email := strings.TrimSpace(session.ResolvedEmail)
+ if email == "" {
+ return 0, infraerrors.BadRequest("PENDING_AUTH_TARGET_USER_MISSING", "pending auth target user is missing")
+ }
+
+ userEntity, err := findUserByNormalizedEmail(ctx, client, email)
+ if err != nil {
+ if errors.Is(err, service.ErrUserNotFound) {
+ return 0, infraerrors.InternalServer("PENDING_AUTH_TARGET_USER_NOT_FOUND", "pending auth target user was not found")
+ }
+ return 0, err
+ }
+ return userEntity.ID, nil
+}
+
+func userNormalizedEmailPredicate(email string) predicate.User {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ if normalized == "" {
+ return dbuser.EmailEQ(email)
+ }
+ return predicate.User(func(s *entsql.Selector) {
+ s.Where(entsql.P(func(b *entsql.Builder) {
+ b.WriteString("LOWER(TRIM(").
+ Ident(s.C(dbuser.FieldEmail)).
+ WriteString(")) = ").
+ Arg(normalized)
+ }))
+ })
+}
+
+func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email string) (*dbent.User, error) {
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ matches, err := client.User.Query().
+ Where(userNormalizedEmailPredicate(email)).
+ Order(dbent.Asc(dbuser.FieldID)).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if len(matches) == 0 {
+ return nil, service.ErrUserNotFound
+ }
+ if len(matches) > 1 {
+ return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users")
+ }
+ return matches[0], nil
+}
+
+func ensurePendingOAuthRegistrationIdentityAvailable(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) error {
+ if client == nil || session == nil {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)),
+ authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)),
+ ).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil
+ }
+ return err
+ }
+ if identity == nil || identity.UserID <= 0 {
+ return nil
+ }
+
+ activeOwner, err := findActiveUserByID(ctx, client, identity.UserID)
+ if err != nil {
+ return err
+ }
+ if activeOwner != nil {
+ return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
+ return nil
+}
+
+func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
+ if session == nil {
+ return nil
+ }
+ switch strings.TrimSpace(session.ProviderType) {
+ case "oidc":
+ issuer := strings.TrimSpace(session.ProviderKey)
+ if issuer == "" {
+ issuer = pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer")
+ }
+ if issuer == "" {
+ return nil
+ }
+ return &issuer
+ default:
+ issuer := pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer")
+ if issuer == "" {
+ return nil
+ }
+ return &issuer
+ }
+}
+
+func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) {
+ if session != nil && strings.EqualFold(strings.TrimSpace(session.ProviderType), "wechat") {
+ return ensurePendingWeChatOAuthIdentityForUser(ctx, tx, session, userID)
+ }
+
+ client := tx.Client()
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)),
+ authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)),
+ ).
+ Only(ctx)
+ if err != nil && !dbent.IsNotFound(err) {
+ return nil, err
+ }
+ if identity != nil {
+ if identity.UserID != userID {
+ activeOwner, err := findActiveUserByID(ctx, client, identity.UserID)
+ if err != nil {
+ return nil, err
+ }
+ if activeOwner != nil {
+ return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
+ return client.AuthIdentity.UpdateOneID(identity.ID).
+ SetUserID(userID).
+ Save(ctx)
+ }
+ return identity, nil
+ }
+
+ create := client.AuthIdentity.Create().
+ SetUserID(userID).
+ SetProviderType(strings.TrimSpace(session.ProviderType)).
+ SetProviderKey(strings.TrimSpace(session.ProviderKey)).
+ SetProviderSubject(strings.TrimSpace(session.ProviderSubject)).
+ SetMetadata(cloneOAuthMetadata(session.UpstreamIdentityClaims))
+ if issuer := oauthIdentityIssuer(session); issuer != nil {
+ create = create.SetIssuer(strings.TrimSpace(*issuer))
+ }
+ return create.Save(ctx)
+}
+
+func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) {
+ client := tx.Client()
+ providerType := strings.TrimSpace(session.ProviderType)
+ providerKey := strings.TrimSpace(session.ProviderKey)
+ providerSubject := strings.TrimSpace(session.ProviderSubject)
+ providerKeys := wechatCompatibleProviderKeys(providerKey)
+ channel := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel"))
+ channelAppID := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_app_id"))
+ channelSubject := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_subject"))
+ metadata := cloneOAuthMetadata(session.UpstreamIdentityClaims)
+
+ identityRecords, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(providerType),
+ authidentity.ProviderKeyIn(providerKeys...),
+ authidentity.ProviderSubjectEQ(providerSubject),
+ ).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+ identity, hasCanonicalKey, err := chooseWeChatIdentityForUser(ctx, client, identityRecords, userID, providerKey)
+ if err != nil {
+ return nil, err
+ }
+
+ var legacyOpenIDIdentity *dbent.AuthIdentity
+ if channelSubject != "" && channelSubject != providerSubject {
+ legacyOpenIDRecords, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(providerType),
+ authidentity.ProviderKeyIn(providerKeys...),
+ authidentity.ProviderSubjectEQ(channelSubject),
+ ).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+ legacyOpenIDIdentity, _, err = chooseWeChatIdentityForUser(ctx, client, legacyOpenIDRecords, userID, providerKey)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ switch {
+ case identity != nil:
+ update := client.AuthIdentity.UpdateOneID(identity.ID).
+ SetMetadata(mergeOAuthMetadata(identity.Metadata, metadata))
+ if identity.UserID != userID {
+ update = update.SetUserID(userID)
+ }
+ if !strings.EqualFold(strings.TrimSpace(identity.ProviderKey), providerKey) && !hasCanonicalKey {
+ update = update.SetProviderKey(providerKey)
+ }
+ if issuer := oauthIdentityIssuer(session); issuer != nil {
+ update = update.SetIssuer(strings.TrimSpace(*issuer))
+ }
+ identity, err = update.Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ case legacyOpenIDIdentity != nil:
+ update := client.AuthIdentity.UpdateOneID(legacyOpenIDIdentity.ID).
+ SetProviderKey(providerKey).
+ SetProviderSubject(providerSubject).
+ SetMetadata(mergeOAuthMetadata(legacyOpenIDIdentity.Metadata, metadata))
+ if issuer := oauthIdentityIssuer(session); issuer != nil {
+ update = update.SetIssuer(strings.TrimSpace(*issuer))
+ }
+ identity, err = update.Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ default:
+ create := client.AuthIdentity.Create().
+ SetUserID(userID).
+ SetProviderType(providerType).
+ SetProviderKey(providerKey).
+ SetProviderSubject(providerSubject).
+ SetMetadata(metadata)
+ if issuer := oauthIdentityIssuer(session); issuer != nil {
+ create = create.SetIssuer(strings.TrimSpace(*issuer))
+ }
+ identity, err = create.Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if channel == "" || channelAppID == "" || channelSubject == "" {
+ return identity, nil
+ }
+
+ channelRecords, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ(providerType),
+ authidentitychannel.ProviderKeyIn(providerKeys...),
+ authidentitychannel.ChannelEQ(channel),
+ authidentitychannel.ChannelAppIDEQ(channelAppID),
+ authidentitychannel.ChannelSubjectEQ(channelSubject),
+ ).
+ WithIdentity().
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+ channelRecord, hasCanonicalChannelKey, err := chooseWeChatChannelForUser(ctx, client, channelRecords, userID, providerKey)
+ if err != nil {
+ return nil, err
+ }
+
+ channelMetadata := mergeOAuthMetadata(channelRecordMetadata(channelRecord), metadata)
+ if channelRecord == nil {
+ if _, err := client.AuthIdentityChannel.Create().
+ SetIdentityID(identity.ID).
+ SetProviderType(providerType).
+ SetProviderKey(providerKey).
+ SetChannel(channel).
+ SetChannelAppID(channelAppID).
+ SetChannelSubject(channelSubject).
+ SetMetadata(channelMetadata).
+ Save(ctx); err != nil {
+ return nil, err
+ }
+ return identity, nil
+ }
+
+ updateChannel := client.AuthIdentityChannel.UpdateOneID(channelRecord.ID).
+ SetIdentityID(identity.ID).
+ SetMetadata(channelMetadata)
+ if !strings.EqualFold(strings.TrimSpace(channelRecord.ProviderKey), providerKey) && !hasCanonicalChannelKey {
+ updateChannel = updateChannel.SetProviderKey(providerKey)
+ }
+ _, err = updateChannel.Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return identity, nil
+}
+
+func chooseWeChatIdentityForUser(ctx context.Context, client *dbent.Client, records []*dbent.AuthIdentity, userID int64, preferredProviderKey string) (*dbent.AuthIdentity, bool, error) {
+ var preferred *dbent.AuthIdentity
+ var fallback *dbent.AuthIdentity
+ hasCanonicalKey := false
+ for _, record := range records {
+ if record == nil {
+ continue
+ }
+ if record.UserID != userID {
+ activeOwner, err := findActiveUserByID(ctx, client, record.UserID)
+ if err != nil {
+ return nil, false, err
+ }
+ if activeOwner != nil {
+ return nil, false, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
+ }
+ if strings.EqualFold(strings.TrimSpace(record.ProviderKey), preferredProviderKey) {
+ hasCanonicalKey = true
+ if preferred == nil {
+ preferred = record
+ }
+ continue
+ }
+ if fallback == nil {
+ fallback = record
+ }
+ }
+ if preferred != nil {
+ return preferred, hasCanonicalKey, nil
+ }
+ return fallback, hasCanonicalKey, nil
+}
+
+func chooseWeChatChannelForUser(ctx context.Context, client *dbent.Client, records []*dbent.AuthIdentityChannel, userID int64, preferredProviderKey string) (*dbent.AuthIdentityChannel, bool, error) {
+ var preferred *dbent.AuthIdentityChannel
+ var fallback *dbent.AuthIdentityChannel
+ hasCanonicalKey := false
+ for _, record := range records {
+ if record == nil {
+ continue
+ }
+ if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
+ activeOwner, err := findActiveUserByID(ctx, client, record.Edges.Identity.UserID)
+ if err != nil {
+ return nil, false, err
+ }
+ if activeOwner != nil {
+ return nil, false, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
+ }
+ }
+ if strings.EqualFold(strings.TrimSpace(record.ProviderKey), preferredProviderKey) {
+ hasCanonicalKey = true
+ if preferred == nil {
+ preferred = record
+ }
+ continue
+ }
+ if fallback == nil {
+ fallback = record
+ }
+ }
+ if preferred != nil {
+ return preferred, hasCanonicalKey, nil
+ }
+ return fallback, hasCanonicalKey, nil
+}
+
+func findActiveUserByID(ctx context.Context, client *dbent.Client, userID int64) (*dbent.User, error) {
+ if client == nil || userID <= 0 {
+ return nil, nil
+ }
+ userEntity, err := client.User.Get(ctx, userID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, nil
+ }
+ return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err)
+ }
+ if !strings.EqualFold(strings.TrimSpace(userEntity.Status), service.StatusActive) {
+ return nil, service.ErrUserNotActive
+ }
+ return userEntity, nil
+}
+
+func channelRecordMetadata(channel *dbent.AuthIdentityChannel) map[string]any {
+ if channel == nil {
+ return map[string]any{}
+ }
+ return cloneOAuthMetadata(channel.Metadata)
+}
+
+func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision) bool {
+ if session == nil || decision == nil {
+ return false
+ }
+ switch strings.ToLower(strings.TrimSpace(session.Intent)) {
+ case "bind_current_user", "login", "adopt_existing_user_by_email":
+ return true
+ default:
+ return decision.AdoptDisplayName || decision.AdoptAvatar
+ }
+}
+
+func shouldSkipAvatarAdoption(err error) bool {
+ return errors.Is(err, service.ErrAvatarInvalid) ||
+ errors.Is(err, service.ErrAvatarTooLarge) ||
+ errors.Is(err, service.ErrAvatarNotImage)
+}
+
+func applyPendingOAuthBinding(
+ ctx context.Context,
+ client *dbent.Client,
+ authService *service.AuthService,
+ userService *service.UserService,
+ session *dbent.PendingAuthSession,
+ decision *dbent.IdentityAdoptionDecision,
+ overrideUserID *int64,
+ forceBind bool,
+ applyFirstBindDefaults bool,
+) error {
+ if client == nil || session == nil {
+ return nil
+ }
+ if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) {
+ return nil
+ }
+
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ return applyPendingOAuthBindingTx(ctx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults)
+ }
+
+ tx, err := client.Tx(ctx)
+ if err != nil {
+ return err
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ txCtx := dbent.NewTxContext(ctx, tx)
+ if err := applyPendingOAuthBindingTx(txCtx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults); err != nil {
+ return err
+ }
+ return tx.Commit()
+}
+
+func applyPendingOAuthBindingTx(
+ ctx context.Context,
+ tx *dbent.Tx,
+ authService *service.AuthService,
+ userService *service.UserService,
+ session *dbent.PendingAuthSession,
+ decision *dbent.IdentityAdoptionDecision,
+ overrideUserID *int64,
+ forceBind bool,
+ applyFirstBindDefaults bool,
+) error {
+ if tx == nil || session == nil {
+ return nil
+ }
+ if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) {
+ return nil
+ }
+
+ targetUserID := int64(0)
+ if overrideUserID != nil && *overrideUserID > 0 {
+ targetUserID = *overrideUserID
+ } else {
+ resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, tx.Client(), session)
+ if err != nil {
+ return err
+ }
+ targetUserID = resolvedUserID
+ }
+
+ adoptedDisplayName := ""
+ if decision != nil && decision.AdoptDisplayName {
+ adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name"))
+ }
+ adoptedAvatarURL := ""
+ if decision != nil && decision.AdoptAvatar {
+ adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url")
+ }
+ shouldAdoptAvatar := false
+ if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" {
+ if err := service.ValidateUserAvatar(adoptedAvatarURL); err == nil {
+ shouldAdoptAvatar = true
+ } else if !shouldSkipAvatarAdoption(err) {
+ return err
+ }
+ }
+
+ if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
+ if err := tx.Client().User.UpdateOneID(targetUserID).
+ SetUsername(adoptedDisplayName).
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID)
+ if err != nil {
+ return err
+ }
+
+ metadata := cloneOAuthMetadata(identity.Metadata)
+ for key, value := range session.UpstreamIdentityClaims {
+ metadata[key] = value
+ }
+ if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
+ metadata["display_name"] = adoptedDisplayName
+ }
+ if shouldAdoptAvatar {
+ metadata["avatar_url"] = adoptedAvatarURL
+ }
+
+ updateIdentity := tx.Client().AuthIdentity.UpdateOneID(identity.ID).SetMetadata(metadata)
+ if issuer := oauthIdentityIssuer(session); issuer != nil {
+ updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer))
+ }
+ if _, err := updateIdentity.Save(ctx); err != nil {
+ return err
+ }
+
+ if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) {
+ if _, err := tx.Client().IdentityAdoptionDecision.Update().
+ Where(
+ identityadoptiondecision.IdentityIDEQ(identity.ID),
+ identityadoptiondecision.IDNEQ(decision.ID),
+ ).
+ ClearIdentityID().
+ Save(ctx); err != nil {
+ return err
+ }
+ if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
+ SetIdentityID(identity.ID).
+ Save(ctx); err != nil {
+ return err
+ }
+ }
+
+ if applyFirstBindDefaults && authService != nil {
+ if err := authService.ApplyProviderDefaultSettingsOnFirstBind(ctx, targetUserID, session.ProviderType); err != nil {
+ return err
+ }
+ }
+
+ if shouldAdoptAvatar && userService != nil {
+ if _, err := userService.SetAvatar(ctx, targetUserID, adoptedAvatarURL); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func consumePendingOAuthBrowserSessionTx(
+ ctx context.Context,
+ tx *dbent.Tx,
+ session *dbent.PendingAuthSession,
+) error {
+ if tx == nil || session == nil {
+ return service.ErrPendingAuthSessionNotFound
+ }
+
+ storedSession, err := tx.Client().PendingAuthSession.Get(ctx, session.ID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return service.ErrPendingAuthSessionNotFound
+ }
+ return err
+ }
+
+ now := time.Now().UTC()
+ if storedSession.ConsumedAt != nil {
+ return service.ErrPendingAuthSessionConsumed
+ }
+ if !storedSession.ExpiresAt.IsZero() && now.After(storedSession.ExpiresAt) {
+ return service.ErrPendingAuthSessionExpired
+ }
+ if strings.TrimSpace(storedSession.BrowserSessionKey) != "" &&
+ strings.TrimSpace(storedSession.BrowserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) {
+ return service.ErrPendingAuthBrowserMismatch
+ }
+
+ if _, err := tx.Client().PendingAuthSession.UpdateOneID(storedSession.ID).
+ SetConsumedAt(now).
+ SetCompletionCodeHash("").
+ ClearCompletionCodeExpiresAt().
+ Save(ctx); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func applyPendingOAuthAdoptionAndConsumeSession(
+ ctx context.Context,
+ client *dbent.Client,
+ authService *service.AuthService,
+ userService *service.UserService,
+ session *dbent.PendingAuthSession,
+ decision *dbent.IdentityAdoptionDecision,
+ userID int64,
+) error {
+ if client == nil {
+ return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+ if session == nil || userID <= 0 {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+
+ tx, err := client.Tx(ctx)
+ if err != nil {
+ return err
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ txCtx := dbent.NewTxContext(ctx, tx)
+ if err := applyPendingOAuthAdoption(txCtx, client, authService, userService, session, decision, &userID); err != nil {
+ return err
+ }
+ if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil {
+ return err
+ }
+ return tx.Commit()
+}
+
+func applyPendingOAuthAdoption(
+ ctx context.Context,
+ client *dbent.Client,
+ authService *service.AuthService,
+ userService *service.UserService,
+ session *dbent.PendingAuthSession,
+ decision *dbent.IdentityAdoptionDecision,
+ overrideUserID *int64,
+) error {
+ return applyPendingOAuthBinding(
+ ctx,
+ client,
+ authService,
+ userService,
+ session,
+ decision,
+ overrideUserID,
+ false,
+ strings.EqualFold(strings.TrimSpace(session.Intent), "bind_current_user"),
+ )
+}
+
+func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) {
+ if len(payload) == 0 || len(upstream) == 0 {
+ return
+ }
+
+ displayName := pendingSessionStringValue(upstream, "suggested_display_name")
+ avatarURL := pendingSessionStringValue(upstream, "suggested_avatar_url")
+
+ if displayName != "" {
+ if _, exists := payload["suggested_display_name"]; !exists {
+ payload["suggested_display_name"] = displayName
+ }
+ }
+ if avatarURL != "" {
+ if _, exists := payload["suggested_avatar_url"]; !exists {
+ payload["suggested_avatar_url"] = avatarURL
+ }
+ }
+ if displayName != "" || avatarURL != "" {
+ payload["adoption_required"] = true
+ }
+}
+
+func pendingOAuthIdentityExistsForUser(
+ ctx context.Context,
+ client *dbent.Client,
+ session *dbent.PendingAuthSession,
+ userID int64,
+) (bool, error) {
+ if client == nil || session == nil || userID <= 0 {
+ return false, nil
+ }
+
+ providerType := strings.TrimSpace(session.ProviderType)
+ providerKey := strings.TrimSpace(session.ProviderKey)
+ providerSubject := strings.TrimSpace(session.ProviderSubject)
+ if providerType == "" || providerSubject == "" {
+ return false, nil
+ }
+
+ query := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(providerType),
+ authidentity.ProviderSubjectEQ(providerSubject),
+ authidentity.UserIDEQ(userID),
+ )
+ if strings.EqualFold(providerType, "wechat") {
+ query = query.Where(authidentity.ProviderKeyIn(wechatCompatibleProviderKeys(providerKey)...))
+ } else if providerKey != "" {
+ query = query.Where(authidentity.ProviderKeyEQ(providerKey))
+ }
+
+ count, err := query.Count(ctx)
+ if err != nil {
+ return false, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
+ }
+ return count > 0, nil
+}
+
+func (h *AuthHandler) shouldSkipPendingOAuthAdoptionPrompt(
+ ctx context.Context,
+ session *dbent.PendingAuthSession,
+ payload map[string]any,
+) (bool, error) {
+ if session == nil || len(payload) == 0 {
+ return false, nil
+ }
+ if !pendingOAuthCompletionCanIssueTokenPair(session, payload) {
+ return false, nil
+ }
+ if pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name") == "" &&
+ pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url") == "" {
+ return false, nil
+ }
+
+ return pendingOAuthIdentityExistsForUser(ctx, h.entClient(), session, *session.TargetUserID)
+}
+
+func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.AuthPendingIdentityService, *dbent.PendingAuthSession, func(), error) {
+ secureCookie := isRequestHTTPS(c)
+ clearCookies := func() {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ }
+
+ sessionToken, err := readOAuthPendingSessionCookie(c)
+ if err != nil || strings.TrimSpace(sessionToken) == "" {
+ clearCookies()
+ return nil, nil, clearCookies, service.ErrPendingAuthSessionNotFound
+ }
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil || strings.TrimSpace(browserSessionKey) == "" {
+ clearCookies()
+ return nil, nil, clearCookies, service.ErrPendingAuthBrowserMismatch
+ }
+
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ clearCookies()
+ return nil, nil, clearCookies, err
+ }
+
+ session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearCookies()
+ return nil, nil, clearCookies, err
+ }
+
+ return svc, session, clearCookies, nil
+}
+
+func (h *AuthHandler) consumePendingOAuthSessionOnLogout(c *gin.Context) {
+ if c == nil || c.Request == nil {
+ return
+ }
+
+ sessionToken, err := readOAuthPendingSessionCookie(c)
+ if err != nil || strings.TrimSpace(sessionToken) == "" {
+ return
+ }
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil || strings.TrimSpace(browserSessionKey) == "" {
+ return
+ }
+
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ return
+ }
+ _, _ = svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+}
+
+func clearOAuthLogoutCookies(c *gin.Context) {
+ secureCookie := isRequestHTTPS(c)
+
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ clearOAuthBindAccessTokenCookie(c, secureCookie)
+
+ clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
+ clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
+ clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
+ clearCookie(c, linuxDoOAuthIntentCookieName, secureCookie)
+ clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
+
+ oidcClearCookie(c, oidcOAuthStateCookieName, secureCookie)
+ oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie)
+ oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie)
+ oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie)
+ oidcClearCookie(c, oidcOAuthIntentCookieName, secureCookie)
+ oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie)
+
+ wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie)
+ wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie)
+ wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie)
+ wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie)
+ wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie)
+
+ wechatPaymentClearCookie(c, wechatPaymentOAuthStateName, secureCookie)
+ wechatPaymentClearCookie(c, wechatPaymentOAuthRedirect, secureCookie)
+ wechatPaymentClearCookie(c, wechatPaymentOAuthContextName, secureCookie)
+ wechatPaymentClearCookie(c, wechatPaymentOAuthScope, secureCookie)
+}
+
+func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H {
+ completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil))
+ payload := gin.H{
+ "auth_result": "pending_session",
+ "provider": strings.TrimSpace(session.ProviderType),
+ "intent": strings.TrimSpace(session.Intent),
+ }
+ for key, value := range completionResponse {
+ payload[key] = value
+ }
+ if email := strings.TrimSpace(session.ResolvedEmail); email != "" {
+ payload["email"] = email
+ }
+ return payload
+}
+
+func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string]any {
+ normalized := clonePendingMap(payload)
+ for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} {
+ delete(normalized, key)
+ }
+ step := strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step")))
+ switch step {
+ case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required":
+ normalized["step"] = oauthPendingChoiceStep
+ }
+ if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(normalized, "step")), oauthPendingChoiceStep) {
+ normalized["adoption_required"] = true
+ }
+ if _, exists := normalized["adoption_required"]; !exists {
+ if _, hasChoiceFields := normalized["email_binding_required"]; hasChoiceFields {
+ normalized["adoption_required"] = true
+ }
+ }
+ return normalized
+}
+
+func pendingOAuthChoiceCompletionResponse(session *dbent.PendingAuthSession, email string) map[string]any {
+ response := mergePendingCompletionResponse(session, map[string]any{
+ "step": oauthPendingChoiceStep,
+ "adoption_required": true,
+ "force_email_on_signup": true,
+ "email_binding_required": true,
+ "existing_account_bindable": true,
+ })
+ if email = strings.TrimSpace(email); email != "" {
+ response["email"] = email
+ response["resolved_email"] = email
+ }
+ return response
+}
+
+func (h *AuthHandler) transitionPendingOAuthAccountToChoiceState(
+ c *gin.Context,
+ client *dbent.Client,
+ session *dbent.PendingAuthSession,
+ targetUser *dbent.User,
+ email string,
+) (*dbent.PendingAuthSession, error) {
+ completionResponse := pendingOAuthChoiceCompletionResponse(session, email)
+ var targetUserID *int64
+ if targetUser != nil && targetUser.ID > 0 {
+ targetUserID = &targetUser.ID
+ }
+ session, err := updatePendingOAuthSessionProgress(
+ c.Request.Context(),
+ client,
+ session,
+ strings.TrimSpace(session.Intent),
+ email,
+ targetUserID,
+ completionResponse,
+ )
+ if err != nil {
+ return nil, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err)
+ }
+ return session, nil
+}
+
+func writeOAuthTokenPairResponse(c *gin.Context, tokenPair *service.TokenPair) {
+ c.JSON(http.StatusOK, gin.H{
+ "access_token": tokenPair.AccessToken,
+ "refresh_token": tokenPair.RefreshToken,
+ "expires_in": tokenPair.ExpiresIn,
+ "token_type": "Bearer",
+ })
+}
+
+func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
+ var req bindPendingOAuthLoginRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
+ response.BadRequest(c, "Pending oauth session provider mismatch")
+ return
+ }
+
+ user, err := h.authService.ValidatePasswordCredentials(c.Request.Context(), strings.TrimSpace(req.Email), req.Password)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if session.TargetUserID != nil && *session.TargetUserID > 0 && user.ID != *session.TargetUserID {
+ response.ErrorFrom(c, infraerrors.Conflict("PENDING_AUTH_TARGET_USER_MISMATCH", "pending oauth session must be completed by the targeted user"))
+ return
+ }
+ if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
+ tempToken, err := h.totpService.CreatePendingOAuthBindLoginSession(
+ c.Request.Context(),
+ user.ID,
+ user.Email,
+ session.SessionToken,
+ session.BrowserSessionKey,
+ )
+ if err != nil {
+ response.InternalError(c, "Failed to create 2FA session")
+ return
+ }
+ response.Success(c, TotpLoginResponse{
+ Requires2FA: true,
+ TempToken: tempToken,
+ UserEmailMasked: service.MaskEmail(user.Email),
+ })
+ return
+ }
+ if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID, true, true); err != nil {
+ respondPendingOAuthBindingApplyError(c, err)
+ return
+ }
+
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
+ if err != nil {
+ response.InternalError(c, "Failed to generate token pair")
+ return
+ }
+ if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ clearCookies()
+ writeOAuthTokenPairResponse(c, tokenPair)
+}
+
+func respondPendingOAuthBindingApplyError(c *gin.Context, err error) {
+ if code := infraerrors.Code(err); code >= http.StatusBadRequest && code < http.StatusInternalServerError {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+}
+
+func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) {
+ var req createPendingOAuthAccountRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ _, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
+ response.BadRequest(c, "Pending oauth session provider mismatch")
+ return
+ }
+
+ client := h.entClient()
+ if client == nil {
+ response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
+ return
+ }
+
+ email := strings.TrimSpace(strings.ToLower(req.Email))
+ existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email)
+ if err != nil {
+ switch {
+ case errors.Is(err, service.ErrUserNotFound):
+ existingUser = nil
+ case infraerrors.Code(err) >= http.StatusBadRequest && infraerrors.Code(err) < http.StatusInternalServerError:
+ response.ErrorFrom(c, err)
+ return
+ default:
+ response.ErrorFrom(c, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable"))
+ return
+ }
+ }
+ if existingUser != nil {
+ session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
+ return
+ }
+ if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ tokenPair, user, err := h.authService.RegisterOAuthEmailAccount(
+ c.Request.Context(),
+ email,
+ req.Password,
+ strings.TrimSpace(req.VerifyCode),
+ strings.TrimSpace(req.InvitationCode),
+ strings.TrimSpace(session.ProviderType),
+ )
+ if err != nil {
+ if errors.Is(err, service.ErrEmailExists) {
+ existingUser, lookupErr := findUserByNormalizedEmail(c.Request.Context(), client, email)
+ if lookupErr != nil {
+ response.ErrorFrom(c, lookupErr)
+ return
+ }
+ session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
+ return
+ }
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ rollbackCreatedUser := func(originalErr error) bool {
+ if user == nil || user.ID <= 0 {
+ return false
+ }
+ if rollbackErr := h.authService.RollbackOAuthEmailAccountCreation(
+ c.Request.Context(),
+ user.ID,
+ strings.TrimSpace(req.InvitationCode),
+ ); rollbackErr != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer(
+ "PENDING_AUTH_ACCOUNT_ROLLBACK_FAILED",
+ "failed to rollback pending oauth account creation",
+ ).WithCause(fmt.Errorf("original error: %w; rollback error: %v", originalErr, rollbackErr)))
+ return true
+ }
+ user = nil
+ return false
+ }
+
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
+ if err != nil {
+ if rollbackCreatedUser(err) {
+ return
+ }
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ tx, err := client.Tx(c.Request.Context())
+ if err != nil {
+ if rollbackCreatedUser(err) {
+ return
+ }
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+ return
+ }
+ defer func() { _ = tx.Rollback() }()
+ txCtx := dbent.NewTxContext(c.Request.Context(), tx)
+
+ if err := applyPendingOAuthBinding(txCtx, client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil {
+ _ = tx.Rollback()
+ if rollbackCreatedUser(err) {
+ return
+ }
+ respondPendingOAuthBindingApplyError(c, err)
+ return
+ }
+
+ if err := h.authService.FinalizeOAuthEmailAccount(
+ txCtx,
+ user,
+ strings.TrimSpace(req.InvitationCode),
+ strings.TrimSpace(session.ProviderType),
+ ); err != nil {
+ _ = tx.Rollback()
+ if rollbackCreatedUser(err) {
+ return
+ }
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil {
+ _ = tx.Rollback()
+ if rollbackCreatedUser(err) {
+ return
+ }
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ if pendingOAuthCreateAccountPreCommitHook != nil {
+ if err := pendingOAuthCreateAccountPreCommitHook(txCtx, session); err != nil {
+ _ = tx.Rollback()
+ if rollbackCreatedUser(err) {
+ return
+ }
+ respondPendingOAuthBindingApplyError(c, err)
+ return
+ }
+ }
+
+ if err := tx.Commit(); err != nil {
+ if rollbackCreatedUser(err) {
+ return
+ }
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+ return
+ }
+
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ clearCookies()
+ writeOAuthTokenPairResponse(c, tokenPair)
+}
+
+// ExchangePendingOAuthCompletion redeems a pending OAuth browser session into a frontend-safe payload.
+// POST /api/v1/auth/oauth/pending/exchange
+func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
+ secureCookie := isRequestHTTPS(c)
+ clearCookies := func() {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ }
+ adoptionDecision, err := bindOptionalOAuthAdoptionDecision(c)
+ if err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ sessionToken, err := readOAuthPendingSessionCookie(c)
+ if err != nil || strings.TrimSpace(sessionToken) == "" {
+ clearCookies()
+ response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
+ return
+ }
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil || strings.TrimSpace(browserSessionKey) == "" {
+ clearCookies()
+ response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
+ return
+ }
+
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ payload, ok := readCompletionResponse(session.LocalFlowState)
+ if !ok {
+ clearCookies()
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid"))
+ return
+ }
+ payload = normalizePendingOAuthCompletionResponse(payload)
+ if strings.TrimSpace(session.RedirectTo) != "" {
+ if _, exists := payload["redirect"]; !exists {
+ payload["redirect"] = session.RedirectTo
+ }
+ }
+ applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims)
+
+ canIssueTokenPair := pendingOAuthCompletionCanIssueTokenPair(session, payload)
+ var loginUser *service.User
+ if canIssueTokenPair {
+ loginUser, err = h.userService.GetByID(c.Request.Context(), *session.TargetUserID)
+ if err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := ensureLoginUserActive(loginUser); err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := h.ensureBackendModeAllowsUser(c.Request.Context(), loginUser); err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+ }
+ skipAdoptionPrompt, err := h.shouldSkipPendingOAuthAdoptionPrompt(c.Request.Context(), session, payload)
+ if err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+ if skipAdoptionPrompt {
+ delete(payload, "adoption_required")
+ }
+
+ if pendingSessionWantsInvitation(payload) {
+ if adoptionDecision.hasDecision() {
+ decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ _ = decision
+ }
+ response.Success(c, payload)
+ return
+ }
+ if !adoptionDecision.hasDecision() {
+ adoptionRequired, _ := payload["adoption_required"].(bool)
+ if adoptionRequired {
+ response.Success(c, payload)
+ return
+ }
+ }
+
+ decisionReq := adoptionDecision
+ if !decisionReq.hasDecision() {
+ adoptDisplayName := false
+ adoptAvatar := false
+ decisionReq = oauthAdoptionDecisionRequest{
+ AdoptDisplayName: &adoptDisplayName,
+ AdoptAvatar: &adoptAvatar,
+ }
+ }
+
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, decisionReq)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, session.TargetUserID); err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
+ return
+ }
+
+ if _, err := svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ if canIssueTokenPair {
+ tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), loginUser, "")
+ if err != nil {
+ clearCookies()
+ response.InternalError(c, "Failed to generate token pair")
+ return
+ }
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), loginUser.ID)
+ payload["access_token"] = tokenPair.AccessToken
+ payload["refresh_token"] = tokenPair.RefreshToken
+ payload["expires_in"] = tokenPair.ExpiresIn
+ payload["token_type"] = "Bearer"
+ }
+
+ clearCookies()
+ response.Success(c, payload)
+}
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
new file mode 100644
index 00000000..a4b7a297
--- /dev/null
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -0,0 +1,2995 @@
+package handler
+
+import (
+ "bytes"
+ "context"
+ "database/sql"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/redeemcode"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/pquerna/otp/totp"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+func TestApplySuggestedProfileToCompletionResponse(t *testing.T) {
+ payload := map[string]any{
+ "access_token": "token",
+ }
+ upstream := map[string]any{
+ "suggested_display_name": "Alice",
+ "suggested_avatar_url": "https://cdn.example/avatar.png",
+ }
+
+ applySuggestedProfileToCompletionResponse(payload, upstream)
+
+ require.Equal(t, "Alice", payload["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"])
+ require.Equal(t, true, payload["adoption_required"])
+}
+
+func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t *testing.T) {
+ payload := map[string]any{
+ "suggested_display_name": "Existing",
+ "adoption_required": false,
+ }
+ upstream := map[string]any{
+ "suggested_display_name": "Alice",
+ "suggested_avatar_url": "https://cdn.example/avatar.png",
+ }
+
+ applySuggestedProfileToCompletionResponse(payload, upstream)
+
+ require.Equal(t, "Existing", payload["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"])
+ require.Equal(t, true, payload["adoption_required"])
+}
+
+func TestSetOAuthPendingSessionCookieUsesProviderCompletionPathPrefix(t *testing.T) {
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ ginCtx.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback", nil)
+
+ setOAuthPendingSessionCookie(ginCtx, "pending-session-token", false)
+
+ cookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, cookie)
+ require.Equal(t, "/api/v1/auth/oauth", cookie.Path)
+}
+
+func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("linuxdo-123@linuxdo-connect.invalid").
+ SetUsername("legacy-name").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("pending-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ "suggested_display_name": "Alice Example",
+ "suggested_avatar_url": "https://cdn.example/alice.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ "redirect": "/dashboard",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ previewRecorder := httptest.NewRecorder()
+ previewCtx, _ := gin.CreateTestContext(previewRecorder)
+ previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+ previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")})
+ previewCtx.Request = previewReq
+
+ handler.ExchangePendingOAuthCompletion(previewCtx)
+
+ require.Equal(t, http.StatusOK, previewRecorder.Code)
+ previewData := decodeJSONResponseData(t, previewRecorder)
+ require.Equal(t, "Alice Example", previewData["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/alice.png", previewData["suggested_avatar_url"])
+ require.Equal(t, true, previewData["adoption_required"])
+
+ storedUser, err := client.User.Get(ctx, userEntity.ID)
+ require.NoError(t, err)
+ require.Equal(t, "legacy-name", storedUser.Username)
+
+ previewSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Nil(t, previewSession.ConsumedAt)
+
+ body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`)
+ finalizeRecorder := httptest.NewRecorder()
+ finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder)
+ finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+ finalizeReq.Header.Set("Content-Type", "application/json")
+ finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")})
+ finalizeCtx.Request = finalizeReq
+
+ handler.ExchangePendingOAuthCompletion(finalizeCtx)
+
+ require.Equal(t, http.StatusOK, finalizeRecorder.Code)
+
+ storedUser, err = client.User.Get(ctx, userEntity.ID)
+ require.NoError(t, err)
+ require.Equal(t, "Alice Example", storedUser.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+ require.Equal(t, "Alice Example", identity.Metadata["display_name"])
+ require.Equal(t, "https://cdn.example/alice.png", identity.Metadata["avatar_url"])
+
+ avatar := loadUserAvatarRecord(t, client, userEntity.ID)
+ require.NotNil(t, avatar)
+ require.Equal(t, "remote_url", avatar.StorageProvider)
+ require.Equal(t, "https://cdn.example/alice.png", avatar.URL)
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.True(t, decision.AdoptDisplayName)
+ require.True(t, decision.AdoptAvatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
+
+func TestExchangePendingOAuthCompletionSkipsInvalidAvatarAdoptionWithoutBlockingCompletion(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("invalid-avatar@example.com").
+ SetUsername("legacy-name").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("pending-invalid-avatar-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("invalid-avatar-123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("browser-invalid-avatar-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ "suggested_display_name": "Alice Example",
+ "suggested_avatar_url": "/avatars/alice.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ "redirect": "/dashboard",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-invalid-avatar-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("invalid-avatar-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "Alice Example", identity.Metadata["display_name"])
+ _, hasAdoptedAvatar := identity.Metadata["avatar_url"]
+ require.False(t, hasAdoptedAvatar)
+
+ avatar := loadUserAvatarRecord(t, client, userEntity.ID)
+ require.Nil(t, avatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
+
+func TestExchangePendingOAuthCompletionBindCurrentUserPreviewThenFinalizeBindsIdentityWithoutAdoption(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("bind-target@example.com").
+ SetUsername("legacy-name").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-pending-session-token").
+ SetIntent("bind_current_user").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("bind-123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("bind-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ "suggested_display_name": "Bound Example",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ "redirect": "/settings/profile",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ previewRecorder := httptest.NewRecorder()
+ previewCtx, _ := gin.CreateTestContext(previewRecorder)
+ previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+ previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-browser-session-key")})
+ previewCtx.Request = previewReq
+
+ handler.ExchangePendingOAuthCompletion(previewCtx)
+
+ require.Equal(t, http.StatusOK, previewRecorder.Code)
+ previewData := decodeJSONResponseData(t, previewRecorder)
+ require.Equal(t, "Bound Example", previewData["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/bound.png", previewData["suggested_avatar_url"])
+ require.Equal(t, true, previewData["adoption_required"])
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("bind-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ previewSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Nil(t, previewSession.ConsumedAt)
+
+ body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`)
+ finalizeRecorder := httptest.NewRecorder()
+ finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder)
+ finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+ finalizeReq.Header.Set("Content-Type", "application/json")
+ finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-browser-session-key")})
+ finalizeCtx.Request = finalizeReq
+
+ handler.ExchangePendingOAuthCompletion(finalizeCtx)
+
+ require.Equal(t, http.StatusOK, finalizeRecorder.Code)
+
+ storedUser, err := client.User.Get(ctx, userEntity.ID)
+ require.NoError(t, err)
+ require.Equal(t, "legacy-name", storedUser.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("bind-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+ require.Equal(t, "Bound Example", identity.Metadata["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/bound.png", identity.Metadata["suggested_avatar_url"])
+ _, hasDisplayName := identity.Metadata["display_name"]
+ require.False(t, hasDisplayName)
+ _, hasAvatarURL := identity.Metadata["avatar_url"]
+ require.False(t, hasAvatarURL)
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.False(t, decision.AdoptDisplayName)
+ require.False(t, decision.AdoptAvatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
+
+func TestExchangePendingOAuthCompletionBindCurrentUserOwnershipConflict(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ targetUser, err := client.User.Create().
+ SetEmail("bind-conflict-target@example.com").
+ SetUsername("target-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ ownerUser, err := client.User.Create().
+ SetEmail("bind-conflict-owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ existingIdentity, err := client.AuthIdentity.Create().
+ SetUserID(ownerUser.ID).
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("conflict-123").
+ SetMetadata(map[string]any{"username": "owner-user"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-conflict-session-token").
+ SetIntent("bind_current_user").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("conflict-123").
+ SetTargetUserID(targetUser.ID).
+ SetResolvedEmail(targetUser.Email).
+ SetBrowserSessionKey("bind-conflict-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Conflict Example",
+ "suggested_avatar_url": "https://cdn.example/conflict.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-conflict-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusInternalServerError, recorder.Code)
+ payload := decodeJSONBody(t, recorder)
+ require.Equal(t, "PENDING_AUTH_ADOPTION_APPLY_FAILED", payload["reason"])
+
+ identity, err := client.AuthIdentity.Get(ctx, existingIdentity.ID)
+ require.NoError(t, err)
+ require.Equal(t, ownerUser.ID, identity.UserID)
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Nil(t, decision.IdentityID)
+ require.False(t, decision.AdoptDisplayName)
+ require.False(t, decision.AdoptAvatar)
+
+ storedSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestExchangePendingOAuthCompletionLoginFalseFalseBindsIdentityWithoutAdoption(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("login-false@example.com").
+ SetUsername("legacy-name").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("login-false-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("login-false-123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("login-false-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Login Example",
+ "suggested_avatar_url": "https://cdn.example/login.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("login-false-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("login-false-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.False(t, decision.AdoptDisplayName)
+ require.False(t, decision.AdoptAvatar)
+
+ storedSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+}
+
+func TestExchangePendingOAuthCompletionLoginReassignsExistingDecisionIdentityReference(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("login-reassign@example.com").
+ SetUsername("legacy-name").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ existingIdentity, err := client.AuthIdentity.Create().
+ SetUserID(userEntity.ID).
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("login-reassign-123").
+ SetMetadata(map[string]any{}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ previousSession, err := client.PendingAuthSession.Create().
+ SetSessionToken("login-reassign-previous-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("login-reassign-123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("login-reassign-previous-browser-session-key").
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "previous-access-token",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ previousDecision, err := client.IdentityAdoptionDecision.Create().
+ SetPendingAuthSessionID(previousSession.ID).
+ SetIdentityID(existingIdentity.ID).
+ SetAdoptDisplayName(true).
+ SetAdoptAvatar(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("login-reassign-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("login-reassign-123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("login-reassign-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Login Reassign",
+ "suggested_avatar_url": "https://cdn.example/login-reassign.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.IdentityAdoptionDecision.Create().
+ SetPendingAuthSessionID(session.ID).
+ SetAdoptDisplayName(false).
+ SetAdoptAvatar(false).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("login-reassign-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ reloadedPrevious, err := client.IdentityAdoptionDecision.Get(ctx, previousDecision.ID)
+ require.NoError(t, err)
+ require.Nil(t, reloadedPrevious.IdentityID)
+
+ currentDecision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, currentDecision.IdentityID)
+ require.Equal(t, existingIdentity.ID, *currentDecision.IdentityID)
+
+ storedSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+}
+
+func TestExchangePendingOAuthCompletionLoginWithoutDecisionStillBindsIdentity(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("login-nodecision@example.com").
+ SetUsername("legacy-name").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("login-nodecision-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("login-nodecision-123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("login-nodecision-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "login-nodecision-user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("login-nodecision-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("login-nodecision-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+
+ storedSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+}
+
+func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdoptionPrompt(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("existing-login@example.com").
+ SetUsername("existing-login-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(userEntity.ID).
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("existing-login-123").
+ SetMetadata(map[string]any{
+ "username": "existing-login-user",
+ }).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("existing-login-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("existing-login-123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("existing-login-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Existing Login Example",
+ "suggested_avatar_url": "https://cdn.example/existing-login.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "legacy-access-token",
+ "refresh_token": "legacy-refresh-token",
+ "expires_in": float64(3600),
+ "token_type": "Bearer",
+ "redirect": "/dashboard",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-login-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ payload := decodeJSONResponseData(t, recorder)
+ require.NotEmpty(t, payload["access_token"])
+ require.NotEmpty(t, payload["refresh_token"])
+ require.NotEqual(t, "legacy-access-token", payload["access_token"])
+ require.NotEqual(t, "legacy-refresh-token", payload["refresh_token"])
+ require.Equal(t, "/dashboard", payload["redirect"])
+ require.Equal(t, "Existing Login Example", payload["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/existing-login.png", payload["suggested_avatar_url"])
+ require.NotContains(t, payload, "adoption_required")
+
+ accessToken, ok := payload["access_token"].(string)
+ require.True(t, ok)
+ claims, err := handler.authService.ValidateToken(accessToken)
+ require.NoError(t, err)
+ reloadedUser, err := handler.userService.GetByID(ctx, userEntity.ID)
+ require.NoError(t, err)
+ require.Equal(t, reloadedUser.TokenVersion, claims.TokenVersion)
+
+ decisionCount, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, decisionCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+
+ completion, ok := storedSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.NotContains(t, completion, "access_token")
+ require.NotContains(t, completion, "refresh_token")
+ require.NotContains(t, completion, "expires_in")
+ require.NotContains(t, completion, "token_type")
+ require.Equal(t, "/dashboard", completion["redirect"])
+}
+
+func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ settingValues: map[string]string{
+ service.SettingKeyBackendModeEnabled: "true",
+ },
+ })
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("blocked@example.com").
+ SetUsername("blocked-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("blocked-backend-mode-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("blocked-subject-123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("blocked-backend-mode-browser-session-key").
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ "refresh_token": "refresh-token",
+ "expires_in": float64(3600),
+ "token_type": "Bearer",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("blocked-backend-mode-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusForbidden, recorder.Code)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestExchangePendingOAuthCompletionRejectsDisabledTargetUser(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("disabled-linked@example.com").
+ SetUsername("disabled-linked-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusDisabled).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("disabled-linked-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("disabled-linked-subject").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("disabled-linked-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Disabled Linked User",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "redirect": "/dashboard",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("disabled-linked-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusForbidden, recorder.Code)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestNormalizePendingOAuthCompletionResponseScrubsLegacyTokenPayload(t *testing.T) {
+ payload := normalizePendingOAuthCompletionResponse(map[string]any{
+ "access_token": "legacy-access-token",
+ "refresh_token": "legacy-refresh-token",
+ "expires_in": float64(3600),
+ "token_type": "Bearer",
+ "redirect": "/dashboard",
+ })
+
+ require.NotContains(t, payload, "access_token")
+ require.NotContains(t, payload, "refresh_token")
+ require.NotContains(t, payload, "expires_in")
+ require.NotContains(t, payload, "token_type")
+ require.Equal(t, "/dashboard", payload["redirect"])
+}
+
+func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, true)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("invitation-required-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("invitation-123").
+ SetBrowserSessionKey("invitation-required-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Invite Example",
+ "suggested_avatar_url": "https://cdn.example/invite.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "error": "invitation_required",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("invitation-required-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ data := decodeJSONResponseData(t, recorder)
+ require.Equal(t, "invitation_required", data["error"])
+ require.Equal(t, true, data["adoption_required"])
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("invitation-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Nil(t, decision.IdentityID)
+ require.False(t, decision.AdoptDisplayName)
+ require.False(t, decision.AdoptAvatar)
+
+ storedSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestCreateOIDCOAuthAccountCreatesUserBindsIdentityAndConsumesSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "fresh@example.com", "246810")
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("create-account-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-create-123").
+ SetBrowserSessionKey("create-account-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Fresh OIDC User",
+ "suggested_avatar_url": "https://cdn.example/fresh.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.NotEmpty(t, payload["access_token"])
+ require.NotEmpty(t, payload["refresh_token"])
+ require.Equal(t, "Bearer", payload["token_type"])
+
+ createdUser, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, service.StatusActive, createdUser.Status)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-create-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, createdUser.ID, identity.UserID)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+}
+
+func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("existing-email-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-existing-123").
+ SetBrowserSessionKey("existing-email-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Existing OIDC User",
+ "suggested_avatar_url": "https://cdn.example/existing.png",
+ }).
+ SetRedirectTo("/dashboard").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.Equal(t, "pending_session", payload["auth_result"])
+ require.Equal(t, oauthIntentLogin, payload["intent"])
+ require.Equal(t, "oidc", payload["provider"])
+ require.Equal(t, "/dashboard", payload["redirect"])
+ require.Equal(t, true, payload["adoption_required"])
+ require.Equal(t, oauthPendingChoiceStep, payload["step"])
+ require.Equal(t, "owner@example.com", payload["email"])
+ require.Equal(t, "Existing OIDC User", payload["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/existing.png", payload["suggested_avatar_url"])
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, storedSession.Intent)
+ require.NotNil(t, storedSession.TargetUserID)
+ require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
+ require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
+ require.Nil(t, storedSession.ConsumedAt)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-existing-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+}
+
+func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail(" Owner@Example.com ").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("existing-email-normalized-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-existing-normalized-123").
+ SetBrowserSessionKey("existing-email-normalized-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Existing OIDC User",
+ "suggested_avatar_url": "https://cdn.example/existing.png",
+ }).
+ SetRedirectTo("/dashboard").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-normalized-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.Equal(t, oauthIntentLogin, payload["intent"])
+ require.Equal(t, oauthPendingChoiceStep, payload["step"])
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.TargetUserID)
+ require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
+ require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
+}
+
+func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("existing-email-send-code-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-existing-send-code-123").
+ SetBrowserSessionKey("existing-email-send-code-browser-session-key").
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": "email_required",
+ },
+ }).
+ SetRedirectTo("/dashboard").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/send-verify-code", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-send-code-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.SendPendingOAuthVerifyCode(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.Equal(t, "pending_session", payload["auth_result"])
+ require.Equal(t, oauthPendingChoiceStep, payload["step"])
+ require.Equal(t, "owner@example.com", payload["email"])
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, storedSession.Intent)
+ require.NotNil(t, storedSession.TargetUserID)
+ require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
+ require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
+}
+
+func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ emailVerifyEnabled: true,
+ emailCache: &oauthPendingFlowEmailCacheStub{
+ verificationCodes: map[string]*service.VerificationCodeData{
+ "fresh@example.com": {
+ Code: "246810",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ },
+ },
+ settingValues: map[string]string{
+ service.SettingKeyBackendModeEnabled: "true",
+ },
+ })
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("create-account-backend-mode-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-create-backend-mode-123").
+ SetBrowserSessionKey("create-account-backend-mode-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-backend-mode-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusForbidden, recorder.Code)
+
+ userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestLogoutClearsPendingOAuthAndBindCookies(t *testing.T) {
+ handler, _ := newOAuthPendingFlowTestHandler(t, false)
+
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", bytes.NewBufferString(`{}`))
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue("pending-session-token")})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("pending-browser-key")})
+ req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-token"})
+ ginCtx.Request = req
+
+ handler.Logout(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName).MaxAge)
+ require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthPendingBrowserCookieName).MaxAge)
+ require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName).MaxAge)
+}
+
+func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, true, "fresh@example.com", "246810")
+ ctx := context.Background()
+
+ conflictOwner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(conflictOwner.ID).
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-conflict-123").
+ SetMetadata(map[string]any{
+ "username": "owner-user",
+ }).
+ Save(ctx)
+ require.NoError(t, err)
+
+ invitation, err := client.RedeemCode.Create().
+ SetCode("INVITE123").
+ SetType(service.RedeemTypeInvitation).
+ SetStatus(service.StatusUnused).
+ SetValue(0).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("create-account-conflict-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-conflict-123").
+ SetBrowserSessionKey("create-account-conflict-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","invitation_code":"INVITE123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-conflict-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusConflict, recorder.Code)
+
+ userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedInvitation, err := client.RedeemCode.Get(ctx, invitation.ID)
+ require.NoError(t, err)
+ require.Equal(t, service.StatusUnused, storedInvitation.Status)
+ require.Nil(t, storedInvitation.UsedBy)
+ require.Nil(t, storedInvitation.UsedAt)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestCreateOIDCOAuthAccountRollsBackPostBindFailureBeforeIdentityCanCommit(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ emailVerifyEnabled: true,
+ emailCache: &oauthPendingFlowEmailCacheStub{
+ verificationCodes: map[string]*service.VerificationCodeData{
+ "fresh@example.com": {
+ Code: "246810",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ },
+ },
+ userRepoOptions: oauthPendingFlowUserRepoOptions{
+ rejectDeleteWhileAuthIdentityExists: true,
+ },
+ })
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("create-account-finalize-failure-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-finalize-failure-123").
+ SetBrowserSessionKey("create-account-finalize-failure-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ pendingOAuthCreateAccountPreCommitHook = func(context.Context, *dbent.PendingAuthSession) error {
+ return errors.New("forced post-bind failure")
+ }
+ t.Cleanup(func() {
+ pendingOAuthCreateAccountPreCommitHook = nil
+ })
+
+ body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-finalize-failure-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusInternalServerError, recorder.Code)
+
+ userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-finalize-failure-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-login-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("bind-login-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Bound OIDC User",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.BindOIDCOAuthLogin(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.NotEmpty(t, payload["access_token"])
+ require.NotEmpty(t, payload["refresh_token"])
+ require.Equal(t, "Bearer", payload["token_type"])
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-bind-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, existingUser.ID, identity.UserID)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+}
+
+func TestBindOIDCOAuthLoginBlocksBackendModeBeforeTokenIssue(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ settingValues: map[string]string{
+ service.SettingKeyBackendModeEnabled: "true",
+ },
+ })
+ ctx := context.Background()
+
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-login-backend-mode-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-backend-mode-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("bind-login-backend-mode-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-backend-mode-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.BindOIDCOAuthLogin(ginCtx)
+
+ require.Equal(t, http.StatusForbidden, recorder.Code)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-bind-backend-mode-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-login-invalid-password-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-invalid-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("bind-login-invalid-password-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Bound OIDC User",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","password":"wrong-password"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-invalid-password-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.BindOIDCOAuthLogin(ginCtx)
+
+ require.Equal(t, http.StatusUnauthorized, recorder.Code)
+ payload := decodeJSONBody(t, recorder)
+ require.Equal(t, "INVALID_CREDENTIALS", payload["reason"])
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-bind-invalid-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestBindOIDCOAuthLoginReclaimsIdentityOwnedBySoftDeletedUser(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ oldOwnerHash, err := handler.authService.HashPassword("old-secret")
+ require.NoError(t, err)
+ oldOwner, err := client.User.Create().
+ SetEmail("old-owner@example.com").
+ SetUsername("old-owner").
+ SetPasswordHash(oldOwnerHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ identity, err := client.AuthIdentity.Create().
+ SetUserID(oldOwner.ID).
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-soft-deleted-123").
+ SetMetadata(map[string]any{"username": "old-owner"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.User.Delete().Where(dbuser.IDEQ(oldOwner.ID)).Exec(ctx)
+ require.NoError(t, err)
+
+ newOwnerHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+ newOwner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(newOwnerHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-login-soft-deleted-owner-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-soft-deleted-123").
+ SetTargetUserID(newOwner.ID).
+ SetResolvedEmail(newOwner.Email).
+ SetBrowserSessionKey("bind-login-soft-deleted-owner-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Recovered OIDC User",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-soft-deleted-owner-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.BindOIDCOAuthLogin(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ identity, err = client.AuthIdentity.Get(ctx, identity.ID)
+ require.NoError(t, err)
+ require.Equal(t, newOwner.ID, identity.UserID)
+}
+
+func TestBindOIDCOAuthLoginAppliesFirstBindGrantOnce(t *testing.T) {
+ defaultSubAssigner := &oauthPendingFlowDefaultSubAssignerStub{}
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ settingValues: map[string]string{
+ service.SettingKeyAuthSourceDefaultOIDCBalance: "12.5",
+ service.SettingKeyAuthSourceDefaultOIDCConcurrency: "3",
+ service.SettingKeyAuthSourceDefaultOIDCSubscriptions: `[{"group_id":101,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "true",
+ },
+ defaultSubAssigner: defaultSubAssigner,
+ })
+ ctx := context.Background()
+
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(5).
+ SetConcurrency(2).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ firstSession, err := client.PendingAuthSession.Create().
+ SetSessionToken("first-bind-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-first-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("first-bind-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Bound OIDC User",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ firstBody := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ firstRecorder := httptest.NewRecorder()
+ firstGinCtx, _ := gin.CreateTestContext(firstRecorder)
+ firstReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", firstBody)
+ firstReq.Header.Set("Content-Type", "application/json")
+ firstReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(firstSession.SessionToken)})
+ firstReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("first-bind-browser-session-key")})
+ firstGinCtx.Request = firstReq
+
+ handler.BindOIDCOAuthLogin(firstGinCtx)
+
+ require.Equal(t, http.StatusOK, firstRecorder.Code)
+
+ storedUser, err := client.User.Get(ctx, existingUser.ID)
+ require.NoError(t, err)
+ require.Equal(t, 17.5, storedUser.Balance)
+ require.Equal(t, 5, storedUser.Concurrency)
+ require.Zero(t, storedUser.TotalRecharged)
+ require.Len(t, defaultSubAssigner.calls, 1)
+ require.Equal(t, int64(existingUser.ID), defaultSubAssigner.calls[0].UserID)
+ require.Equal(t, int64(101), defaultSubAssigner.calls[0].GroupID)
+ require.Equal(t, 30, defaultSubAssigner.calls[0].ValidityDays)
+ require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind"))
+
+ secondSession, err := client.PendingAuthSession.Create().
+ SetSessionToken("second-bind-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-second-456").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("second-bind-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Second OIDC User",
+ "suggested_avatar_url": "https://cdn.example/second.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ secondBody := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ secondRecorder := httptest.NewRecorder()
+ secondGinCtx, _ := gin.CreateTestContext(secondRecorder)
+ secondReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", secondBody)
+ secondReq.Header.Set("Content-Type", "application/json")
+ secondReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(secondSession.SessionToken)})
+ secondReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("second-bind-browser-session-key")})
+ secondGinCtx.Request = secondReq
+
+ handler.BindOIDCOAuthLogin(secondGinCtx)
+
+ require.Equal(t, http.StatusOK, secondRecorder.Code)
+
+ storedUser, err = client.User.Get(ctx, existingUser.ID)
+ require.NoError(t, err)
+ require.Equal(t, 17.5, storedUser.Balance)
+ require.Equal(t, 5, storedUser.Concurrency)
+ require.Zero(t, storedUser.TotalRecharged)
+ require.Len(t, defaultSubAssigner.calls, 1)
+ require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind"))
+}
+
+func TestResolvePendingOAuthTargetUserIDNormalizesLegacySpacingAndCase(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ _ = handler
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail(" Owner@Example.com ").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("resolve-target-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-target-123").
+ SetResolvedEmail("owner@example.com").
+ SetBrowserSessionKey("resolve-target-browser-session-key").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session)
+ require.NoError(t, err)
+ require.Equal(t, existingUser.ID, resolvedUserID)
+}
+
+func TestBindOIDCOAuthLoginReturns2FAChallengeWhenUserHasTotp(t *testing.T) {
+ totpCache := &oauthPendingFlowTotpCacheStub{}
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ settingValues: map[string]string{
+ service.SettingKeyTotpEnabled: "true",
+ },
+ totpCache: totpCache,
+ totpEncryptor: oauthPendingFlowTotpEncryptorStub{},
+ })
+ ctx := context.Background()
+
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+ totpEnabledAt := time.Now().UTC().Add(-time.Hour)
+ secret := "JBSWY3DPEHPK3PXP"
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ SetTotpEnabled(true).
+ SetTotpSecretEncrypted(secret).
+ SetTotpEnabledAt(totpEnabledAt).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-login-2fa-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-2fa-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("bind-login-2fa-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Bound OIDC User",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-2fa-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.BindOIDCOAuthLogin(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ data := decodeJSONResponseData(t, recorder)
+ require.Equal(t, true, data["requires_2fa"])
+ require.Equal(t, "o***r@example.com", data["user_email_masked"])
+ tempToken, ok := data["temp_token"].(string)
+ require.True(t, ok)
+ require.NotEmpty(t, tempToken)
+
+ loginSession, err := totpCache.GetLoginSession(ctx, tempToken)
+ require.NoError(t, err)
+ require.NotNil(t, loginSession)
+ require.NotNil(t, loginSession.PendingOAuthBind)
+ require.Equal(t, session.SessionToken, loginSession.PendingOAuthBind.PendingSessionToken)
+ require.Equal(t, session.BrowserSessionKey, loginSession.PendingOAuthBind.BrowserSessionKey)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-bind-2fa-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestLogin2FACompletesPendingOAuthBindAndConsumesSession(t *testing.T) {
+ totpCache := &oauthPendingFlowTotpCacheStub{}
+ defaultSubAssigner := &oauthPendingFlowDefaultSubAssignerStub{}
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ settingValues: map[string]string{
+ service.SettingKeyTotpEnabled: "true",
+ service.SettingKeyAuthSourceDefaultOIDCBalance: "8",
+ service.SettingKeyAuthSourceDefaultOIDCConcurrency: "2",
+ service.SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "true",
+ },
+ defaultSubAssigner: defaultSubAssigner,
+ totpCache: totpCache,
+ totpEncryptor: oauthPendingFlowTotpEncryptorStub{},
+ })
+ ctx := context.Background()
+
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+ totpEnabledAt := time.Now().UTC().Add(-time.Hour)
+ secret := "JBSWY3DPEHPK3PXP"
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(1.5).
+ SetConcurrency(4).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ SetTotpEnabled(true).
+ SetTotpSecretEncrypted(secret).
+ SetTotpEnabledAt(totpEnabledAt).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("login-2fa-pending-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-login-2fa-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("login-2fa-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Bound OIDC User",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.IdentityAdoptionDecision.Create().
+ SetPendingAuthSessionID(session.ID).
+ SetAdoptDisplayName(false).
+ SetAdoptAvatar(false).
+ Save(ctx)
+ require.NoError(t, err)
+
+ tempToken, err := handler.totpService.CreatePendingOAuthBindLoginSession(
+ ctx,
+ existingUser.ID,
+ existingUser.Email,
+ session.SessionToken,
+ session.BrowserSessionKey,
+ )
+ require.NoError(t, err)
+
+ code, err := totp.GenerateCode(secret, time.Now().UTC())
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"temp_token":"` + tempToken + `","totp_code":"` + code + `"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/2fa", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue(session.BrowserSessionKey)})
+ ginCtx.Request = req
+
+ handler.Login2FA(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ payload := decodeJSONResponseData(t, recorder)
+ require.NotEmpty(t, payload["access_token"])
+ require.NotEmpty(t, payload["refresh_token"])
+ accessToken, ok := payload["access_token"].(string)
+ require.True(t, ok)
+ claims, err := handler.authService.ValidateToken(accessToken)
+ require.NoError(t, err)
+ reloadedUser, err := handler.userService.GetByID(ctx, existingUser.ID)
+ require.NoError(t, err)
+ require.Equal(t, reloadedUser.TokenVersion, claims.TokenVersion)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-login-2fa-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, existingUser.ID, identity.UserID)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+
+ loginSession, err := totpCache.GetLoginSession(ctx, tempToken)
+ require.NoError(t, err)
+ require.Nil(t, loginSession)
+
+ storedUser, err := client.User.Get(ctx, existingUser.ID)
+ require.NoError(t, err)
+ require.Equal(t, 9.5, storedUser.Balance)
+ require.Equal(t, 6, storedUser.Concurrency)
+ require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind"))
+ require.Empty(t, defaultSubAssigner.calls)
+}
+
+func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+
+ return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, false, nil)
+}
+
+func newOAuthPendingFlowTestHandlerWithEmailVerification(
+ t *testing.T,
+ invitationEnabled bool,
+ email string,
+ code string,
+) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+
+ cache := &oauthPendingFlowEmailCacheStub{
+ verificationCodes: map[string]*service.VerificationCodeData{
+ email: {
+ Code: code,
+ Attempts: 0,
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ },
+ }
+ return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, true, cache)
+}
+
+func newOAuthPendingFlowTestHandlerWithOptions(
+ t *testing.T,
+ invitationEnabled bool,
+ emailVerifyEnabled bool,
+ emailCache service.EmailCache,
+) (*AuthHandler, *dbent.Client) {
+ return newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ invitationEnabled: invitationEnabled,
+ emailVerifyEnabled: emailVerifyEnabled,
+ emailCache: emailCache,
+ })
+}
+
+type oauthPendingFlowTestHandlerOptions struct {
+ invitationEnabled bool
+ emailVerifyEnabled bool
+ emailCache service.EmailCache
+ settingValues map[string]string
+ defaultSubAssigner service.DefaultSubscriptionAssigner
+ totpCache service.TotpCache
+ totpEncryptor service.SecretEncryptor
+ userRepoOptions oauthPendingFlowUserRepoOptions
+}
+
+func newOAuthPendingFlowTestHandlerWithDependencies(
+ t *testing.T,
+ options oauthPendingFlowTestHandlerOptions,
+) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+ _, err = db.Exec(`
+CREATE TABLE IF NOT EXISTS user_provider_default_grants (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ provider_type TEXT NOT NULL,
+ grant_reason TEXT NOT NULL DEFAULT 'first_bind',
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ UNIQUE(user_id, provider_type, grant_reason)
+)`)
+ require.NoError(t, err)
+ _, err = db.Exec(`
+CREATE TABLE IF NOT EXISTS user_avatars (
+ user_id INTEGER PRIMARY KEY,
+ storage_provider TEXT NOT NULL,
+ storage_key TEXT NOT NULL DEFAULT '',
+ url TEXT NOT NULL,
+ content_type TEXT NOT NULL DEFAULT '',
+ byte_size INTEGER NOT NULL DEFAULT 0,
+ sha256 TEXT NOT NULL DEFAULT '',
+ updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+)`)
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 0,
+ UserConcurrency: 1,
+ },
+ }
+ settingValues := map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyInvitationCodeEnabled: boolSettingValue(options.invitationEnabled),
+ service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled),
+ }
+ for key, value := range options.settingValues {
+ settingValues[key] = value
+ }
+ settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg)
+ userRepo := &oauthPendingFlowUserRepo{
+ client: client,
+ options: options.userRepoOptions,
+ }
+ redeemRepo := &oauthPendingFlowRedeemCodeRepo{client: client}
+ var emailService *service.EmailService
+ if options.emailCache != nil {
+ emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{
+ values: map[string]string{
+ service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled),
+ },
+ }, options.emailCache)
+ }
+ authSvc := service.NewAuthService(
+ client,
+ userRepo,
+ redeemRepo,
+ &oauthPendingFlowRefreshTokenCacheStub{},
+ cfg,
+ settingSvc,
+ emailService,
+ nil,
+ nil,
+ nil,
+ options.defaultSubAssigner,
+ )
+ userSvc := service.NewUserService(userRepo, nil, nil, nil)
+ var totpSvc *service.TotpService
+ if options.totpCache != nil || options.totpEncryptor != nil {
+ totpCache := options.totpCache
+ if totpCache == nil {
+ totpCache = &oauthPendingFlowTotpCacheStub{}
+ }
+ totpEncryptor := options.totpEncryptor
+ if totpEncryptor == nil {
+ totpEncryptor = oauthPendingFlowTotpEncryptorStub{}
+ }
+ totpSvc = service.NewTotpService(userRepo, totpEncryptor, totpCache, settingSvc, nil, nil)
+ }
+
+ return &AuthHandler{
+ authService: authSvc,
+ userService: userSvc,
+ settingSvc: settingSvc,
+ totpService: totpSvc,
+ }, client
+}
+
+func boolSettingValue(v bool) string {
+ if v {
+ return "true"
+ }
+ return "false"
+}
+
+func boolPtr(v bool) *bool {
+ return &v
+}
+
+type oauthPendingFlowSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *oauthPendingFlowSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
+ return nil, service.ErrSettingNotFound
+}
+
+func (s *oauthPendingFlowSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ value, ok := s.values[key]
+ if !ok {
+ return "", service.ErrSettingNotFound
+ }
+ return value, nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) Set(context.Context, string, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ result := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ result[key] = value
+ }
+ }
+ return result, nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ result := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ result[key] = value
+ }
+ return result, nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error {
+ return nil
+}
+
+type oauthPendingFlowRefreshTokenCacheStub struct{}
+
+type oauthPendingFlowEmailCacheStub struct {
+ verificationCodes map[string]*service.VerificationCodeData
+}
+
+func (s *oauthPendingFlowEmailCacheStub) GetVerificationCode(_ context.Context, email string) (*service.VerificationCodeData, error) {
+ if s == nil || s.verificationCodes == nil {
+ return nil, nil
+ }
+ return s.verificationCodes[email], nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) SetVerificationCode(_ context.Context, email string, data *service.VerificationCodeData, _ time.Duration) error {
+ if s.verificationCodes == nil {
+ s.verificationCodes = map[string]*service.VerificationCodeData{}
+ }
+ s.verificationCodes[email] = data
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) DeleteVerificationCode(_ context.Context, email string) error {
+ delete(s.verificationCodes, email)
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
+ return nil, nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
+ return nil, nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) DeletePasswordResetToken(context.Context, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
+ return false
+}
+
+func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
+ return 0, nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
+ return nil, service.ErrRefreshTokenNotFound
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
+ return nil, nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
+ return false, nil
+}
+
+type oauthPendingFlowRedeemCodeRepo struct {
+ client *dbent.Client
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) Create(context.Context, *service.RedeemCode) error {
+ panic("unexpected Create call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) CreateBatch(context.Context, []service.RedeemCode) error {
+ panic("unexpected CreateBatch call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) GetByID(context.Context, int64) (*service.RedeemCode, error) {
+ panic("unexpected GetByID call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
+ entity, err := r.client.RedeemCode.Query().Where(redeemcode.CodeEQ(code)).Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, service.ErrRedeemCodeNotFound
+ }
+ return nil, err
+ }
+ notes := ""
+ if entity.Notes != nil {
+ notes = *entity.Notes
+ }
+ return &service.RedeemCode{
+ ID: entity.ID,
+ Code: entity.Code,
+ Type: entity.Type,
+ Value: entity.Value,
+ Status: entity.Status,
+ UsedBy: entity.UsedBy,
+ UsedAt: entity.UsedAt,
+ Notes: notes,
+ CreatedAt: entity.CreatedAt,
+ GroupID: entity.GroupID,
+ ValidityDays: entity.ValidityDays,
+ }, nil
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) Update(ctx context.Context, code *service.RedeemCode) error {
+ if code == nil {
+ return nil
+ }
+ update := r.client.RedeemCode.UpdateOneID(code.ID).
+ SetCode(code.Code).
+ SetType(code.Type).
+ SetValue(code.Value).
+ SetStatus(code.Status).
+ SetNotes(code.Notes).
+ SetValidityDays(code.ValidityDays)
+ if code.UsedBy != nil {
+ update = update.SetUsedBy(*code.UsedBy)
+ } else {
+ update = update.ClearUsedBy()
+ }
+ if code.UsedAt != nil {
+ update = update.SetUsedAt(*code.UsedAt)
+ } else {
+ update = update.ClearUsedAt()
+ }
+ if code.GroupID != nil {
+ update = update.SetGroupID(*code.GroupID)
+ } else {
+ update = update.ClearGroupID()
+ }
+ _, err := update.Save(ctx)
+ return err
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) Delete(context.Context, int64) error {
+ panic("unexpected Delete call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) Use(ctx context.Context, id, userID int64) error {
+ affected, err := r.client.RedeemCode.Update().
+ Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)).
+ SetStatus(service.StatusUsed).
+ SetUsedBy(userID).
+ SetUsedAt(time.Now().UTC()).
+ Save(ctx)
+ if err != nil {
+ return err
+ }
+ if affected == 0 {
+ return service.ErrRedeemCodeUsed
+ }
+ return nil
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) List(context.Context, pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected List call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected ListWithFilters call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) ListByUser(context.Context, int64, int) ([]service.RedeemCode, error) {
+ panic("unexpected ListByUser call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected ListByUserPaginated call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
+ panic("unexpected SumPositiveBalanceByUser call")
+}
+
+func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
+ t.Helper()
+
+ var envelope struct {
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &envelope))
+ return envelope.Data
+}
+
+func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
+ t.Helper()
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ return payload
+}
+
+type oauthPendingFlowAvatarRecord struct {
+ StorageProvider string
+ URL string
+}
+
+func loadUserAvatarRecord(t *testing.T, client *dbent.Client, userID int64) *oauthPendingFlowAvatarRecord {
+ t.Helper()
+
+ var rows entsql.Rows
+ err := client.Driver().Query(
+ context.Background(),
+ `SELECT storage_provider, url FROM user_avatars WHERE user_id = ?`,
+ []any{userID},
+ &rows,
+ )
+ require.NoError(t, err)
+ defer func() { _ = rows.Close() }()
+
+ if !rows.Next() {
+ require.NoError(t, rows.Err())
+ return nil
+ }
+
+ var record oauthPendingFlowAvatarRecord
+ require.NoError(t, rows.Scan(&record.StorageProvider, &record.URL))
+ require.NoError(t, rows.Err())
+ return &record
+}
+
+func countProviderGrantRecords(
+ t *testing.T,
+ client *dbent.Client,
+ userID int64,
+ providerType string,
+ grantReason string,
+) int {
+ t.Helper()
+
+ var rows entsql.Rows
+ err := client.Driver().Query(
+ context.Background(),
+ `SELECT COUNT(*) FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ?`,
+ []any{userID, providerType, grantReason},
+ &rows,
+ )
+ require.NoError(t, err)
+ defer func() { _ = rows.Close() }()
+
+ require.True(t, rows.Next())
+ var count int
+ require.NoError(t, rows.Scan(&count))
+ require.False(t, rows.Next())
+ return count
+}
+
+type oauthPendingFlowUserRepo struct {
+ client *dbent.Client
+ options oauthPendingFlowUserRepoOptions
+}
+
+type oauthPendingFlowUserRepoOptions struct {
+ rejectDeleteWhileAuthIdentityExists bool
+}
+
+func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.User) error {
+ entity, err := r.client.User.Create().
+ SetEmail(user.Email).
+ SetUsername(user.Username).
+ SetNotes(user.Notes).
+ SetPasswordHash(user.PasswordHash).
+ SetRole(user.Role).
+ SetBalance(user.Balance).
+ SetConcurrency(user.Concurrency).
+ SetStatus(user.Status).
+ SetNillableTotpSecretEncrypted(user.TotpSecretEncrypted).
+ SetTotpEnabled(user.TotpEnabled).
+ SetNillableTotpEnabledAt(user.TotpEnabledAt).
+ SetTotalRecharged(user.TotalRecharged).
+ SetSignupSource(user.SignupSource).
+ SetNillableLastLoginAt(user.LastLoginAt).
+ SetNillableLastActiveAt(user.LastActiveAt).
+ Save(ctx)
+ if err != nil {
+ return err
+ }
+ user.ID = entity.ID
+ user.CreatedAt = entity.CreatedAt
+ user.UpdatedAt = entity.UpdatedAt
+ return nil
+}
+
+func (r *oauthPendingFlowUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) {
+ entity, err := r.client.User.Get(ctx, id)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, service.ErrUserNotFound
+ }
+ return nil, err
+ }
+ return oauthPendingFlowServiceUser(entity), nil
+}
+
+func (r *oauthPendingFlowUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) {
+ entity, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, service.ErrUserNotFound
+ }
+ return nil, err
+ }
+ return oauthPendingFlowServiceUser(entity), nil
+}
+
+func (r *oauthPendingFlowUserRepo) GetFirstAdmin(context.Context) (*service.User, error) {
+ panic("unexpected GetFirstAdmin call")
+}
+
+func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.User) error {
+ entity, err := r.client.User.UpdateOneID(user.ID).
+ SetEmail(user.Email).
+ SetUsername(user.Username).
+ SetNotes(user.Notes).
+ SetPasswordHash(user.PasswordHash).
+ SetRole(user.Role).
+ SetBalance(user.Balance).
+ SetConcurrency(user.Concurrency).
+ SetStatus(user.Status).
+ SetNillableTotpSecretEncrypted(user.TotpSecretEncrypted).
+ SetTotpEnabled(user.TotpEnabled).
+ SetNillableTotpEnabledAt(user.TotpEnabledAt).
+ SetTotalRecharged(user.TotalRecharged).
+ SetSignupSource(user.SignupSource).
+ SetNillableLastLoginAt(user.LastLoginAt).
+ SetNillableLastActiveAt(user.LastActiveAt).
+ Save(ctx)
+ if err != nil {
+ return err
+ }
+ user.UpdatedAt = entity.UpdatedAt
+ return nil
+}
+
+func (r *oauthPendingFlowUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+ return r.client.User.UpdateOneID(userID).SetLastActiveAt(activeAt).Exec(ctx)
+}
+
+func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error {
+ if r.options.rejectDeleteWhileAuthIdentityExists {
+ count, err := r.client.AuthIdentity.Query().Where(authidentity.UserIDEQ(id)).Count(ctx)
+ if err != nil {
+ return err
+ }
+ if count > 0 {
+ return errors.New("cannot delete user while auth identities still exist")
+ }
+ }
+ return r.client.User.DeleteOneID(id).Exec(ctx)
+}
+
+func (r *oauthPendingFlowUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
+ driver := r.client.Driver()
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ driver = tx.Client().Driver()
+ }
+
+ var rows entsql.Rows
+ if err := driver.Query(
+ ctx,
+ `SELECT storage_provider, storage_key, url, content_type, byte_size, sha256 FROM user_avatars WHERE user_id = ?`,
+ []any{userID},
+ &rows,
+ ); err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ if !rows.Next() {
+ return nil, rows.Err()
+ }
+
+ var avatar service.UserAvatar
+ if err := rows.Scan(
+ &avatar.StorageProvider,
+ &avatar.StorageKey,
+ &avatar.URL,
+ &avatar.ContentType,
+ &avatar.ByteSize,
+ &avatar.SHA256,
+ ); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return &avatar, nil
+}
+
+func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ driver := r.client.Driver()
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ driver = tx.Client().Driver()
+ }
+
+ var result entsql.Result
+ if err := driver.Exec(
+ ctx,
+ `INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at)
+VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
+ON CONFLICT(user_id) DO UPDATE SET
+ storage_provider = excluded.storage_provider,
+ storage_key = excluded.storage_key,
+ url = excluded.url,
+ content_type = excluded.content_type,
+ byte_size = excluded.byte_size,
+ sha256 = excluded.sha256,
+ updated_at = CURRENT_TIMESTAMP`,
+ []any{
+ userID,
+ input.StorageProvider,
+ input.StorageKey,
+ input.URL,
+ input.ContentType,
+ input.ByteSize,
+ input.SHA256,
+ },
+ &result,
+ ); err != nil {
+ return nil, err
+ }
+
+ return &service.UserAvatar{
+ StorageProvider: input.StorageProvider,
+ StorageKey: input.StorageKey,
+ URL: input.URL,
+ ContentType: input.ContentType,
+ ByteSize: input.ByteSize,
+ SHA256: input.SHA256,
+ }, nil
+}
+
+func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ driver := r.client.Driver()
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ driver = tx.Client().Driver()
+ }
+
+ var result entsql.Result
+ return driver.Exec(ctx, `DELETE FROM user_avatars WHERE user_id = ?`, []any{userID}, &result)
+}
+
+func (r *oauthPendingFlowUserRepo) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
+ panic("unexpected List call")
+}
+
+func (r *oauthPendingFlowUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
+ panic("unexpected ListWithFilters call")
+}
+
+func (r *oauthPendingFlowUserRepo) UpdateBalance(context.Context, int64, float64) error {
+ panic("unexpected UpdateBalance call")
+}
+
+func (r *oauthPendingFlowUserRepo) DeductBalance(context.Context, int64, float64) error {
+ panic("unexpected DeductBalance call")
+}
+
+func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int) error {
+ panic("unexpected UpdateConcurrency call")
+}
+
+func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+
+func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
+
+func (r *oauthPendingFlowUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
+ count, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Count(ctx)
+ return count > 0, err
+}
+
+func (r *oauthPendingFlowUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ panic("unexpected RemoveGroupFromAllowedGroups call")
+}
+
+func (r *oauthPendingFlowUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error {
+ panic("unexpected AddGroupToAllowedGroups call")
+}
+
+func (r *oauthPendingFlowUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+ panic("unexpected RemoveGroupFromUserAllowedGroups call")
+}
+
+func (r *oauthPendingFlowUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) {
+ identities, err := r.client.AuthIdentity.Query().
+ Where(authidentity.UserIDEQ(userID)).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ records := make([]service.UserAuthIdentityRecord, 0, len(identities))
+ for _, identity := range identities {
+ if identity == nil {
+ continue
+ }
+ records = append(records, service.UserAuthIdentityRecord{
+ ProviderType: identity.ProviderType,
+ ProviderKey: identity.ProviderKey,
+ ProviderSubject: identity.ProviderSubject,
+ VerifiedAt: identity.VerifiedAt,
+ Issuer: identity.Issuer,
+ Metadata: identity.Metadata,
+ CreatedAt: identity.CreatedAt,
+ UpdatedAt: identity.UpdatedAt,
+ })
+ }
+ return records, nil
+}
+
+func (r *oauthPendingFlowUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error {
+ panic("unexpected UnbindUserAuthProvider call")
+}
+
+func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
+ update := r.client.User.UpdateOneID(userID)
+ if encryptedSecret == nil {
+ update = update.ClearTotpSecretEncrypted()
+ } else {
+ update = update.SetTotpSecretEncrypted(*encryptedSecret)
+ }
+ return update.Exec(ctx)
+}
+
+func (r *oauthPendingFlowUserRepo) EnableTotp(ctx context.Context, userID int64) error {
+ return r.client.User.UpdateOneID(userID).
+ SetTotpEnabled(true).
+ SetTotpEnabledAt(time.Now().UTC()).
+ Exec(ctx)
+}
+
+func (r *oauthPendingFlowUserRepo) DisableTotp(ctx context.Context, userID int64) error {
+ return r.client.User.UpdateOneID(userID).
+ SetTotpEnabled(false).
+ ClearTotpSecretEncrypted().
+ ClearTotpEnabledAt().
+ Exec(ctx)
+}
+
+func oauthPendingFlowServiceUser(entity *dbent.User) *service.User {
+ if entity == nil {
+ return nil
+ }
+ return &service.User{
+ ID: entity.ID,
+ Email: entity.Email,
+ Username: entity.Username,
+ Notes: entity.Notes,
+ PasswordHash: entity.PasswordHash,
+ Role: entity.Role,
+ Balance: entity.Balance,
+ Concurrency: entity.Concurrency,
+ Status: entity.Status,
+ SignupSource: entity.SignupSource,
+ LastLoginAt: entity.LastLoginAt,
+ LastActiveAt: entity.LastActiveAt,
+ TotpSecretEncrypted: entity.TotpSecretEncrypted,
+ TotpEnabled: entity.TotpEnabled,
+ TotpEnabledAt: entity.TotpEnabledAt,
+ TotalRecharged: entity.TotalRecharged,
+ CreatedAt: entity.CreatedAt,
+ UpdatedAt: entity.UpdatedAt,
+ }
+}
+
+type oauthPendingFlowDefaultSubAssignerStub struct {
+ calls []service.AssignSubscriptionInput
+}
+
+func (s *oauthPendingFlowDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ if input != nil {
+ s.calls = append(s.calls, *input)
+ }
+ return nil, false, nil
+}
+
+type oauthPendingFlowTotpCacheStub struct {
+ setupSessions map[int64]*service.TotpSetupSession
+ loginSessions map[string]*service.TotpLoginSession
+ verifyAttempts map[int64]int
+}
+
+func (s *oauthPendingFlowTotpCacheStub) GetSetupSession(_ context.Context, userID int64) (*service.TotpSetupSession, error) {
+ if s == nil || s.setupSessions == nil {
+ return nil, nil
+ }
+ return s.setupSessions[userID], nil
+}
+
+func (s *oauthPendingFlowTotpCacheStub) SetSetupSession(_ context.Context, userID int64, session *service.TotpSetupSession, _ time.Duration) error {
+ if s.setupSessions == nil {
+ s.setupSessions = map[int64]*service.TotpSetupSession{}
+ }
+ s.setupSessions[userID] = session
+ return nil
+}
+
+func (s *oauthPendingFlowTotpCacheStub) DeleteSetupSession(_ context.Context, userID int64) error {
+ delete(s.setupSessions, userID)
+ return nil
+}
+
+func (s *oauthPendingFlowTotpCacheStub) GetLoginSession(_ context.Context, tempToken string) (*service.TotpLoginSession, error) {
+ if s == nil || s.loginSessions == nil {
+ return nil, nil
+ }
+ return s.loginSessions[tempToken], nil
+}
+
+func (s *oauthPendingFlowTotpCacheStub) SetLoginSession(_ context.Context, tempToken string, session *service.TotpLoginSession, _ time.Duration) error {
+ if s.loginSessions == nil {
+ s.loginSessions = map[string]*service.TotpLoginSession{}
+ }
+ s.loginSessions[tempToken] = session
+ return nil
+}
+
+func (s *oauthPendingFlowTotpCacheStub) DeleteLoginSession(_ context.Context, tempToken string) error {
+ delete(s.loginSessions, tempToken)
+ return nil
+}
+
+func (s *oauthPendingFlowTotpCacheStub) IncrementVerifyAttempts(_ context.Context, userID int64) (int, error) {
+ if s.verifyAttempts == nil {
+ s.verifyAttempts = map[int64]int{}
+ }
+ s.verifyAttempts[userID]++
+ return s.verifyAttempts[userID], nil
+}
+
+func (s *oauthPendingFlowTotpCacheStub) GetVerifyAttempts(_ context.Context, userID int64) (int, error) {
+ if s == nil || s.verifyAttempts == nil {
+ return 0, nil
+ }
+ return s.verifyAttempts[userID], nil
+}
+
+func (s *oauthPendingFlowTotpCacheStub) ClearVerifyAttempts(_ context.Context, userID int64) error {
+ delete(s.verifyAttempts, userID)
+ return nil
+}
+
+type oauthPendingFlowTotpEncryptorStub struct{}
+
+func (oauthPendingFlowTotpEncryptorStub) Encrypt(plaintext string) (string, error) {
+ return plaintext, nil
+}
+
+func (oauthPendingFlowTotpEncryptorStub) Decrypt(ciphertext string) (string, error) {
+ return ciphertext, nil
+}
diff --git a/backend/internal/handler/auth_oauth_test_helpers_test.go b/backend/internal/handler/auth_oauth_test_helpers_test.go
new file mode 100644
index 00000000..47bad942
--- /dev/null
+++ b/backend/internal/handler/auth_oauth_test_helpers_test.go
@@ -0,0 +1,57 @@
+package handler
+
+import (
+ "net/http"
+ "net/url"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func buildEncodedOAuthBindUserCookie(t *testing.T, userID int64, secret string) string {
+ t.Helper()
+ value, err := buildOAuthBindUserCookieValue(userID, secret)
+ require.NoError(t, err)
+ return value
+}
+
+func encodedCookie(name, value string) *http.Cookie {
+ return &http.Cookie{
+ Name: name,
+ Value: encodeCookieValue(value),
+ Path: "/",
+ }
+}
+
+func findCookie(cookies []*http.Cookie, name string) *http.Cookie {
+ for _, cookie := range cookies {
+ if cookie.Name == name {
+ return cookie
+ }
+ }
+ return nil
+}
+
+func decodeCookieValueForTest(t *testing.T, value string) string {
+ t.Helper()
+ decoded, err := decodeCookieValue(value)
+ require.NoError(t, err)
+ return decoded
+}
+
+func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) {
+ t.Helper()
+ require.NotEmpty(t, location)
+
+ parsed, err := url.Parse(location)
+ require.NoError(t, err)
+
+ rawValues := parsed.RawQuery
+ if rawValues == "" {
+ rawValues = parsed.Fragment
+ }
+ values, err := url.ParseQuery(rawValues)
+ require.NoError(t, err)
+ require.Equal(t, errorCode, values.Get("error"))
+ require.Equal(t, errorMessage, values.Get("error_message"))
+}
diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go
index 9d24df88..0ac8871b 100644
--- a/backend/internal/handler/auth_oidc_oauth.go
+++ b/backend/internal/handler/auth_oidc_oauth.go
@@ -19,6 +19,7 @@ import (
"strings"
"time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
@@ -32,14 +33,16 @@ import (
)
const (
- oidcOAuthCookiePath = "/api/v1/auth/oauth/oidc"
- oidcOAuthStateCookieName = "oidc_oauth_state"
- oidcOAuthVerifierCookie = "oidc_oauth_verifier"
- oidcOAuthRedirectCookie = "oidc_oauth_redirect"
- oidcOAuthNonceCookie = "oidc_oauth_nonce"
- oidcOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
- oidcOAuthDefaultRedirectTo = "/dashboard"
- oidcOAuthDefaultFrontendCB = "/auth/oidc/callback"
+ oidcOAuthCookiePath = "/api/v1/auth/oauth/oidc"
+ oidcOAuthStateCookieName = "oidc_oauth_state"
+ oidcOAuthVerifierCookie = "oidc_oauth_verifier"
+ oidcOAuthRedirectCookie = "oidc_oauth_redirect"
+ oidcOAuthNonceCookie = "oidc_oauth_nonce"
+ oidcOAuthIntentCookieName = "oidc_oauth_intent"
+ oidcOAuthBindUserCookieName = "oidc_oauth_bind_user"
+ oidcOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
+ oidcOAuthDefaultRedirectTo = "/dashboard"
+ oidcOAuthDefaultFrontendCB = "/auth/oidc/callback"
)
type oidcTokenResponse struct {
@@ -87,6 +90,8 @@ type oidcUserInfoClaims struct {
Username string
Subject string
EmailVerified *bool
+ DisplayName string
+ AvatarURL string
}
type oidcJWKSet struct {
@@ -127,9 +132,29 @@ func (h *AuthHandler) OIDCOAuthStart(c *gin.Context) {
redirectTo = oidcOAuthDefaultRedirectTo
}
+ browserSessionKey, err := generateOAuthPendingBrowserSession()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err))
+ return
+ }
+
secureCookie := isRequestHTTPS(c)
oidcSetCookie(c, oidcOAuthStateCookieName, encodeCookieValue(state), oidcOAuthCookieMaxAgeSec, secureCookie)
oidcSetCookie(c, oidcOAuthRedirectCookie, encodeCookieValue(redirectTo), oidcOAuthCookieMaxAgeSec, secureCookie)
+ intent := normalizeOAuthIntent(c.Query("intent"))
+ oidcSetCookie(c, oidcOAuthIntentCookieName, encodeCookieValue(intent), oidcOAuthCookieMaxAgeSec, secureCookie)
+ setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie)
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ if intent == oauthIntentBindCurrentUser {
+ bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ oidcSetCookie(c, oidcOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), oidcOAuthCookieMaxAgeSec, secureCookie)
+ } else {
+ oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie)
+ }
codeChallenge := ""
if cfg.UsePKCE {
@@ -199,6 +224,8 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie)
oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie)
oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie)
+ oidcClearCookie(c, oidcOAuthIntentCookieName, secureCookie)
+ oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie)
}()
expectedState, err := readCookieDecoded(c, oidcOAuthStateCookieName)
@@ -212,6 +239,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
if redirectTo == "" {
redirectTo = oidcOAuthDefaultRedirectTo
}
+ browserSessionKey, _ := readOAuthPendingBrowserCookie(c)
+ if strings.TrimSpace(browserSessionKey) == "" {
+ redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "")
+ return
+ }
+ intent, _ := readCookieDecoded(c, oidcOAuthIntentCookieName)
+ intent = normalizeOAuthIntent(intent)
codeVerifier := ""
if cfg.UsePKCE {
@@ -258,16 +292,19 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
- if cfg.ValidateIDToken && strings.TrimSpace(tokenResp.IDToken) == "" {
- redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "")
- return
- }
+ var idClaims *oidcIDTokenClaims
+ if cfg.ValidateIDToken {
+ if strings.TrimSpace(tokenResp.IDToken) == "" {
+ redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "")
+ return
+ }
- idClaims, err := oidcParseAndValidateIDToken(c.Request.Context(), cfg, tokenResp.IDToken, expectedNonce)
- if err != nil {
- log.Printf("[OIDC OAuth] id_token validation failed: %v", err)
- redirectOAuthError(c, frontendCallback, "invalid_id_token", "failed to validate id_token", "")
- return
+ idClaims, err = oidcParseAndValidateIDToken(c.Request.Context(), cfg, tokenResp.IDToken, expectedNonce)
+ if err != nil {
+ log.Printf("[OIDC OAuth] id_token validation failed: %v", err)
+ redirectOAuthError(c, frontendCallback, "invalid_id_token", "failed to validate id_token", "")
+ return
+ }
}
userInfoClaims, err := oidcFetchUserInfo(c.Request.Context(), cfg, tokenResp)
@@ -277,7 +314,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
- subject := strings.TrimSpace(idClaims.Subject)
+ subject := ""
+ if idClaims != nil {
+ subject = strings.TrimSpace(idClaims.Subject)
+ }
if subject == "" {
subject = strings.TrimSpace(userInfoClaims.Subject)
}
@@ -285,7 +325,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
redirectOAuthError(c, frontendCallback, "missing_subject", "missing subject claim", "")
return
}
- issuer := strings.TrimSpace(idClaims.Issuer)
+ issuer := ""
+ if idClaims != nil {
+ issuer = strings.TrimSpace(idClaims.Issuer)
+ }
if issuer == "" {
issuer = strings.TrimSpace(cfg.IssuerURL)
}
@@ -295,9 +338,115 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
}
emailVerified := userInfoClaims.EmailVerified
- if emailVerified == nil {
+ if emailVerified == nil && idClaims != nil {
emailVerified = idClaims.EmailVerified
}
+ if idClaims != nil && userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) {
+ redirectOAuthError(c, frontendCallback, "subject_mismatch", "userinfo subject does not match id_token", "")
+ return
+ }
+
+ identityKey := oidcIdentityKey(issuer, subject)
+ compatEmail := strings.TrimSpace(userInfoClaims.Email)
+ if compatEmail == "" && idClaims != nil {
+ compatEmail = strings.TrimSpace(idClaims.Email)
+ }
+ email := oidcSyntheticEmailFromIdentityKey(identityKey)
+ username := firstNonEmpty(
+ userInfoClaims.Username,
+ func() string {
+ if idClaims != nil {
+ return idClaims.PreferredUsername
+ }
+ return ""
+ }(),
+ func() string {
+ if idClaims != nil {
+ return idClaims.Name
+ }
+ return ""
+ }(),
+ oidcFallbackUsername(subject),
+ )
+ identityRef := service.PendingAuthIdentityKey{
+ ProviderType: "oidc",
+ ProviderKey: issuer,
+ ProviderSubject: subject,
+ }
+ upstreamClaims := map[string]any{
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "issuer": issuer,
+ "email_verified": emailVerified != nil && *emailVerified,
+ "provider_fallback": strings.TrimSpace(cfg.ProviderName),
+ "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, func() string {
+ if idClaims != nil {
+ return idClaims.Name
+ }
+ return ""
+ }(), username),
+ "suggested_avatar_url": userInfoClaims.AvatarURL,
+ }
+ if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) {
+ upstreamClaims["compat_email"] = compatEmail
+ }
+ if intent == oauthIntentBindCurrentUser {
+ targetUserID, err := h.readOAuthBindUserIDFromCookie(c, oidcOAuthBindUserCookieName)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth bind target", "")
+ return
+ }
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentBindCurrentUser,
+ Identity: identityRef,
+ TargetUserID: &targetUserID,
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth bind", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if existingIdentityUser != nil {
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin,
+ Identity: identityRef,
+ TargetUserID: &existingIdentityUser.ID,
+ ResolvedEmail: existingIdentityUser.Email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ compatEmailUser, err := h.findOIDCCompatEmailUser(c.Request.Context(), compatEmail)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+
if cfg.RequireEmailVerified {
if emailVerified == nil || !*emailVerified {
redirectOAuthError(c, frontendCallback, "email_not_verified", "email is not verified", "")
@@ -305,47 +454,136 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
}
}
- identityKey := oidcIdentityKey(issuer, subject)
- email := oidcSelectLoginEmail(userInfoClaims.Email, idClaims.Email, identityKey)
- username := firstNonEmpty(
- userInfoClaims.Username,
- idClaims.PreferredUsername,
- idClaims.Name,
- oidcFallbackUsername(subject),
- )
-
- // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
- tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
- if err != nil {
- if errors.Is(err, service.ErrOAuthInvitationRequired) {
- pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
- if tokenErr != nil {
- redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
- return
- }
- fragment := url.Values{}
- fragment.Set("error", "invitation_required")
- fragment.Set("pending_oauth_token", pendingToken)
- fragment.Set("redirect", redirectTo)
- redirectWithFragment(c, frontendCallback, fragment)
+ if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
+ if err := h.createOIDCOAuthChoicePendingSession(
+ c,
+ identityRef,
+ email,
+ email,
+ redirectTo,
+ browserSessionKey,
+ upstreamClaims,
+ compatEmail,
+ compatEmailUser,
+ true,
+ ); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
- redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
+ redirectToFrontendCallback(c, frontendCallback)
return
}
- fragment := url.Values{}
- fragment.Set("access_token", tokenPair.AccessToken)
- fragment.Set("refresh_token", tokenPair.RefreshToken)
- fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
- fragment.Set("token_type", "Bearer")
- fragment.Set("redirect", redirectTo)
- redirectWithFragment(c, frontendCallback, fragment)
+ if err := h.createOIDCOAuthChoicePendingSession(
+ c,
+ identityRef,
+ email,
+ email,
+ redirectTo,
+ browserSessionKey,
+ upstreamClaims,
+ compatEmail,
+ compatEmailUser,
+ h.isForceEmailOnThirdPartySignup(c.Request.Context()),
+ ); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+}
+
+func (h *AuthHandler) findOIDCCompatEmailUser(ctx context.Context, email string) (*dbent.User, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ email = strings.TrimSpace(strings.ToLower(email))
+ if email == "" ||
+ strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
+ return nil, nil
+ }
+
+ userEntity, err := findUserByNormalizedEmail(ctx, client, email)
+ if err != nil {
+ if errors.Is(err, service.ErrUserNotFound) {
+ return nil, nil
+ }
+ return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err)
+ }
+ return userEntity, nil
+}
+
+func (h *AuthHandler) createOIDCOAuthChoicePendingSession(
+ c *gin.Context,
+ identity service.PendingAuthIdentityKey,
+ suggestedEmail string,
+ resolvedEmail string,
+ redirectTo string,
+ browserSessionKey string,
+ upstreamClaims map[string]any,
+ compatEmail string,
+ compatEmailUser *dbent.User,
+ forceEmailOnSignup bool,
+) error {
+ suggestionEmail := strings.TrimSpace(suggestedEmail)
+ canonicalEmail := strings.TrimSpace(resolvedEmail)
+ if suggestionEmail == "" {
+ suggestionEmail = canonicalEmail
+ }
+
+ completionResponse := map[string]any{
+ "step": oauthPendingChoiceStep,
+ "adoption_required": true,
+ "redirect": strings.TrimSpace(redirectTo),
+ "email": suggestionEmail,
+ "resolved_email": canonicalEmail,
+ "existing_account_email": "",
+ "existing_account_bindable": false,
+ "create_account_allowed": true,
+ "force_email_on_signup": forceEmailOnSignup,
+ "choice_reason": "third_party_signup",
+ }
+ if strings.TrimSpace(compatEmail) != "" {
+ completionResponse["compat_email"] = strings.TrimSpace(compatEmail)
+ }
+ if compatEmailUser != nil {
+ completionResponse["email"] = strings.TrimSpace(compatEmailUser.Email)
+ completionResponse["existing_account_email"] = strings.TrimSpace(compatEmailUser.Email)
+ completionResponse["existing_account_bindable"] = true
+ completionResponse["choice_reason"] = "compat_email_match"
+ }
+ if forceEmailOnSignup && compatEmailUser == nil {
+ completionResponse["choice_reason"] = "force_email_on_signup"
+ }
+
+ resolvedChoiceEmail := suggestionEmail
+ if compatEmailUser != nil {
+ resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email)
+ }
+ var targetUserID *int64
+ if compatEmailUser != nil && compatEmailUser.ID > 0 {
+ targetUserID = &compatEmailUser.ID
+ }
+
+ return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin,
+ Identity: identity,
+ TargetUserID: targetUserID,
+ ResolvedEmail: resolvedChoiceEmail,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: completionResponse,
+ })
}
type completeOIDCOAuthRequest struct {
- PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
- InvitationCode string `json:"invitation_code" binding:"required"`
+ InvitationCode string `json:"invitation_code" binding:"required"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
// CompleteOIDCOAuthRegistration completes a pending OAuth registration by validating
@@ -358,17 +596,87 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
return
}
- email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
+ secureCookie := isRequestHTTPS(c)
+ sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil {
- c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
return
}
-
- tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
+ return
+ }
+ pendingSvc, err := h.pendingIdentityService()
if err != nil {
response.ErrorFrom(c, err)
return
}
+ session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ } else if handled {
+ c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
+ return
+ } else {
+ session = updatedSession
+ }
+ if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ email := strings.TrimSpace(session.ResolvedEmail)
+ username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
+ if email == "" || username == "" {
+ response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid"))
+ return
+ }
+
+ client := h.entClient()
+ if client == nil {
+ response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
+ return
+ }
+ if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil {
+ respondPendingOAuthBindingApplyError(c, err)
+ return
+ }
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
+ AdoptDisplayName: req.AdoptDisplayName,
+ AdoptAvatar: req.AdoptAvatar,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil {
+ respondPendingOAuthBindingApplyError(c, err)
+ return
+ }
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken,
@@ -405,7 +713,7 @@ func oidcExchangeCode(
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
- if cfg.UsePKCE {
+ if strings.TrimSpace(codeVerifier) != "" {
form.Set("code_verifier", codeVerifier)
}
@@ -560,9 +868,26 @@ func oidcParseUserInfo(body string, cfg config.OIDCConnectConfig) *oidcUserInfoC
if verified, ok := getGJSONBool(body, "email_verified"); ok {
claims.EmailVerified = &verified
}
+ claims.DisplayName = firstNonEmpty(
+ getGJSON(body, "name"),
+ getGJSON(body, "nickname"),
+ getGJSON(body, "display_name"),
+ getGJSON(body, "preferred_username"),
+ getGJSON(body, "username"),
+ )
+ claims.AvatarURL = firstNonEmpty(
+ getGJSON(body, "picture"),
+ getGJSON(body, "avatar_url"),
+ getGJSON(body, "avatar"),
+ getGJSON(body, "profile_image_url"),
+ getGJSON(body, "user.avatar"),
+ getGJSON(body, "user.avatar_url"),
+ )
claims.Email = strings.TrimSpace(claims.Email)
claims.Username = strings.TrimSpace(claims.Username)
claims.Subject = strings.TrimSpace(claims.Subject)
+ claims.DisplayName = strings.TrimSpace(claims.DisplayName)
+ claims.AvatarURL = strings.TrimSpace(claims.AvatarURL)
return claims
}
@@ -595,7 +920,7 @@ func buildOIDCAuthorizeURL(cfg config.OIDCConnectConfig, state, nonce, codeChall
if strings.TrimSpace(nonce) != "" {
q.Set("nonce", nonce)
}
- if cfg.UsePKCE {
+ if strings.TrimSpace(codeChallenge) != "" {
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
}
@@ -831,14 +1156,6 @@ func oidcSyntheticEmailFromIdentityKey(identityKey string) string {
return "oidc-" + hex.EncodeToString(sum[:16]) + service.OIDCConnectSyntheticEmailDomain
}
-func oidcSelectLoginEmail(userInfoEmail, idTokenEmail, identityKey string) string {
- email := strings.TrimSpace(firstNonEmpty(userInfoEmail, idTokenEmail))
- if email != "" {
- return email
- }
- return oidcSyntheticEmailFromIdentityKey(identityKey)
-}
-
func oidcFallbackUsername(subject string) string {
subject = strings.TrimSpace(subject)
if subject == "" {
diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go
index a161aa77..3216d51e 100644
--- a/backend/internal/handler/auth_oidc_oauth_test.go
+++ b/backend/internal/handler/auth_oidc_oauth_test.go
@@ -1,6 +1,7 @@
package handler
import (
+ "bytes"
"context"
"crypto/rand"
"crypto/rsa"
@@ -12,7 +13,15 @@ import (
"testing"
"time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
+ servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
)
@@ -30,26 +39,11 @@ func TestOIDCSyntheticEmailStableAndDistinct(t *testing.T) {
require.Contains(t, e1, "@oidc-connect.invalid")
}
-func TestOIDCSelectLoginEmailPrefersRealEmail(t *testing.T) {
- identityKey := oidcIdentityKey("https://issuer.example.com", "subject-a")
-
- email := oidcSelectLoginEmail("user@example.com", "idtoken@example.com", identityKey)
- require.Equal(t, "user@example.com", email)
-
- email = oidcSelectLoginEmail("", "idtoken@example.com", identityKey)
- require.Equal(t, "idtoken@example.com", email)
-
- email = oidcSelectLoginEmail("", "", identityKey)
- require.Contains(t, email, "@oidc-connect.invalid")
- require.Equal(t, oidcSyntheticEmailFromIdentityKey(identityKey), email)
-}
-
func TestBuildOIDCAuthorizeURLIncludesNonceAndPKCE(t *testing.T) {
cfg := config.OIDCConnectConfig{
AuthorizeURL: "https://issuer.example.com/auth",
ClientID: "cid",
Scopes: "openid email profile",
- UsePKCE: true,
}
u, err := buildOIDCAuthorizeURL(cfg, "state123", "nonce123", "challenge123", "https://app.example.com/callback")
@@ -106,6 +100,26 @@ func TestOIDCParseAndValidateIDToken(t *testing.T) {
require.Error(t, err)
}
+func TestOIDCParseUserInfoIncludesSuggestedProfile(t *testing.T) {
+ cfg := config.OIDCConnectConfig{}
+
+ claims := oidcParseUserInfo(`{
+ "sub":"subject-1",
+ "preferred_username":"alice",
+ "name":"Alice Example",
+ "picture":"https://cdn.example/avatar.png",
+ "email":"alice@example.com",
+ "email_verified":true
+ }`, cfg)
+
+ require.Equal(t, "subject-1", claims.Subject)
+ require.Equal(t, "alice", claims.Username)
+ require.Equal(t, "Alice Example", claims.DisplayName)
+ require.Equal(t, "https://cdn.example/avatar.png", claims.AvatarURL)
+ require.NotNil(t, claims.EmailVerified)
+ require.True(t, *claims.EmailVerified)
+}
+
func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK {
n := base64.RawURLEncoding.EncodeToString(pub.N.Bytes())
e := base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes())
@@ -118,3 +132,909 @@ func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK {
E: e,
}
}
+
+func TestOIDCOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) {
+ handler := newOIDCOAuthTestHandler(t, false, config.OIDCConnectConfig{
+ Enabled: true,
+ ClientID: "oidc-client",
+ ClientSecret: "oidc-secret",
+ IssuerURL: "https://issuer.example.com",
+ AuthorizeURL: "https://issuer.example.com/oauth/authorize",
+ TokenURL: "https://issuer.example.com/oauth/token",
+ UserInfoURL: "https://issuer.example.com/oauth/userinfo",
+ JWKSURL: "https://issuer.example.com/oauth/jwks",
+ Scopes: "openid profile email",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ ValidateIDToken: true,
+ AllowedSigningAlgs: "RS256",
+ ClockSkewSeconds: 120,
+ RequireEmailVerified: false,
+ })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=/settings/connections", nil)
+ c.Request = req
+ c.Set(string(servermiddleware.ContextKeyUser), servermiddleware.AuthSubject{UserID: 84})
+
+ handler.OIDCOAuthStart(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ require.Contains(t, location, "issuer.example.com/oauth/authorize")
+ require.Contains(t, location, "client_id=oidc-client")
+ require.Contains(t, location, "nonce=")
+
+ cookies := recorder.Result().Cookies()
+ require.NotNil(t, findCookie(cookies, oidcOAuthStateCookieName))
+ require.NotNil(t, findCookie(cookies, oidcOAuthRedirectCookie))
+ require.NotNil(t, findCookie(cookies, oidcOAuthVerifierCookie))
+ require.NotNil(t, findCookie(cookies, oidcOAuthNonceCookie))
+ require.NotNil(t, findCookie(cookies, oauthPendingBrowserCookieName))
+
+ intentCookie := findCookie(cookies, oidcOAuthIntentCookieName)
+ require.NotNil(t, intentCookie)
+ require.Equal(t, oauthIntentBindCurrentUser, decodeCookieValueForTest(t, intentCookie.Value))
+
+ bindCookie := findCookie(cookies, oidcOAuthBindUserCookieName)
+ require.NotNil(t, bindCookie)
+ userID, err := parseOAuthBindUserCookieValue(decodeCookieValueForTest(t, bindCookie.Value), "test-secret")
+ require.NoError(t, err)
+ require.Equal(t, int64(84), userID)
+}
+
+func TestOIDCOAuthStartOmitsPKCEAndNonceWhenDisabled(t *testing.T) {
+ handler := newOIDCOAuthTestHandler(t, false, config.OIDCConnectConfig{
+ Enabled: true,
+ ClientID: "oidc-client",
+ ClientSecret: "oidc-secret",
+ IssuerURL: "https://issuer.example.com",
+ AuthorizeURL: "https://issuer.example.com/oauth/authorize",
+ TokenURL: "https://issuer.example.com/oauth/token",
+ UserInfoURL: "https://issuer.example.com/oauth/userinfo",
+ Scopes: "openid profile email",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: false,
+ ValidateIDToken: false,
+ RequireEmailVerified: false,
+ })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/start?redirect=/dashboard", nil)
+
+ handler.OIDCOAuthStart(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ require.NotContains(t, location, "code_challenge=")
+ require.NotContains(t, location, "nonce=")
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oidcOAuthVerifierCookie))
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oidcOAuthNonceCookie))
+}
+
+func TestOIDCOAuthCallbackAllowsOptionalPKCEAndIDTokenValidation(t *testing.T) {
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ require.NoError(t, r.ParseForm())
+ require.Empty(t, r.PostForm.Get("code_verifier"))
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"oidc-access","token_type":"Bearer","expires_in":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"sub":"oidc-subject-compat","preferred_username":"oidc_user","name":"OIDC Display","email":"oidc@example.com"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, false, config.OIDCConnectConfig{
+ Enabled: true,
+ ClientID: "oidc-client",
+ ClientSecret: "oidc-secret",
+ IssuerURL: "https://issuer.example.com",
+ AuthorizeURL: upstream.URL + "/authorize",
+ TokenURL: upstream.URL + "/token",
+ UserInfoURL: upstream.URL + "/userinfo",
+ Scopes: "openid profile email",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: false,
+ ValidateIDToken: false,
+ RequireEmailVerified: false,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-123", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
+ require.NotNil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+}
+
+func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) {
+ cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
+ Subject: "oidc-subject-login",
+ PreferredUsername: "oidc_login",
+ DisplayName: "OIDC Login Display",
+ AvatarURL: "https://cdn.example/oidc-login.png",
+ Email: "oidc-login@example.com",
+ EmailVerified: true,
+ })
+ defer cleanup()
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail(oidcSyntheticEmailFromIdentityKey(oidcIdentityKey(cfg.IssuerURL, "oidc-subject-login"))).
+ SetUsername("legacy-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.AuthIdentity.Create().
+ SetUserID(existingUser.ID).
+ SetProviderType("oidc").
+ SetProviderKey(cfg.IssuerURL).
+ SetProviderSubject("oidc-subject-login").
+ SetMetadata(map[string]any{"username": "legacy-user"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-123", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-123"))
+ req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-login"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, existingUser.ID, *session.TargetUserID)
+ require.Equal(t, cfg.IssuerURL, session.ProviderKey)
+ require.Equal(t, "OIDC Login Display", session.UpstreamIdentityClaims["suggested_display_name"])
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "/dashboard", completion["redirect"])
+ _, hasAccessToken := completion["access_token"]
+ require.False(t, hasAccessToken)
+ _, hasRefreshToken := completion["refresh_token"]
+ require.False(t, hasRefreshToken)
+ require.Nil(t, completion["error"])
+}
+
+func TestOIDCOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) {
+ cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
+ Subject: "oidc-disabled-subject",
+ PreferredUsername: "oidc_disabled",
+ DisplayName: "OIDC Disabled",
+ })
+ defer cleanup()
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail(oidcSyntheticEmailFromIdentityKey(oidcIdentityKey(cfg.IssuerURL, "oidc-disabled-subject"))).
+ SetUsername("disabled-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusDisabled).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.AuthIdentity.Create().
+ SetUserID(existingUser.ID).
+ SetProviderType("oidc").
+ SetProviderKey(cfg.IssuerURL).
+ SetProviderSubject("oidc-disabled-subject").
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-disabled", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-disabled"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-disabled"))
+ req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-disabled-subject"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+ assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE")
+
+ count, err := client.PendingAuthSession.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, count)
+}
+
+func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) {
+ cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
+ Subject: "oidc-subject-compat",
+ PreferredUsername: "oidc_compat",
+ DisplayName: "OIDC Compat Display",
+ AvatarURL: "https://cdn.example/oidc-compat.png",
+ Email: "legacy@example.com",
+ EmailVerified: true,
+ })
+ defer cleanup()
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail("legacy@example.com").
+ SetUsername("legacy-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-compat", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-compat"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, existingUser.ID, *session.TargetUserID)
+ require.Equal(t, existingUser.Email, session.ResolvedEmail)
+ require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "/dashboard", completion["redirect"])
+ require.Equal(t, oauthPendingChoiceStep, completion["step"])
+ require.Equal(t, existingUser.Email, completion["email"])
+ require.Equal(t, existingUser.Email, completion["existing_account_email"])
+ require.Equal(t, true, completion["existing_account_bindable"])
+ require.Equal(t, "compat_email_match", completion["choice_reason"])
+ _, hasAccessToken := completion["access_token"]
+ require.False(t, hasAccessToken)
+}
+
+func TestOIDCOAuthCallbackAllowsCompatEmailBindWhenUpstreamEmailIsUnverified(t *testing.T) {
+ cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
+ Subject: "oidc-subject-unverified-compat",
+ PreferredUsername: "oidc_unverified",
+ DisplayName: "OIDC Unverified Compat Display",
+ AvatarURL: "https://cdn.example/oidc-unverified.png",
+ Email: "owner@example.com",
+ EmailVerified: false,
+ })
+ defer cleanup()
+ cfg.RequireEmailVerified = true
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ _, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-unverified-compat", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-unverified-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/settings/connections"))
+ req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-unverified-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-unverified-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-unverified-compat"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/oidc/callback#error=email_not_verified&error_message=email+is+not+verified", recorder.Header().Get("Location"))
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+
+ count, err := client.PendingAuthSession.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, count)
+}
+
+func TestOIDCOAuthCallbackCreatesChoicePendingSessionWhenSignupRequiresInvite(t *testing.T) {
+ cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
+ Subject: "oidc-subject-invite",
+ PreferredUsername: "oidc_invite",
+ DisplayName: "OIDC Invite Display",
+ AvatarURL: "https://cdn.example/oidc-invite.png",
+ Email: "oidc-invite@example.com",
+ EmailVerified: true,
+ })
+ defer cleanup()
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, true, cfg)
+ t.Cleanup(func() { _ = client.Close() })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-456", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-456"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-456"))
+ req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-invite"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-456"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ ctx := context.Background()
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.Nil(t, session.TargetUserID)
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, oauthPendingChoiceStep, completion["step"])
+ require.Equal(t, "/dashboard", completion["redirect"])
+ require.Equal(t, "third_party_signup", completion["choice_reason"])
+}
+
+func TestOIDCOAuthCallbackCreatesBindPendingSessionForCurrentUser(t *testing.T) {
+ cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
+ Subject: "oidc-subject-bind",
+ PreferredUsername: "oidc_bind",
+ DisplayName: "OIDC Bind Display",
+ AvatarURL: "https://cdn.example/oidc-bind.png",
+ Email: "oidc-bind@example.com",
+ EmailVerified: true,
+ })
+ defer cleanup()
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ currentUser, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("current-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-bind", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-bind"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/settings/connections"))
+ req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-bind"))
+ req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-bind"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentBindCurrentUser))
+ req.AddCookie(encodedCookie(oidcOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-bind"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentBindCurrentUser, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, currentUser.ID, *session.TargetUserID)
+ require.Equal(t, cfg.IssuerURL, session.ProviderKey)
+ require.Equal(t, "OIDC Bind Display", session.UpstreamIdentityClaims["suggested_display_name"])
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "/settings/connections", completion["redirect"])
+ require.Empty(t, completion["access_token"])
+
+ userCount, err := client.User.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, userCount)
+}
+
+func TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("oidc-complete-session").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example.com").
+ SetProviderSubject("oidc-subject-1").
+ SetResolvedEmail("93a310f4c1944c5bbd2e246df1f76485@oidc-connect.invalid").
+ SetBrowserSessionKey("oidc-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "issuer": "https://issuer.example.com",
+ "suggested_display_name": "OIDC Display",
+ "suggested_avatar_url": "https://cdn.example/oidc.png",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ AdoptAvatar: true,
+ })
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser")})
+ c.Request = req
+
+ handler.CompleteOIDCOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.NotEmpty(t, responseData["access_token"])
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ(session.ResolvedEmail)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "OIDC Display", userEntity.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example.com"),
+ authidentity.ProviderSubjectEQ("oidc-subject-1"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+ require.Equal(t, "OIDC Display", identity.Metadata["display_name"])
+ require.Equal(t, "https://cdn.example/oidc.png", identity.Metadata["avatar_url"])
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.True(t, decision.AdoptDisplayName)
+ require.True(t, decision.AdoptAvatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
+
+func TestCompleteOIDCOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("oidc-complete-invalid-session").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example.com").
+ SetProviderSubject("oidc-invalid-subject-1").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("oidc-invalid-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": "bind_login_required",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-invalid-browser")})
+ c.Request = req
+
+ handler.CompleteOIDCOAuthRegistration(c)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestCompleteOIDCOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("oidc-complete-choice-session").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example.com").
+ SetProviderSubject("oidc-choice-subject-1").
+ SetResolvedEmail("oidc-choice-subject-1@oidc-connect.invalid").
+ SetBrowserSessionKey("oidc-choice-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "issuer": "https://issuer.example.com",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": oauthPendingChoiceStep,
+ "redirect": "/dashboard",
+ "email": "fresh@example.com",
+ "resolved_email": "fresh@example.com",
+ "force_email_on_signup": true,
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-choice-browser")})
+ c.Request = req
+
+ handler.CompleteOIDCOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.Equal(t, "pending_session", responseData["auth_result"])
+ require.Equal(t, oauthPendingChoiceStep, responseData["step"])
+ require.Equal(t, true, responseData["force_email_on_signup"])
+ require.Empty(t, responseData["access_token"])
+
+ userCount, err := client.User.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestCompleteOIDCOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("oidc-complete-no-adoption-session").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example.com").
+ SetProviderSubject("oidc-subject-no-adoption").
+ SetResolvedEmail("8c9f12b2a2e14b1db9efc08b27e0ef5c@oidc-connect.invalid").
+ SetBrowserSessionKey("oidc-browser-no-adoption").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "issuer": "https://issuer.example.com",
+ "suggested_display_name": "OIDC Legacy",
+ "suggested_avatar_url": "https://cdn.example/oidc-legacy.png",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser-no-adoption")})
+ c.Request = req
+
+ handler.CompleteOIDCOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.NotEmpty(t, responseData["access_token"])
+ require.NotEmpty(t, responseData["refresh_token"])
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ(session.ResolvedEmail)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "oidc_user", userEntity.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example.com"),
+ authidentity.ProviderSubjectEQ("oidc-subject-no-adoption"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.False(t, decision.AdoptDisplayName)
+ require.False(t, decision.AdoptAvatar)
+}
+
+func TestCompleteOIDCOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ existingOwner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(existingOwner.ID).
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example.com").
+ SetProviderSubject("oidc-conflict-subject").
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("oidc-complete-conflict-session").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example.com").
+ SetProviderSubject("oidc-conflict-subject").
+ SetResolvedEmail("f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid").
+ SetBrowserSessionKey("oidc-conflict-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "issuer": "https://issuer.example.com",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-conflict-browser")})
+ c.Request = req
+
+ handler.CompleteOIDCOAuthRegistration(c)
+
+ require.Equal(t, http.StatusConflict, recorder.Code)
+ payload := decodeJSONBody(t, recorder)
+ require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"])
+
+ userCount, err := client.User.Query().
+ Where(dbuser.EmailEQ("f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid")).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+type oidcProviderFixture struct {
+ Subject string
+ PreferredUsername string
+ DisplayName string
+ AvatarURL string
+ Email string
+ EmailVerified bool
+}
+
+func newOIDCOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.OIDCConnectConfig) *AuthHandler {
+ t.Helper()
+ handler, _ := newOIDCOAuthHandlerAndClient(t, invitationEnabled, oauthCfg)
+ return handler
+}
+
+func newOIDCOAuthHandlerAndClient(t *testing.T, invitationEnabled bool, oauthCfg config.OIDCConnectConfig) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+ handler, client := newOAuthPendingFlowTestHandler(t, invitationEnabled)
+ handler.settingSvc = nil
+ handler.cfg = &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ OIDC: oauthCfg,
+ }
+ return handler, client
+}
+
+func newOIDCTestProvider(t *testing.T, fixture oidcProviderFixture) (config.OIDCConnectConfig, func()) {
+ t.Helper()
+
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ require.NoError(t, err)
+
+ kid := "test-kid"
+ jwks := oidcJWKSet{Keys: []oidcJWK{buildRSAJWK(kid, &privateKey.PublicKey)}}
+ tokenResponse := oidcTokenResponse{
+ AccessToken: "oidc-access-token",
+ TokenType: "Bearer",
+ ExpiresIn: 3600,
+ }
+
+ userInfoPayload := map[string]any{
+ "sub": fixture.Subject,
+ "preferred_username": fixture.PreferredUsername,
+ "name": fixture.DisplayName,
+ "picture": fixture.AvatarURL,
+ "email": fixture.Email,
+ "email_verified": fixture.EmailVerified,
+ }
+
+ var issuer string
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ require.NoError(t, json.NewEncoder(w).Encode(tokenResponse))
+ case "/userinfo":
+ require.NoError(t, json.NewEncoder(w).Encode(userInfoPayload))
+ case "/jwks":
+ require.NoError(t, json.NewEncoder(w).Encode(jwks))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+
+ issuer = server.URL
+ now := time.Now()
+ claims := oidcIDTokenClaims{
+ Email: fixture.Email,
+ EmailVerified: boolPtr(fixture.EmailVerified),
+ PreferredUsername: fixture.PreferredUsername,
+ Name: fixture.DisplayName,
+ Nonce: "nonce-" + fixture.Subject,
+ RegisteredClaims: jwt.RegisteredClaims{
+ Issuer: issuer,
+ Subject: fixture.Subject,
+ Audience: jwt.ClaimStrings{"oidc-client"},
+ IssuedAt: jwt.NewNumericDate(now),
+ NotBefore: jwt.NewNumericDate(now.Add(-30 * time.Second)),
+ ExpiresAt: jwt.NewNumericDate(now.Add(5 * time.Minute)),
+ },
+ }
+ token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
+ token.Header["kid"] = kid
+ tokenResponse.IDToken, err = token.SignedString(privateKey)
+ require.NoError(t, err)
+
+ cfg := config.OIDCConnectConfig{
+ Enabled: true,
+ ProviderName: "Test OIDC",
+ ClientID: "oidc-client",
+ ClientSecret: "oidc-secret",
+ IssuerURL: issuer,
+ AuthorizeURL: issuer + "/authorize",
+ TokenURL: issuer + "/token",
+ UserInfoURL: issuer + "/userinfo",
+ JWKSURL: issuer + "/jwks",
+ Scopes: "openid profile email",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ ValidateIDToken: true,
+ AllowedSigningAlgs: "RS256",
+ ClockSkewSeconds: 120,
+ RequireEmailVerified: false,
+ }
+ return cfg, server.Close
+}
diff --git a/backend/internal/handler/auth_session_revocation_test.go b/backend/internal/handler/auth_session_revocation_test.go
new file mode 100644
index 00000000..1924cb81
--- /dev/null
+++ b/backend/internal/handler/auth_session_revocation_test.go
@@ -0,0 +1,61 @@
+//go:build unit
+
+package handler
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 29,
+ Email: "session@example.com",
+ Username: "session-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ TokenVersion: 7,
+ },
+ }
+ refreshTokenCache := &userHandlerRefreshTokenCacheStub{}
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ },
+ }
+ authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
+ handler := &AuthHandler{authService: authService}
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/auth/revoke-all-sessions", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 29})
+
+ handler.RevokeAllSessions(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ require.Equal(t, []int64{29}, refreshTokenCache.revokedUserIDs)
+ require.Equal(t, int64(8), repo.user.TokenVersion)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ Message string `json:"message"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, "All sessions have been revoked. Please log in again.", resp.Data.Message)
+}
diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go
new file mode 100644
index 00000000..efee4cc0
--- /dev/null
+++ b/backend/internal/handler/auth_wechat_oauth.go
@@ -0,0 +1,1349 @@
+package handler
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strconv"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ wechatOAuthCookiePath = "/api/v1/auth/oauth/wechat"
+ wechatOAuthCookieMaxAgeSec = 10 * 60
+ wechatOAuthStateCookieName = "wechat_oauth_state"
+ wechatOAuthRedirectCookieName = "wechat_oauth_redirect"
+ wechatOAuthIntentCookieName = "wechat_oauth_intent"
+ wechatOAuthModeCookieName = "wechat_oauth_mode"
+ wechatOAuthBindUserCookieName = "wechat_oauth_bind_user"
+ wechatOAuthDefaultRedirectTo = "/dashboard"
+ wechatOAuthDefaultFrontendCB = "/auth/wechat/callback"
+ wechatOAuthProviderKey = "wechat-main"
+ wechatOAuthLegacyProviderKey = "wechat"
+ wechatPaymentOAuthCookiePath = "/api/v1/auth/oauth/wechat/payment"
+ wechatPaymentOAuthStateName = "wechat_payment_oauth_state"
+ wechatPaymentOAuthRedirect = "wechat_payment_oauth_redirect"
+ wechatPaymentOAuthContextName = "wechat_payment_oauth_context"
+ wechatPaymentOAuthScope = "wechat_payment_oauth_scope"
+ wechatPaymentOAuthDefaultTo = "/purchase"
+ wechatPaymentOAuthFrontendCB = "/auth/wechat/payment/callback"
+
+ wechatOAuthIntentLogin = "login"
+ wechatOAuthIntentBind = "bind_current_user"
+ wechatOAuthIntentAdoptEmail = "adopt_existing_user_by_email"
+)
+
+var (
+ wechatOAuthAccessTokenURL = "https://api.weixin.qq.com/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = "https://api.weixin.qq.com/sns/userinfo"
+)
+
+type wechatOAuthConfig struct {
+ mode string
+ appID string
+ appSecret string
+ authorizeURL string
+ scope string
+ redirectURI string
+ frontendCallback string
+ openEnabled bool
+ mpEnabled bool
+}
+
+type wechatOAuthTokenResponse struct {
+ AccessToken string `json:"access_token"`
+ ExpiresIn int64 `json:"expires_in"`
+ RefreshToken string `json:"refresh_token"`
+ OpenID string `json:"openid"`
+ Scope string `json:"scope"`
+ UnionID string `json:"unionid"`
+ ErrCode int64 `json:"errcode"`
+ ErrMsg string `json:"errmsg"`
+}
+
+type wechatOAuthUserInfoResponse struct {
+ OpenID string `json:"openid"`
+ Nickname string `json:"nickname"`
+ HeadImgURL string `json:"headimgurl"`
+ UnionID string `json:"unionid"`
+ ErrCode int64 `json:"errcode"`
+ ErrMsg string `json:"errmsg"`
+}
+
+type wechatPaymentOAuthContext struct {
+ PaymentType string `json:"payment_type"`
+ Amount string `json:"amount,omitempty"`
+ OrderType string `json:"order_type,omitempty"`
+ PlanID int64 `json:"plan_id,omitempty"`
+}
+
+// WeChatOAuthStart starts the WeChat OAuth login flow and stores the short-lived
+// browser cookies required by the rebuild pending-auth bridge.
+func (h *AuthHandler) WeChatOAuthStart(c *gin.Context) {
+ cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), c.Query("mode"), c)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ state, err := oauth.GenerateState()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
+ return
+ }
+
+ redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
+ if redirectTo == "" {
+ redirectTo = wechatOAuthDefaultRedirectTo
+ }
+
+ browserSessionKey, err := generateOAuthPendingBrowserSession()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err))
+ return
+ }
+
+ intent := normalizeWeChatOAuthIntent(c.Query("intent"))
+ secureCookie := isRequestHTTPS(c)
+ wechatSetCookie(c, wechatOAuthStateCookieName, encodeCookieValue(state), wechatOAuthCookieMaxAgeSec, secureCookie)
+ wechatSetCookie(c, wechatOAuthRedirectCookieName, encodeCookieValue(redirectTo), wechatOAuthCookieMaxAgeSec, secureCookie)
+ wechatSetCookie(c, wechatOAuthIntentCookieName, encodeCookieValue(intent), wechatOAuthCookieMaxAgeSec, secureCookie)
+ wechatSetCookie(c, wechatOAuthModeCookieName, encodeCookieValue(cfg.mode), wechatOAuthCookieMaxAgeSec, secureCookie)
+ setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie)
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ if intent == oauthIntentBindCurrentUser {
+ bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ wechatSetCookie(c, wechatOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), wechatOAuthCookieMaxAgeSec, secureCookie)
+ } else {
+ wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie)
+ }
+
+ authURL, err := buildWeChatAuthorizeURL(cfg, state)
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
+ return
+ }
+
+ c.Redirect(http.StatusFound, authURL)
+}
+
+// WeChatOAuthCallback exchanges the code with WeChat, resolves openid/unionid,
+// and stores the result in the unified pending-auth flow.
+func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
+ frontendCallback := h.wechatOAuthFrontendCallback(c.Request.Context())
+
+ if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
+ redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
+ return
+ }
+
+ code := strings.TrimSpace(c.Query("code"))
+ state := strings.TrimSpace(c.Query("state"))
+ if code == "" || state == "" {
+ redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
+ return
+ }
+
+ secureCookie := isRequestHTTPS(c)
+ defer func() {
+ wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie)
+ wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie)
+ wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie)
+ wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie)
+ wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie)
+ }()
+
+ expectedState, err := readCookieDecoded(c, wechatOAuthStateCookieName)
+ if err != nil || expectedState == "" || state != expectedState {
+ redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
+ return
+ }
+
+ redirectTo, _ := readCookieDecoded(c, wechatOAuthRedirectCookieName)
+ redirectTo = sanitizeFrontendRedirectPath(redirectTo)
+ if redirectTo == "" {
+ redirectTo = wechatOAuthDefaultRedirectTo
+ }
+ browserSessionKey, _ := readOAuthPendingBrowserCookie(c)
+ if strings.TrimSpace(browserSessionKey) == "" {
+ redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "")
+ return
+ }
+
+ intent, _ := readCookieDecoded(c, wechatOAuthIntentCookieName)
+ mode, err := readCookieDecoded(c, wechatOAuthModeCookieName)
+ if err != nil || strings.TrimSpace(mode) == "" {
+ redirectOAuthError(c, frontendCallback, "invalid_state", "missing oauth mode", "")
+ return
+ }
+
+ cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), mode, c)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "provider_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+
+ tokenResp, userInfo, err := fetchWeChatOAuthIdentity(c.Request.Context(), cfg, code)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "provider_error", "wechat_identity_fetch_failed", singleLine(err.Error()))
+ return
+ }
+
+ unionid := strings.TrimSpace(firstNonEmpty(userInfo.UnionID, tokenResp.UnionID))
+ openid := strings.TrimSpace(firstNonEmpty(userInfo.OpenID, tokenResp.OpenID))
+ providerSubject := unionid
+ if providerSubject == "" {
+ if cfg.requiresUnionID() {
+ redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_unionid", "")
+ return
+ }
+ providerSubject = openid
+ }
+ if providerSubject == "" {
+ redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_unionid", "")
+ return
+ }
+
+ username := firstNonEmpty(userInfo.Nickname, wechatFallbackUsername(providerSubject))
+ email := wechatSyntheticEmail(providerSubject)
+ upstreamClaims := map[string]any{
+ "email": email,
+ "username": username,
+ "subject": providerSubject,
+ "openid": openid,
+ "unionid": unionid,
+ "mode": cfg.mode,
+ "channel": cfg.mode,
+ "channel_app_id": strings.TrimSpace(cfg.appID),
+ "channel_subject": openid,
+ "suggested_display_name": strings.TrimSpace(userInfo.Nickname),
+ "suggested_avatar_url": strings.TrimSpace(userInfo.HeadImgURL),
+ }
+ identityRef := service.PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: wechatOAuthProviderKey,
+ ProviderSubject: providerSubject,
+ }
+
+ normalizedIntent := normalizeWeChatOAuthIntent(intent)
+ if normalizedIntent == wechatOAuthIntentBind {
+ if err := h.createWeChatBindPendingSession(c, cfg, providerSubject, openid, redirectTo, browserSessionKey, upstreamClaims); err != nil {
+ switch infraerrors.Code(err) {
+ case http.StatusConflict:
+ redirectOAuthError(c, frontendCallback, "ownership_conflict", infraerrors.Reason(err), infraerrors.Message(err))
+ case http.StatusUnauthorized, http.StatusForbidden:
+ redirectOAuthError(c, frontendCallback, "auth_required", infraerrors.Reason(err), infraerrors.Message(err))
+ default:
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ }
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if existingIdentityUser == nil {
+ existingIdentityUser, err = h.findWeChatUserByLegacyOpenID(c.Request.Context(), identityRef, cfg, openid)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ }
+ if existingIdentityUser != nil {
+ if err := h.ensureWeChatRuntimeIdentityBinding(c.Request.Context(), existingIdentityUser.ID, identityRef, upstreamClaims); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, nil, nil, &existingIdentityUser.ID); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
+ if err := h.createWeChatChoicePendingSession(
+ c,
+ identityRef,
+ email,
+ email,
+ redirectTo,
+ browserSessionKey,
+ upstreamClaims,
+ "",
+ nil,
+ true,
+ ); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ if err := h.createWeChatChoicePendingSession(
+ c,
+ identityRef,
+ email,
+ email,
+ redirectTo,
+ browserSessionKey,
+ upstreamClaims,
+ "",
+ nil,
+ false,
+ ); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+}
+
+// WeChatPaymentOAuthStart starts the WeChat payment OAuth flow.
+// GET /api/v1/auth/oauth/wechat/payment/start?payment_type=wxpay&redirect=/purchase
+func (h *AuthHandler) WeChatPaymentOAuthStart(c *gin.Context) {
+ cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), "mp", c)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ paymentType := normalizeWeChatPaymentType(c.Query("payment_type"))
+ if paymentType == "" {
+ response.BadRequest(c, "Invalid payment type")
+ return
+ }
+
+ state, err := oauth.GenerateState()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
+ return
+ }
+
+ redirectTo := normalizeWeChatPaymentRedirectPath(sanitizeFrontendRedirectPath(c.Query("redirect")))
+ if redirectTo == "" {
+ redirectTo = wechatPaymentOAuthDefaultTo
+ }
+ rawContext, err := encodeWeChatPaymentOAuthContext(wechatPaymentOAuthContext{
+ PaymentType: paymentType,
+ Amount: strings.TrimSpace(c.Query("amount")),
+ OrderType: strings.TrimSpace(c.Query("order_type")),
+ PlanID: parseWeChatPaymentPlanID(c.Query("plan_id")),
+ })
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONTEXT_ENCODE_FAILED", "failed to encode oauth context").WithCause(err))
+ return
+ }
+
+ scope := normalizeWeChatPaymentScope(c.Query("scope"))
+ secureCookie := isRequestHTTPS(c)
+ wechatPaymentSetCookie(c, wechatPaymentOAuthStateName, encodeCookieValue(state), wechatOAuthCookieMaxAgeSec, secureCookie)
+ wechatPaymentSetCookie(c, wechatPaymentOAuthRedirect, encodeCookieValue(redirectTo), wechatOAuthCookieMaxAgeSec, secureCookie)
+ wechatPaymentSetCookie(c, wechatPaymentOAuthContextName, encodeCookieValue(rawContext), wechatOAuthCookieMaxAgeSec, secureCookie)
+ wechatPaymentSetCookie(c, wechatPaymentOAuthScope, encodeCookieValue(scope), wechatOAuthCookieMaxAgeSec, secureCookie)
+
+ cfg.redirectURI = h.resolveWeChatPaymentOAuthCallbackURL(c.Request.Context(), c)
+ cfg.scope = scope
+ authURL, err := buildWeChatAuthorizeURL(cfg, state)
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
+ return
+ }
+
+ c.Redirect(http.StatusFound, authURL)
+}
+
+// WeChatPaymentOAuthCallback exchanges a payment OAuth code for an OpenID and
+// forwards the browser back to the frontend callback route.
+func (h *AuthHandler) WeChatPaymentOAuthCallback(c *gin.Context) {
+ frontendCallback := wechatPaymentOAuthFrontendCB
+
+ if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
+ redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
+ return
+ }
+
+ code := strings.TrimSpace(c.Query("code"))
+ state := strings.TrimSpace(c.Query("state"))
+ if code == "" || state == "" {
+ redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
+ return
+ }
+
+ secureCookie := isRequestHTTPS(c)
+ defer func() {
+ wechatPaymentClearCookie(c, wechatPaymentOAuthStateName, secureCookie)
+ wechatPaymentClearCookie(c, wechatPaymentOAuthRedirect, secureCookie)
+ wechatPaymentClearCookie(c, wechatPaymentOAuthContextName, secureCookie)
+ wechatPaymentClearCookie(c, wechatPaymentOAuthScope, secureCookie)
+ }()
+
+ expectedState, err := readCookieDecoded(c, wechatPaymentOAuthStateName)
+ if err != nil || expectedState == "" || state != expectedState {
+ redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
+ return
+ }
+
+ redirectTo, _ := readCookieDecoded(c, wechatPaymentOAuthRedirect)
+ redirectTo = normalizeWeChatPaymentRedirectPath(sanitizeFrontendRedirectPath(redirectTo))
+ if redirectTo == "" {
+ redirectTo = wechatPaymentOAuthDefaultTo
+ }
+
+ rawContext, _ := readCookieDecoded(c, wechatPaymentOAuthContextName)
+ paymentContext, err := decodeWeChatPaymentOAuthContext(rawContext)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "invalid_context", "invalid oauth context", "")
+ return
+ }
+ if paymentContext.PaymentType == "" {
+ paymentContext.PaymentType = payment.TypeWxpay
+ }
+
+ scope, _ := readCookieDecoded(c, wechatPaymentOAuthScope)
+ scope = normalizeWeChatPaymentScope(scope)
+
+ cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), "mp", c)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "provider_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ cfg.redirectURI = h.resolveWeChatPaymentOAuthCallbackURL(c.Request.Context(), c)
+ tokenResp, err := exchangeWeChatOAuthCode(c.Request.Context(), cfg, code)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", err.Error())
+ return
+ }
+
+ openid := strings.TrimSpace(tokenResp.OpenID)
+ if openid == "" {
+ redirectOAuthError(c, frontendCallback, "missing_openid", "missing openid", "")
+ return
+ }
+ if strings.TrimSpace(tokenResp.Scope) != "" {
+ scope = strings.TrimSpace(tokenResp.Scope)
+ }
+
+ resumeToken, err := h.wechatPaymentResumeService().CreateWeChatPaymentResumeToken(service.WeChatPaymentResumeClaims{
+ OpenID: openid,
+ PaymentType: paymentContext.PaymentType,
+ Amount: paymentContext.Amount,
+ OrderType: paymentContext.OrderType,
+ PlanID: paymentContext.PlanID,
+ RedirectTo: redirectTo,
+ Scope: scope,
+ })
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "invalid_context", "failed to encode payment resume context", "")
+ return
+ }
+
+ fragment := url.Values{}
+ fragment.Set("wechat_resume_token", resumeToken)
+ fragment.Set("redirect", redirectTo)
+ redirectWithFragment(c, frontendCallback, fragment)
+}
+
+func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService {
+ var legacyKey []byte
+ key, err := payment.ProvideEncryptionKey(h.cfg)
+ if err == nil {
+ legacyKey = []byte(key)
+ }
+ return service.NewLegacyAwarePaymentResumeService(legacyKey)
+}
+
+type completeWeChatOAuthRequest struct {
+ InvitationCode string `json:"invitation_code" binding:"required"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
+}
+
+// CompleteWeChatOAuthRegistration completes a pending WeChat OAuth registration by
+// validating the invitation code and consuming the current pending browser session.
+// POST /api/v1/auth/oauth/wechat/complete-registration
+func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
+ var req completeWeChatOAuthRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
+ return
+ }
+
+ secureCookie := isRequestHTTPS(c)
+ sessionToken, err := readOAuthPendingSessionCookie(c)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
+ return
+ }
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
+ return
+ }
+ pendingSvc, err := h.pendingIdentityService()
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ } else if handled {
+ c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
+ return
+ } else {
+ session = updatedSession
+ }
+ if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ email := strings.TrimSpace(session.ResolvedEmail)
+ username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
+ if email == "" || username == "" {
+ response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid"))
+ return
+ }
+
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
+ AdoptDisplayName: req.AdoptDisplayName,
+ AdoptAvatar: req.AdoptAvatar,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
+ return
+ }
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, err)
+ return
+ }
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+
+ c.JSON(http.StatusOK, gin.H{
+ "access_token": tokenPair.AccessToken,
+ "refresh_token": tokenPair.RefreshToken,
+ "expires_in": tokenPair.ExpiresIn,
+ "token_type": "Bearer",
+ })
+}
+
+func (h *AuthHandler) createWeChatPendingSession(
+ c *gin.Context,
+ intent string,
+ providerSubject string,
+ email string,
+ redirectTo string,
+ browserSessionKey string,
+ upstreamClaims map[string]any,
+ tokenPair *service.TokenPair,
+ authErr error,
+ targetUserID *int64,
+) error {
+ completionResponse := map[string]any{
+ "redirect": redirectTo,
+ }
+ if authErr != nil {
+ if errors.Is(authErr, service.ErrOAuthInvitationRequired) {
+ completionResponse["error"] = "invitation_required"
+ } else {
+ return authErr
+ }
+ } else if tokenPair != nil {
+ completionResponse["access_token"] = tokenPair.AccessToken
+ completionResponse["refresh_token"] = tokenPair.RefreshToken
+ completionResponse["expires_in"] = tokenPair.ExpiresIn
+ completionResponse["token_type"] = "Bearer"
+ }
+
+ return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: intent,
+ Identity: service.PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: wechatOAuthProviderKey,
+ ProviderSubject: providerSubject,
+ },
+ TargetUserID: targetUserID,
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: completionResponse,
+ })
+}
+
+func (h *AuthHandler) createWeChatChoicePendingSession(
+ c *gin.Context,
+ identity service.PendingAuthIdentityKey,
+ suggestedEmail string,
+ resolvedEmail string,
+ redirectTo string,
+ browserSessionKey string,
+ upstreamClaims map[string]any,
+ compatEmail string,
+ compatEmailUser *dbent.User,
+ forceEmailOnSignup bool,
+) error {
+ suggestionEmail := strings.TrimSpace(suggestedEmail)
+ canonicalEmail := strings.TrimSpace(resolvedEmail)
+ if suggestionEmail == "" {
+ suggestionEmail = canonicalEmail
+ }
+
+ completionResponse := map[string]any{
+ "step": oauthPendingChoiceStep,
+ "adoption_required": true,
+ "redirect": strings.TrimSpace(redirectTo),
+ "email": suggestionEmail,
+ "resolved_email": canonicalEmail,
+ "existing_account_email": "",
+ "existing_account_bindable": false,
+ "create_account_allowed": true,
+ "force_email_on_signup": forceEmailOnSignup,
+ "choice_reason": "third_party_signup",
+ }
+ if strings.TrimSpace(compatEmail) != "" {
+ completionResponse["compat_email"] = strings.TrimSpace(compatEmail)
+ }
+ if compatEmailUser != nil {
+ completionResponse["email"] = strings.TrimSpace(compatEmailUser.Email)
+ completionResponse["existing_account_email"] = strings.TrimSpace(compatEmailUser.Email)
+ completionResponse["existing_account_bindable"] = true
+ completionResponse["choice_reason"] = "compat_email_match"
+ }
+ if forceEmailOnSignup {
+ completionResponse["choice_reason"] = "force_email_on_signup"
+ }
+
+ resolvedChoiceEmail := suggestionEmail
+ if compatEmailUser != nil {
+ resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email)
+ }
+
+ return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin,
+ Identity: identity,
+ ResolvedEmail: resolvedChoiceEmail,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: completionResponse,
+ })
+}
+
+func (h *AuthHandler) createWeChatBindPendingSession(
+ c *gin.Context,
+ cfg wechatOAuthConfig,
+ providerSubject string,
+ channelSubject string,
+ redirectTo string,
+ browserSessionKey string,
+ upstreamClaims map[string]any,
+) error {
+ currentUser, err := h.readOAuthBindTargetUser(c, wechatOAuthBindUserCookieName)
+ if err != nil {
+ return err
+ }
+ if err := h.ensureWeChatBindOwnership(c.Request.Context(), currentUser.ID, providerSubject, cfg, channelSubject); err != nil {
+ return err
+ }
+ return h.createWeChatPendingSession(
+ c,
+ wechatOAuthIntentBind,
+ providerSubject,
+ currentUser.Email,
+ redirectTo,
+ browserSessionKey,
+ upstreamClaims,
+ nil,
+ nil,
+ ¤tUser.ID,
+ )
+}
+
+func (h *AuthHandler) readOAuthBindTargetUser(c *gin.Context, cookieName string) (*dbent.User, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+ userID, err := h.readOAuthBindUserIDFromCookie(c, cookieName)
+ if err != nil {
+ return nil, infraerrors.Unauthorized("AUTH_REQUIRED", "current user is required to bind wechat account")
+ }
+ userEntity, err := client.User.Get(c.Request.Context(), userID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, infraerrors.Unauthorized("AUTH_REQUIRED", "current user is required to bind wechat account")
+ }
+ return nil, infraerrors.InternalServer("WECHAT_BIND_USER_LOOKUP_FAILED", "failed to load current user").WithCause(err)
+ }
+ return userEntity, nil
+}
+
+func (h *AuthHandler) ensureWeChatBindOwnership(
+ ctx context.Context,
+ userID int64,
+ providerSubject string,
+ cfg wechatOAuthConfig,
+ channelSubject string,
+) error {
+ client := h.entClient()
+ if client == nil {
+ return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ identities, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyIn(wechatCompatibleProviderKeys(wechatOAuthProviderKey)...),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(providerSubject)),
+ ).
+ All(ctx)
+ if err != nil {
+ return infraerrors.InternalServer("WECHAT_BIND_LOOKUP_FAILED", "failed to inspect wechat identity ownership").WithCause(err)
+ }
+ for _, identity := range identities {
+ if identity != nil && identity.UserID != userID {
+ activeOwner, lookupErr := findActiveUserByID(ctx, client, identity.UserID)
+ if lookupErr != nil {
+ return lookupErr
+ }
+ if activeOwner != nil {
+ return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
+ }
+ }
+
+ channelSubject = strings.TrimSpace(channelSubject)
+ channelAppID := strings.TrimSpace(cfg.appID)
+ if channelSubject == "" || channelAppID == "" {
+ return nil
+ }
+
+ channels, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ProviderKeyIn(wechatCompatibleProviderKeys(wechatOAuthProviderKey)...),
+ authidentitychannel.ChannelEQ(strings.TrimSpace(cfg.mode)),
+ authidentitychannel.ChannelAppIDEQ(channelAppID),
+ authidentitychannel.ChannelSubjectEQ(channelSubject),
+ ).
+ WithIdentity().
+ All(ctx)
+ if err != nil {
+ return infraerrors.InternalServer("WECHAT_BIND_CHANNEL_LOOKUP_FAILED", "failed to inspect wechat identity channel ownership").WithCause(err)
+ }
+ for _, channel := range channels {
+ if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID {
+ activeOwner, lookupErr := findActiveUserByID(ctx, client, channel.Edges.Identity.UserID)
+ if lookupErr != nil {
+ return lookupErr
+ }
+ if activeOwner != nil {
+ return infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
+ }
+ }
+ }
+ return nil
+}
+
+func (h *AuthHandler) findWeChatUserByLegacyOpenID(
+ ctx context.Context,
+ identity service.PendingAuthIdentityKey,
+ cfg wechatOAuthConfig,
+ openid string,
+) (*dbent.User, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ providerType := strings.TrimSpace(identity.ProviderType)
+ providerSubject := strings.TrimSpace(identity.ProviderSubject)
+ providerKeys := wechatCompatibleProviderKeys(identity.ProviderKey)
+ if providerSubject != "" {
+ records, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(providerType),
+ authidentity.ProviderKeyIn(providerKeys...),
+ authidentity.ProviderSubjectEQ(providerSubject),
+ ).
+ WithUser().
+ All(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
+ }
+ if user, err := singleWeChatIdentityUser(records); err != nil || user != nil {
+ if err != nil || user == nil {
+ return user, err
+ }
+ return findActiveUserByID(ctx, client, user.ID)
+ }
+ }
+
+ openid = strings.TrimSpace(openid)
+ channel := strings.TrimSpace(cfg.mode)
+ channelAppID := strings.TrimSpace(cfg.appID)
+ if openid != "" && channel != "" && channelAppID != "" {
+ records, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ(providerType),
+ authidentitychannel.ProviderKeyIn(providerKeys...),
+ authidentitychannel.ChannelEQ(channel),
+ authidentitychannel.ChannelAppIDEQ(channelAppID),
+ authidentitychannel.ChannelSubjectEQ(openid),
+ ).
+ WithIdentity(func(q *dbent.AuthIdentityQuery) {
+ q.WithUser()
+ }).
+ All(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
+ }
+ if user, err := singleWeChatChannelUser(records); err != nil || user != nil {
+ if err != nil || user == nil {
+ return user, err
+ }
+ return findActiveUserByID(ctx, client, user.ID)
+ }
+ }
+
+ if openid == "" {
+ return nil, nil
+ }
+
+ records, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(providerType),
+ authidentity.ProviderKeyIn(providerKeys...),
+ authidentity.ProviderSubjectEQ(openid),
+ ).
+ WithUser().
+ All(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
+ }
+ user, err := singleWeChatIdentityUser(records)
+ if err != nil || user == nil {
+ return user, err
+ }
+ return findActiveUserByID(ctx, client, user.ID)
+}
+
+func wechatCompatibleProviderKeys(providerKey string) []string {
+ preferred := strings.TrimSpace(providerKey)
+ if preferred == "" {
+ preferred = wechatOAuthProviderKey
+ }
+ keys := []string{preferred}
+ if !strings.EqualFold(preferred, wechatOAuthLegacyProviderKey) {
+ keys = append(keys, wechatOAuthLegacyProviderKey)
+ }
+ return keys
+}
+
+func singleWeChatIdentityUser(records []*dbent.AuthIdentity) (*dbent.User, error) {
+ var resolved *dbent.User
+ for _, record := range records {
+ if record == nil || record.Edges.User == nil {
+ continue
+ }
+ if resolved == nil {
+ resolved = record.Edges.User
+ continue
+ }
+ if resolved.ID != record.Edges.User.ID {
+ return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
+ }
+ return resolved, nil
+}
+
+func singleWeChatChannelUser(records []*dbent.AuthIdentityChannel) (*dbent.User, error) {
+ var resolved *dbent.User
+ for _, record := range records {
+ if record == nil || record.Edges.Identity == nil || record.Edges.Identity.Edges.User == nil {
+ continue
+ }
+ if resolved == nil {
+ resolved = record.Edges.Identity.Edges.User
+ continue
+ }
+ if resolved.ID != record.Edges.Identity.Edges.User.ID {
+ return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
+ }
+ }
+ return resolved, nil
+}
+
+func (h *AuthHandler) ensureWeChatRuntimeIdentityBinding(
+ ctx context.Context,
+ userID int64,
+ identity service.PendingAuthIdentityKey,
+ upstreamClaims map[string]any,
+) error {
+ client := h.entClient()
+ if client == nil {
+ return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ tx, err := client.Tx(ctx)
+ if err != nil {
+ return infraerrors.InternalServer("AUTH_IDENTITY_BIND_FAILED", "failed to begin wechat identity repair transaction").WithCause(err)
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ _, err = ensurePendingOAuthIdentityForUser(dbent.NewTxContext(ctx, tx), tx, &dbent.PendingAuthSession{
+ ProviderType: strings.TrimSpace(identity.ProviderType),
+ ProviderKey: strings.TrimSpace(identity.ProviderKey),
+ ProviderSubject: strings.TrimSpace(identity.ProviderSubject),
+ UpstreamIdentityClaims: cloneOAuthMetadata(upstreamClaims),
+ }, userID)
+ if err != nil {
+ return err
+ }
+ return tx.Commit()
+}
+
+func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, c *gin.Context) (wechatOAuthConfig, error) {
+ mode, err := resolveWeChatOAuthMode(rawMode, c)
+ if err != nil {
+ return wechatOAuthConfig{}, err
+ }
+
+ if h == nil || h.settingSvc == nil {
+ return wechatOAuthConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "wechat oauth settings service not ready")
+ }
+
+ apiBaseURL := ""
+ if h != nil && h.settingSvc != nil {
+ settings, err := h.settingSvc.GetAllSettings(ctx)
+ if err == nil && settings != nil {
+ apiBaseURL = strings.TrimSpace(settings.APIBaseURL)
+ }
+ }
+
+ effective, err := h.settingSvc.GetWeChatConnectOAuthConfig(ctx)
+ if err != nil {
+ return wechatOAuthConfig{}, err
+ }
+ if !effective.SupportsMode(mode) {
+ return wechatOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled")
+ }
+
+ cfg := wechatOAuthConfig{
+ mode: mode,
+ appID: strings.TrimSpace(effective.AppIDForMode(mode)),
+ appSecret: strings.TrimSpace(effective.AppSecretForMode(mode)),
+ redirectURI: firstNonEmpty(strings.TrimSpace(effective.RedirectURL), resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/callback")),
+ frontendCallback: firstNonEmpty(strings.TrimSpace(effective.FrontendRedirectURL), wechatOAuthDefaultFrontendCB),
+ scope: effective.ScopeForMode(mode),
+ openEnabled: effective.OpenEnabled,
+ mpEnabled: effective.MPEnabled,
+ }
+
+ switch mode {
+ case "mp":
+ cfg.authorizeURL = "https://open.weixin.qq.com/connect/oauth2/authorize"
+ default:
+ cfg.authorizeURL = "https://open.weixin.qq.com/connect/qrconnect"
+ }
+ if strings.TrimSpace(cfg.redirectURI) == "" {
+ return wechatOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured")
+ }
+
+ return cfg, nil
+}
+
+func (cfg wechatOAuthConfig) requiresUnionID() bool {
+ return cfg.openEnabled && cfg.mpEnabled
+}
+
+func (h *AuthHandler) wechatOAuthFrontendCallback(ctx context.Context) string {
+ if h != nil && h.settingSvc != nil {
+ cfg, err := h.settingSvc.GetWeChatConnectOAuthConfig(ctx)
+ if err == nil && strings.TrimSpace(cfg.FrontendRedirectURL) != "" {
+ return strings.TrimSpace(cfg.FrontendRedirectURL)
+ }
+ }
+ return wechatOAuthDefaultFrontendCB
+}
+
+func resolveWeChatOAuthMode(rawMode string, c *gin.Context) (string, error) {
+ mode := strings.ToLower(strings.TrimSpace(rawMode))
+ if mode == "" {
+ if isWeChatBrowserRequest(c) {
+ return "mp", nil
+ }
+ return "open", nil
+ }
+ if mode != "open" && mode != "mp" {
+ return "", infraerrors.BadRequest("INVALID_MODE", "wechat oauth mode must be open or mp")
+ }
+ return mode, nil
+}
+
+func isWeChatBrowserRequest(c *gin.Context) bool {
+ if c == nil || c.Request == nil {
+ return false
+ }
+ return strings.Contains(strings.ToLower(strings.TrimSpace(c.GetHeader("User-Agent"))), "micromessenger")
+}
+
+func normalizeWeChatOAuthIntent(raw string) string {
+ switch strings.ToLower(strings.TrimSpace(raw)) {
+ case "", "login":
+ return wechatOAuthIntentLogin
+ case "bind", "bind_current_user":
+ return wechatOAuthIntentBind
+ case "adopt", "adopt_existing_user_by_email":
+ return wechatOAuthIntentAdoptEmail
+ default:
+ return wechatOAuthIntentLogin
+ }
+}
+
+func buildWeChatAuthorizeURL(cfg wechatOAuthConfig, state string) (string, error) {
+ u, err := url.Parse(cfg.authorizeURL)
+ if err != nil {
+ return "", fmt.Errorf("parse authorize url: %w", err)
+ }
+ query := u.Query()
+ query.Set("appid", cfg.appID)
+ query.Set("redirect_uri", cfg.redirectURI)
+ query.Set("response_type", "code")
+ query.Set("scope", cfg.scope)
+ query.Set("state", state)
+ u.RawQuery = query.Encode()
+ u.Fragment = "wechat_redirect"
+ return u.String(), nil
+}
+
+func resolveWeChatOAuthAbsoluteURL(apiBaseURL string, c *gin.Context, callbackPath string) string {
+ callbackPath = strings.TrimSpace(callbackPath)
+ if callbackPath == "" {
+ return ""
+ }
+
+ if raw := strings.TrimSpace(apiBaseURL); raw != "" {
+ if parsed, err := url.Parse(raw); err == nil && parsed.Scheme != "" && parsed.Host != "" {
+ basePath := strings.TrimRight(parsed.EscapedPath(), "/")
+ targetPath := callbackPath
+ if basePath != "" && strings.HasSuffix(basePath, "/api/v1") && strings.HasPrefix(callbackPath, "/api/v1") {
+ targetPath = basePath + strings.TrimPrefix(callbackPath, "/api/v1")
+ } else if basePath != "" {
+ targetPath = basePath + callbackPath
+ }
+ return parsed.Scheme + "://" + parsed.Host + targetPath
+ }
+ }
+
+ if c == nil || c.Request == nil {
+ return ""
+ }
+ scheme := "http"
+ if isRequestHTTPS(c) {
+ scheme = "https"
+ }
+ host := strings.TrimSpace(c.Request.Host)
+ if forwardedHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); forwardedHost != "" {
+ host = forwardedHost
+ }
+ if host == "" {
+ return ""
+ }
+ return scheme + "://" + host + callbackPath
+}
+
+func fetchWeChatOAuthIdentity(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, *wechatOAuthUserInfoResponse, error) {
+ tokenResp, err := exchangeWeChatOAuthCode(ctx, cfg, code)
+ if err != nil {
+ return nil, nil, err
+ }
+ userInfo, err := fetchWeChatUserInfo(ctx, tokenResp)
+ if err != nil {
+ return nil, nil, err
+ }
+ return tokenResp, userInfo, nil
+}
+
+func exchangeWeChatOAuthCode(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, error) {
+ endpoint, err := url.Parse(wechatOAuthAccessTokenURL)
+ if err != nil {
+ return nil, fmt.Errorf("parse wechat access token url: %w", err)
+ }
+
+ query := endpoint.Query()
+ query.Set("appid", cfg.appID)
+ query.Set("secret", cfg.appSecret)
+ query.Set("code", strings.TrimSpace(code))
+ query.Set("grant_type", "authorization_code")
+ endpoint.RawQuery = query.Encode()
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil)
+ if err != nil {
+ return nil, fmt.Errorf("build wechat access token request: %w", err)
+ }
+
+ client := &http.Client{Timeout: 30 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("request wechat access token: %w", err)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("read wechat access token response: %w", err)
+ }
+ if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
+ return nil, fmt.Errorf("wechat access token status=%d", resp.StatusCode)
+ }
+
+ var tokenResp wechatOAuthTokenResponse
+ if err := json.Unmarshal(body, &tokenResp); err != nil {
+ return nil, fmt.Errorf("decode wechat access token response: %w", err)
+ }
+ if tokenResp.ErrCode != 0 {
+ return nil, fmt.Errorf("wechat access token error=%d %s", tokenResp.ErrCode, strings.TrimSpace(tokenResp.ErrMsg))
+ }
+ if strings.TrimSpace(tokenResp.AccessToken) == "" {
+ return nil, fmt.Errorf("wechat access token missing access_token")
+ }
+ return &tokenResp, nil
+}
+
+func fetchWeChatUserInfo(ctx context.Context, tokenResp *wechatOAuthTokenResponse) (*wechatOAuthUserInfoResponse, error) {
+ if tokenResp == nil {
+ return nil, fmt.Errorf("wechat token response is nil")
+ }
+
+ endpoint, err := url.Parse(wechatOAuthUserInfoURL)
+ if err != nil {
+ return nil, fmt.Errorf("parse wechat userinfo url: %w", err)
+ }
+ query := endpoint.Query()
+ query.Set("access_token", strings.TrimSpace(tokenResp.AccessToken))
+ query.Set("openid", strings.TrimSpace(tokenResp.OpenID))
+ query.Set("lang", "zh_CN")
+ endpoint.RawQuery = query.Encode()
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil)
+ if err != nil {
+ return nil, fmt.Errorf("build wechat userinfo request: %w", err)
+ }
+
+ client := &http.Client{Timeout: 30 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("request wechat userinfo: %w", err)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("read wechat userinfo response: %w", err)
+ }
+ if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
+ return nil, fmt.Errorf("wechat userinfo status=%d", resp.StatusCode)
+ }
+
+ var userInfo wechatOAuthUserInfoResponse
+ if err := json.Unmarshal(body, &userInfo); err != nil {
+ return nil, fmt.Errorf("decode wechat userinfo response: %w", err)
+ }
+ if userInfo.ErrCode != 0 {
+ return nil, fmt.Errorf("wechat userinfo error=%d %s", userInfo.ErrCode, strings.TrimSpace(userInfo.ErrMsg))
+ }
+ return &userInfo, nil
+}
+
+func wechatSyntheticEmail(subject string) string {
+ subject = strings.TrimSpace(subject)
+ if subject == "" {
+ return ""
+ }
+ return "wechat-" + subject + service.WeChatConnectSyntheticEmailDomain
+}
+
+func wechatFallbackUsername(subject string) string {
+ subject = strings.TrimSpace(subject)
+ if subject == "" {
+ return "wechat_user"
+ }
+ return "wechat_" + truncateFragmentValue(subject)
+}
+
+func wechatSetCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: name,
+ Value: value,
+ Path: wechatOAuthCookiePath,
+ MaxAge: maxAgeSec,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func wechatClearCookie(c *gin.Context, name string, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: name,
+ Value: "",
+ Path: wechatOAuthCookiePath,
+ MaxAge: -1,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func normalizeWeChatPaymentType(raw string) string {
+ switch strings.TrimSpace(raw) {
+ case payment.TypeWxpay, payment.TypeWxpayDirect:
+ return strings.TrimSpace(raw)
+ default:
+ return ""
+ }
+}
+
+func normalizeWeChatPaymentScope(raw string) string {
+ for _, part := range strings.FieldsFunc(strings.TrimSpace(raw), func(r rune) bool {
+ return r == ',' || r == ' ' || r == '\t' || r == '\n' || r == '\r'
+ }) {
+ switch strings.TrimSpace(part) {
+ case "snsapi_userinfo":
+ return "snsapi_userinfo"
+ case "snsapi_base":
+ return "snsapi_base"
+ }
+ }
+ return "snsapi_base"
+}
+
+func normalizeWeChatPaymentRedirectPath(path string) string {
+ path = strings.TrimSpace(path)
+ if path == "" {
+ return wechatPaymentOAuthDefaultTo
+ }
+ if path == "/payment" {
+ return "/purchase"
+ }
+ if strings.HasPrefix(path, "/payment?") {
+ return "/purchase" + strings.TrimPrefix(path, "/payment")
+ }
+ return path
+}
+
+func (h *AuthHandler) resolveWeChatPaymentOAuthCallbackURL(ctx context.Context, c *gin.Context) string {
+ apiBaseURL := ""
+ if h != nil && h.settingSvc != nil {
+ if settings, err := h.settingSvc.GetAllSettings(ctx); err == nil && settings != nil {
+ apiBaseURL = strings.TrimSpace(settings.APIBaseURL)
+ }
+ }
+ return resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/payment/callback")
+}
+
+func encodeWeChatPaymentOAuthContext(ctx wechatPaymentOAuthContext) (string, error) {
+ data, err := json.Marshal(ctx)
+ if err != nil {
+ return "", err
+ }
+ return string(data), nil
+}
+
+func decodeWeChatPaymentOAuthContext(raw string) (wechatPaymentOAuthContext, error) {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return wechatPaymentOAuthContext{}, nil
+ }
+ var ctx wechatPaymentOAuthContext
+ if err := json.Unmarshal([]byte(raw), &ctx); err != nil {
+ return wechatPaymentOAuthContext{}, err
+ }
+ return ctx, nil
+}
+
+func parseWeChatPaymentPlanID(raw string) int64 {
+ id, _ := strconv.ParseInt(strings.TrimSpace(raw), 10, 64)
+ return id
+}
+
+func wechatPaymentSetCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: name,
+ Value: value,
+ Path: wechatPaymentOAuthCookiePath,
+ MaxAge: maxAgeSec,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func wechatPaymentClearCookie(c *gin.Context, name string, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: name,
+ Value: "",
+ Path: wechatPaymentOAuthCookiePath,
+ MaxAge: -1,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go
new file mode 100644
index 00000000..7cf114c1
--- /dev/null
+++ b/backend/internal/handler/auth_wechat_oauth_test.go
@@ -0,0 +1,1497 @@
+//go:build unit
+
+package handler
+
+import (
+ "bytes"
+ "context"
+ "database/sql"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/repository"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, map[string]string{
+ service.SettingKeyWeChatConnectEnabled: "true",
+ service.SettingKeyWeChatConnectAppID: "wx-open-app",
+ service.SettingKeyWeChatConnectAppSecret: "wx-open-secret",
+ service.SettingKeyWeChatConnectMode: "open",
+ service.SettingKeyWeChatConnectScopes: "snsapi_login",
+ service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ })
+ defer client.Close()
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil)
+ c.Request.Host = "api.example.com"
+
+ handler.WeChatOAuthStart(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ require.NotEmpty(t, location)
+ require.Contains(t, location, "open.weixin.qq.com")
+ require.Contains(t, location, "appid=wx-open-app")
+ require.Contains(t, location, "scope=snsapi_login")
+
+ cookies := recorder.Result().Cookies()
+ require.NotEmpty(t, findCookie(cookies, wechatOAuthStateCookieName))
+ require.NotEmpty(t, findCookie(cookies, wechatOAuthRedirectCookieName))
+ require.NotEmpty(t, findCookie(cookies, wechatOAuthModeCookieName))
+ require.NotEmpty(t, findCookie(cookies, oauthPendingBrowserCookieName))
+}
+
+func TestWeChatOAuthStart_AllowsOpenModeWhenBothCapabilitiesEnabled(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, map[string]string{
+ service.SettingKeyWeChatConnectEnabled: "true",
+ service.SettingKeyWeChatConnectAppID: "wx-shared-app",
+ service.SettingKeyWeChatConnectAppSecret: "wx-shared-secret",
+ service.SettingKeyWeChatConnectMode: "mp",
+ service.SettingKeyWeChatConnectScopes: "snsapi_base",
+ service.SettingKeyWeChatConnectOpenEnabled: "true",
+ service.SettingKeyWeChatConnectMPEnabled: "true",
+ service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ })
+ defer client.Close()
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil)
+ c.Request.Host = "api.example.com"
+
+ handler.WeChatOAuthStart(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ require.NotEmpty(t, location)
+ require.Contains(t, location, "open.weixin.qq.com")
+ require.Contains(t, location, "connect/qrconnect")
+ require.Contains(t, location, "scope=snsapi_login")
+}
+
+func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ ctx := context.Background()
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "wechat", session.ProviderType)
+ require.Equal(t, "wechat-main", session.ProviderKey)
+ require.Equal(t, "union-456", session.ProviderSubject)
+ require.Equal(t, "wechat-union-456@wechat-connect.invalid", session.ResolvedEmail)
+ require.Equal(t, "WeChat Nick", session.UpstreamIdentityClaims["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/avatar.png", session.UpstreamIdentityClaims["suggested_avatar_url"])
+ require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"])
+ require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"])
+}
+
+func TestWeChatOAuthCallbackFallsBackToOpenIDWhenUnionIDMissingInSingleChannelMode(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback"))
+ defer client.Close()
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "https://app.example.com/auth/wechat/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(context.Background())
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.Equal(t, "openid-123", session.ProviderSubject)
+ require.Equal(t, wechatSyntheticEmail("openid-123"), session.ResolvedEmail)
+
+ completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.Equal(t, oauthPendingChoiceStep, completion["step"])
+ require.Equal(t, "third_party_signup", completion["choice_reason"])
+}
+
+func TestWeChatOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUserWithoutStoredTokens(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat-login.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback"))
+ defer client.Close()
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail(wechatSyntheticEmail("union-456")).
+ SetUsername("wechat-existing-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.AuthIdentity.Create().
+ SetUserID(existingUser.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthProviderKey).
+ SetProviderSubject("union-456").
+ SetMetadata(map[string]any{"username": "wechat-existing-user"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "https://app.example.com/auth/wechat/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, existingUser.ID, *session.TargetUserID)
+ require.Equal(t, existingUser.Email, session.ResolvedEmail)
+
+ completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.Equal(t, "/dashboard", completion["redirect"])
+ _, hasAccessToken := completion["access_token"]
+ require.False(t, hasAccessToken)
+ _, hasRefreshToken := completion["refresh_token"]
+ require.False(t, hasRefreshToken)
+}
+
+func TestWeChatOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-disabled","unionid":"union-disabled","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-disabled","unionid":"union-disabled","nickname":"Disabled WeChat","headimgurl":"https://cdn.example/disabled.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail(wechatSyntheticEmail("union-disabled")).
+ SetUsername("disabled-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusDisabled).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.AuthIdentity.Create().
+ SetUserID(existingUser.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthProviderKey).
+ SetProviderSubject("union-disabled").
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-disabled", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-disabled"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+ assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE")
+
+ count, err := client.PendingAuthSession.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, count)
+}
+
+func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if strings.Contains(r.URL.Path, "/sns/oauth2/access_token") {
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_base"}`))
+ return
+ }
+ http.NotFound(w, r)
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+
+ handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("mp", "wx-mp-app", "wx-mp-secret", "/auth/wechat/callback"))
+ defer client.Close()
+ handler.cfg.Totp.EncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
+ handler.cfg.Totp.EncryptionKeyConfigured = true
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/payment/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatPaymentOAuthStateName, "state-123"))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthRedirect, "/purchase?from=wechat"))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthContextName, `{"payment_type":"wxpay","amount":"12.5","order_type":"subscription","plan_id":7}`))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthScope, "snsapi_base"))
+ c.Request = req
+
+ handler.WeChatPaymentOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ parsed, err := url.Parse(location)
+ require.NoError(t, err)
+ fragment, err := url.ParseQuery(parsed.Fragment)
+ require.NoError(t, err)
+ require.Equal(t, "/purchase?from=wechat", fragment.Get("redirect"))
+ require.NotEmpty(t, fragment.Get("wechat_resume_token"))
+ require.Empty(t, fragment.Get("openid"))
+ require.Empty(t, fragment.Get("payment_type"))
+ require.Empty(t, fragment.Get("amount"))
+ require.Empty(t, fragment.Get("order_type"))
+ require.Empty(t, fragment.Get("plan_id"))
+
+ claims, err := handler.wechatPaymentResumeService().ParseWeChatPaymentResumeToken(fragment.Get("wechat_resume_token"))
+ require.NoError(t, err)
+ require.Equal(t, "openid-123", claims.OpenID)
+ require.Equal(t, payment.TypeWxpay, claims.PaymentType)
+ require.Equal(t, "12.5", claims.Amount)
+ require.Equal(t, payment.OrderTypeSubscription, claims.OrderType)
+ require.EqualValues(t, 7, claims.PlanID)
+ require.Equal(t, "/purchase?from=wechat", claims.RedirectTo)
+}
+
+func TestWeChatPaymentOAuthCallbackUsesExplicitPaymentResumeSigningKeyWhenMixedKeysConfigured(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if strings.Contains(r.URL.Path, "/sns/oauth2/access_token") {
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-mixed-key","scope":"snsapi_base"}`))
+ return
+ }
+ http.NotFound(w, r)
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+
+ handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("mp", "wx-mp-app", "wx-mp-secret", "/auth/wechat/callback"))
+ defer client.Close()
+
+ legacyKeyHex := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
+ explicitSigningKey := "explicit-payment-resume-signing-key"
+ t.Setenv("PAYMENT_RESUME_SIGNING_KEY", explicitSigningKey)
+ handler.cfg.Totp.EncryptionKey = legacyKeyHex
+ handler.cfg.Totp.EncryptionKeyConfigured = true
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/payment/callback?code=wechat-code&state=state-mixed", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatPaymentOAuthStateName, "state-mixed"))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthRedirect, "/purchase?from=wechat"))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthContextName, `{"payment_type":"wxpay","amount":"18.8","order_type":"subscription","plan_id":9}`))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthScope, "snsapi_base"))
+ c.Request = req
+
+ handler.WeChatPaymentOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ parsed, err := url.Parse(location)
+ require.NoError(t, err)
+ fragment, err := url.ParseQuery(parsed.Fragment)
+ require.NoError(t, err)
+
+ token := fragment.Get("wechat_resume_token")
+ require.NotEmpty(t, token)
+
+ claims, err := service.NewPaymentResumeService([]byte(explicitSigningKey)).ParseWeChatPaymentResumeToken(token)
+ require.NoError(t, err)
+ require.Equal(t, "openid-mixed-key", claims.OpenID)
+ require.Equal(t, payment.TypeWxpay, claims.PaymentType)
+ require.Equal(t, "18.8", claims.Amount)
+ require.Equal(t, payment.OrderTypeSubscription, claims.OrderType)
+ require.EqualValues(t, 9, claims.PlanID)
+ require.Equal(t, "/purchase?from=wechat", claims.RedirectTo)
+
+ _, err = service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")).ParseWeChatPaymentResumeToken(token)
+ require.Error(t, err)
+}
+
+func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *testing.T) {
+ testCases := []struct {
+ name string
+ mode string
+ appID string
+ appSecret string
+ openID string
+ }{
+ {
+ name: "open",
+ mode: "open",
+ appID: "wx-open-app",
+ appSecret: "wx-open-secret",
+ openID: "openid-open-123",
+ },
+ {
+ name: "mp",
+ mode: "mp",
+ appID: "wx-mp-app",
+ appSecret: "wx-mp-secret",
+ openID: "openid-mp-123",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"` + tc.openID + `","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"` + tc.openID + `","unionid":"union-456","nickname":"Bind Nick","headimgurl":"https://cdn.example/bind.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings(tc.mode, tc.appID, tc.appSecret, "/auth/wechat/callback"))
+ defer client.Close()
+
+ currentUser, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("current-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(context.Background())
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, tc.mode))
+ req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(context.Background())
+ require.NoError(t, err)
+ require.Equal(t, wechatOAuthIntentBind, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, currentUser.ID, *session.TargetUserID)
+ require.Equal(t, currentUser.Email, session.ResolvedEmail)
+ require.Equal(t, "union-456", session.ProviderSubject)
+ require.Equal(t, "union-456", session.UpstreamIdentityClaims["subject"])
+ require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"])
+ require.Equal(t, tc.openID, session.UpstreamIdentityClaims["openid"])
+ require.Equal(t, tc.mode, session.UpstreamIdentityClaims["channel"])
+ require.Equal(t, tc.appID, session.UpstreamIdentityClaims["channel_app_id"])
+ require.Equal(t, tc.openID, session.UpstreamIdentityClaims["channel_subject"])
+
+ completionResponse := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.Equal(t, "/dashboard", completionResponse["redirect"])
+ _, hasAccessToken := completionResponse["access_token"]
+ require.False(t, hasAccessToken)
+ })
+ }
+}
+
+func TestWeChatOAuthCallbackBindRejectsCanonicalOwnershipConflict(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ owner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ currentUser, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("current").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(owner.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthProviderKey).
+ SetProviderSubject("union-456").
+ SetMetadata(map[string]any{"unionid": "union-456"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+ assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT")
+
+ count, err := client.PendingAuthSession.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, count)
+}
+
+func TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ owner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ currentUser, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("current").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ ownerIdentity, err := client.AuthIdentity.Create().
+ SetUserID(owner.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthProviderKey).
+ SetProviderSubject("union-owner").
+ SetMetadata(map[string]any{"unionid": "union-owner"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentityChannel.Create().
+ SetIdentityID(ownerIdentity.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthProviderKey).
+ SetChannel("open").
+ SetChannelAppID("wx-open-app").
+ SetChannelSubject("openid-123").
+ SetMetadata(map[string]any{"openid": "openid-123"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+ assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT")
+
+ count, err := client.PendingAuthSession.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, count)
+}
+
+func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ owner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ currentUser, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("current").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(owner.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthLegacyProviderKey).
+ SetProviderSubject("union-456").
+ SetMetadata(map[string]any{"unionid": "union-456"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+ assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT")
+
+ count, err := client.PendingAuthSession.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, count)
+}
+
+func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSessionReturnsPendingSession(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, true)
+ defer client.Close()
+
+ ctx := context.Background()
+ redeemRepo := repository.NewRedeemCodeRepository(client)
+ require.NoError(t, redeemRepo.Create(ctx, &service.RedeemCode{
+ Code: "invite-1",
+ Type: service.RedeemTypeInvitation,
+ Status: service.StatusUnused,
+ }))
+
+ callbackRecorder := httptest.NewRecorder()
+ callbackCtx, _ := gin.CreateTestContext(callbackRecorder)
+ callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ callbackReq.Host = "api.example.com"
+ callbackReq.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ callbackReq.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ callbackReq.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ callbackReq.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ callbackCtx.Request = callbackReq
+
+ handler.WeChatOAuthCallback(callbackCtx)
+
+ require.Equal(t, http.StatusFound, callbackRecorder.Code)
+ require.Equal(t, "/auth/wechat/callback", callbackRecorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(callbackRecorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+ sessionToken := decodeCookieValueForTest(t, sessionCookie.Value)
+
+ pendingSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(sessionToken)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthPendingChoiceStep, pendingSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)["step"])
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true,"adopt_avatar":true}`)
+ completeRecorder := httptest.NewRecorder()
+ completeCtx, _ := gin.CreateTestContext(completeRecorder)
+ completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
+ completeReq.Header.Set("Content-Type", "application/json")
+ completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(sessionToken)})
+ completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-123")})
+ completeCtx.Request = completeReq
+
+ handler.CompleteWeChatOAuthRegistration(completeCtx)
+
+ require.Equal(t, http.StatusOK, completeRecorder.Code)
+ responseData := decodeJSONBody(t, completeRecorder)
+ require.Equal(t, "pending_session", responseData["auth_result"])
+ require.Equal(t, oauthPendingChoiceStep, responseData["step"])
+ require.Equal(t, true, responseData["adoption_required"])
+ require.Empty(t, responseData["access_token"])
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(pendingSession.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Nil(t, consumed.ConsumedAt)
+
+ userCount, err := client.User.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ("wechat-main"),
+ authidentity.ProviderSubjectEQ("union-456"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ channelCount, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ProviderKeyEQ("wechat-main"),
+ authidentitychannel.ChannelEQ("open"),
+ authidentitychannel.ChannelAppIDEQ("wx-open-app"),
+ authidentitychannel.ChannelSubjectEQ("openid-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, channelCount)
+
+ decisionCount, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, decisionCount)
+}
+
+func TestCompleteWeChatOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("wechat-complete-no-adoption-session").
+ SetIntent("login").
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthProviderKey).
+ SetProviderSubject("wechat-subject-no-adoption").
+ SetResolvedEmail("wechat-subject-no-adoption@wechat-connect.invalid").
+ SetBrowserSessionKey("wechat-browser-no-adoption").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "wechat_user",
+ "suggested_display_name": "WeChat Legacy",
+ "suggested_avatar_url": "https://cdn.example/wechat-legacy.png",
+ "mode": "open",
+ "channel": "open",
+ "channel_app_id": "wx-open-app",
+ "channel_subject": "openid-legacy",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ completeCtx, _ := gin.CreateTestContext(recorder)
+ completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
+ completeReq.Header.Set("Content-Type", "application/json")
+ completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-browser-no-adoption")})
+ completeCtx.Request = completeReq
+
+ handler.CompleteWeChatOAuthRegistration(completeCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.NotEmpty(t, responseData["access_token"])
+ require.NotEmpty(t, responseData["refresh_token"])
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ(session.ResolvedEmail)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "wechat_user", userEntity.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentity.ProviderSubjectEQ("wechat-subject-no-adoption"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.False(t, decision.AdoptDisplayName)
+ require.False(t, decision.AdoptAvatar)
+}
+
+func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy WeChat","headimgurl":"https://cdn.example/legacy.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ legacyUser, err := client.User.Create().
+ SetEmail("legacy@example.com").
+ SetUsername("legacy-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ legacyIdentity, err := client.AuthIdentity.Create().
+ SetUserID(legacyUser.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthProviderKey).
+ SetProviderSubject("openid-123").
+ SetMetadata(map[string]any{"openid": "openid-123"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, legacyUser.ID, *session.TargetUserID)
+ require.Equal(t, legacyUser.Email, session.ResolvedEmail)
+
+ repairedIdentity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentity.ProviderSubjectEQ("union-456"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, legacyIdentity.ID, repairedIdentity.ID)
+ require.Equal(t, legacyUser.ID, repairedIdentity.UserID)
+
+ openIDIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentity.ProviderSubjectEQ("openid-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, openIDIdentityCount)
+
+ channel, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentitychannel.ChannelEQ("open"),
+ authidentitychannel.ChannelAppIDEQ("wx-open-app"),
+ authidentitychannel.ChannelSubjectEQ("openid-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, repairedIdentity.ID, channel.IdentityID)
+}
+
+func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("wechat-complete-invalid-session").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("wechat").
+ SetProviderKey("wechat-main").
+ SetProviderSubject("union-invalid-1").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("wechat-invalid-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "wechat_user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": "bind_login_required",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ completeCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-invalid-browser")})
+ completeCtx.Request = req
+
+ handler.CompleteWeChatOAuthRegistration(completeCtx)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestCompleteWeChatOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("wechat-complete-choice-session").
+ SetIntent("login").
+ SetProviderType("wechat").
+ SetProviderKey("wechat-main").
+ SetProviderSubject("wechat-choice-subject-1").
+ SetResolvedEmail("wechat-choice-subject-1@wechat-connect.invalid").
+ SetBrowserSessionKey("wechat-choice-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "wechat_user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": oauthPendingChoiceStep,
+ "redirect": "/dashboard",
+ "email": "fresh@example.com",
+ "resolved_email": "fresh@example.com",
+ "force_email_on_signup": true,
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ completeCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-choice-browser")})
+ completeCtx.Request = req
+
+ handler.CompleteWeChatOAuthRegistration(completeCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.Equal(t, "pending_session", responseData["auth_result"])
+ require.Equal(t, oauthPendingChoiceStep, responseData["step"])
+ require.Equal(t, true, responseData["force_email_on_signup"])
+ require.Empty(t, responseData["access_token"])
+
+ userCount, err := client.User.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy Canonical","headimgurl":"https://cdn.example/legacy-canonical.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ legacyUser, err := client.User.Create().
+ SetEmail("legacy@example.com").
+ SetUsername("legacy-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ legacyIdentity, err := client.AuthIdentity.Create().
+ SetUserID(legacyUser.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthLegacyProviderKey).
+ SetProviderSubject("union-456").
+ SetMetadata(map[string]any{"unionid": "union-456"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, legacyUser.ID, *session.TargetUserID)
+ require.Equal(t, legacyUser.Email, session.ResolvedEmail)
+
+ repairedIdentity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentity.ProviderSubjectEQ("union-456"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, legacyIdentity.ID, repairedIdentity.ID)
+ require.Equal(t, legacyUser.ID, repairedIdentity.UserID)
+
+ legacyIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ(wechatOAuthLegacyProviderKey),
+ authidentity.ProviderSubjectEQ("union-456"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, legacyIdentityCount)
+
+ channel, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentitychannel.ChannelEQ("open"),
+ authidentitychannel.ChannelAppIDEQ("wx-open-app"),
+ authidentitychannel.ChannelSubjectEQ("openid-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, repairedIdentity.ID, channel.IdentityID)
+}
+
+func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
+ return newWeChatOAuthTestHandlerWithSettings(t, invitationEnabled, nil)
+}
+
+func wechatOAuthTestSettings(mode, appID, secret, frontendRedirect string) map[string]string {
+ return map[string]string{
+ service.SettingKeyWeChatConnectEnabled: "true",
+ service.SettingKeyWeChatConnectAppID: appID,
+ service.SettingKeyWeChatConnectAppSecret: secret,
+ service.SettingKeyWeChatConnectMode: mode,
+ service.SettingKeyWeChatConnectScopes: service.DefaultWeChatConnectScopesForMode(mode),
+ service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ service.SettingKeyWeChatConnectFrontendRedirectURL: frontendRedirect,
+ }
+}
+
+func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool, extraSettings map[string]string) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_wechat_oauth?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+
+ userRepo := &oauthPendingFlowUserRepo{client: client}
+ redeemRepo := repository.NewRedeemCodeRepository(client)
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 0,
+ UserConcurrency: 1,
+ },
+ }
+ values := map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
+ }
+ for key, value := range wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "/auth/wechat/callback") {
+ values[key] = value
+ }
+ for key, value := range extraSettings {
+ values[key] = value
+ }
+ settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{values: values}, cfg)
+
+ authSvc := service.NewAuthService(
+ client,
+ userRepo,
+ redeemRepo,
+ &wechatOAuthRefreshTokenCacheStub{},
+ cfg,
+ settingSvc,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ )
+
+ return &AuthHandler{
+ authService: authSvc,
+ settingSvc: settingSvc,
+ cfg: cfg,
+ }, client
+}
+
+type wechatOAuthSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *wechatOAuthSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
+ return nil, service.ErrSettingNotFound
+}
+
+func (s *wechatOAuthSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ value, ok := s.values[key]
+ if !ok {
+ return "", service.ErrSettingNotFound
+ }
+ return value, nil
+}
+
+func (s *wechatOAuthSettingRepoStub) Set(context.Context, string, string) error {
+ return nil
+}
+
+func (s *wechatOAuthSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ result := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ result[key] = value
+ }
+ }
+ return result, nil
+}
+
+func (s *wechatOAuthSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ return nil
+}
+
+func (s *wechatOAuthSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ result := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ result[key] = value
+ }
+ return result, nil
+}
+
+func (s *wechatOAuthSettingRepoStub) Delete(context.Context, string) error {
+ return nil
+}
+
+type wechatOAuthRefreshTokenCacheStub struct{}
+
+func (s *wechatOAuthRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
+ return nil, service.ErrRefreshTokenNotFound
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
+ return nil, nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
+ return false, nil
+}
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index d2ccb8d6..9780ff79 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -21,6 +21,7 @@ func UserFromServiceShallow(u *service.User) *User {
Concurrency: u.Concurrency,
Status: u.Status,
AllowedGroups: u.AllowedGroups,
+ LastActiveAt: u.LastActiveAt,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
BalanceNotifyEnabled: u.BalanceNotifyEnabled,
@@ -66,6 +67,7 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
return &AdminUser{
User: *base,
Notes: u.Notes,
+ LastUsedAt: u.LastUsedAt,
GroupRates: u.GroupRates,
}
}
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index 3659e79b..fc6a3f9e 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -51,6 +51,23 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
+ WeChatConnectEnabled bool `json:"wechat_connect_enabled"`
+ WeChatConnectAppID string `json:"wechat_connect_app_id"`
+ WeChatConnectAppSecretConfigured bool `json:"wechat_connect_app_secret_configured"`
+ WeChatConnectOpenAppID string `json:"wechat_connect_open_app_id"`
+ WeChatConnectOpenAppSecretConfigured bool `json:"wechat_connect_open_app_secret_configured"`
+ WeChatConnectMPAppID string `json:"wechat_connect_mp_app_id"`
+ WeChatConnectMPAppSecretConfigured bool `json:"wechat_connect_mp_app_secret_configured"`
+ WeChatConnectMobileAppID string `json:"wechat_connect_mobile_app_id"`
+ WeChatConnectMobileAppSecretConfigured bool `json:"wechat_connect_mobile_app_secret_configured"`
+ WeChatConnectOpenEnabled bool `json:"wechat_connect_open_enabled"`
+ WeChatConnectMPEnabled bool `json:"wechat_connect_mp_enabled"`
+ WeChatConnectMobileEnabled bool `json:"wechat_connect_mobile_enabled"`
+ WeChatConnectMode string `json:"wechat_connect_mode"`
+ WeChatConnectScopes string `json:"wechat_connect_scopes"`
+ WeChatConnectRedirectURL string `json:"wechat_connect_redirect_url"`
+ WeChatConnectFrontendRedirectURL string `json:"wechat_connect_frontend_redirect_url"`
+
OIDCConnectEnabled bool `json:"oidc_connect_enabled"`
OIDCConnectProviderName string `json:"oidc_connect_provider_name"`
OIDCConnectClientID string `json:"oidc_connect_client_id"`
@@ -127,6 +144,15 @@ type SystemSettings struct {
// Web Search Emulation
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
+ // Payment visible method routing
+ PaymentVisibleMethodAlipaySource string `json:"payment_visible_method_alipay_source"`
+ PaymentVisibleMethodWxpaySource string `json:"payment_visible_method_wxpay_source"`
+ PaymentVisibleMethodAlipayEnabled bool `json:"payment_visible_method_alipay_enabled"`
+ PaymentVisibleMethodWxpayEnabled bool `json:"payment_visible_method_wxpay_enabled"`
+
+ // OpenAI account scheduling
+ OpenAIAdvancedSchedulerEnabled bool `json:"openai_advanced_scheduler_enabled"`
+
// Payment configuration
PaymentEnabled bool `json:"payment_enabled"`
PaymentMinAmount float64 `json:"payment_min_amount"`
@@ -167,6 +193,7 @@ type DefaultSubscriptionSetting struct {
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
+ ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
@@ -189,6 +216,10 @@ type PublicSettings struct {
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
+ WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
+ WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
+ WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
+ WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
SoraClientEnabled bool `json:"sora_client_enabled"`
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index 8c1e166f..c0bce40b 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -7,16 +7,17 @@ import (
)
type User struct {
- ID int64 `json:"id"`
- Email string `json:"email"`
- Username string `json:"username"`
- Role string `json:"role"`
- Balance float64 `json:"balance"`
- Concurrency int `json:"concurrency"`
- Status string `json:"status"`
- AllowedGroups []int64 `json:"allowed_groups"`
- CreatedAt time.Time `json:"created_at"`
- UpdatedAt time.Time `json:"updated_at"`
+ ID int64 `json:"id"`
+ Email string `json:"email"`
+ Username string `json:"username"`
+ Role string `json:"role"`
+ Balance float64 `json:"balance"`
+ Concurrency int `json:"concurrency"`
+ Status string `json:"status"`
+ AllowedGroups []int64 `json:"allowed_groups"`
+ LastActiveAt *time.Time `json:"last_active_at,omitempty"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
// 余额不足通知
BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
@@ -34,7 +35,8 @@ type User struct {
type AdminUser struct {
User
- Notes string `json:"notes"`
+ Notes string `json:"notes"`
+ LastUsedAt *time.Time `json:"last_used_at"`
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
diff --git a/backend/internal/handler/dto/user_mapper_activity_test.go b/backend/internal/handler/dto/user_mapper_activity_test.go
new file mode 100644
index 00000000..a17f0ce4
--- /dev/null
+++ b/backend/internal/handler/dto/user_mapper_activity_test.go
@@ -0,0 +1,33 @@
+package dto
+
+import (
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+func TestUserFromServiceAdmin_MapsActivityTimestamps(t *testing.T) {
+ t.Parallel()
+
+ lastLoginAt := time.Date(2026, time.April, 20, 10, 0, 0, 0, time.UTC)
+ lastActiveAt := lastLoginAt.Add(15 * time.Minute)
+ lastUsedAt := lastLoginAt.Add(45 * time.Minute)
+
+ out := UserFromServiceAdmin(&service.User{
+ ID: 42,
+ Email: "admin@example.com",
+ Username: "admin",
+ Role: service.RoleAdmin,
+ Status: service.StatusActive,
+ LastActiveAt: &lastActiveAt,
+ LastUsedAt: &lastUsedAt,
+ })
+
+ require.NotNil(t, out)
+ require.NotNil(t, out.LastActiveAt)
+ require.NotNil(t, out.LastUsedAt)
+ require.WithinDuration(t, lastActiveAt, *out.LastActiveAt, time.Second)
+ require.WithinDuration(t, lastUsedAt, *out.LastUsedAt, time.Second)
+}
diff --git a/backend/internal/handler/endpoint.go b/backend/internal/handler/endpoint.go
index a897bc40..db29618a 100644
--- a/backend/internal/handler/endpoint.go
+++ b/backend/internal/handler/endpoint.go
@@ -15,10 +15,12 @@ import (
// ──────────────────────────────────────────────────────────
const (
- EndpointMessages = "/v1/messages"
- EndpointChatCompletions = "/v1/chat/completions"
- EndpointResponses = "/v1/responses"
- EndpointGeminiModels = "/v1beta/models"
+ EndpointMessages = "/v1/messages"
+ EndpointChatCompletions = "/v1/chat/completions"
+ EndpointResponses = "/v1/responses"
+ EndpointImagesGenerations = "/v1/images/generations"
+ EndpointImagesEdits = "/v1/images/edits"
+ EndpointGeminiModels = "/v1beta/models"
)
// gin.Context keys used by the middleware and helpers below.
@@ -44,6 +46,10 @@ func NormalizeInboundEndpoint(path string) string {
return EndpointChatCompletions
case strings.Contains(path, EndpointMessages):
return EndpointMessages
+ case strings.Contains(path, EndpointImagesGenerations) || strings.Contains(path, "/images/generations"):
+ return EndpointImagesGenerations
+ case strings.Contains(path, EndpointImagesEdits) || strings.Contains(path, "/images/edits"):
+ return EndpointImagesEdits
case strings.Contains(path, EndpointResponses):
return EndpointResponses
case strings.Contains(path, EndpointGeminiModels):
@@ -69,6 +75,9 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
switch platform {
case service.PlatformOpenAI:
+ if inbound == EndpointImagesGenerations || inbound == EndpointImagesEdits {
+ return inbound
+ }
// OpenAI forwards everything to the Responses API.
// Preserve subresource suffix (e.g. /v1/responses/compact).
if suffix := responsesSubpathSuffix(rawRequestPath); suffix != "" {
diff --git a/backend/internal/handler/endpoint_test.go b/backend/internal/handler/endpoint_test.go
index 1519bc9e..369c5fa7 100644
--- a/backend/internal/handler/endpoint_test.go
+++ b/backend/internal/handler/endpoint_test.go
@@ -25,12 +25,16 @@ func TestNormalizeInboundEndpoint(t *testing.T) {
{"/v1/messages", EndpointMessages},
{"/v1/chat/completions", EndpointChatCompletions},
{"/v1/responses", EndpointResponses},
+ {"/v1/images/generations", EndpointImagesGenerations},
+ {"/v1/images/edits", EndpointImagesEdits},
{"/v1beta/models", EndpointGeminiModels},
// Prefixed paths (antigravity, openai).
{"/antigravity/v1/messages", EndpointMessages},
{"/openai/v1/responses", EndpointResponses},
{"/openai/v1/responses/compact", EndpointResponses},
+ {"/openai/v1/images/generations", EndpointImagesGenerations},
+ {"/openai/v1/images/edits", EndpointImagesEdits},
{"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels},
// Gin route patterns with wildcards.
@@ -73,6 +77,8 @@ func TestDeriveUpstreamEndpoint(t *testing.T) {
{"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"},
{"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses},
{"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses},
+ {"openai image generations", EndpointImagesGenerations, "/v1/images/generations", service.PlatformOpenAI, EndpointImagesGenerations},
+ {"openai image edits", EndpointImagesEdits, "/openai/v1/images/edits", service.PlatformOpenAI, EndpointImagesEdits},
// Antigravity — uses inbound to pick Claude vs Gemini upstream.
{"antigravity claude", EndpointMessages, "/antigravity/v1/messages", service.PlatformAntigravity, EndpointMessages},
diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go
index 5319b55d..43999a01 100644
--- a/backend/internal/handler/openai_gateway_handler.go
+++ b/backend/internal/handler/openai_gateway_handler.go
@@ -187,6 +187,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id")
return
}
+ reqLog.Warn("openai.request_validation_failed",
+ zap.String("reason", "previous_response_id_requires_wsv2"),
+ )
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id is only supported on Responses WebSocket v2")
+ return
}
setOpsRequestContext(c, reqModel, reqStream, body)
@@ -856,7 +861,7 @@ func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context,
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_call_id"),
)
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id on HTTP requests; continuation via previous_response_id is only supported on Responses WebSocket v2")
return false
}
if validation.HasItemReferenceForAllCallIDs {
@@ -866,7 +871,7 @@ func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context,
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_item_reference"),
)
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id on HTTP requests; continuation via previous_response_id is only supported on Responses WebSocket v2")
return false
}
diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go
index d299fb81..8ecee59a 100644
--- a/backend/internal/handler/openai_gateway_handler_test.go
+++ b/backend/internal/handler/openai_gateway_handler_test.go
@@ -494,6 +494,64 @@ func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
require.Contains(t, w.Body.String(), "previous_response_id must be a response.id")
}
+func TestOpenAIResponses_RejectsHTTPContinuationPreviousResponseID(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
+ `{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123456","input":[{"type":"input_text","text":"hello"}]}`,
+ ))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ groupID := int64(2)
+ c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
+ ID: 101,
+ GroupID: &groupID,
+ User: &service.User{ID: 1},
+ })
+ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
+ UserID: 1,
+ Concurrency: 1,
+ })
+
+ h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
+ h.Responses(c)
+
+ require.Equal(t, http.StatusBadRequest, w.Code)
+ require.Contains(t, w.Body.String(), "Responses WebSocket v2")
+ require.Contains(t, w.Body.String(), "previous_response_id")
+}
+
+func TestOpenAIResponses_FunctionCallOutputHTTPGuidanceDoesNotSuggestPreviousResponseReuse(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
+ `{"model":"gpt-5.1","stream":false,"input":[{"type":"function_call_output","output":"{}"}]}`,
+ ))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ groupID := int64(2)
+ c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
+ ID: 101,
+ GroupID: &groupID,
+ User: &service.User{ID: 1},
+ })
+ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
+ UserID: 1,
+ Concurrency: 1,
+ })
+
+ h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
+ h.Responses(c)
+
+ require.Equal(t, http.StatusBadRequest, w.Code)
+ require.Contains(t, w.Body.String(), "Responses WebSocket v2")
+ require.NotContains(t, w.Body.String(), "reuse previous_response_id")
+}
+
func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) {
gin.SetMode(gin.TestMode)
diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go
new file mode 100644
index 00000000..8dbf8935
--- /dev/null
+++ b/backend/internal/handler/openai_images.go
@@ -0,0 +1,300 @@
+package handler
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "strings"
+ "time"
+
+ pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "go.uber.org/zap"
+)
+
+// Images handles OpenAI Images API requests.
+// POST /v1/images/generations
+// POST /v1/images/edits
+func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
+ streamStarted := false
+ defer h.recoverResponsesPanic(c, &streamStarted)
+
+ requestStart := time.Now()
+
+ apiKey, ok := middleware2.GetAPIKeyFromContext(c)
+ if !ok {
+ h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
+ return
+ }
+
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
+ return
+ }
+ reqLog := requestLogger(
+ c,
+ "handler.openai_gateway.images",
+ zap.Int64("user_id", subject.UserID),
+ zap.Int64("api_key_id", apiKey.ID),
+ zap.Any("group_id", apiKey.GroupID),
+ )
+ if !h.ensureResponsesDependencies(c, reqLog) {
+ return
+ }
+
+ body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
+ if err != nil {
+ if maxErr, ok := extractMaxBytesError(err); ok {
+ h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
+ return
+ }
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
+ return
+ }
+ if len(body) == 0 {
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
+ return
+ }
+
+ if isMultipartImagesContentType(c.GetHeader("Content-Type")) {
+ setOpsRequestContext(c, "", false, nil)
+ } else {
+ setOpsRequestContext(c, "", false, body)
+ }
+
+ parsed, err := h.gatewayService.ParseOpenAIImagesRequest(c, body)
+ if err != nil {
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error())
+ return
+ }
+
+ reqLog = reqLog.With(
+ zap.String("model", parsed.Model),
+ zap.Bool("stream", parsed.Stream),
+ zap.Bool("multipart", parsed.Multipart),
+ zap.String("capability", string(parsed.RequiredCapability)),
+ )
+
+ if parsed.Multipart {
+ setOpsRequestContext(c, parsed.Model, parsed.Stream, nil)
+ } else {
+ setOpsRequestContext(c, parsed.Model, parsed.Stream, body)
+ }
+ setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsed.Stream, false)))
+
+ channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, parsed.Model)
+
+ if h.errorPassthroughService != nil {
+ service.BindErrorPassthroughService(c, h.errorPassthroughService)
+ }
+
+ subscription, _ := middleware2.GetSubscriptionFromContext(c)
+
+ service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
+ routingStart := time.Now()
+
+ userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, parsed.Stream, &streamStarted, reqLog)
+ if !acquired {
+ return
+ }
+ if userReleaseFunc != nil {
+ defer userReleaseFunc()
+ }
+
+ if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
+ reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err))
+ status, code, message := billingErrorDetails(err)
+ h.handleStreamingAwareError(c, status, code, message, streamStarted)
+ return
+ }
+
+ sessionHash := ""
+ if parsed.Multipart {
+ sessionHash = h.gatewayService.GenerateSessionHashWithFallback(c, nil, parsed.StickySessionSeed())
+ } else {
+ sessionHash = h.gatewayService.GenerateSessionHash(c, body)
+ }
+
+ maxAccountSwitches := h.maxAccountSwitches
+ switchCount := 0
+ failedAccountIDs := make(map[int64]struct{})
+ sameAccountRetryCount := make(map[int64]int)
+ var lastFailoverErr *service.UpstreamFailoverError
+
+ for {
+ reqLog.Debug("openai.images.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
+ selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForImages(
+ c.Request.Context(),
+ apiKey.GroupID,
+ sessionHash,
+ parsed.Model,
+ failedAccountIDs,
+ parsed.RequiredCapability,
+ )
+ if err != nil {
+ reqLog.Warn("openai.images.account_select_failed",
+ zap.Error(err),
+ zap.Int("excluded_account_count", len(failedAccountIDs)),
+ )
+ if len(failedAccountIDs) == 0 {
+ h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted)
+ return
+ }
+ if lastFailoverErr != nil {
+ h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
+ } else {
+ h.handleFailoverExhaustedSimple(c, 502, streamStarted)
+ }
+ return
+ }
+ if selection == nil || selection.Account == nil {
+ h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted)
+ return
+ }
+
+ reqLog.Debug("openai.images.account_schedule_decision",
+ zap.String("layer", scheduleDecision.Layer),
+ zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit),
+ zap.Int("candidate_count", scheduleDecision.CandidateCount),
+ zap.Int("top_k", scheduleDecision.TopK),
+ zap.Int64("latency_ms", scheduleDecision.LatencyMs),
+ zap.Float64("load_skew", scheduleDecision.LoadSkew),
+ )
+
+ account := selection.Account
+ sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
+ reqLog.Debug("openai.images.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
+ setOpsSelectedAccount(c, account.ID, account.Platform)
+
+ accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, parsed.Stream, &streamStarted, reqLog)
+ if !acquired {
+ return
+ }
+
+ service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
+ forwardStart := time.Now()
+ result, err := h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel)
+ forwardDurationMs := time.Since(forwardStart).Milliseconds()
+ if accountReleaseFunc != nil {
+ accountReleaseFunc()
+ }
+ upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
+ responseLatencyMs := forwardDurationMs
+ if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
+ responseLatencyMs = forwardDurationMs - upstreamLatencyMs
+ }
+ service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
+ if err == nil && result != nil && result.FirstTokenMs != nil {
+ service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
+ }
+ if err != nil {
+ var failoverErr *service.UpstreamFailoverError
+ if errors.As(err, &failoverErr) {
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ if failoverErr.RetryableOnSameAccount {
+ retryLimit := account.GetPoolModeRetryCount()
+ if sameAccountRetryCount[account.ID] < retryLimit {
+ sameAccountRetryCount[account.ID]++
+ reqLog.Warn("openai.images.pool_mode_same_account_retry",
+ zap.Int64("account_id", account.ID),
+ zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.Int("retry_limit", retryLimit),
+ zap.Int("retry_count", sameAccountRetryCount[account.ID]),
+ )
+ select {
+ case <-c.Request.Context().Done():
+ return
+ case <-time.After(sameAccountRetryDelay):
+ }
+ continue
+ }
+ }
+ h.gatewayService.RecordOpenAIAccountSwitch()
+ failedAccountIDs[account.ID] = struct{}{}
+ lastFailoverErr = failoverErr
+ if switchCount >= maxAccountSwitches {
+ h.handleFailoverExhausted(c, failoverErr, streamStarted)
+ return
+ }
+ switchCount++
+ reqLog.Warn("openai.images.upstream_failover_switching",
+ zap.Int64("account_id", account.ID),
+ zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.Int("switch_count", switchCount),
+ zap.Int("max_switches", maxAccountSwitches),
+ )
+ continue
+ }
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
+ fields := []zap.Field{
+ zap.Int64("account_id", account.ID),
+ zap.Bool("fallback_error_response_written", wroteFallback),
+ zap.Error(err),
+ }
+ if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
+ reqLog.Warn("openai.images.forward_failed", fields...)
+ return
+ }
+ reqLog.Error("openai.images.forward_failed", fields...)
+ return
+ }
+
+ if result != nil {
+ if account.Type == service.AccountTypeOAuth {
+ h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders)
+ }
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
+ } else {
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
+ }
+
+ userAgent := c.GetHeader("User-Agent")
+ clientIP := ip.GetClientIP(c)
+ requestPayloadHash := service.HashUsageRequestPayload(body)
+ if parsed.Multipart {
+ requestPayloadHash = service.HashUsageRequestPayload([]byte(parsed.StickySessionSeed()))
+ }
+
+ h.submitUsageRecordTask(func(ctx context.Context) {
+ if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
+ Result: result,
+ APIKey: apiKey,
+ User: apiKey.User,
+ Account: account,
+ Subscription: subscription,
+ InboundEndpoint: GetInboundEndpoint(c),
+ UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
+ UserAgent: userAgent,
+ IPAddress: clientIP,
+ RequestPayloadHash: requestPayloadHash,
+ APIKeyService: h.apiKeyService,
+ ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, result.UpstreamModel),
+ }); err != nil {
+ logger.L().With(
+ zap.String("component", "handler.openai_gateway.images"),
+ zap.Int64("user_id", subject.UserID),
+ zap.Int64("api_key_id", apiKey.ID),
+ zap.Any("group_id", apiKey.GroupID),
+ zap.String("model", parsed.Model),
+ zap.Int64("account_id", account.ID),
+ ).Error("openai.images.record_usage_failed", zap.Error(err))
+ }
+ })
+
+ reqLog.Debug("openai.images.request_completed",
+ zap.Int64("account_id", account.ID),
+ zap.Int("switch_count", switchCount),
+ )
+ return
+ }
+}
+
+func isMultipartImagesContentType(contentType string) bool {
+ return strings.HasPrefix(strings.ToLower(strings.TrimSpace(contentType)), "multipart/form-data")
+}
diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go
index 90e90dd0..93554912 100644
--- a/backend/internal/handler/ops_error_logger.go
+++ b/backend/internal/handler/ops_error_logger.go
@@ -1068,7 +1068,7 @@ func guessPlatformFromPath(path string) string {
return service.PlatformAntigravity
case strings.HasPrefix(p, "/v1beta/"):
return service.PlatformGemini
- case strings.Contains(p, "/responses"):
+ case strings.Contains(p, "/responses"), strings.Contains(p, "/images/"):
return service.PlatformOpenAI
default:
return ""
diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go
index 1ddb8ae2..09580442 100644
--- a/backend/internal/handler/payment_handler.go
+++ b/backend/internal/handler/payment_handler.go
@@ -1,9 +1,14 @@
package handler
import (
+ "fmt"
"strconv"
"strings"
+ "time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -202,10 +207,18 @@ func (h *PaymentHandler) GetLimits(c *gin.Context) {
// CreateOrderRequest is the request body for creating a payment order.
type CreateOrderRequest struct {
- Amount float64 `json:"amount"`
- PaymentType string `json:"payment_type" binding:"required"`
- OrderType string `json:"order_type"`
- PlanID int64 `json:"plan_id"`
+ Amount float64 `json:"amount"`
+ PaymentType string `json:"payment_type" binding:"required"`
+ OpenID string `json:"openid"`
+ WechatResumeToken string `json:"wechat_resume_token"`
+ ReturnURL string `json:"return_url"`
+ PaymentSource string `json:"payment_source"`
+ OrderType string `json:"order_type"`
+ PlanID int64 `json:"plan_id"`
+ // IsMobile lets the frontend declare its mobile status directly. When
+ // nil we fall back to User-Agent heuristics (which miss iPadOS / some
+ // embedded browsers that strip the "Mobile" keyword).
+ IsMobile *bool `json:"is_mobile,omitempty"`
}
// CreateOrder creates a new payment order.
@@ -221,17 +234,36 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
+ if strings.TrimSpace(req.WechatResumeToken) != "" {
+ claims, err := h.paymentService.ParseWeChatPaymentResumeToken(req.WechatResumeToken)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyWeChatPaymentResumeClaims(&req, claims); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ }
+ mobile := isMobile(c)
+ if req.IsMobile != nil {
+ mobile = *req.IsMobile
+ }
result, err := h.paymentService.CreateOrder(c.Request.Context(), service.CreateOrderRequest{
- UserID: subject.UserID,
- Amount: req.Amount,
- PaymentType: req.PaymentType,
- ClientIP: c.ClientIP(),
- IsMobile: isMobile(c),
- SrcHost: c.Request.Host,
- SrcURL: c.Request.Referer(),
- OrderType: req.OrderType,
- PlanID: req.PlanID,
+ UserID: subject.UserID,
+ Amount: req.Amount,
+ PaymentType: req.PaymentType,
+ OpenID: req.OpenID,
+ ClientIP: c.ClientIP(),
+ IsMobile: mobile,
+ IsWeChatBrowser: isWeChatBrowser(c),
+ SrcHost: c.Request.Host,
+ SrcURL: c.Request.Referer(),
+ ReturnURL: req.ReturnURL,
+ PaymentSource: req.PaymentSource,
+ OrderType: req.OrderType,
+ PlanID: req.PlanID,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -240,6 +272,44 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
response.Success(c, result)
}
+func applyWeChatPaymentResumeClaims(req *CreateOrderRequest, claims *service.WeChatPaymentResumeClaims) error {
+ if req == nil || claims == nil {
+ return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume context is missing")
+ }
+ openid := strings.TrimSpace(claims.OpenID)
+ if openid == "" {
+ return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token missing openid")
+ }
+
+ paymentType := service.NormalizeVisibleMethod(claims.PaymentType)
+ if paymentType == "" {
+ paymentType = payment.TypeWxpay
+ }
+ if req.PaymentType != "" {
+ requestPaymentType := service.NormalizeVisibleMethod(req.PaymentType)
+ if requestPaymentType != "" && requestPaymentType != paymentType {
+ return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payment type mismatch")
+ }
+ }
+ req.PaymentType = paymentType
+ req.OpenID = openid
+
+ if strings.TrimSpace(claims.Amount) != "" {
+ amount, err := strconv.ParseFloat(strings.TrimSpace(claims.Amount), 64)
+ if err != nil || amount <= 0 {
+ return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", fmt.Sprintf("invalid resume amount: %s", claims.Amount))
+ }
+ req.Amount = amount
+ }
+ if claims.OrderType != "" {
+ req.OrderType = claims.OrderType
+ }
+ if claims.PlanID > 0 {
+ req.PlanID = claims.PlanID
+ }
+ return nil
+}
+
// GetMyOrders returns the authenticated user's orders.
// GET /api/v1/payment/orders/my
func (h *PaymentHandler) GetMyOrders(c *gin.Context) {
@@ -260,7 +330,7 @@ func (h *PaymentHandler) GetMyOrders(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- response.Paginated(c, orders, int64(total), page, pageSize)
+ response.Paginated(c, sanitizePaymentOrdersForResponse(orders), int64(total), page, pageSize)
}
// GetOrder returns a single order for the authenticated user.
@@ -282,7 +352,7 @@ func (h *PaymentHandler) GetOrder(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- response.Success(c, order)
+ response.Success(c, sanitizePaymentOrderForResponse(order))
}
// CancelOrder cancels a pending order for the authenticated user.
@@ -354,6 +424,10 @@ type VerifyOrderRequest struct {
OutTradeNo string `json:"out_trade_no" binding:"required"`
}
+type ResolveOrderByResumeTokenRequest struct {
+ ResumeToken string `json:"resume_token" binding:"required"`
+}
+
// VerifyOrder actively queries the upstream payment provider to check
// if payment was made, and processes it if so.
// POST /api/v1/payment/orders/verify
@@ -374,23 +448,57 @@ func (h *PaymentHandler) VerifyOrder(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- response.Success(c, order)
+ response.Success(c, sanitizePaymentOrderForResponse(order))
}
// PublicOrderResult is the limited order info returned by the public verify endpoint.
// No user details are exposed — only payment status information.
type PublicOrderResult struct {
- ID int64 `json:"id"`
- OutTradeNo string `json:"out_trade_no"`
- Amount float64 `json:"amount"`
- PayAmount float64 `json:"pay_amount"`
- PaymentType string `json:"payment_type"`
- OrderType string `json:"order_type"`
- Status string `json:"status"`
+ ID int64 `json:"id"`
+ OutTradeNo string `json:"out_trade_no"`
+ Amount float64 `json:"amount"`
+ PayAmount float64 `json:"pay_amount"`
+ FeeRate float64 `json:"fee_rate"`
+ PaymentType string `json:"payment_type"`
+ OrderType string `json:"order_type"`
+ Status string `json:"status"`
+ CreatedAt time.Time `json:"created_at"`
+ ExpiresAt time.Time `json:"expires_at"`
+ PaidAt *time.Time `json:"paid_at,omitempty"`
+ CompletedAt *time.Time `json:"completed_at,omitempty"`
+ RefundAmount float64 `json:"refund_amount"`
+ RefundReason *string `json:"refund_reason,omitempty"`
+ RefundRequestedAt *time.Time `json:"refund_requested_at,omitempty"`
+ RefundRequestedBy *string `json:"refund_requested_by,omitempty"`
+ RefundRequestReason *string `json:"refund_request_reason,omitempty"`
+ PlanID *int64 `json:"plan_id,omitempty"`
}
-// VerifyOrderPublic verifies payment status without requiring authentication.
-// Returns limited order info (no user details) to prevent information leakage.
+func buildPublicOrderResult(order *dbent.PaymentOrder) PublicOrderResult {
+ return PublicOrderResult{
+ ID: order.ID,
+ OutTradeNo: order.OutTradeNo,
+ Amount: order.Amount,
+ PayAmount: order.PayAmount,
+ FeeRate: order.FeeRate,
+ PaymentType: order.PaymentType,
+ OrderType: order.OrderType,
+ Status: order.Status,
+ CreatedAt: order.CreatedAt,
+ ExpiresAt: order.ExpiresAt,
+ PaidAt: order.PaidAt,
+ CompletedAt: order.CompletedAt,
+ RefundAmount: order.RefundAmount,
+ RefundReason: order.RefundReason,
+ RefundRequestedAt: order.RefundRequestedAt,
+ RefundRequestedBy: order.RefundRequestedBy,
+ RefundRequestReason: order.RefundRequestReason,
+ PlanID: order.PlanID,
+ }
+}
+
+// VerifyOrderPublic keeps the legacy anonymous out_trade_no lookup available as
+// a compatibility path for older result pages and staggered deploys.
// POST /api/v1/payment/public/orders/verify
func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) {
var req VerifyOrderRequest
@@ -398,20 +506,30 @@ func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
+
order, err := h.paymentService.VerifyOrderPublic(c.Request.Context(), req.OutTradeNo)
if err != nil {
response.ErrorFrom(c, err)
return
}
- response.Success(c, PublicOrderResult{
- ID: order.ID,
- OutTradeNo: order.OutTradeNo,
- Amount: order.Amount,
- PayAmount: order.PayAmount,
- PaymentType: order.PaymentType,
- OrderType: order.OrderType,
- Status: order.Status,
- })
+ response.Success(c, buildPublicOrderResult(order))
+}
+
+// ResolveOrderPublicByResumeToken resolves a payment order from a signed resume token.
+// POST /api/v1/payment/public/orders/resolve
+func (h *PaymentHandler) ResolveOrderPublicByResumeToken(c *gin.Context) {
+ var req ResolveOrderByResumeTokenRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ order, err := h.paymentService.GetPublicOrderByResumeToken(c.Request.Context(), req.ResumeToken)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, buildPublicOrderResult(order))
}
// requireAuth extracts the authenticated subject from the context.
@@ -435,3 +553,27 @@ func isMobile(c *gin.Context) bool {
}
return false
}
+
+func sanitizePaymentOrdersForResponse(orders []*dbent.PaymentOrder) []*dbent.PaymentOrder {
+ if len(orders) == 0 {
+ return orders
+ }
+ out := make([]*dbent.PaymentOrder, 0, len(orders))
+ for _, order := range orders {
+ out = append(out, sanitizePaymentOrderForResponse(order))
+ }
+ return out
+}
+
+func sanitizePaymentOrderForResponse(order *dbent.PaymentOrder) *dbent.PaymentOrder {
+ if order == nil {
+ return nil
+ }
+ cloned := *order
+ cloned.ProviderSnapshot = nil
+ return &cloned
+}
+
+func isWeChatBrowser(c *gin.Context) bool {
+ return strings.Contains(strings.ToLower(c.GetHeader("User-Agent")), "micromessenger")
+}
diff --git a/backend/internal/handler/payment_handler_resume_test.go b/backend/internal/handler/payment_handler_resume_test.go
new file mode 100644
index 00000000..a7bc4ba3
--- /dev/null
+++ b/backend/internal/handler/payment_handler_resume_test.go
@@ -0,0 +1,368 @@
+//go:build unit
+
+package handler
+
+import (
+ "bytes"
+ "context"
+ "database/sql"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+func TestApplyWeChatPaymentResumeClaims(t *testing.T) {
+ t.Parallel()
+
+ req := CreateOrderRequest{
+ Amount: 0,
+ PaymentType: payment.TypeWxpay,
+ OrderType: payment.OrderTypeBalance,
+ }
+
+ err := applyWeChatPaymentResumeClaims(&req, &service.WeChatPaymentResumeClaims{
+ OpenID: "openid-123",
+ PaymentType: payment.TypeWxpay,
+ Amount: "12.50",
+ OrderType: payment.OrderTypeSubscription,
+ PlanID: 7,
+ })
+ if err != nil {
+ t.Fatalf("applyWeChatPaymentResumeClaims returned error: %v", err)
+ }
+ if req.OpenID != "openid-123" {
+ t.Fatalf("openid = %q, want %q", req.OpenID, "openid-123")
+ }
+ if req.Amount != 12.5 {
+ t.Fatalf("amount = %v, want 12.5", req.Amount)
+ }
+ if req.OrderType != payment.OrderTypeSubscription {
+ t.Fatalf("order_type = %q, want %q", req.OrderType, payment.OrderTypeSubscription)
+ }
+ if req.PlanID != 7 {
+ t.Fatalf("plan_id = %d, want 7", req.PlanID)
+ }
+}
+
+func TestApplyWeChatPaymentResumeClaimsRejectsPaymentTypeMismatch(t *testing.T) {
+ t.Parallel()
+
+ req := CreateOrderRequest{
+ PaymentType: payment.TypeAlipay,
+ }
+
+ err := applyWeChatPaymentResumeClaims(&req, &service.WeChatPaymentResumeClaims{
+ OpenID: "openid-123",
+ PaymentType: payment.TypeWxpay,
+ Amount: "12.50",
+ OrderType: payment.OrderTypeBalance,
+ })
+ if err == nil {
+ t.Fatal("applyWeChatPaymentResumeClaims should reject mismatched payment types")
+ }
+}
+
+func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
+ t.Parallel()
+
+ gin.SetMode(gin.TestMode)
+
+ db, err := sql.Open("sqlite", "file:payment_handler_public_verify?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ user, err := client.User.Create().
+ SetEmail("public-verify@example.com").
+ SetPasswordHash("hash").
+ SetUsername("public-verify-user").
+ Save(context.Background())
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(90.64).
+ SetFeeRate(0.03).
+ SetRechargeCode("PUBLIC-VERIFY").
+ SetOutTradeNo("legacy-order-no").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-public-verify").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(service.OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(context.Background())
+ require.NoError(t, err)
+
+ paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
+ h := NewPaymentHandler(paymentSvc, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ ctx, _ := gin.CreateTestContext(recorder)
+ ctx.Request = httptest.NewRequest(
+ http.MethodPost,
+ "/api/v1/payment/public/orders/verify",
+ bytes.NewBufferString(`{"out_trade_no":"legacy-order-no"}`),
+ )
+ ctx.Request.Header.Set("Content-Type", "application/json")
+
+ h.VerifyOrderPublic(ctx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ ID int64 `json:"id"`
+ OutTradeNo string `json:"out_trade_no"`
+ Amount float64 `json:"amount"`
+ PayAmount float64 `json:"pay_amount"`
+ FeeRate float64 `json:"fee_rate"`
+ PaymentType string `json:"payment_type"`
+ OrderType string `json:"order_type"`
+ Status string `json:"status"`
+ RefundAmount float64 `json:"refund_amount"`
+ CreatedAt string `json:"created_at"`
+ ExpiresAt string `json:"expires_at"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, order.ID, resp.Data.ID)
+ require.Equal(t, "legacy-order-no", resp.Data.OutTradeNo)
+ require.Equal(t, 90.64, resp.Data.PayAmount)
+ require.Equal(t, 0.03, resp.Data.FeeRate)
+ require.Equal(t, payment.TypeAlipay, resp.Data.PaymentType)
+ require.Equal(t, payment.OrderTypeBalance, resp.Data.OrderType)
+ require.Equal(t, service.OrderStatusPending, resp.Data.Status)
+ require.Equal(t, 0.0, resp.Data.RefundAmount)
+ require.NotEmpty(t, resp.Data.CreatedAt)
+ require.NotEmpty(t, resp.Data.ExpiresAt)
+}
+
+func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
+
+ db, err := sql.Open("sqlite", "file:payment_handler_public_resolve?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ user, err := client.User.Create().
+ SetEmail("public-resolve@example.com").
+ SetPasswordHash("hash").
+ SetUsername("public-resolve-user").
+ Save(context.Background())
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(100).
+ SetPayAmount(103).
+ SetFeeRate(0.03).
+ SetRechargeCode("PUBLIC-RESOLVE").
+ SetOutTradeNo("resolve-order-no").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-public-resolve").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(service.OrderStatusPaid).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetPaidAt(time.Now()).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(context.Background())
+ require.NoError(t, err)
+
+ resumeSvc := service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := resumeSvc.CreateToken(service.ResumeTokenClaims{
+ OrderID: order.ID,
+ UserID: user.ID,
+ PaymentType: payment.TypeAlipay,
+ CanonicalReturnURL: "https://app.example.com/payment/result",
+ })
+ require.NoError(t, err)
+
+ configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
+ paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
+ h := NewPaymentHandler(paymentSvc, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ ctx, _ := gin.CreateTestContext(recorder)
+ ctx.Request = httptest.NewRequest(
+ http.MethodPost,
+ "/api/v1/payment/public/orders/resolve",
+ bytes.NewBufferString(`{"resume_token":"`+token+`"}`),
+ )
+ ctx.Request.Header.Set("Content-Type", "application/json")
+
+ h.ResolveOrderPublicByResumeToken(ctx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, float64(order.ID), resp.Data["id"])
+ require.Equal(t, "resolve-order-no", resp.Data["out_trade_no"])
+ require.Equal(t, 100.0, resp.Data["amount"])
+ require.Equal(t, 103.0, resp.Data["pay_amount"])
+ require.Equal(t, 0.03, resp.Data["fee_rate"])
+ require.Equal(t, payment.TypeAlipay, resp.Data["payment_type"])
+ require.Equal(t, payment.OrderTypeBalance, resp.Data["order_type"])
+ require.Equal(t, service.OrderStatusPaid, resp.Data["status"])
+ require.Contains(t, resp.Data, "created_at")
+ require.Contains(t, resp.Data, "expires_at")
+ require.Contains(t, resp.Data, "refund_amount")
+}
+
+func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
+
+ db, err := sql.Open("sqlite", "file:payment_handler_public_resolve_mismatch?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ user, err := client.User.Create().
+ SetEmail("public-resolve-mismatch@example.com").
+ SetPasswordHash("hash").
+ SetUsername("public-resolve-mismatch-user").
+ Save(context.Background())
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(100).
+ SetPayAmount(103).
+ SetFeeRate(0.03).
+ SetRechargeCode("PUBLIC-RESOLVE-MISMATCH").
+ SetOutTradeNo("resolve-order-mismatch-no").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-public-resolve-mismatch").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(service.OrderStatusPaid).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetPaidAt(time.Now()).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(context.Background())
+ require.NoError(t, err)
+
+ resumeSvc := service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := resumeSvc.CreateToken(service.ResumeTokenClaims{
+ OrderID: order.ID,
+ UserID: user.ID + 999,
+ PaymentType: payment.TypeAlipay,
+ CanonicalReturnURL: "https://app.example.com/payment/result",
+ })
+ require.NoError(t, err)
+
+ configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
+ paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
+ h := NewPaymentHandler(paymentSvc, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ ctx, _ := gin.CreateTestContext(recorder)
+ ctx.Request = httptest.NewRequest(
+ http.MethodPost,
+ "/api/v1/payment/public/orders/resolve",
+ bytes.NewBufferString(`{"resume_token":"`+token+`"}`),
+ )
+ ctx.Request.Header.Set("Content-Type", "application/json")
+
+ h.ResolveOrderPublicByResumeToken(ctx)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Reason string `json:"reason"`
+ Message string `json:"message"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, http.StatusBadRequest, resp.Code)
+ require.Equal(t, "INVALID_RESUME_TOKEN", resp.Reason)
+}
+
+func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ db, err := sql.Open("sqlite", "file:payment_handler_public_verify_blank?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
+ h := NewPaymentHandler(paymentSvc, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ ctx, _ := gin.CreateTestContext(recorder)
+ ctx.Request = httptest.NewRequest(
+ http.MethodPost,
+ "/api/v1/payment/public/orders/verify",
+ bytes.NewBufferString(`{"out_trade_no":" "}`),
+ )
+ ctx.Request.Header.Set("Content-Type", "application/json")
+
+ h.VerifyOrderPublic(ctx)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Reason string `json:"reason"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, http.StatusBadRequest, resp.Code)
+ require.Equal(t, "INVALID_OUT_TRADE_NO", resp.Reason)
+}
diff --git a/backend/internal/handler/payment_webhook_handler.go b/backend/internal/handler/payment_webhook_handler.go
index 8a83bfeb..c06a5b7e 100644
--- a/backend/internal/handler/payment_webhook_handler.go
+++ b/backend/internal/handler/payment_webhook_handler.go
@@ -1,6 +1,8 @@
package handler
import (
+ "context"
+ "fmt"
"io"
"log/slog"
"net/http"
@@ -77,9 +79,13 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
// This is needed when multiple instances of the same provider exist (e.g. multiple EasyPay accounts).
outTradeNo := extractOutTradeNo(rawBody, providerKey)
- provider, err := h.paymentService.GetWebhookProvider(c.Request.Context(), providerKey, outTradeNo)
+ providers, err := h.paymentService.GetWebhookProviders(c.Request.Context(), providerKey, outTradeNo)
if err != nil {
slog.Warn("[Payment Webhook] provider not found", "provider", providerKey, "outTradeNo", outTradeNo, "error", err)
+ if providerKey == payment.TypeWxpay {
+ c.String(http.StatusBadRequest, "verify failed")
+ return
+ }
writeSuccessResponse(c, providerKey)
return
}
@@ -89,7 +95,7 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
headers[strings.ToLower(k)] = c.GetHeader(k)
}
- notification, err := provider.VerifyNotification(c.Request.Context(), rawBody, headers)
+ resolvedProviderKey, notification, err := verifyNotificationWithProviders(c.Request.Context(), providers, rawBody, headers)
if err != nil {
truncatedBody := rawBody
if len(truncatedBody) > webhookLogTruncateLen {
@@ -103,24 +109,24 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
// nil notification means irrelevant event (e.g. Stripe non-payment event); return success.
if notification == nil {
- writeSuccessResponse(c, providerKey)
+ writeSuccessResponse(c, resolvedProviderKey)
return
}
- if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, providerKey); err != nil {
- slog.Error("[Payment Webhook] handle notification failed", "provider", providerKey, "error", err)
+ if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, resolvedProviderKey); err != nil {
+ slog.Error("[Payment Webhook] handle notification failed", "provider", resolvedProviderKey, "error", err)
c.String(http.StatusInternalServerError, "handle failed")
return
}
- writeSuccessResponse(c, providerKey)
+ writeSuccessResponse(c, resolvedProviderKey)
}
// extractOutTradeNo parses the webhook body to find the out_trade_no.
// This allows looking up the correct provider instance before verification.
func extractOutTradeNo(rawBody, providerKey string) string {
switch providerKey {
- case payment.TypeEasyPay:
+ case payment.TypeEasyPay, payment.TypeAlipay:
values, err := url.ParseQuery(rawBody)
if err == nil {
return values.Get("out_trade_no")
@@ -131,6 +137,25 @@ func extractOutTradeNo(rawBody, providerKey string) string {
return ""
}
+func verifyNotificationWithProviders(ctx context.Context, providers []payment.Provider, rawBody string, headers map[string]string) (string, *payment.PaymentNotification, error) {
+ var lastErr error
+ for _, provider := range providers {
+ if provider == nil {
+ continue
+ }
+ notification, err := provider.VerifyNotification(ctx, rawBody, headers)
+ if err != nil {
+ lastErr = err
+ continue
+ }
+ return provider.ProviderKey(), notification, nil
+ }
+ if lastErr != nil {
+ return "", nil, lastErr
+ }
+ return "", nil, fmt.Errorf("no webhook provider could verify notification")
+}
+
// wxpaySuccessResponse is the JSON response expected by WeChat Pay webhook.
type wxpaySuccessResponse struct {
Code string `json:"code"`
diff --git a/backend/internal/handler/payment_webhook_handler_test.go b/backend/internal/handler/payment_webhook_handler_test.go
index bdef1766..88221b5c 100644
--- a/backend/internal/handler/payment_webhook_handler_test.go
+++ b/backend/internal/handler/payment_webhook_handler_test.go
@@ -3,11 +3,14 @@
package handler
import (
+ "context"
"encoding/json"
+ "errors"
"net/http"
"net/http/httptest"
"testing"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -97,3 +100,104 @@ func TestWebhookConstants(t *testing.T) {
assert.Equal(t, 200, webhookLogTruncateLen)
})
}
+
+func TestExtractOutTradeNo(t *testing.T) {
+ tests := []struct {
+ name string
+ providerKey string
+ rawBody string
+ want string
+ }{
+ {
+ name: "easypay query payload",
+ providerKey: "easypay",
+ rawBody: "out_trade_no=sub2_123&trade_status=TRADE_SUCCESS",
+ want: "sub2_123",
+ },
+ {
+ name: "alipay query payload",
+ providerKey: "alipay",
+ rawBody: "notify_time=2026-04-20+12%3A00%3A00&out_trade_no=sub2_456",
+ want: "sub2_456",
+ },
+ {
+ name: "unknown provider",
+ providerKey: "wxpay",
+ rawBody: "{}",
+ want: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ assert.Equal(t, tt.want, extractOutTradeNo(tt.rawBody, tt.providerKey))
+ })
+ }
+}
+
+func TestVerifyNotificationWithProvidersReturnsMatchedProvider(t *testing.T) {
+ firstErr := errors.New("wrong provider")
+ providers := []payment.Provider{
+ webhookHandlerProviderStub{
+ key: payment.TypeWxpay,
+ verifyErr: firstErr,
+ },
+ webhookHandlerProviderStub{
+ key: payment.TypeWxpay,
+ notification: &payment.PaymentNotification{
+ OrderID: "sub2_42",
+ TradeNo: "trade-42",
+ Status: payment.NotificationStatusSuccess,
+ },
+ },
+ }
+
+ providerKey, notification, err := verifyNotificationWithProviders(context.Background(), providers, "{}", map[string]string{"wechatpay-signature": "sig"})
+ require.NoError(t, err)
+ require.Equal(t, payment.TypeWxpay, providerKey)
+ require.NotNil(t, notification)
+ require.Equal(t, "sub2_42", notification.OrderID)
+}
+
+func TestVerifyNotificationWithProvidersFailsWhenAllProvidersReject(t *testing.T) {
+ providers := []payment.Provider{
+ webhookHandlerProviderStub{
+ key: payment.TypeWxpay,
+ verifyErr: errors.New("verify failed a"),
+ },
+ webhookHandlerProviderStub{
+ key: payment.TypeWxpay,
+ verifyErr: errors.New("verify failed b"),
+ },
+ }
+
+ _, _, err := verifyNotificationWithProviders(context.Background(), providers, "{}", nil)
+ require.Error(t, err)
+}
+
+type webhookHandlerProviderStub struct {
+ key string
+ notification *payment.PaymentNotification
+ verifyErr error
+}
+
+func (p webhookHandlerProviderStub) Name() string { return p.key }
+func (p webhookHandlerProviderStub) ProviderKey() string { return p.key }
+func (p webhookHandlerProviderStub) SupportedTypes() []payment.PaymentType {
+ return []payment.PaymentType{payment.PaymentType(p.key)}
+}
+func (p webhookHandlerProviderStub) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ panic("unexpected call")
+}
+func (p webhookHandlerProviderStub) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) {
+ panic("unexpected call")
+}
+func (p webhookHandlerProviderStub) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
+ if p.verifyErr != nil {
+ return nil, p.verifyErr
+ }
+ return p.notification, nil
+}
+func (p webhookHandlerProviderStub) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
+ panic("unexpected call")
+}
diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go
index 1717b7a1..c0f5c28b 100644
--- a/backend/internal/handler/setting_handler.go
+++ b/backend/internal/handler/setting_handler.go
@@ -34,6 +34,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
+ ForceEmailOnThirdPartySignup: settings.ForceEmailOnThirdPartySignup,
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
@@ -56,6 +57,10 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
+ WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
+ WeChatOAuthOpenEnabled: settings.WeChatOAuthOpenEnabled,
+ WeChatOAuthMPEnabled: settings.WeChatOAuthMPEnabled,
+ WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
BackendModeEnabled: settings.BackendModeEnabled,
diff --git a/backend/internal/handler/setting_handler_public_test.go b/backend/internal/handler/setting_handler_public_test.go
new file mode 100644
index 00000000..45d66f8e
--- /dev/null
+++ b/backend/internal/handler/setting_handler_public_test.go
@@ -0,0 +1,122 @@
+//go:build unit
+
+package handler
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type settingHandlerPublicRepoStub struct {
+ values map[string]string
+}
+
+func (s *settingHandlerPublicRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *settingHandlerPublicRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ panic("unexpected GetValue call")
+}
+
+func (s *settingHandlerPublicRepoStub) Set(ctx context.Context, key, value string) error {
+ panic("unexpected Set call")
+}
+
+func (s *settingHandlerPublicRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *settingHandlerPublicRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ panic("unexpected SetMultiple call")
+}
+
+func (s *settingHandlerPublicRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *settingHandlerPublicRepoStub) Delete(ctx context.Context, key string) error {
+ panic("unexpected Delete call")
+}
+
+func TestSettingHandler_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &settingHandlerPublicRepoStub{
+ values: map[string]string{
+ service.SettingKeyForceEmailOnThirdPartySignup: "true",
+ },
+ }
+ h := NewSettingHandler(service.NewSettingService(repo, &config.Config{}), "test-version")
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/settings/public", nil)
+
+ h.GetPublicSettings(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.True(t, resp.Data.ForceEmailOnThirdPartySignup)
+}
+
+func TestSettingHandler_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ h := NewSettingHandler(service.NewSettingService(&settingHandlerPublicRepoStub{
+ values: map[string]string{
+ service.SettingKeyWeChatConnectEnabled: "true",
+ service.SettingKeyWeChatConnectAppID: "wx-mp-app",
+ service.SettingKeyWeChatConnectAppSecret: "wx-mp-secret",
+ service.SettingKeyWeChatConnectMode: "mp",
+ service.SettingKeyWeChatConnectScopes: "snsapi_base",
+ service.SettingKeyWeChatConnectOpenEnabled: "true",
+ service.SettingKeyWeChatConnectMPEnabled: "true",
+ service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ },
+ }, &config.Config{}), "test-version")
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/settings/public", nil)
+
+ h.GetPublicSettings(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
+ WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
+ WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.True(t, resp.Data.WeChatOAuthEnabled)
+ require.True(t, resp.Data.WeChatOAuthOpenEnabled)
+ require.True(t, resp.Data.WeChatOAuthMPEnabled)
+}
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index 2535ea5e..f74c2b72 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -1,6 +1,9 @@
package handler
import (
+ "context"
+ "strings"
+
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -12,14 +15,21 @@ import (
// UserHandler handles user-related requests
type UserHandler struct {
userService *service.UserService
+ authService *service.AuthService
emailService *service.EmailService
emailCache service.EmailCache
}
// NewUserHandler creates a new UserHandler
-func NewUserHandler(userService *service.UserService, emailService *service.EmailService, emailCache service.EmailCache) *UserHandler {
+func NewUserHandler(
+ userService *service.UserService,
+ authService *service.AuthService,
+ emailService *service.EmailService,
+ emailCache service.EmailCache,
+) *UserHandler {
return &UserHandler{
userService: userService,
+ authService: authService,
emailService: emailService,
emailCache: emailCache,
}
@@ -34,10 +44,33 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct {
Username *string `json:"username"`
+ AvatarURL *string `json:"avatar_url"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
+type userProfileResponse struct {
+ dto.User
+ AvatarURL string `json:"avatar_url,omitempty"`
+ AvatarSource *userProfileSourceContext `json:"avatar_source,omitempty"`
+ UsernameSource *userProfileSourceContext `json:"username_source,omitempty"`
+ DisplayNameSource *userProfileSourceContext `json:"display_name_source,omitempty"`
+ NicknameSource *userProfileSourceContext `json:"nickname_source,omitempty"`
+ ProfileSources map[string]*userProfileSourceContext `json:"profile_sources,omitempty"`
+ Identities service.UserIdentitySummarySet `json:"identities"`
+ AuthBindings map[string]service.UserIdentitySummary `json:"auth_bindings"`
+ IdentityBindings map[string]service.UserIdentitySummary `json:"identity_bindings"`
+ EmailBound bool `json:"email_bound"`
+ LinuxDoBound bool `json:"linuxdo_bound"`
+ OIDCBound bool `json:"oidc_bound"`
+ WeChatBound bool `json:"wechat_bound"`
+}
+
+type userProfileSourceContext struct {
+ Provider string `json:"provider,omitempty"`
+ Source string `json:"source,omitempty"`
+}
+
// GetProfile handles getting user profile
// GET /api/v1/users/me
func (h *UserHandler) GetProfile(c *gin.Context) {
@@ -47,13 +80,19 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
return
}
- userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
+ userData, err := h.userService.GetProfile(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
- response.Success(c, dto.UserFromService(userData))
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, userData)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
}
// ChangePassword handles changing user password
@@ -101,6 +140,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
svcReq := service.UpdateProfileRequest{
Username: req.Username,
+ AvatarURL: req.AvatarURL,
BalanceNotifyEnabled: req.BalanceNotifyEnabled,
BalanceNotifyThreshold: req.BalanceNotifyThreshold,
}
@@ -110,7 +150,155 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
return
}
- response.Success(c, dto.UserFromService(updatedUser))
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
+}
+
+type StartIdentityBindingRequest struct {
+ Provider string `json:"provider" binding:"required"`
+ RedirectTo string `json:"redirect_to"`
+}
+
+type BindEmailIdentityRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ VerifyCode string `json:"verify_code" binding:"required"`
+ Password string `json:"password" binding:"required"`
+}
+
+type SendEmailBindingCodeRequest struct {
+ Email string `json:"email" binding:"required,email"`
+}
+
+// StartIdentityBinding returns the backend authorize URL for starting a third-party identity bind flow.
+// POST /api/v1/user/auth-identities/bind/start
+func (h *UserHandler) StartIdentityBinding(c *gin.Context) {
+ if _, ok := middleware2.GetAuthSubjectFromContext(c); !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req StartIdentityBindingRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ result, err := h.userService.PrepareIdentityBindingStart(c.Request.Context(), service.StartUserIdentityBindingRequest{
+ Provider: req.Provider,
+ RedirectTo: req.RedirectTo,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, result)
+}
+
+// BindEmailIdentity verifies and binds a local email identity for the current user.
+// POST /api/v1/user/account-bindings/email
+func (h *UserHandler) BindEmailIdentity(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+ if h.authService == nil {
+ response.InternalError(c, "Auth service not configured")
+ return
+ }
+
+ var req BindEmailIdentityRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ updatedUser, err := h.authService.BindEmailIdentity(
+ c.Request.Context(),
+ subject.UserID,
+ req.Email,
+ req.VerifyCode,
+ req.Password,
+ )
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
+}
+
+// UnbindIdentity removes a third-party sign-in provider from the current user.
+// DELETE /api/v1/user/account-bindings/:provider
+func (h *UserHandler) UnbindIdentity(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ updatedUser, unbound, err := h.userService.UnbindUserAuthProviderWithResult(
+ c.Request.Context(),
+ subject.UserID,
+ c.Param("provider"),
+ )
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if unbound && h.authService != nil {
+ if err := h.authService.RevokeAllUserTokens(c.Request.Context(), subject.UserID); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ }
+
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
+}
+
+// SendEmailBindingCode sends a verification code for the current user's email binding flow.
+// POST /api/v1/user/account-bindings/email/send-code
+func (h *UserHandler) SendEmailBindingCode(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+ if h.authService == nil {
+ response.InternalError(c, "Auth service not configured")
+ return
+ }
+
+ var req SendEmailBindingCodeRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if err := h.authService.SendEmailIdentityBindCode(c.Request.Context(), subject.UserID, req.Email); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"message": "Verification code sent successfully"})
}
// SendNotifyEmailCodeRequest represents the request to send notify email verification code
@@ -176,7 +364,13 @@ func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) {
return
}
- response.Success(c, dto.UserFromService(updatedUser))
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
}
// RemoveNotifyEmailRequest represents the request to remove a notify email
@@ -212,7 +406,13 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) {
return
}
- response.Success(c, dto.UserFromService(updatedUser))
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
}
// ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state
@@ -248,5 +448,117 @@ func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) {
return
}
- response.Success(c, dto.UserFromService(updatedUser))
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
+}
+
+func (h *UserHandler) buildUserProfileResponse(ctx context.Context, userID int64, user *service.User) (userProfileResponse, error) {
+ identities, err := h.userService.GetProfileIdentitySummaries(ctx, userID, user)
+ if err != nil {
+ return userProfileResponse{}, err
+ }
+ return userProfileResponseFromService(user, identities), nil
+}
+
+func userProfileResponseFromService(user *service.User, identities service.UserIdentitySummarySet) userProfileResponse {
+ base := dto.UserFromService(user)
+ if base == nil {
+ return userProfileResponse{}
+ }
+ bindings := userProfileBindingMap(identities)
+ profileSources, avatarSource, usernameSource := inferUserProfileSources(user, identities)
+ return userProfileResponse{
+ User: *base,
+ AvatarURL: user.AvatarURL,
+ AvatarSource: avatarSource,
+ UsernameSource: usernameSource,
+ DisplayNameSource: usernameSource,
+ NicknameSource: usernameSource,
+ ProfileSources: profileSources,
+ Identities: identities,
+ AuthBindings: bindings,
+ IdentityBindings: bindings,
+ EmailBound: identities.Email.Bound,
+ LinuxDoBound: identities.LinuxDo.Bound,
+ OIDCBound: identities.OIDC.Bound,
+ WeChatBound: identities.WeChat.Bound,
+ }
+}
+
+func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string]service.UserIdentitySummary {
+ return map[string]service.UserIdentitySummary{
+ "email": identities.Email,
+ "linuxdo": identities.LinuxDo,
+ "oidc": identities.OIDC,
+ "wechat": identities.WeChat,
+ }
+}
+
+func inferUserProfileSources(user *service.User, identities service.UserIdentitySummarySet) (
+ map[string]*userProfileSourceContext,
+ *userProfileSourceContext,
+ *userProfileSourceContext,
+) {
+ if user == nil {
+ return nil, nil, nil
+ }
+
+ thirdParty := thirdPartyIdentityProviders(identities)
+ var avatarSource *userProfileSourceContext
+ avatarValue := strings.TrimSpace(user.AvatarURL)
+ for _, summary := range thirdParty {
+ if avatarValue != "" && avatarValue == strings.TrimSpace(summary.AvatarURL) {
+ avatarSource = buildUserProfileSourceContext(summary.Provider)
+ break
+ }
+ }
+
+ usernameValue := strings.TrimSpace(user.Username)
+ var usernameSource *userProfileSourceContext
+ for _, summary := range thirdParty {
+ if usernameValue != "" && usernameValue == strings.TrimSpace(summary.DisplayName) {
+ usernameSource = buildUserProfileSourceContext(summary.Provider)
+ break
+ }
+ }
+
+ profileSources := map[string]*userProfileSourceContext{}
+ if avatarSource != nil {
+ profileSources["avatar"] = avatarSource
+ }
+ if usernameSource != nil {
+ profileSources["username"] = usernameSource
+ profileSources["display_name"] = usernameSource
+ profileSources["nickname"] = usernameSource
+ }
+ if len(profileSources) == 0 {
+ return nil, avatarSource, usernameSource
+ }
+ return profileSources, avatarSource, usernameSource
+}
+
+func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary {
+ out := make([]service.UserIdentitySummary, 0, 3)
+ for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} {
+ if summary.Bound {
+ out = append(out, summary)
+ }
+ }
+ return out
+}
+
+func buildUserProfileSourceContext(provider string) *userProfileSourceContext {
+ provider = strings.TrimSpace(provider)
+ if provider == "" {
+ return nil
+ }
+ return &userProfileSourceContext{
+ Provider: provider,
+ Source: provider,
+ }
}
diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go
new file mode 100644
index 00000000..a655b81c
--- /dev/null
+++ b/backend/internal/handler/user_handler_test.go
@@ -0,0 +1,783 @@
+//go:build unit
+
+package handler
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type userHandlerRepoStub struct {
+ user *service.User
+ identities []service.UserAuthIdentityRecord
+ unbound []string
+}
+
+func (s *userHandlerRepoStub) Create(context.Context, *service.User) error { return nil }
+func (s *userHandlerRepoStub) GetByID(context.Context, int64) (*service.User, error) {
+ cloned := *s.user
+ return &cloned, nil
+}
+func (s *userHandlerRepoStub) GetByEmail(context.Context, string) (*service.User, error) {
+ cloned := *s.user
+ return &cloned, nil
+}
+func (s *userHandlerRepoStub) GetFirstAdmin(context.Context) (*service.User, error) {
+ cloned := *s.user
+ return &cloned, nil
+}
+func (s *userHandlerRepoStub) Update(_ context.Context, user *service.User) error {
+ cloned := *user
+ s.user = &cloned
+ return nil
+}
+func (s *userHandlerRepoStub) Delete(context.Context, int64) error { return nil }
+func (s *userHandlerRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
+ if s.user == nil || s.user.AvatarURL == "" {
+ return nil, nil
+ }
+ return &service.UserAvatar{
+ StorageProvider: s.user.AvatarSource,
+ URL: s.user.AvatarURL,
+ ContentType: s.user.AvatarMIME,
+ ByteSize: s.user.AvatarByteSize,
+ SHA256: s.user.AvatarSHA256,
+ }, nil
+}
+func (s *userHandlerRepoStub) UpsertUserAvatar(_ context.Context, _ int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ s.user.AvatarURL = input.URL
+ s.user.AvatarSource = input.StorageProvider
+ s.user.AvatarMIME = input.ContentType
+ s.user.AvatarByteSize = input.ByteSize
+ s.user.AvatarSHA256 = input.SHA256
+ return &service.UserAvatar{
+ StorageProvider: input.StorageProvider,
+ URL: input.URL,
+ ContentType: input.ContentType,
+ ByteSize: input.ByteSize,
+ SHA256: input.SHA256,
+ }, nil
+}
+func (s *userHandlerRepoStub) DeleteUserAvatar(context.Context, int64) error {
+ s.user.AvatarURL = ""
+ s.user.AvatarSource = ""
+ s.user.AvatarMIME = ""
+ s.user.AvatarByteSize = 0
+ s.user.AvatarSHA256 = ""
+ return nil
+}
+func (s *userHandlerRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
+func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
+func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
+func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
+func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+func (s *userHandlerRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+func (s *userHandlerRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+func (s *userHandlerRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
+func (s *userHandlerRepoStub) UpdateUserLastActiveAt(_ context.Context, _ int64, activeAt time.Time) error {
+ if s.user != nil {
+ s.user.LastActiveAt = &activeAt
+ }
+ return nil
+}
+func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+func (s *userHandlerRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+func (s *userHandlerRepoStub) EnableTotp(context.Context, int64) error { return nil }
+func (s *userHandlerRepoStub) DisableTotp(context.Context, int64) error { return nil }
+func (s *userHandlerRepoStub) ListUserAuthIdentities(context.Context, int64) ([]service.UserAuthIdentityRecord, error) {
+ out := make([]service.UserAuthIdentityRecord, len(s.identities))
+ copy(out, s.identities)
+ return out, nil
+}
+func (s *userHandlerRepoStub) UnbindUserAuthProvider(_ context.Context, _ int64, provider string) error {
+ s.unbound = append(s.unbound, provider)
+ filtered := s.identities[:0]
+ for _, identity := range s.identities {
+ if identity.ProviderType == provider {
+ continue
+ }
+ filtered = append(filtered, identity)
+ }
+ s.identities = append([]service.UserAuthIdentityRecord(nil), filtered...)
+ return nil
+}
+
+func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 11,
+ Email: "handler-avatar@example.com",
+ Username: "handler-avatar",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
+
+ body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/user", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
+
+ handler.UpdateProfile(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ AvatarURL string `json:"avatar_url"`
+ Username string `json:"username"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, "https://cdn.example.com/avatar.png", resp.Data.AvatarURL)
+ require.Equal(t, "handler-avatar", resp.Data.Username)
+}
+
+func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC)
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 11,
+ Email: "identity@example.com",
+ Username: "identity-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-123456",
+ VerifiedAt: &verifiedAt,
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ },
+ },
+ {
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example.com",
+ ProviderSubject: "oidc-user-abc",
+ Metadata: map[string]any{
+ "suggested_display_name": "OIDC Display",
+ },
+ },
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
+
+ handler.GetProfile(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ Identities struct {
+ Email struct {
+ Bound bool `json:"bound"`
+ BoundCount int `json:"bound_count"`
+ DisplayName string `json:"display_name"`
+ } `json:"email"`
+ LinuxDo struct {
+ Bound bool `json:"bound"`
+ BoundCount int `json:"bound_count"`
+ DisplayName string `json:"display_name"`
+ ProviderKey string `json:"provider_key"`
+ } `json:"linuxdo"`
+ OIDC struct {
+ Bound bool `json:"bound"`
+ DisplayName string `json:"display_name"`
+ ProviderKey string `json:"provider_key"`
+ } `json:"oidc"`
+ WeChat struct {
+ Bound bool `json:"bound"`
+ CanBind bool `json:"can_bind"`
+ BindStartPath string `json:"bind_start_path"`
+ } `json:"wechat"`
+ } `json:"identities"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.True(t, resp.Data.Identities.Email.Bound)
+ require.Equal(t, 1, resp.Data.Identities.Email.BoundCount)
+ require.Equal(t, "identity@example.com", resp.Data.Identities.Email.DisplayName)
+ require.True(t, resp.Data.Identities.LinuxDo.Bound)
+ require.Equal(t, 1, resp.Data.Identities.LinuxDo.BoundCount)
+ require.Equal(t, "linuxdo-handle", resp.Data.Identities.LinuxDo.DisplayName)
+ require.Equal(t, "linuxdo", resp.Data.Identities.LinuxDo.ProviderKey)
+ require.True(t, resp.Data.Identities.OIDC.Bound)
+ require.Equal(t, "OIDC Display", resp.Data.Identities.OIDC.DisplayName)
+ require.Equal(t, "https://issuer.example.com", resp.Data.Identities.OIDC.ProviderKey)
+ require.False(t, resp.Data.Identities.WeChat.Bound)
+ require.True(t, resp.Data.Identities.WeChat.CanBind)
+ require.Contains(t, resp.Data.Identities.WeChat.BindStartPath, "/api/v1/auth/oauth/wechat/bind/start")
+}
+
+func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC)
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 21,
+ Email: "legacy-profile@example.com",
+ Username: "linuxdo-handle",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ AvatarURL: "https://cdn.example.com/linuxdo.png",
+ AvatarSource: "remote_url",
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-21",
+ VerifiedAt: &verifiedAt,
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ "avatar_url": "https://cdn.example.com/linuxdo.png",
+ },
+ },
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 21})
+
+ handler.GetProfile(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, true, resp.Data["email_bound"])
+ require.Equal(t, true, resp.Data["linuxdo_bound"])
+ require.Equal(t, false, resp.Data["oidc_bound"])
+ require.Equal(t, false, resp.Data["wechat_bound"])
+ require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"])
+
+ avatarSource, ok := resp.Data["avatar_source"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "linuxdo", avatarSource["provider"])
+ require.Equal(t, "linuxdo", avatarSource["source"])
+
+ authBindings, ok := resp.Data["auth_bindings"].(map[string]any)
+ require.True(t, ok)
+ linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, true, linuxdoBinding["bound"])
+ require.Equal(t, "linuxdo", linuxdoBinding["provider"])
+
+ identityBindings, ok := resp.Data["identity_bindings"].(map[string]any)
+ require.True(t, ok)
+ emailBinding, ok := identityBindings["email"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, true, emailBinding["bound"])
+ require.Equal(t, "profile.authBindings.notes.emailManagedFromProfile", emailBinding["note_key"])
+
+ linuxdoCompatBinding, ok := identityBindings["linuxdo"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "profile.authBindings.notes.canUnbind", linuxdoCompatBinding["note_key"])
+
+ profileSources, ok := resp.Data["profile_sources"].(map[string]any)
+ require.True(t, ok)
+ usernameSource, ok := profileSources["username"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "linuxdo", usernameSource["provider"])
+ require.Equal(t, "linuxdo", usernameSource["source"])
+}
+
+func TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIdentityMetadata(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 22,
+ Email: "edited-profile@example.com",
+ Username: "custom-name",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ AvatarURL: "https://cdn.example.com/custom.png",
+ AvatarSource: "remote_url",
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-22",
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ "avatar_url": "https://cdn.example.com/linuxdo.png",
+ },
+ },
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 22})
+
+ handler.GetProfile(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.NotContains(t, resp.Data, "avatar_source")
+ require.NotContains(t, resp.Data, "username_source")
+ require.NotContains(t, resp.Data, "profile_sources")
+}
+
+type userHandlerEmailCacheStub struct {
+ data *service.VerificationCodeData
+}
+
+type userHandlerRefreshTokenCacheStub struct {
+ revokedUserIDs []int64
+}
+
+func (s *userHandlerRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
+ return nil, service.ErrRefreshTokenNotFound
+}
+
+func (s *userHandlerRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
+ return nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) DeleteUserRefreshTokens(_ context.Context, userID int64) error {
+ s.revokedUserIDs = append(s.revokedUserIDs, userID)
+ return nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
+ return nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
+ return nil, nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
+ return false, nil
+}
+
+func (s *userHandlerEmailCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) {
+ return s.data, nil
+}
+
+func (s *userHandlerEmailCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) DeleteVerificationCode(context.Context, string) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
+ return nil, nil
+}
+
+func (s *userHandlerEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
+ return nil, nil
+}
+
+func (s *userHandlerEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) DeletePasswordResetToken(context.Context, string) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
+ return false
+}
+
+func (s *userHandlerEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *userHandlerEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
+ return 0, nil
+}
+
+func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 11,
+ Email: "legacy-user" + service.LinuxDoConnectSyntheticEmailDomain,
+ Username: "legacy-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ },
+ }
+ emailCache := &userHandlerEmailCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ },
+ }
+ emailService := service.NewEmailService(nil, emailCache)
+ authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
+
+ body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/account-bindings/email", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Params = gin.Params{{Key: "provider", Value: "email"}}
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
+
+ handler.BindEmailIdentity(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ Email string `json:"email"`
+ EmailBound bool `json:"email_bound"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, "new@example.com", resp.Data.Email)
+ require.True(t, resp.Data.EmailBound)
+}
+
+func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 21,
+ Email: "identity@example.com",
+ Username: "identity-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "email",
+ ProviderKey: "email",
+ ProviderSubject: "identity@example.com",
+ },
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-21",
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ },
+ },
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 21})
+ c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}}
+
+ handler.UnbindIdentity(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ require.Equal(t, []string{"linuxdo"}, repo.unbound)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+
+ authBindings, ok := resp.Data["auth_bindings"].(map[string]any)
+ require.True(t, ok)
+ linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, false, linuxdoBinding["bound"])
+}
+
+func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigured(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 23,
+ Email: "identity@example.com",
+ Username: "identity-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ TokenVersion: 4,
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "email",
+ ProviderKey: "email",
+ ProviderSubject: "identity@example.com",
+ },
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-23",
+ },
+ },
+ }
+ refreshTokenCache := &userHandlerRefreshTokenCacheStub{}
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ },
+ }
+ authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 23})
+ c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}}
+
+ handler.UnbindIdentity(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ require.Equal(t, []int64{23}, refreshTokenCache.revokedUserIDs)
+ require.Equal(t, int64(5), repo.user.TokenVersion)
+}
+
+func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 24,
+ Email: "identity@example.com",
+ Username: "identity-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ TokenVersion: 4,
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "email",
+ ProviderKey: "email",
+ ProviderSubject: "identity@example.com",
+ },
+ },
+ }
+ refreshTokenCache := &userHandlerRefreshTokenCacheStub{}
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ },
+ }
+ authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 24})
+ c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}}
+
+ handler.UnbindIdentity(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ require.Empty(t, repo.unbound)
+ require.Empty(t, refreshTokenCache.revokedUserIDs)
+ require.Equal(t, int64(4), repo.user.TokenVersion)
+}
+
+func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ user := &service.User{
+ ID: 11,
+ Email: "current@example.com",
+ Username: "bound-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ }
+ require.NoError(t, user.SetPassword("current-password"))
+
+ repo := &userHandlerRepoStub{user: user}
+ emailCache := &userHandlerEmailCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ },
+ }
+ emailService := service.NewEmailService(nil, emailCache)
+ authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
+
+ body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/account-bindings/email", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
+
+ handler.BindEmailIdentity(c)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Reason string `json:"reason"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, http.StatusBadRequest, resp.Code)
+ require.Equal(t, "PASSWORD_INCORRECT", resp.Reason)
+ require.Equal(t, "current password is incorrect", resp.Message)
+ require.Equal(t, "current@example.com", repo.user.Email)
+}
+
+func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 11,
+ Email: "identity@example.com",
+ Username: "identity-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
+
+ body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/auth-identities/bind/start", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
+
+ handler.StartIdentityBinding(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ Provider string `json:"provider"`
+ AuthorizeURL string `json:"authorize_url"`
+ Method string `json:"method"`
+ UseBrowserRedirect bool `json:"use_browser_redirect"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, "wechat", resp.Data.Provider)
+ require.Equal(t, "GET", resp.Data.Method)
+ require.True(t, resp.Data.UseBrowserRedirect)
+ require.Contains(t, resp.Data.AuthorizeURL, "/api/v1/auth/oauth/wechat/bind/start")
+ require.Contains(t, resp.Data.AuthorizeURL, "intent=bind_current_user")
+ require.Contains(t, resp.Data.AuthorizeURL, "redirect=%2Fsettings%2Fprofile")
+}
diff --git a/backend/internal/payment/crypto.go b/backend/internal/payment/crypto.go
index e39e957f..0581469d 100644
--- a/backend/internal/payment/crypto.go
+++ b/backend/internal/payment/crypto.go
@@ -10,12 +10,20 @@ import (
"strings"
)
+// AES256KeySize is the required key length (in bytes) for AES-256-GCM.
+const AES256KeySize = 32
+
// Encrypt encrypts plaintext using AES-256-GCM with the given 32-byte key.
// The output format is "iv:authTag:ciphertext" where each component is base64-encoded,
// matching the Node.js crypto.ts format for cross-compatibility.
+//
+// Deprecated: payment provider configs are now stored as plaintext JSON.
+// This function is kept only for seeding legacy ciphertext in tests and for
+// the transitional Decrypt fallback. Scheduled for removal after all live
+// deployments complete migration by re-saving their configs.
func Encrypt(plaintext string, key []byte) (string, error) {
- if len(key) != 32 {
- return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key))
+ if len(key) != AES256KeySize {
+ return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key))
}
block, err := aes.NewCipher(key)
@@ -51,9 +59,14 @@ func Encrypt(plaintext string, key []byte) (string, error) {
// Decrypt decrypts a ciphertext string produced by Encrypt.
// The input format is "iv:authTag:ciphertext" where each component is base64-encoded.
+//
+// Deprecated: payment provider configs are now stored as plaintext JSON.
+// This function remains only as a read-path fallback for pre-migration
+// ciphertext records. Scheduled for removal once all deployments re-save
+// their provider configs through the admin UI.
func Decrypt(ciphertext string, key []byte) (string, error) {
- if len(key) != 32 {
- return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key))
+ if len(key) != AES256KeySize {
+ return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key))
}
parts := strings.SplitN(ciphertext, ":", 3)
diff --git a/backend/internal/payment/load_balancer.go b/backend/internal/payment/load_balancer.go
index f0353173..41fd2c50 100644
--- a/backend/internal/payment/load_balancer.go
+++ b/backend/internal/payment/load_balancer.go
@@ -45,11 +45,31 @@ type DefaultLoadBalancer struct {
counter atomic.Uint64
}
+type contextKey string
+
+const wxpayJSAPIAppIDContextKey contextKey = "payment.wxpay.jsapi_app_id"
+
// NewDefaultLoadBalancer creates a new load balancer.
func NewDefaultLoadBalancer(db *dbent.Client, encryptionKey []byte) *DefaultLoadBalancer {
return &DefaultLoadBalancer{db: db, encryptionKey: encryptionKey}
}
+func WithWxpayJSAPIAppID(ctx context.Context, appID string) context.Context {
+ appID = strings.TrimSpace(appID)
+ if appID == "" {
+ return ctx
+ }
+ return context.WithValue(ctx, wxpayJSAPIAppIDContextKey, appID)
+}
+
+func wxpayJSAPIAppIDFromContext(ctx context.Context) string {
+ if ctx == nil {
+ return ""
+ }
+ appID, _ := ctx.Value(wxpayJSAPIAppIDContextKey).(string)
+ return strings.TrimSpace(appID)
+}
+
// instanceCandidate pairs an instance with its pre-fetched daily usage.
type instanceCandidate struct {
inst *dbent.PaymentProviderInstance
@@ -116,6 +136,7 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances(
}
var matched []*dbent.PaymentProviderInstance
+ expectedWxpayJSAPIAppID := wxpayJSAPIAppIDFromContext(ctx)
for _, inst := range instances {
// Stripe: match by provider_key because supported_types lists sub-types (card,link,alipay,wxpay),
// not "stripe" itself. The checkout page aggregates all sub-types under "stripe".
@@ -124,6 +145,16 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances(
matched = append(matched, inst)
}
} else if InstanceSupportsType(inst.SupportedTypes, paymentType) {
+ if expectedWxpayJSAPIAppID != "" && normalizeVisibleMethodSupportType(paymentType) == TypeWxpay && inst.ProviderKey == TypeWxpay {
+ config, cfgErr := lb.decryptConfig(inst.Config)
+ if cfgErr != nil {
+ slog.Warn("skip wxpay instance with unreadable config during jsapi filtering", "instance_id", inst.ID, "error", cfgErr)
+ continue
+ }
+ if resolveWxpayJSAPIAppID(config) != expectedWxpayJSAPIAppID {
+ continue
+ }
+ }
matched = append(matched, inst)
}
}
@@ -231,6 +262,11 @@ func getInstanceChannelLimits(inst *dbent.PaymentProviderInstance, paymentType P
if cl, ok := limits[lookupKey]; ok {
return cl
}
+ if aliasKey := legacyVisibleMethodAlias(lookupKey); aliasKey != "" {
+ if cl, ok := limits[aliasKey]; ok {
+ return cl
+ }
+ }
return ChannelLimits{}
}
@@ -261,6 +297,9 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns
if err != nil {
return nil, fmt.Errorf("decrypt instance %d config: %w", selected.ID, err)
}
+ if config == nil {
+ config = map[string]string{}
+ }
if selected.PaymentMode != "" {
config["paymentMode"] = selected.PaymentMode
@@ -275,16 +314,36 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns
}, nil
}
-func (lb *DefaultLoadBalancer) decryptConfig(encrypted string) (map[string]string, error) {
- plaintext, err := Decrypt(encrypted, lb.encryptionKey)
- if err != nil {
- return nil, err
+// decryptConfig parses a stored provider config.
+// New records are plaintext JSON; legacy records are AES-256-GCM ciphertext.
+// Unreadable values (legacy ciphertext without a valid key, or malformed data)
+// are treated as empty so the service keeps running while the admin re-enters
+// the config via the UI.
+//
+// TODO(deprecated-legacy-ciphertext): The AES fallback branch below is a
+// transitional compatibility shim for pre-plaintext records. Remove it (and
+// the encryptionKey field + the Decrypt import) after a few releases once all
+// live deployments have re-saved their provider configs through the UI.
+func (lb *DefaultLoadBalancer) decryptConfig(stored string) (map[string]string, error) {
+ if stored == "" {
+ return nil, nil
}
var config map[string]string
- if err := json.Unmarshal([]byte(plaintext), &config); err != nil {
- return nil, fmt.Errorf("unmarshal config: %w", err)
+ if err := json.Unmarshal([]byte(stored), &config); err == nil {
+ return config, nil
}
- return config, nil
+ // Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal.
+ if len(lb.encryptionKey) == AES256KeySize {
+ //nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal
+ if plaintext, err := Decrypt(stored, lb.encryptionKey); err == nil {
+ if err := json.Unmarshal([]byte(plaintext), &config); err == nil {
+ return config, nil
+ }
+ }
+ }
+ slog.Warn("payment provider config unreadable, treating as empty for re-entry",
+ "stored_len", len(stored))
+ return nil, nil
}
// GetInstanceDailyAmount returns the total completed order amount for an instance today.
@@ -321,14 +380,45 @@ func InstanceSupportsType(supportedTypes string, target PaymentType) bool {
if supportedTypes == "" {
return true
}
+ normalizedTarget := normalizeVisibleMethodSupportType(target)
for _, t := range strings.Split(supportedTypes, ",") {
- if strings.TrimSpace(t) == target {
+ supported := strings.TrimSpace(t)
+ if supported == target || normalizeVisibleMethodSupportType(supported) == normalizedTarget {
return true
}
}
return false
}
+func normalizeVisibleMethodSupportType(paymentType PaymentType) PaymentType {
+ switch strings.TrimSpace(paymentType) {
+ case TypeAlipay, TypeAlipayDirect:
+ return TypeAlipay
+ case TypeWxpay, TypeWxpayDirect:
+ return TypeWxpay
+ default:
+ return strings.TrimSpace(paymentType)
+ }
+}
+
+func legacyVisibleMethodAlias(paymentType PaymentType) PaymentType {
+ switch normalizeVisibleMethodSupportType(paymentType) {
+ case TypeAlipay:
+ return TypeAlipayDirect
+ case TypeWxpay:
+ return TypeWxpayDirect
+ default:
+ return ""
+ }
+}
+
+func resolveWxpayJSAPIAppID(config map[string]string) string {
+ if appID := strings.TrimSpace(config["mpAppId"]); appID != "" {
+ return appID
+ }
+ return strings.TrimSpace(config["appId"])
+}
+
// GetInstanceConfig decrypts and returns the configuration for a provider instance by ID.
func (lb *DefaultLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) {
inst, err := lb.db.PaymentProviderInstance.Get(ctx, instanceID)
diff --git a/backend/internal/payment/load_balancer_test.go b/backend/internal/payment/load_balancer_test.go
index 04b3c25b..ed08a7dd 100644
--- a/backend/internal/payment/load_balancer_test.go
+++ b/backend/internal/payment/load_balancer_test.go
@@ -68,10 +68,16 @@ func TestInstanceSupportsType(t *testing.T) {
expected: true,
},
{
- name: "partial match should not succeed",
+ name: "legacy alipay direct supports canonical visible method",
supportedTypes: "alipay_direct",
target: "alipay",
- expected: false,
+ expected: true,
+ },
+ {
+ name: "legacy wxpay direct supports canonical visible method",
+ supportedTypes: "wxpay_direct",
+ target: "wxpay",
+ expected: true,
},
{
name: "empty supported types means all supported",
@@ -92,6 +98,22 @@ func TestInstanceSupportsType(t *testing.T) {
}
}
+func TestGetInstanceChannelLimitsFallsBackToLegacyDirectAliases(t *testing.T) {
+ t.Parallel()
+
+ inst := testInstance(1, TypeAlipay, makeLimitsJSON(TypeAlipayDirect, ChannelLimits{SingleMax: 66}))
+ got := getInstanceChannelLimits(inst, TypeAlipay)
+ if got.SingleMax != 66 {
+ t.Fatalf("getInstanceChannelLimits() = %+v, want SingleMax=66", got)
+ }
+
+ wxInst := testInstance(2, TypeWxpay, makeLimitsJSON(TypeWxpayDirect, ChannelLimits{SingleMin: 8}))
+ wxGot := getInstanceChannelLimits(wxInst, TypeWxpay)
+ if wxGot.SingleMin != 8 {
+ t.Fatalf("getInstanceChannelLimits() = %+v, want SingleMin=8", wxGot)
+ }
+}
+
// ---------------------------------------------------------------------------
// Helper to build test PaymentProviderInstance values
// ---------------------------------------------------------------------------
@@ -452,6 +474,103 @@ func TestStartOfDay(t *testing.T) {
}
}
+func TestDecryptConfig_PlaintextAndLegacyCompat(t *testing.T) {
+ t.Parallel()
+
+ key := make([]byte, AES256KeySize)
+ for i := range key {
+ key[i] = byte(i + 1)
+ }
+ wrongKey := make([]byte, AES256KeySize)
+ for i := range wrongKey {
+ wrongKey[i] = byte(0xFF - i)
+ }
+
+ plaintextJSON := `{"appId":"app-123","secret":"sec-xyz"}`
+
+ legacyEncrypted, err := Encrypt(plaintextJSON, key)
+ if err != nil {
+ t.Fatalf("seed Encrypt: %v", err)
+ }
+
+ tests := []struct {
+ name string
+ stored string
+ key []byte
+ want map[string]string
+ }{
+ {
+ name: "empty stored returns nil map",
+ stored: "",
+ key: key,
+ want: nil,
+ },
+ {
+ name: "plaintext JSON parses directly",
+ stored: plaintextJSON,
+ key: nil,
+ want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
+ },
+ {
+ name: "plaintext JSON works even with key present",
+ stored: plaintextJSON,
+ key: key,
+ want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
+ },
+ {
+ name: "legacy ciphertext with correct key decrypts",
+ stored: legacyEncrypted,
+ key: key,
+ want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
+ },
+ {
+ name: "legacy ciphertext with no key treated as empty",
+ stored: legacyEncrypted,
+ key: nil,
+ want: nil,
+ },
+ {
+ name: "legacy ciphertext with wrong key treated as empty",
+ stored: legacyEncrypted,
+ key: wrongKey,
+ want: nil,
+ },
+ {
+ name: "garbage data treated as empty",
+ stored: "not-json-and-not-ciphertext",
+ key: key,
+ want: nil,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ lb := NewDefaultLoadBalancer(nil, tt.key)
+ got, err := lb.decryptConfig(tt.stored)
+ if err != nil {
+ t.Fatalf("decryptConfig unexpected error: %v", err)
+ }
+ if !stringMapEqual(got, tt.want) {
+ t.Fatalf("decryptConfig = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+// stringMapEqual compares two map[string]string values; nil and empty are equal.
+func stringMapEqual(a, b map[string]string) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for k, v := range a {
+ if bv, ok := b[k]; !ok || bv != v {
+ return false
+ }
+ }
+ return true
+}
+
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
diff --git a/backend/internal/payment/provider/alipay.go b/backend/internal/payment/provider/alipay.go
index af8a90c6..4a260295 100644
--- a/backend/internal/payment/provider/alipay.go
+++ b/backend/internal/payment/provider/alipay.go
@@ -15,8 +15,8 @@ import (
// Alipay product codes.
const (
- alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY"
alipayProductCodeWapPay = "QUICK_WAP_WAY"
+ alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY"
)
// Alipay response constants.
@@ -26,6 +26,15 @@ const (
alipayRefundSuffix = "-refund"
)
+var (
+ alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) {
+ return client.TradeWapPay(param)
+ }
+ alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
+ return client.TradePagePay(param)
+ }
+)
+
// Alipay implements payment.Provider and payment.CancelableProvider using the smartwalle/alipay SDK.
type Alipay struct {
instanceID string
@@ -79,7 +88,23 @@ func (a *Alipay) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeAlipay}
}
-// CreatePayment creates an Alipay payment page URL.
+func (a *Alipay) MerchantIdentityMetadata() map[string]string {
+ if a == nil {
+ return nil
+ }
+ appID := strings.TrimSpace(a.config["appId"])
+ if appID == "" {
+ return nil
+ }
+ return map[string]string{"app_id": appID}
+}
+
+// CreatePayment creates an Alipay payment using redirect-only flow:
+// - Mobile (H5): alipay.trade.wap.pay — returns a URL the browser jumps to.
+// - PC: alipay.trade.page.pay — returns a gateway URL the browser opens in a
+// new window; Alipay's own page then shows login/QR. We intentionally do
+// NOT encode the URL into a QR on the client (it isn't a scannable payload
+// and would produce an invalid scan result).
func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
client, err := a.getClient()
if err != nil {
@@ -96,31 +121,31 @@ func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentReque
}
if req.IsMobile {
- return a.createTrade(client, req, notifyURL, returnURL, true)
+ return a.createWapTrade(client, req, notifyURL, returnURL)
}
- return a.createTrade(client, req, notifyURL, returnURL, false)
+ return a.createPagePayTrade(client, req, notifyURL, returnURL)
}
-func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string, isMobile bool) (*payment.CreatePaymentResponse, error) {
- if isMobile {
- param := alipay.TradeWapPay{}
- param.OutTradeNo = req.OrderID
- param.TotalAmount = req.Amount
- param.Subject = req.Subject
- param.ProductCode = alipayProductCodeWapPay
- param.NotifyURL = notifyURL
- param.ReturnURL = returnURL
+func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
+ param := alipay.TradeWapPay{}
+ param.OutTradeNo = req.OrderID
+ param.TotalAmount = req.Amount
+ param.Subject = req.Subject
+ param.ProductCode = alipayProductCodeWapPay
+ param.NotifyURL = notifyURL
+ param.ReturnURL = returnURL
- payURL, err := client.TradeWapPay(param)
- if err != nil {
- return nil, fmt.Errorf("alipay TradeWapPay: %w", err)
- }
- return &payment.CreatePaymentResponse{
- TradeNo: req.OrderID,
- PayURL: payURL.String(),
- }, nil
+ payURL, err := alipayTradeWapPay(client, param)
+ if err != nil {
+ return nil, fmt.Errorf("alipay TradeWapPay: %w", err)
}
+ return &payment.CreatePaymentResponse{
+ TradeNo: req.OrderID,
+ PayURL: payURL.String(),
+ }, nil
+}
+func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
param := alipay.TradePagePay{}
param.OutTradeNo = req.OrderID
param.TotalAmount = req.Amount
@@ -129,14 +154,13 @@ func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentReq
param.NotifyURL = notifyURL
param.ReturnURL = returnURL
- payURL, err := client.TradePagePay(param)
+ payURL, err := alipayTradePagePay(client, param)
if err != nil {
return nil, fmt.Errorf("alipay TradePagePay: %w", err)
}
return &payment.CreatePaymentResponse{
TradeNo: req.OrderID,
PayURL: payURL.String(),
- QRCode: payURL.String(),
}, nil
}
@@ -172,10 +196,11 @@ func (a *Alipay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Query
}
return &payment.QueryOrderResponse{
- TradeNo: result.TradeNo,
- Status: status,
- Amount: amount,
- PaidAt: result.SendPayDate,
+ TradeNo: result.TradeNo,
+ Status: status,
+ Amount: amount,
+ PaidAt: result.SendPayDate,
+ Metadata: a.MerchantIdentityMetadata(),
}, nil
}
@@ -206,12 +231,21 @@ func (a *Alipay) VerifyNotification(ctx context.Context, rawBody string, _ map[s
return nil, fmt.Errorf("alipay parse notification amount %q: %w", notification.TotalAmount, err)
}
+ metadata := a.MerchantIdentityMetadata()
+ if appID := strings.TrimSpace(notification.AppId); appID != "" {
+ if metadata == nil {
+ metadata = map[string]string{}
+ }
+ metadata["app_id"] = appID
+ }
+
return &payment.PaymentNotification{
- TradeNo: notification.TradeNo,
- OrderID: notification.OutTradeNo,
- Amount: amount,
- Status: status,
- RawData: rawBody,
+ TradeNo: notification.TradeNo,
+ OrderID: notification.OutTradeNo,
+ Amount: amount,
+ Status: status,
+ RawData: rawBody,
+ Metadata: metadata,
}, nil
}
@@ -274,6 +308,7 @@ func isTradeNotExist(err error) bool {
// Ensure interface compliance.
var (
- _ payment.Provider = (*Alipay)(nil)
- _ payment.CancelableProvider = (*Alipay)(nil)
+ _ payment.Provider = (*Alipay)(nil)
+ _ payment.CancelableProvider = (*Alipay)(nil)
+ _ payment.MerchantIdentityProvider = (*Alipay)(nil)
)
diff --git a/backend/internal/payment/provider/alipay_test.go b/backend/internal/payment/provider/alipay_test.go
index 7b0ce0d8..8b3ff8ce 100644
--- a/backend/internal/payment/provider/alipay_test.go
+++ b/backend/internal/payment/provider/alipay_test.go
@@ -4,8 +4,12 @@ package provider
import (
"errors"
+ "net/url"
"strings"
"testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/smartwalle/alipay/v3"
)
func TestIsTradeNotExist(t *testing.T) {
@@ -130,3 +134,96 @@ func TestNewAlipay(t *testing.T) {
})
}
}
+
+func TestCreateTradeUsesPagePayForDesktop(t *testing.T) {
+ origPagePay := alipayTradePagePay
+ origWapPay := alipayTradeWapPay
+ t.Cleanup(func() {
+ alipayTradePagePay = origPagePay
+ alipayTradeWapPay = origWapPay
+ })
+
+ pagePayCalls := 0
+ wapPayCalls := 0
+ alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
+ pagePayCalls++
+ if param.OutTradeNo != "sub2_100" {
+ t.Fatalf("out_trade_no = %q, want %q", param.OutTradeNo, "sub2_100")
+ }
+ if param.NotifyURL != "https://merchant.example.com/api/v1/payment/webhook/alipay" {
+ t.Fatalf("notify_url = %q", param.NotifyURL)
+ }
+ return url.Parse("https://openapi.alipay.com/gateway.do?page-pay")
+ }
+ alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) {
+ wapPayCalls++
+ return url.Parse("https://openapi.alipay.com/gateway.do?wap-pay")
+ }
+
+ provider := &Alipay{}
+ resp, err := provider.createPagePayTrade(&alipay.Client{}, payment.CreatePaymentRequest{
+ OrderID: "sub2_100",
+ Amount: "88.00",
+ Subject: "Balance recharge",
+ }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result")
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if pagePayCalls != 1 {
+ t.Fatalf("page pay calls = %d, want 1", pagePayCalls)
+ }
+ if wapPayCalls != 0 {
+ t.Fatalf("wap pay calls = %d, want 0", wapPayCalls)
+ }
+ if resp.PayURL == "" {
+ t.Fatal("expected pay_url for desktop page pay")
+ }
+}
+
+func TestCreateTradeUsesWapPayForMobile(t *testing.T) {
+ origWapPay := alipayTradeWapPay
+ t.Cleanup(func() {
+ alipayTradeWapPay = origWapPay
+ })
+
+ wapPayCalls := 0
+ alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) {
+ wapPayCalls++
+ if param.ReturnURL != "https://merchant.example.com/payment/result" {
+ t.Fatalf("return_url = %q", param.ReturnURL)
+ }
+ return url.Parse("https://openapi.alipay.com/gateway.do?wap-pay")
+ }
+
+ provider := &Alipay{}
+ resp, err := provider.createWapTrade(&alipay.Client{}, payment.CreatePaymentRequest{
+ OrderID: "sub2_101",
+ Amount: "18.00",
+ Subject: "Balance recharge",
+ IsMobile: true,
+ }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result")
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if wapPayCalls != 1 {
+ t.Fatalf("wap pay calls = %d, want 1", wapPayCalls)
+ }
+ if resp.PayURL == "" {
+ t.Fatal("expected pay_url for mobile wap pay")
+ }
+}
+
+func TestAlipayMerchantIdentityMetadata(t *testing.T) {
+ t.Parallel()
+
+ provider := &Alipay{
+ config: map[string]string{
+ "appId": "2021001234567890",
+ },
+ }
+
+ metadata := provider.MerchantIdentityMetadata()
+ if metadata["app_id"] != "2021001234567890" {
+ t.Fatalf("app_id = %q, want %q", metadata["app_id"], "2021001234567890")
+ }
+}
diff --git a/backend/internal/payment/provider/easypay.go b/backend/internal/payment/provider/easypay.go
index e33a567d..37bd38b2 100644
--- a/backend/internal/payment/provider/easypay.go
+++ b/backend/internal/payment/provider/easypay.go
@@ -59,6 +59,17 @@ func (e *EasyPay) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeAlipay, payment.TypeWxpay}
}
+func (e *EasyPay) MerchantIdentityMetadata() map[string]string {
+ if e == nil {
+ return nil
+ }
+ pid := strings.TrimSpace(e.config["pid"])
+ if pid == "" {
+ return nil
+ }
+ return map[string]string{"pid": pid}
+}
+
func (e *EasyPay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
// Payment mode determined by instance config, not payment type.
// "popup" → hosted page (submit.php); "qrcode"/default → API call (mapi.php).
@@ -178,7 +189,12 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer
status = payment.ProviderStatusPaid
}
amount, _ := strconv.ParseFloat(resp.Money, 64)
- return &payment.QueryOrderResponse{TradeNo: tradeNo, Status: status, Amount: amount}, nil
+ return &payment.QueryOrderResponse{
+ TradeNo: tradeNo,
+ Status: status,
+ Amount: amount,
+ Metadata: e.MerchantIdentityMetadata(),
+ }, nil
}
func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[string]string) (*payment.PaymentNotification, error) {
@@ -203,9 +219,17 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st
status = payment.ProviderStatusSuccess
}
amount, _ := strconv.ParseFloat(params["money"], 64)
+
+ metadata := e.MerchantIdentityMetadata()
+ if pid := strings.TrimSpace(params["pid"]); pid != "" {
+ if metadata == nil {
+ metadata = map[string]string{}
+ }
+ metadata["pid"] = pid
+ }
return &payment.PaymentNotification{
TradeNo: params["trade_no"], OrderID: params["out_trade_no"],
- Amount: amount, Status: status, RawData: rawBody,
+ Amount: amount, Status: status, RawData: rawBody, Metadata: metadata,
}, nil
}
diff --git a/backend/internal/payment/provider/easypay_sign_test.go b/backend/internal/payment/provider/easypay_sign_test.go
index 146a6fa1..8328d294 100644
--- a/backend/internal/payment/provider/easypay_sign_test.go
+++ b/backend/internal/payment/provider/easypay_sign_test.go
@@ -178,3 +178,18 @@ func TestEasyPayVerifySignWrongSignValue(t *testing.T) {
t.Fatal("easyPayVerifySign should return false for an incorrect sign value")
}
}
+
+func TestEasyPayMerchantIdentityMetadata(t *testing.T) {
+ t.Parallel()
+
+ provider := &EasyPay{
+ config: map[string]string{
+ "pid": "1001",
+ },
+ }
+
+ metadata := provider.MerchantIdentityMetadata()
+ if metadata["pid"] != "1001" {
+ t.Fatalf("pid = %q, want %q", metadata["pid"], "1001")
+ }
+}
diff --git a/backend/internal/payment/provider/wxpay.go b/backend/internal/payment/provider/wxpay.go
index 0b41c4fb..e6291dd3 100644
--- a/backend/internal/payment/provider/wxpay.go
+++ b/backend/internal/payment/provider/wxpay.go
@@ -3,22 +3,24 @@ package provider
import (
"bytes"
"context"
- "crypto/rsa"
"fmt"
"io"
- "log/slog"
"net/http"
+ "net/url"
+ "strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/wechatpay-apiv3/wechatpay-go/core"
"github.com/wechatpay-apiv3/wechatpay-go/core/auth/verifiers"
"github.com/wechatpay-apiv3/wechatpay-go/core/notify"
"github.com/wechatpay-apiv3/wechatpay-go/core/option"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/payments/jsapi"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/native"
"github.com/wechatpay-apiv3/wechatpay-go/services/refunddomestic"
"github.com/wechatpay-apiv3/wechatpay-go/utils"
@@ -26,8 +28,23 @@ import (
// WeChat Pay constants.
const (
- wxpayCurrency = "CNY"
- wxpayH5Type = "Wap"
+ wxpayCurrency = "CNY"
+ wxpayH5Type = "Wap"
+ wxpayResultPath = "/payment/result"
+)
+
+const (
+ wxpayMetadataAppID = "appid"
+ wxpayMetadataMerchantID = "mchid"
+ wxpayMetadataCurrency = "currency"
+ wxpayMetadataTradeState = "trade_state"
+)
+
+// WeChat Pay create-payment modes.
+const (
+ wxpayModeNative = "native"
+ wxpayModeH5 = "h5"
+ wxpayModeJSAPI = "jsapi"
)
// WeChat Pay trade states.
@@ -43,9 +60,16 @@ const (
wxpayEventTransactionSuccess = "TRANSACTION.SUCCESS"
)
-// WeChat Pay error codes.
-const (
- wxpayErrNoAuth = "NO_AUTH"
+var (
+ wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) {
+ return svc.Prepay(ctx, req)
+ }
+ wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) {
+ return svc.Prepay(ctx, req)
+ }
+ wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) {
+ return svc.PrepayWithRequestPayment(ctx, req)
+ }
)
type Wxpay struct {
@@ -56,15 +80,35 @@ type Wxpay struct {
notifyHandler *notify.Handler
}
+const wxpayAPIv3KeyLength = 32
+
func NewWxpay(instanceID string, config map[string]string) (*Wxpay, error) {
- required := []string{"appId", "mchId", "privateKey", "apiV3Key", "publicKey", "publicKeyId", "certSerial"}
+ // All fields are required. Platform-certificate mode is intentionally unsupported —
+ // WeChat has been migrating all merchants to the pubkey verifier since 2024-10,
+ // and newly-provisioned merchants cannot download platform certificates at all.
+ required := []string{"appId", "mchId", "privateKey", "apiV3Key", "certSerial", "publicKey", "publicKeyId"}
for _, k := range required {
if config[k] == "" {
- return nil, fmt.Errorf("wxpay config missing required key: %s", k)
+ return nil, infraerrors.BadRequest("WXPAY_CONFIG_MISSING_KEY", "missing_required_key").
+ WithMetadata(map[string]string{"key": k})
}
}
- if len(config["apiV3Key"]) != 32 {
- return nil, fmt.Errorf("wxpay apiV3Key must be exactly 32 bytes, got %d", len(config["apiV3Key"]))
+ if len(config["apiV3Key"]) != wxpayAPIv3KeyLength {
+ return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY_LENGTH", "invalid_key_length").
+ WithMetadata(map[string]string{
+ "key": "apiV3Key",
+ "expected": strconv.Itoa(wxpayAPIv3KeyLength),
+ "actual": strconv.Itoa(len(config["apiV3Key"])),
+ })
+ }
+ // Parse PEMs eagerly so malformed keys surface at save time, not at order creation.
+ if _, err := utils.LoadPrivateKey(formatPEM(config["privateKey"], "PRIVATE KEY")); err != nil {
+ return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY", "invalid_key").
+ WithMetadata(map[string]string{"key": "privateKey"})
+ }
+ if _, err := utils.LoadPublicKey(formatPEM(config["publicKey"], "PUBLIC KEY")); err != nil {
+ return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY", "invalid_key").
+ WithMetadata(map[string]string{"key": "publicKey"})
}
return &Wxpay{instanceID: instanceID, config: config}, nil
}
@@ -75,6 +119,16 @@ func (w *Wxpay) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeWxpay}
}
+// ResolveWxpayJSAPIAppID returns the AppID that JSAPI prepay will use for a
+// given provider config. A dedicated MP AppID takes precedence over the base
+// merchant AppID.
+func ResolveWxpayJSAPIAppID(config map[string]string) string {
+ if appID := strings.TrimSpace(config["mpAppId"]); appID != "" {
+ return appID
+ }
+ return strings.TrimSpace(config["appId"])
+}
+
func formatPEM(key, keyType string) string {
key = strings.TrimSpace(key)
if strings.HasPrefix(key, "-----BEGIN") {
@@ -89,14 +143,19 @@ func (w *Wxpay) ensureClient() (*core.Client, error) {
if w.coreClient != nil {
return w.coreClient, nil
}
- privateKey, publicKey, err := w.loadKeyPair()
+ privateKey, err := utils.LoadPrivateKey(formatPEM(w.config["privateKey"], "PRIVATE KEY"))
if err != nil {
- return nil, err
+ return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY", "invalid_key").
+ WithMetadata(map[string]string{"key": "privateKey"})
+ }
+ publicKey, err := utils.LoadPublicKey(formatPEM(w.config["publicKey"], "PUBLIC KEY"))
+ if err != nil {
+ return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY", "invalid_key").
+ WithMetadata(map[string]string{"key": "publicKey"})
}
- certSerial := w.config["certSerial"]
verifier := verifiers.NewSHA256WithRSAPubkeyVerifier(w.config["publicKeyId"], *publicKey)
client, err := core.NewClient(context.Background(),
- option.WithMerchantCredential(w.config["mchId"], certSerial, privateKey),
+ option.WithMerchantCredential(w.config["mchId"], w.config["certSerial"], privateKey),
option.WithVerifier(verifier))
if err != nil {
return nil, fmt.Errorf("wxpay init client: %w", err)
@@ -110,18 +169,6 @@ func (w *Wxpay) ensureClient() (*core.Client, error) {
return w.coreClient, nil
}
-func (w *Wxpay) loadKeyPair() (*rsa.PrivateKey, *rsa.PublicKey, error) {
- privateKey, err := utils.LoadPrivateKey(formatPEM(w.config["privateKey"], "PRIVATE KEY"))
- if err != nil {
- return nil, nil, fmt.Errorf("wxpay load private key: %w", err)
- }
- publicKey, err := utils.LoadPublicKey(formatPEM(w.config["publicKey"], "PUBLIC KEY"))
- if err != nil {
- return nil, nil, fmt.Errorf("wxpay load public key: %w", err)
- }
- return privateKey, publicKey, nil
-}
-
func (w *Wxpay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
client, err := w.ensureClient()
if err != nil {
@@ -139,30 +186,61 @@ func (w *Wxpay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequ
if err != nil {
return nil, fmt.Errorf("wxpay create payment: %w", err)
}
- if req.IsMobile && req.ClientIP != "" {
- resp, err := w.createOrder(ctx, client, req, notifyURL, totalFen, true)
- if err == nil {
- return resp, nil
- }
- if !strings.Contains(err.Error(), wxpayErrNoAuth) {
- return nil, err
- }
- slog.Warn("wxpay H5 payment not authorized, falling back to native", "order", req.OrderID)
+
+ mode, err := resolveWxpayCreateMode(req)
+ if err != nil {
+ return nil, err
+ }
+ switch mode {
+ case wxpayModeJSAPI:
+ return w.prepayJSAPI(ctx, client, req, notifyURL, totalFen)
+ case wxpayModeH5:
+ return w.prepayH5(ctx, client, req, notifyURL, totalFen)
+ case wxpayModeNative:
+ return w.prepayNative(ctx, client, req, notifyURL, totalFen)
+ default:
+ return nil, fmt.Errorf("wxpay create payment: unsupported mode %q", mode)
}
- return w.createOrder(ctx, client, req, notifyURL, totalFen, false)
}
-func (w *Wxpay) createOrder(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64, useH5 bool) (*payment.CreatePaymentResponse, error) {
- if useH5 {
- return w.prepayH5(ctx, c, req, notifyURL, totalFen)
+func (w *Wxpay) prepayJSAPI(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) {
+ svc := jsapi.JsapiApiService{Client: c}
+ cur := wxpayCurrency
+ appID := ResolveWxpayJSAPIAppID(w.config)
+ prepayReq := jsapi.PrepayRequest{
+ Appid: core.String(appID),
+ Mchid: core.String(w.config["mchId"]),
+ Description: core.String(req.Subject),
+ OutTradeNo: core.String(req.OrderID),
+ NotifyUrl: core.String(notifyURL),
+ Amount: &jsapi.Amount{Total: core.Int64(totalFen), Currency: &cur},
+ Payer: &jsapi.Payer{Openid: core.String(strings.TrimSpace(req.OpenID))},
}
- return w.prepayNative(ctx, c, req, notifyURL, totalFen)
+ if clientIP := strings.TrimSpace(req.ClientIP); clientIP != "" {
+ prepayReq.SceneInfo = &jsapi.SceneInfo{PayerClientIp: core.String(clientIP)}
+ }
+ resp, _, err := wxpayJSAPIPrepayWithRequestPayment(ctx, svc, prepayReq)
+ if err != nil {
+ return nil, fmt.Errorf("wxpay jsapi prepay: %w", err)
+ }
+ return &payment.CreatePaymentResponse{
+ TradeNo: req.OrderID,
+ ResultType: payment.CreatePaymentResultJSAPIReady,
+ JSAPI: &payment.WechatJSAPIPayload{
+ AppID: wxSV(resp.Appid),
+ TimeStamp: wxSV(resp.TimeStamp),
+ NonceStr: wxSV(resp.NonceStr),
+ Package: wxSV(resp.Package),
+ SignType: wxSV(resp.SignType),
+ PaySign: wxSV(resp.PaySign),
+ },
+ }, nil
}
func (w *Wxpay) prepayNative(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) {
svc := native.NativeApiService{Client: c}
cur := wxpayCurrency
- resp, _, err := svc.Prepay(ctx, native.PrepayRequest{
+ resp, _, err := wxpayNativePrepay(ctx, svc, native.PrepayRequest{
Appid: core.String(w.config["appId"]), Mchid: core.String(w.config["mchId"]),
Description: core.String(req.Subject), OutTradeNo: core.String(req.OrderID),
NotifyUrl: core.String(notifyURL),
@@ -181,13 +259,12 @@ func (w *Wxpay) prepayNative(ctx context.Context, c *core.Client, req payment.Cr
func (w *Wxpay) prepayH5(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) {
svc := h5.H5ApiService{Client: c}
cur := wxpayCurrency
- tp := wxpayH5Type
- resp, _, err := svc.Prepay(ctx, h5.PrepayRequest{
+ resp, _, err := wxpayH5Prepay(ctx, svc, h5.PrepayRequest{
Appid: core.String(w.config["appId"]), Mchid: core.String(w.config["mchId"]),
Description: core.String(req.Subject), OutTradeNo: core.String(req.OrderID),
NotifyUrl: core.String(notifyURL),
Amount: &h5.Amount{Total: core.Int64(totalFen), Currency: &cur},
- SceneInfo: &h5.SceneInfo{PayerClientIp: core.String(req.ClientIP), H5Info: &h5.H5Info{Type: &tp}},
+ SceneInfo: &h5.SceneInfo{PayerClientIp: core.String(req.ClientIP), H5Info: buildWxpayH5Info(w.config)},
})
if err != nil {
return nil, fmt.Errorf("wxpay h5 prepay: %w", err)
@@ -196,9 +273,77 @@ func (w *Wxpay) prepayH5(ctx context.Context, c *core.Client, req payment.Create
if resp.H5Url != nil {
h5URL = *resp.H5Url
}
+ h5URL, err = appendWxpayRedirectURL(h5URL, req)
+ if err != nil {
+ return nil, err
+ }
return &payment.CreatePaymentResponse{TradeNo: req.OrderID, PayURL: h5URL}, nil
}
+func buildWxpayH5Info(config map[string]string) *h5.H5Info {
+ tp := wxpayH5Type
+ info := &h5.H5Info{Type: &tp}
+ if appName := strings.TrimSpace(config["h5AppName"]); appName != "" {
+ info.AppName = core.String(appName)
+ }
+ if appURL := strings.TrimSpace(config["h5AppUrl"]); appURL != "" {
+ info.AppUrl = core.String(appURL)
+ }
+ return info
+}
+
+func resolveWxpayCreateMode(req payment.CreatePaymentRequest) (string, error) {
+ if strings.TrimSpace(req.OpenID) != "" {
+ return wxpayModeJSAPI, nil
+ }
+ if req.IsMobile {
+ if strings.TrimSpace(req.ClientIP) == "" {
+ return "", fmt.Errorf("wxpay H5 payment requires client IP")
+ }
+ return wxpayModeH5, nil
+ }
+ return wxpayModeNative, nil
+}
+
+func appendWxpayRedirectURL(h5URL string, req payment.CreatePaymentRequest) (string, error) {
+ h5URL = strings.TrimSpace(h5URL)
+ returnURL := strings.TrimSpace(req.ReturnURL)
+ if h5URL == "" || returnURL == "" {
+ return h5URL, nil
+ }
+
+ redirectURL, err := buildWxpayResultURL(returnURL, req)
+ if err != nil {
+ return "", err
+ }
+
+ sep := "&"
+ if !strings.Contains(h5URL, "?") {
+ sep = "?"
+ }
+ return h5URL + sep + "redirect_url=" + url.QueryEscape(redirectURL), nil
+}
+
+func buildWxpayResultURL(returnURL string, req payment.CreatePaymentRequest) (string, error) {
+ u, err := url.Parse(returnURL)
+ if err != nil || !u.IsAbs() || u.Host == "" || (u.Scheme != "http" && u.Scheme != "https") {
+ return "", fmt.Errorf("return URL must be an absolute http(s) URL")
+ }
+
+ values := u.Query()
+ values.Set("out_trade_no", strings.TrimSpace(req.OrderID))
+ if paymentType := strings.TrimSpace(req.PaymentType); paymentType != "" {
+ values.Set("payment_type", paymentType)
+ }
+ if strings.TrimSpace(u.Path) == "" {
+ u.Path = wxpayResultPath
+ }
+ u.RawPath = ""
+ u.RawQuery = values.Encode()
+ u.Fragment = ""
+ return u.String(), nil
+}
+
func wxSV(s *string) string {
if s == nil {
return ""
@@ -219,6 +364,32 @@ func mapWxState(s string) string {
}
}
+func buildWxpayTransactionMetadata(tx *payments.Transaction) map[string]string {
+ if tx == nil {
+ return nil
+ }
+
+ metadata := map[string]string{}
+ if appID := wxSV(tx.Appid); appID != "" {
+ metadata[wxpayMetadataAppID] = appID
+ }
+ if merchantID := wxSV(tx.Mchid); merchantID != "" {
+ metadata[wxpayMetadataMerchantID] = merchantID
+ }
+ if tradeState := wxSV(tx.TradeState); tradeState != "" {
+ metadata[wxpayMetadataTradeState] = tradeState
+ }
+ if tx.Amount != nil {
+ if currency := wxSV(tx.Amount.Currency); currency != "" {
+ metadata[wxpayMetadataCurrency] = currency
+ }
+ }
+ if len(metadata) == 0 {
+ return nil
+ }
+ return metadata
+}
+
func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
c, err := w.ensureClient()
if err != nil {
@@ -243,7 +414,13 @@ func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryO
if tx.SuccessTime != nil {
pa = *tx.SuccessTime
}
- return &payment.QueryOrderResponse{TradeNo: id, Status: mapWxState(wxSV(tx.TradeState)), Amount: amt, PaidAt: pa}, nil
+ return &payment.QueryOrderResponse{
+ TradeNo: id,
+ Status: mapWxState(wxSV(tx.TradeState)),
+ Amount: amt,
+ PaidAt: pa,
+ Metadata: buildWxpayTransactionMetadata(tx),
+ }, nil
}
func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) {
@@ -275,7 +452,7 @@ func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers
}
return &payment.PaymentNotification{
TradeNo: wxSV(tx.TransactionId), OrderID: wxSV(tx.OutTradeNo),
- Amount: amt, Status: st, RawData: rawBody,
+ Amount: amt, Status: st, RawData: rawBody, Metadata: buildWxpayTransactionMetadata(&tx),
}, nil
}
diff --git a/backend/internal/payment/provider/wxpay_test.go b/backend/internal/payment/provider/wxpay_test.go
index b8b99537..e8ac5e54 100644
--- a/backend/internal/payment/provider/wxpay_test.go
+++ b/backend/internal/payment/provider/wxpay_test.go
@@ -3,12 +3,44 @@
package provider
import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/pem"
+ "errors"
+ "net/url"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/wechatpay-apiv3/wechatpay-go/core"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/payments"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/payments/jsapi"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/payments/native"
)
+// generateTestKeyPair returns a fresh RSA 2048 key pair as PEM strings.
+// The wechatpay-go SDK expects PKCS8 private keys and PKIX public keys.
+func generateTestKeyPair(t *testing.T) (privPEM, pubPEM string) {
+ t.Helper()
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ t.Fatalf("generate rsa key: %v", err)
+ }
+ privDER, err := x509.MarshalPKCS8PrivateKey(key)
+ if err != nil {
+ t.Fatalf("marshal pkcs8: %v", err)
+ }
+ pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
+ if err != nil {
+ t.Fatalf("marshal pkix: %v", err)
+ }
+ return string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})),
+ string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER}))
+}
+
func TestMapWxState(t *testing.T) {
t.Parallel()
@@ -96,6 +128,33 @@ func TestWxSV(t *testing.T) {
}
}
+func TestBuildWxpayTransactionMetadata(t *testing.T) {
+ t.Parallel()
+
+ tx := &payments.Transaction{
+ Appid: strPtr("wx-app-id"),
+ Mchid: strPtr("mch-id"),
+ TradeState: strPtr(wxpayTradeStateSuccess),
+ Amount: &payments.TransactionAmount{
+ Currency: strPtr(wxpayCurrency),
+ },
+ }
+
+ metadata := buildWxpayTransactionMetadata(tx)
+ if metadata[wxpayMetadataAppID] != "wx-app-id" {
+ t.Fatalf("appid = %q", metadata[wxpayMetadataAppID])
+ }
+ if metadata[wxpayMetadataMerchantID] != "mch-id" {
+ t.Fatalf("mchid = %q", metadata[wxpayMetadataMerchantID])
+ }
+ if metadata[wxpayMetadataCurrency] != wxpayCurrency {
+ t.Fatalf("currency = %q", metadata[wxpayMetadataCurrency])
+ }
+ if metadata[wxpayMetadataTradeState] != wxpayTradeStateSuccess {
+ t.Fatalf("trade_state = %q", metadata[wxpayMetadataTradeState])
+ }
+}
+
func strPtr(s string) *string {
return &s
}
@@ -149,13 +208,14 @@ func TestFormatPEM(t *testing.T) {
func TestNewWxpay(t *testing.T) {
t.Parallel()
+ privPEM, pubPEM := generateTestKeyPair(t)
validConfig := map[string]string{
"appId": "wx1234567890",
"mchId": "1234567890",
- "privateKey": "fake-private-key",
+ "privateKey": privPEM,
"apiV3Key": "12345678901234567890123456789012", // exactly 32 bytes
- "publicKey": "fake-public-key",
- "publicKeyId": "key-id-001",
+ "publicKey": pubPEM,
+ "publicKeyId": "PUB_KEY_ID_TEST",
"certSerial": "SERIAL001",
}
@@ -206,6 +266,12 @@ func TestNewWxpay(t *testing.T) {
wantErr: true,
errSubstr: "apiV3Key",
},
+ {
+ name: "missing certSerial",
+ config: withOverride(map[string]string{"certSerial": ""}),
+ wantErr: true,
+ errSubstr: "certSerial",
+ },
{
name: "missing publicKey",
config: withOverride(map[string]string{"publicKey": ""}),
@@ -218,17 +284,29 @@ func TestNewWxpay(t *testing.T) {
wantErr: true,
errSubstr: "publicKeyId",
},
+ {
+ name: "malformed privateKey PEM",
+ config: withOverride(map[string]string{"privateKey": "not-a-valid-pem"}),
+ wantErr: true,
+ errSubstr: "WXPAY_CONFIG_INVALID_KEY",
+ },
+ {
+ name: "malformed publicKey PEM",
+ config: withOverride(map[string]string{"publicKey": "not-a-valid-pem"}),
+ wantErr: true,
+ errSubstr: "WXPAY_CONFIG_INVALID_KEY",
+ },
{
name: "apiV3Key too short",
config: withOverride(map[string]string{"apiV3Key": "short"}),
wantErr: true,
- errSubstr: "exactly 32 bytes",
+ errSubstr: "WXPAY_CONFIG_INVALID_KEY_LENGTH",
},
{
name: "apiV3Key too long",
config: withOverride(map[string]string{"apiV3Key": "123456789012345678901234567890123"}), // 33 bytes
wantErr: true,
- errSubstr: "exactly 32 bytes",
+ errSubstr: "WXPAY_CONFIG_INVALID_KEY_LENGTH",
},
}
@@ -257,3 +335,375 @@ func TestNewWxpay(t *testing.T) {
})
}
}
+
+func TestBuildWxpayResultURLPreservesResumeToken(t *testing.T) {
+ t.Parallel()
+
+ resultURL, err := buildWxpayResultURL("https://app.example.com/payment/result?order_id=42&resume_token=resume-42&status=success", payment.CreatePaymentRequest{
+ OrderID: "sub2_42",
+ PaymentType: payment.TypeWxpay,
+ })
+ if err != nil {
+ t.Fatalf("buildWxpayResultURL returned error: %v", err)
+ }
+
+ parsed, err := url.Parse(resultURL)
+ if err != nil {
+ t.Fatalf("url.Parse returned error: %v", err)
+ }
+ query := parsed.Query()
+ if parsed.Path != wxpayResultPath {
+ t.Fatalf("path = %q, want %q", parsed.Path, wxpayResultPath)
+ }
+ if query.Get("resume_token") != "resume-42" {
+ t.Fatalf("resume_token = %q, want %q", query.Get("resume_token"), "resume-42")
+ }
+ if query.Get("order_id") != "42" {
+ t.Fatalf("order_id = %q, want %q", query.Get("order_id"), "42")
+ }
+ if query.Get("out_trade_no") != "sub2_42" {
+ t.Fatalf("out_trade_no = %q, want %q", query.Get("out_trade_no"), "sub2_42")
+ }
+}
+
+func TestResolveWxpayJSAPIAppID(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ config map[string]string
+ want string
+ }{
+ {
+ name: "prefers dedicated mp app id",
+ config: map[string]string{
+ "mpAppId": "wx-mp-app",
+ "appId": "wx-merchant-app",
+ },
+ want: "wx-mp-app",
+ },
+ {
+ name: "falls back to merchant app id",
+ config: map[string]string{
+ "appId": "wx-merchant-app",
+ },
+ want: "wx-merchant-app",
+ },
+ {
+ name: "missing app ids returns empty",
+ config: map[string]string{},
+ want: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ if got := ResolveWxpayJSAPIAppID(tt.config); got != tt.want {
+ t.Fatalf("ResolveWxpayJSAPIAppID() = %q, want %q", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestResolveWxpayCreateMode(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ req payment.CreatePaymentRequest
+ wantMode string
+ wantErr string
+ }{
+ {
+ name: "desktop uses native",
+ req: payment.CreatePaymentRequest{},
+ wantMode: wxpayModeNative,
+ },
+ {
+ name: "mobile uses h5 when client ip is present",
+ req: payment.CreatePaymentRequest{
+ IsMobile: true,
+ ClientIP: "203.0.113.10",
+ },
+ wantMode: wxpayModeH5,
+ },
+ {
+ name: "mobile without client ip returns clear error",
+ req: payment.CreatePaymentRequest{
+ IsMobile: true,
+ },
+ wantErr: "requires client IP",
+ },
+ {
+ name: "openid uses jsapi mode",
+ req: payment.CreatePaymentRequest{
+ OpenID: "openid-123",
+ },
+ wantMode: wxpayModeJSAPI,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got, err := resolveWxpayCreateMode(tt.req)
+ if tt.wantErr != "" {
+ if err == nil {
+ t.Fatal("expected error, got nil")
+ }
+ if !strings.Contains(err.Error(), tt.wantErr) {
+ t.Fatalf("error %q should contain %q", err.Error(), tt.wantErr)
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if got != tt.wantMode {
+ t.Fatalf("resolveWxpayCreateMode() = %q, want %q", got, tt.wantMode)
+ }
+ })
+ }
+}
+
+func TestCreatePaymentWithOpenIDReturnsJSAPIResult(t *testing.T) {
+ origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment
+ origNativePrepay := wxpayNativePrepay
+ origH5Prepay := wxpayH5Prepay
+ t.Cleanup(func() {
+ wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay
+ wxpayNativePrepay = origNativePrepay
+ wxpayH5Prepay = origH5Prepay
+ })
+
+ jsapiCalls := 0
+ nativeCalls := 0
+ h5Calls := 0
+ wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) {
+ jsapiCalls++
+ if got := wxSV(req.Payer.Openid); got != "openid-123" {
+ t.Fatalf("openid = %q, want %q", got, "openid-123")
+ }
+ if req.SceneInfo == nil || wxSV(req.SceneInfo.PayerClientIp) != "203.0.113.10" {
+ t.Fatalf("scene_info payer_client_ip = %q, want %q", wxSV(req.SceneInfo.PayerClientIp), "203.0.113.10")
+ }
+ return &jsapi.PrepayWithRequestPaymentResponse{
+ Appid: core.String("wx123"),
+ TimeStamp: core.String("1712345678"),
+ NonceStr: core.String("nonce-123"),
+ Package: core.String("prepay_id=wx_prepay_123"),
+ SignType: core.String("RSA"),
+ PaySign: core.String("signed-payload"),
+ }, nil, nil
+ }
+ wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) {
+ nativeCalls++
+ return &native.PrepayResponse{}, nil, nil
+ }
+ wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) {
+ h5Calls++
+ return &h5.PrepayResponse{}, nil, nil
+ }
+
+ provider := &Wxpay{
+ config: map[string]string{
+ "appId": "wx123",
+ "mchId": "mch123",
+ },
+ coreClient: &core.Client{},
+ }
+
+ resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{
+ OrderID: "sub2_88",
+ Amount: "66.88",
+ PaymentType: payment.TypeWxpay,
+ NotifyURL: "https://merchant.example/payment/notify",
+ OpenID: "openid-123",
+ ClientIP: "203.0.113.10",
+ })
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if jsapiCalls != 1 {
+ t.Fatalf("jsapi prepay calls = %d, want 1", jsapiCalls)
+ }
+ if nativeCalls != 0 {
+ t.Fatalf("native prepay calls = %d, want 0", nativeCalls)
+ }
+ if h5Calls != 0 {
+ t.Fatalf("h5 prepay calls = %d, want 0", h5Calls)
+ }
+ if resp.ResultType != payment.CreatePaymentResultJSAPIReady {
+ t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultJSAPIReady)
+ }
+ if resp.JSAPI == nil {
+ t.Fatal("expected jsapi payload, got nil")
+ }
+ if resp.JSAPI.AppID != "wx123" {
+ t.Fatalf("jsapi appId = %q, want %q", resp.JSAPI.AppID, "wx123")
+ }
+ if resp.JSAPI.TimeStamp != "1712345678" {
+ t.Fatalf("jsapi timeStamp = %q, want %q", resp.JSAPI.TimeStamp, "1712345678")
+ }
+ if resp.JSAPI.NonceStr != "nonce-123" {
+ t.Fatalf("jsapi nonceStr = %q, want %q", resp.JSAPI.NonceStr, "nonce-123")
+ }
+ if resp.JSAPI.Package != "prepay_id=wx_prepay_123" {
+ t.Fatalf("jsapi package = %q, want %q", resp.JSAPI.Package, "prepay_id=wx_prepay_123")
+ }
+ if resp.JSAPI.SignType != "RSA" {
+ t.Fatalf("jsapi signType = %q, want %q", resp.JSAPI.SignType, "RSA")
+ }
+ if resp.JSAPI.PaySign != "signed-payload" {
+ t.Fatalf("jsapi paySign = %q, want %q", resp.JSAPI.PaySign, "signed-payload")
+ }
+}
+
+func TestCreatePaymentMobileH5IncludesConfiguredSceneInfo(t *testing.T) {
+ origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment
+ origNativePrepay := wxpayNativePrepay
+ origH5Prepay := wxpayH5Prepay
+ t.Cleanup(func() {
+ wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay
+ wxpayNativePrepay = origNativePrepay
+ wxpayH5Prepay = origH5Prepay
+ })
+
+ jsapiCalls := 0
+ nativeCalls := 0
+ h5Calls := 0
+ wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) {
+ jsapiCalls++
+ return &jsapi.PrepayWithRequestPaymentResponse{}, nil, nil
+ }
+ wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) {
+ nativeCalls++
+ return &native.PrepayResponse{}, nil, nil
+ }
+ wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) {
+ h5Calls++
+ if req.SceneInfo == nil {
+ t.Fatal("expected scene_info, got nil")
+ }
+ if got := wxSV(req.SceneInfo.PayerClientIp); got != "203.0.113.10" {
+ t.Fatalf("scene_info payer_client_ip = %q, want %q", got, "203.0.113.10")
+ }
+ if req.SceneInfo.H5Info == nil {
+ t.Fatal("expected scene_info.h5_info, got nil")
+ }
+ if got := wxSV(req.SceneInfo.H5Info.Type); got != wxpayH5Type {
+ t.Fatalf("scene_info.h5_info.type = %q, want %q", got, wxpayH5Type)
+ }
+ if got := wxSV(req.SceneInfo.H5Info.AppName); got != "Sub2API" {
+ t.Fatalf("scene_info.h5_info.app_name = %q, want %q", got, "Sub2API")
+ }
+ if got := wxSV(req.SceneInfo.H5Info.AppUrl); got != "https://app.example.com" {
+ t.Fatalf("scene_info.h5_info.app_url = %q, want %q", got, "https://app.example.com")
+ }
+ return &h5.PrepayResponse{
+ H5Url: core.String("https://wx.tenpay.example/h5pay?prepay_id=1"),
+ }, nil, nil
+ }
+
+ provider := &Wxpay{
+ config: map[string]string{
+ "appId": "wx123",
+ "mchId": "mch123",
+ "h5AppName": "Sub2API",
+ "h5AppUrl": "https://app.example.com",
+ },
+ coreClient: &core.Client{},
+ }
+
+ resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{
+ OrderID: "sub2_99",
+ Amount: "66.88",
+ PaymentType: payment.TypeWxpay,
+ Subject: "Balance Recharge",
+ NotifyURL: "https://merchant.example/payment/notify",
+ ReturnURL: "https://merchant.example/payment/result?resume_token=resume-99",
+ ClientIP: "203.0.113.10",
+ IsMobile: true,
+ })
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if jsapiCalls != 0 {
+ t.Fatalf("jsapi prepay calls = %d, want 0", jsapiCalls)
+ }
+ if nativeCalls != 0 {
+ t.Fatalf("native prepay calls = %d, want 0", nativeCalls)
+ }
+ if h5Calls != 1 {
+ t.Fatalf("h5 prepay calls = %d, want 1", h5Calls)
+ }
+ if !strings.Contains(resp.PayURL, "redirect_url=") {
+ t.Fatalf("pay_url = %q, want redirect_url query appended", resp.PayURL)
+ }
+}
+
+func TestCreatePaymentMobileH5ReturnsNoAuthErrorWithoutNativeFallback(t *testing.T) {
+ origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment
+ origNativePrepay := wxpayNativePrepay
+ origH5Prepay := wxpayH5Prepay
+ t.Cleanup(func() {
+ wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay
+ wxpayNativePrepay = origNativePrepay
+ wxpayH5Prepay = origH5Prepay
+ })
+
+ jsapiCalls := 0
+ nativeCalls := 0
+ h5Calls := 0
+ wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) {
+ jsapiCalls++
+ return &jsapi.PrepayWithRequestPaymentResponse{}, nil, nil
+ }
+ wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) {
+ h5Calls++
+ return nil, nil, errors.New("NO_AUTH")
+ }
+ wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) {
+ nativeCalls++
+ return &native.PrepayResponse{
+ CodeUrl: core.String("weixin://wxpay/bizpayurl?pr=fallback-native"),
+ }, nil, nil
+ }
+
+ provider := &Wxpay{
+ config: map[string]string{
+ "appId": "wx123",
+ "mchId": "mch123",
+ },
+ coreClient: &core.Client{},
+ }
+
+ resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{
+ OrderID: "sub2_100",
+ Amount: "66.88",
+ PaymentType: payment.TypeWxpay,
+ Subject: "Balance Recharge",
+ NotifyURL: "https://merchant.example/payment/notify",
+ ClientIP: "203.0.113.10",
+ IsMobile: true,
+ })
+ if err == nil {
+ t.Fatal("expected no-auth error, got nil")
+ }
+ if jsapiCalls != 0 {
+ t.Fatalf("jsapi prepay calls = %d, want 0", jsapiCalls)
+ }
+ if h5Calls != 1 {
+ t.Fatalf("h5 prepay calls = %d, want 1", h5Calls)
+ }
+ if nativeCalls != 0 {
+ t.Fatalf("native prepay calls = %d, want 0", nativeCalls)
+ }
+ if resp != nil {
+ t.Fatalf("expected nil response, got %+v", resp)
+ }
+ if !strings.Contains(err.Error(), "NO_AUTH") {
+ t.Fatalf("error = %v, want NO_AUTH", err)
+ }
+}
diff --git a/backend/internal/payment/types.go b/backend/internal/payment/types.go
index 5d613a4a..e7ac6727 100644
--- a/backend/internal/payment/types.go
+++ b/backend/internal/payment/types.go
@@ -101,34 +101,69 @@ type CreatePaymentRequest struct {
Subject string // Product description
NotifyURL string // Webhook callback URL
ReturnURL string // Browser redirect URL after payment
+ OpenID string // WeChat JSAPI payer OpenID when available
ClientIP string // Payer's IP address
IsMobile bool // Whether the request comes from a mobile device
InstanceSubMethods string // Comma-separated sub-methods from instance supported_types (for Stripe)
}
+// CreatePaymentResultType describes the shape of the create-payment result.
+type CreatePaymentResultType = string
+
+const (
+ CreatePaymentResultOrderCreated CreatePaymentResultType = "order_created"
+ CreatePaymentResultOAuthRequired CreatePaymentResultType = "oauth_required"
+ CreatePaymentResultJSAPIReady CreatePaymentResultType = "jsapi_ready"
+)
+
+// WechatOAuthInfo describes the next step when WeChat OAuth is required before payment.
+type WechatOAuthInfo struct {
+ AuthorizeURL string `json:"authorize_url,omitempty"`
+ AppID string `json:"appid,omitempty"`
+ OpenID string `json:"openid,omitempty"`
+ Scope string `json:"scope,omitempty"`
+ State string `json:"state,omitempty"`
+ RedirectURL string `json:"redirect_url,omitempty"`
+}
+
+// WechatJSAPIPayload contains the fields the frontend needs to invoke WeChat JSAPI payment.
+type WechatJSAPIPayload struct {
+ AppID string `json:"appId,omitempty"`
+ TimeStamp string `json:"timeStamp,omitempty"`
+ NonceStr string `json:"nonceStr,omitempty"`
+ Package string `json:"package,omitempty"`
+ SignType string `json:"signType,omitempty"`
+ PaySign string `json:"paySign,omitempty"`
+}
+
// CreatePaymentResponse is returned after successfully initiating a payment.
type CreatePaymentResponse struct {
- TradeNo string // Third-party transaction ID
- PayURL string // H5 payment URL (alipay/wxpay)
- QRCode string // QR code content for scanning
- ClientSecret string // Stripe PaymentIntent client secret
+ TradeNo string // Third-party transaction ID
+ PayURL string // H5 payment URL (alipay/wxpay)
+ QRCode string // QR code content for scanning
+ ClientSecret string // Stripe PaymentIntent client secret
+ ResultType CreatePaymentResultType // Typed result contract for frontend flows
+ OAuth *WechatOAuthInfo // WeChat OAuth bootstrap payload when required
+ JSAPI *WechatJSAPIPayload // WeChat JSAPI invocation payload when ready
}
// QueryOrderResponse describes the payment status from the upstream provider.
type QueryOrderResponse struct {
- TradeNo string
- Status string // "pending", "paid", "failed", "refunded"
- Amount float64 // Amount in CNY
- PaidAt string // RFC3339 timestamp or empty
+ TradeNo string
+ Status string // "pending", "paid", "failed", "refunded"
+ Amount float64 // Amount in CNY
+ PaidAt string // RFC3339 timestamp or empty
+ Metadata map[string]string
}
// PaymentNotification is the parsed result of a webhook/notify callback.
type PaymentNotification struct {
- TradeNo string
- OrderID string
- Amount float64
- Status string // "success" or "failed"
- RawData string // Raw notification body for audit
+ TradeNo string
+ OrderID string
+ Amount float64
+ Status string // "success" or "failed"
+ RawData string // Raw notification body for audit
+ Metadata map[string]string
}
// RefundRequest contains the parameters for requesting a refund.
@@ -179,3 +214,9 @@ type CancelableProvider interface {
// CancelPayment cancels/expires a pending payment on the upstream platform.
CancelPayment(ctx context.Context, tradeNo string) error
}
+
+// MerchantIdentityProvider exposes the current non-sensitive merchant identity
+// derived from provider configuration for snapshot consistency checks.
+type MerchantIdentityProvider interface {
+ MerchantIdentityMetadata() map[string]string
+}
diff --git a/backend/internal/payment/wire.go b/backend/internal/payment/wire.go
index 9717465d..4b7f422d 100644
--- a/backend/internal/payment/wire.go
+++ b/backend/internal/payment/wire.go
@@ -4,6 +4,7 @@ import (
"encoding/hex"
"fmt"
"log/slog"
+ "strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -19,11 +20,22 @@ type EncryptionKey []byte
// When the key is non-empty but invalid (bad hex or wrong length), an error is returned
// to prevent startup with a misconfigured encryption key.
func ProvideEncryptionKey(cfg *config.Config) (EncryptionKey, error) {
- if cfg.Totp.EncryptionKey == "" {
+ if cfg == nil {
+ slog.Warn("payment encryption key not configured — encrypted payment config and resume signing will be unavailable")
+ return nil, nil
+ }
+ keyHex := strings.TrimSpace(cfg.Totp.EncryptionKey)
+ if keyHex == "" {
slog.Warn("payment encryption key not configured — encrypted payment config will be unavailable")
return nil, nil
}
- key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
+ // Reject auto-generated TOTP keys for payment signing.
+ // They change across restarts/instances and can silently break resume-token flows.
+ if !cfg.Totp.EncryptionKeyConfigured {
+ slog.Warn("payment encryption/signing key is not explicitly configured; set TOTP_ENCRYPTION_KEY to enable payment resume tokens")
+ return nil, nil
+ }
+ key, err := hex.DecodeString(keyHex)
if err != nil {
return nil, fmt.Errorf("invalid payment encryption key (hex decode): %w", err)
}
diff --git a/backend/internal/payment/wire_test.go b/backend/internal/payment/wire_test.go
new file mode 100644
index 00000000..1b360f89
--- /dev/null
+++ b/backend/internal/payment/wire_test.go
@@ -0,0 +1,62 @@
+package payment
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+)
+
+func TestProvideEncryptionKeySkipsAutoGeneratedTotpKey(t *testing.T) {
+ t.Parallel()
+
+ cfg := &config.Config{
+ Totp: config.TotpConfig{
+ EncryptionKey: strings.Repeat("a", 64),
+ EncryptionKeyConfigured: false,
+ },
+ }
+
+ key, err := ProvideEncryptionKey(cfg)
+ if err != nil {
+ t.Fatalf("ProvideEncryptionKey returned error: %v", err)
+ }
+ if len(key) != 0 {
+ t.Fatalf("encryption key len = %d, want 0", len(key))
+ }
+}
+
+func TestProvideEncryptionKeyUsesConfiguredTotpKey(t *testing.T) {
+ t.Parallel()
+
+ cfg := &config.Config{
+ Totp: config.TotpConfig{
+ EncryptionKey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
+ EncryptionKeyConfigured: true,
+ },
+ }
+
+ key, err := ProvideEncryptionKey(cfg)
+ if err != nil {
+ t.Fatalf("ProvideEncryptionKey returned error: %v", err)
+ }
+ if len(key) != 32 {
+ t.Fatalf("encryption key len = %d, want 32", len(key))
+ }
+}
+
+func TestProvideEncryptionKeyRejectsConfiguredInvalidLength(t *testing.T) {
+ t.Parallel()
+
+ cfg := &config.Config{
+ Totp: config.TotpConfig{
+ EncryptionKey: "abcd",
+ EncryptionKeyConfigured: true,
+ },
+ }
+
+ _, err := ProvideEncryptionKey(cfg)
+ if err == nil {
+ t.Fatal("expected error for invalid key length")
+ }
+}
diff --git a/backend/internal/pkg/openai/constants.go b/backend/internal/pkg/openai/constants.go
index 49e38bf8..60ffefb3 100644
--- a/backend/internal/pkg/openai/constants.go
+++ b/backend/internal/pkg/openai/constants.go
@@ -17,16 +17,12 @@ type Model struct {
var DefaultModels = []Model{
{ID: "gpt-5.4", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4"},
{ID: "gpt-5.4-mini", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4 Mini"},
- {ID: "gpt-5.4-nano", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4 Nano"},
{ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"},
{ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"},
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
- {ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
- {ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
- {ID: "gpt-5.1-codex", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex"},
- {ID: "gpt-5.1", Object: "model", Created: 1731456000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1"},
- {ID: "gpt-5.1-codex-mini", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Mini"},
- {ID: "gpt-5", Object: "model", Created: 1722988800, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5"},
+ {ID: "gpt-image-1", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT Image 1"},
+ {ID: "gpt-image-1.5", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT Image 1.5"},
+ {ID: "gpt-image-2", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT Image 2"},
}
// DefaultModelIDs returns the default model ID list
@@ -39,7 +35,7 @@ func DefaultModelIDs() []string {
}
// DefaultTestModel default model for testing OpenAI accounts
-const DefaultTestModel = "gpt-5.1-codex"
+const DefaultTestModel = "gpt-5.4"
// DefaultInstructions default instructions for non-Codex CLI requests
// Content loaded from instructions.txt at compile time
diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go
index 24115c33..78f739ac 100644
--- a/backend/internal/repository/account_repo.go
+++ b/backend/internal/repository/account_repo.go
@@ -438,6 +438,9 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil {
return err
}
+ if _, err := txClient.ExecContext(ctx, "DELETE FROM scheduled_test_plans WHERE account_id = $1", id); err != nil {
+ return err
+ }
if _, err := txClient.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx); err != nil {
return err
}
diff --git a/backend/internal/repository/announcement_read_repo.go b/backend/internal/repository/announcement_read_repo.go
index 2dc346b1..5268ec45 100644
--- a/backend/internal/repository/announcement_read_repo.go
+++ b/backend/internal/repository/announcement_read_repo.go
@@ -19,13 +19,17 @@ func NewAnnouncementReadRepository(client *dbent.Client) service.AnnouncementRea
func (r *announcementReadRepository) MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error {
client := clientFromContext(ctx, r.client)
- return client.AnnouncementRead.Create().
+ err := client.AnnouncementRead.Create().
SetAnnouncementID(announcementID).
SetUserID(userID).
SetReadAt(readAt).
OnConflictColumns(announcementread.FieldAnnouncementID, announcementread.FieldUserID).
DoNothing().
Exec(ctx)
+ if isSQLNoRowsError(err) {
+ return nil
+ }
+ return err
}
func (r *announcementReadRepository) GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error) {
diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go
index 38ea9bde..36d80309 100644
--- a/backend/internal/repository/api_key_repo.go
+++ b/backend/internal/repository/api_key_repo.go
@@ -149,6 +149,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
user.FieldBalanceNotifyThreshold,
user.FieldBalanceNotifyExtraEmails,
user.FieldTotalRecharged,
+ user.FieldSignupSource,
+ user.FieldLastLoginAt,
+ user.FieldLastActiveAt,
)
}).
WithGroup(func(q *dbent.GroupQuery) {
@@ -656,6 +659,9 @@ func userEntityToService(u *dbent.User) *service.User {
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
+ SignupSource: u.SignupSource,
+ LastLoginAt: u.LastLoginAt,
+ LastActiveAt: u.LastActiveAt,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
diff --git a/backend/internal/repository/auth_identity_compat_backfill_integration_test.go b/backend/internal/repository/auth_identity_compat_backfill_integration_test.go
new file mode 100644
index 00000000..7e34777a
--- /dev/null
+++ b/backend/internal/repository/auth_identity_compat_backfill_integration_test.go
@@ -0,0 +1,80 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "strconv"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthIdentityCompatBackfillMigration_AllowsLongReportTypes(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migration108Path := filepath.Join("..", "..", "migrations", "108_auth_identity_foundation_core.sql")
+ migration108SQL, err := os.ReadFile(migration108Path)
+ require.NoError(t, err)
+
+ migration108aPath := filepath.Join("..", "..", "migrations", "108a_widen_auth_identity_migration_report_type.sql")
+ migration108aSQL, err := os.ReadFile(migration108aPath)
+ require.NoError(t, err)
+
+ migration109Path := filepath.Join("..", "..", "migrations", "109_auth_identity_compat_backfill.sql")
+ migration109SQL, err := os.ReadFile(migration109Path)
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, `
+DROP TABLE IF EXISTS auth_identity_migration_reports CASCADE;
+DROP TABLE IF EXISTS auth_identity_channels CASCADE;
+DROP TABLE IF EXISTS identity_adoption_decisions CASCADE;
+DROP TABLE IF EXISTS pending_auth_sessions CASCADE;
+DROP TABLE IF EXISTS auth_identities CASCADE;
+
+ALTER TABLE users
+ DROP COLUMN IF EXISTS signup_source,
+ DROP COLUMN IF EXISTS last_login_at,
+ DROP COLUMN IF EXISTS last_active_at;
+`)
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, string(migration108SQL))
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, string(migration108aSQL))
+ require.NoError(t, err)
+
+ var userID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('oidc-demo-subject@oidc-connect.invalid', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&userID))
+
+ _, err = tx.ExecContext(ctx, string(migration109SQL))
+ require.NoError(t, err)
+
+ var reportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'oidc_synthetic_email_requires_manual_recovery'
+ AND report_key = $1
+`, strconv.FormatInt(userID, 10)).Scan(&reportCount))
+ require.Equal(t, 1, reportCount)
+
+ var reportTypeLimit int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT character_maximum_length
+FROM information_schema.columns
+WHERE table_schema = 'public'
+ AND table_name = 'auth_identity_migration_reports'
+ AND column_name = 'report_type'
+`).Scan(&reportTypeLimit))
+ require.GreaterOrEqual(t, reportTypeLimit, 45)
+
+ require.NotZero(t, userID)
+}
diff --git a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go
new file mode 100644
index 00000000..e64934c5
--- /dev/null
+++ b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go
@@ -0,0 +1,959 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "os"
+ "path/filepath"
+ "strconv"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthIdentityLegacyExternalBackfillMigration(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ prepareLegacyExternalIdentitiesTable(t, tx, ctx)
+ truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
+
+ var linuxDoUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoUserID))
+
+ var wechatUnionUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-union@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatUnionUserID))
+
+ var wechatOpenIDOnlyUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-openid@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatOpenIDOnlyUserID))
+
+ var syntheticAuthIdentityID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
+VALUES ($1, 'wechat', 'wechat-main', 'openid-synthetic', '{"backfill_source":"synthetic_email"}'::jsonb)
+RETURNING id`, wechatOpenIDOnlyUserID).Scan(&syntheticAuthIdentityID))
+
+ var linuxDoLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-user-1', NULL, 'linux-user', 'Linux User', '{"source":"legacy"}')
+RETURNING id
+`, linuxDoUserID).Scan(&linuxDoLegacyID))
+
+ var wechatUnionLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-union-1', 'union-1', 'wechat-union-user', 'WeChat Union User', '{"channel":"oa","appid":"wx-app-1"}')
+RETURNING id
+`, wechatUnionUserID).Scan(&wechatUnionLegacyID))
+
+ var wechatOpenIDLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-only-1', NULL, 'wechat-openid-user', 'WeChat OpenID User', '{"channel":"oa","appid":"wx-app-2"}')
+RETURNING id
+`, wechatOpenIDOnlyUserID).Scan(&wechatOpenIDLegacyID))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var linuxDoCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-user-1'
+`, linuxDoUserID).Scan(&linuxDoCount))
+ require.Equal(t, 1, linuxDoCount)
+
+ var wechatSubject string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT provider_subject
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'wechat'
+ AND provider_key = 'wechat-main'
+ AND provider_subject = 'union-1'
+`, wechatUnionUserID).Scan(&wechatSubject))
+ require.Equal(t, "union-1", wechatSubject)
+
+ var wechatChannelCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_channels channel
+JOIN auth_identities ai ON ai.id = channel.identity_id
+WHERE ai.user_id = $1
+ AND channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat-main'
+ AND channel.channel = 'oa'
+ AND channel.channel_app_id = 'wx-app-1'
+ AND channel.channel_subject = 'openid-union-1'
+`, wechatUnionUserID).Scan(&wechatChannelCount))
+ require.Equal(t, 1, wechatChannelCount)
+
+ var legacyOpenIDOnlyReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatOpenIDLegacyID, 10)).Scan(&legacyOpenIDOnlyReportCount))
+ require.Equal(t, 1, legacyOpenIDOnlyReportCount)
+
+ var syntheticReviewCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "synthetic_auth_identity:"+strconv.FormatInt(syntheticAuthIdentityID, 10)).Scan(&syntheticReviewCount))
+ require.Equal(t, 1, syntheticReviewCount)
+
+ var unionLegacyReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatUnionLegacyID, 10)).Scan(&unionLegacyReportCount))
+ require.Zero(t, unionLegacyReportCount)
+ require.NotZero(t, linuxDoLegacyID)
+}
+
+func TestAuthIdentityLegacyExternalBackfillMigration_IsSafeWhenLegacyTableMissing(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ var beforeCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+`).Scan(&beforeCount))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var afterCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+ `).Scan(&afterCount))
+ require.Equal(t, beforeCount, afterCount)
+}
+
+func TestAuthIdentityLegacyExternalMigrations_ChainHandlesMalformedAndNonObjectMetadata(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
+ migration115SQL, err := os.ReadFile(migration115Path)
+ require.NoError(t, err)
+
+ migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
+ migration116SQL, err := os.ReadFile(migration116Path)
+ require.NoError(t, err)
+
+ prepareLegacyExternalIdentitiesTable(t, tx, ctx)
+ truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
+
+ var linuxDoMalformedUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-malformed@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoMalformedUserID))
+
+ var linuxDoArrayUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-array@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoArrayUserID))
+
+ var wechatUnionArrayUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-array@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatUnionArrayUserID))
+
+ var wechatOpenIDArrayUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-openid-array@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatOpenIDArrayUserID))
+
+ var linuxDoMalformedLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-malformed', NULL, 'legacy-linuxdo-malformed', 'Legacy LinuxDo Malformed', '{invalid')
+RETURNING id
+`, linuxDoMalformedUserID).Scan(&linuxDoMalformedLegacyID))
+
+ var linuxDoArrayLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-array', NULL, 'legacy-linuxdo-array', 'Legacy LinuxDo Array', '["legacy-linuxdo-array"]')
+RETURNING id
+`, linuxDoArrayUserID).Scan(&linuxDoArrayLegacyID))
+
+ var wechatUnionArrayLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-array', 'union-array', 'legacy-wechat-array', 'Legacy WeChat Array', '["legacy-wechat-array"]')
+RETURNING id
+`, wechatUnionArrayUserID).Scan(&wechatUnionArrayLegacyID))
+
+ var wechatOpenIDArrayLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-array-only', NULL, 'legacy-wechat-array-only', 'Legacy WeChat Array Only', '["legacy-wechat-openid-array"]')
+RETURNING id
+`, wechatOpenIDArrayUserID).Scan(&wechatOpenIDArrayLegacyID))
+
+ _, err = tx.ExecContext(ctx, string(migration115SQL))
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, string(migration116SQL))
+ require.NoError(t, err)
+
+ var linuxDoMalformedMetadataType string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT jsonb_typeof(metadata)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-malformed'
+`, linuxDoMalformedUserID).Scan(&linuxDoMalformedMetadataType))
+ require.Equal(t, "object", linuxDoMalformedMetadataType)
+
+ var linuxDoArrayMetadataType string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT jsonb_typeof(metadata)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-array'
+`, linuxDoArrayUserID).Scan(&linuxDoArrayMetadataType))
+ require.Equal(t, "object", linuxDoArrayMetadataType)
+
+ var wechatUnionArrayMetadataType string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT jsonb_typeof(metadata)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'wechat'
+ AND provider_key = 'wechat-main'
+ AND provider_subject = 'union-array'
+`, wechatUnionArrayUserID).Scan(&wechatUnionArrayMetadataType))
+ require.Equal(t, "object", wechatUnionArrayMetadataType)
+
+ var invalidJSONReportDetailsType string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT jsonb_typeof(details)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(linuxDoMalformedLegacyID, 10)).Scan(&invalidJSONReportDetailsType))
+ require.Equal(t, "object", invalidJSONReportDetailsType)
+
+ var openIDOnlyReportDetailsType string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT jsonb_typeof(details)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatOpenIDArrayLegacyID, 10)).Scan(&openIDOnlyReportDetailsType))
+ require.Equal(t, "object", openIDOnlyReportDetailsType)
+
+ var preservedArrayMetadataCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE id IN (
+ SELECT id
+ FROM auth_identities
+ WHERE (user_id = $1 AND provider_subject = 'linuxdo-array')
+ OR (user_id = $2 AND provider_subject = 'union-array')
+)
+ AND metadata ? '_legacy_metadata_raw_json'
+`, linuxDoArrayUserID, wechatUnionArrayUserID).Scan(&preservedArrayMetadataCount))
+ require.Equal(t, 2, preservedArrayMetadataCount)
+
+ require.NotZero(t, linuxDoArrayLegacyID)
+ require.NotZero(t, wechatUnionArrayLegacyID)
+}
+
+func TestAuthIdentityLegacyExternalSafetyMigration_ReportsConflictsAndDowngradesInvalidJSON(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ prepareLegacyExternalIdentitiesTable(t, tx, ctx)
+ truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
+
+ userIDs := make([]int64, 0, 8)
+ for _, email := range []string{
+ "linuxdo-conflict-legacy@example.com",
+ "linuxdo-conflict-owner@example.com",
+ "wechat-conflict-legacy@example.com",
+ "wechat-conflict-owner@example.com",
+ "wechat-channel-legacy@example.com",
+ "wechat-channel-owner@example.com",
+ "linuxdo-invalid-json@example.com",
+ "wechat-openid-invalid-json@example.com",
+ } {
+ var userID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ($1, 'hash', 'user', 'active', 0, 1)
+RETURNING id`, email).Scan(&userID))
+ userIDs = append(userIDs, userID)
+ }
+
+ linuxdoConflictLegacyUserID := userIDs[0]
+ linuxdoConflictOwnerUserID := userIDs[1]
+ wechatConflictLegacyUserID := userIDs[2]
+ wechatConflictOwnerUserID := userIDs[3]
+ wechatChannelLegacyUserID := userIDs[4]
+ wechatChannelOwnerUserID := userIDs[5]
+ linuxdoInvalidJSONUserID := userIDs[6]
+ wechatInvalidOpenIDUserID := userIDs[7]
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
+VALUES ($1, 'linuxdo', 'linuxdo', 'linuxdo-conflict', '{}'::jsonb)
+RETURNING id`, linuxdoConflictOwnerUserID).Scan(new(int64)))
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
+VALUES ($1, 'wechat', 'wechat-main', 'union-conflict', '{}'::jsonb)
+RETURNING id`, wechatConflictOwnerUserID).Scan(new(int64)))
+
+ var wechatChannelOwnerIdentityID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
+VALUES ($1, 'wechat', 'wechat-main', 'union-channel-owner', '{}'::jsonb)
+RETURNING id`, wechatChannelOwnerUserID).Scan(&wechatChannelOwnerIdentityID))
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identity_channels (
+ identity_id,
+ provider_type,
+ provider_key,
+ channel,
+ channel_app_id,
+ channel_subject,
+ metadata
+)
+VALUES ($1, 'wechat', 'wechat-main', 'oa', 'wx-app-conflict', 'openid-channel-conflict', '{}'::jsonb)
+RETURNING id`, wechatChannelOwnerIdentityID).Scan(new(int64)))
+
+ var linuxdoConflictLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-conflict', NULL, 'legacy-linuxdo', 'Legacy LinuxDo Conflict', '{"source":"legacy"}')
+RETURNING id
+`, linuxdoConflictLegacyUserID).Scan(&linuxdoConflictLegacyID))
+
+ var wechatConflictLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-union-conflict', 'union-conflict', 'legacy-wechat', 'Legacy WeChat Conflict', '{"channel":"oa","appid":"wx-app-conflict-canon"}')
+RETURNING id
+`, wechatConflictLegacyUserID).Scan(&wechatConflictLegacyID))
+
+ var wechatChannelConflictLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-channel-conflict', 'union-channel-legacy', 'legacy-wechat-channel', 'Legacy WeChat Channel Conflict', '{"channel":"oa","appid":"wx-app-conflict"}')
+RETURNING id
+`, wechatChannelLegacyUserID).Scan(&wechatChannelConflictLegacyID))
+
+ var linuxdoInvalidJSONLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-invalid-json', NULL, 'legacy-linuxdo-invalid', 'Legacy LinuxDo Invalid JSON', '{invalid')
+RETURNING id
+`, linuxdoInvalidJSONUserID).Scan(&linuxdoInvalidJSONLegacyID))
+
+ var wechatInvalidOpenIDLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-invalid-json-only', NULL, 'legacy-wechat-invalid', 'Legacy WeChat Invalid JSON', '{still-invalid')
+RETURNING id
+`, wechatInvalidOpenIDUserID).Scan(&wechatInvalidOpenIDLegacyID))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var linuxdoConflictReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_conflict'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(linuxdoConflictLegacyID, 10)).Scan(&linuxdoConflictReportCount))
+ require.Equal(t, 1, linuxdoConflictReportCount)
+
+ var wechatConflictReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_conflict'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatConflictLegacyID, 10)).Scan(&wechatConflictReportCount))
+ require.Equal(t, 1, wechatConflictReportCount)
+
+ var channelConflictReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_channel_conflict'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatChannelConflictLegacyID, 10)).Scan(&channelConflictReportCount))
+ require.Equal(t, 1, channelConflictReportCount)
+
+ var invalidJSONReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
+ AND report_key IN ($1, $2)
+`, "legacy_external_identity:"+strconv.FormatInt(linuxdoInvalidJSONLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatInvalidOpenIDLegacyID, 10)).Scan(&invalidJSONReportCount))
+ require.Equal(t, 2, invalidJSONReportCount)
+
+ var linuxdoInvalidIdentityCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-invalid-json'
+`, linuxdoInvalidJSONUserID).Scan(&linuxdoInvalidIdentityCount))
+ require.Equal(t, 1, linuxdoInvalidIdentityCount)
+
+ var wechatOpenIDOnlyReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatInvalidOpenIDLegacyID, 10)).Scan(&wechatOpenIDOnlyReportCount))
+ require.Equal(t, 1, wechatOpenIDOnlyReportCount)
+}
+
+func TestAuthIdentityLegacyExternalSafetyMigration_IsSafeWhenLegacyTableMissing(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ var beforeCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+`).Scan(&beforeCount))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var afterCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+ `).Scan(&afterCount))
+ require.Equal(t, beforeCount, afterCount)
+}
+
+func TestAuthIdentityLegacyExternalBackfillMigration_SkipsAmbiguousCanonicalSubjects(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ prepareLegacyExternalIdentitiesTable(t, tx, ctx)
+ truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
+
+ var linuxDoFirstUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoFirstUserID))
+
+ var linuxDoSecondUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoSecondUserID))
+
+ var wechatFirstUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatFirstUserID))
+
+ var wechatSecondUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatSecondUserID))
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-a', 'Legacy LinuxDo Ambiguous A', '{"source":"legacy"}')
+RETURNING id
+`, linuxDoFirstUserID).Scan(new(int64)))
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-b', 'Legacy LinuxDo Ambiguous B', '{"source":"legacy"}')
+RETURNING id
+`, linuxDoSecondUserID).Scan(new(int64)))
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-ambiguous-a', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-a', 'Legacy WeChat Ambiguous A', '{"channel":"oa","appid":"wx-ambiguous-a"}')
+RETURNING id
+`, wechatFirstUserID).Scan(new(int64)))
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-ambiguous-b', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-b', 'Legacy WeChat Ambiguous B', '{"channel":"oa","appid":"wx-ambiguous-b"}')
+RETURNING id
+`, wechatSecondUserID).Scan(new(int64)))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var linuxDoIdentityCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-ambiguous-subject'
+`).Scan(&linuxDoIdentityCount))
+ require.Zero(t, linuxDoIdentityCount)
+
+ var wechatIdentityCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE provider_type = 'wechat'
+ AND provider_key = 'wechat-main'
+ AND provider_subject = 'union-ambiguous-subject'
+`).Scan(&wechatIdentityCount))
+ require.Zero(t, wechatIdentityCount)
+
+ var wechatChannelCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_channels
+WHERE provider_type = 'wechat'
+ AND provider_key = 'wechat-main'
+ AND channel = 'oa'
+ AND channel_app_id IN ('wx-ambiguous-a', 'wx-ambiguous-b')
+`).Scan(&wechatChannelCount))
+ require.Zero(t, wechatChannelCount)
+}
+
+func TestAuthIdentityLegacyExternalMigrations_ReportAmbiguousCanonicalSubjectsWithoutWinnerAttribution(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
+ migration115SQL, err := os.ReadFile(migration115Path)
+ require.NoError(t, err)
+
+ migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
+ migration116SQL, err := os.ReadFile(migration116Path)
+ require.NoError(t, err)
+
+ prepareLegacyExternalIdentitiesTable(t, tx, ctx)
+ truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
+
+ var linuxDoFirstUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoFirstUserID))
+
+ var linuxDoSecondUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoSecondUserID))
+
+ var wechatFirstUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatFirstUserID))
+
+ var wechatSecondUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatSecondUserID))
+
+ var linuxDoFirstLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-a', 'Legacy LinuxDo Conflict A', '{"source":"legacy"}')
+RETURNING id
+`, linuxDoFirstUserID).Scan(&linuxDoFirstLegacyID))
+
+ var linuxDoSecondLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-b', 'Legacy LinuxDo Conflict B', '{"source":"legacy"}')
+RETURNING id
+`, linuxDoSecondUserID).Scan(&linuxDoSecondLegacyID))
+
+ var wechatFirstLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-conflict-a', 'union-conflict-subject', 'legacy-wechat-conflict-a', 'Legacy WeChat Conflict A', '{"channel":"oa","appid":"wx-conflict-a"}')
+RETURNING id
+`, wechatFirstUserID).Scan(&wechatFirstLegacyID))
+
+ var wechatSecondLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-conflict-b', 'union-conflict-subject', 'legacy-wechat-conflict-b', 'Legacy WeChat Conflict B', '{"channel":"oa","appid":"wx-conflict-b"}')
+RETURNING id
+`, wechatSecondUserID).Scan(&wechatSecondLegacyID))
+
+ _, err = tx.ExecContext(ctx, string(migration115SQL))
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, string(migration116SQL))
+ require.NoError(t, err)
+
+ var identityCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE (provider_type = 'linuxdo' AND provider_key = 'linuxdo' AND provider_subject = 'linuxdo-conflict-subject')
+ OR (provider_type = 'wechat' AND provider_key = 'wechat-main' AND provider_subject = 'union-conflict-subject')
+`).Scan(&identityCount))
+ require.Zero(t, identityCount)
+
+ var conflictReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_conflict'
+ AND report_key IN ($1, $2, $3, $4)
+`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&conflictReportCount))
+ require.Equal(t, 4, conflictReportCount)
+
+ var winnerAttributedReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_conflict'
+ AND report_key IN ($1, $2, $3, $4)
+ AND details ->> 'existing_identity_id' IS NOT NULL
+`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&winnerAttributedReportCount))
+ require.Zero(t, winnerAttributedReportCount)
+}
+
+func TestAuthIdentityMigrationReportTypeWideningPreflightKeeps109And116SafeBefore121(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migration108aPath := filepath.Join("..", "..", "migrations", "108a_widen_auth_identity_migration_report_type.sql")
+ migration108aSQL, err := os.ReadFile(migration108aPath)
+ require.NoError(t, err)
+
+ migration109Path := filepath.Join("..", "..", "migrations", "109_auth_identity_compat_backfill.sql")
+ migration109SQL, err := os.ReadFile(migration109Path)
+ require.NoError(t, err)
+
+ migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
+ migration116SQL, err := os.ReadFile(migration116Path)
+ require.NoError(t, err)
+
+ prepareLegacyExternalIdentitiesTable(t, tx, ctx)
+ truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
+
+ _, err = tx.ExecContext(ctx, `
+ALTER TABLE auth_identity_migration_reports
+ALTER COLUMN report_type TYPE VARCHAR(40);
+`)
+ require.NoError(t, err)
+
+ var oidcSyntheticUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('oidc-before-121@oidc-connect.invalid', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&oidcSyntheticUserID))
+
+ var linuxdoLegacyUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-before-121@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxdoLegacyUserID))
+
+ var invalidMetadataLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-before-121', NULL, 'legacy-linuxdo-before-121', 'Legacy LinuxDo Before 121', '{invalid')
+RETURNING id
+`, linuxdoLegacyUserID).Scan(&invalidMetadataLegacyID))
+
+ _, err = tx.ExecContext(ctx, string(migration108aSQL))
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, string(migration109SQL))
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, string(migration116SQL))
+ require.NoError(t, err)
+
+ var reportTypeWidth int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT character_maximum_length
+FROM information_schema.columns
+WHERE table_schema = 'public'
+ AND table_name = 'auth_identity_migration_reports'
+ AND column_name = 'report_type'
+`).Scan(&reportTypeWidth))
+ require.Equal(t, 80, reportTypeWidth)
+
+ var oidcSyntheticRecoveryReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'oidc_synthetic_email_requires_manual_recovery'
+ AND report_key = $1
+`, strconv.FormatInt(oidcSyntheticUserID, 10)).Scan(&oidcSyntheticRecoveryReportCount))
+ require.Equal(t, 1, oidcSyntheticRecoveryReportCount)
+
+ var invalidMetadataReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(invalidMetadataLegacyID, 10)).Scan(&invalidMetadataReportCount))
+ require.Equal(t, 1, invalidMetadataReportCount)
+}
+
+func prepareLegacyExternalIdentitiesTable(t *testing.T, tx *sql.Tx, ctx context.Context) {
+ t.Helper()
+
+ _, err := tx.ExecContext(ctx, `
+CREATE TABLE IF NOT EXISTS user_external_identities (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL,
+ provider TEXT NOT NULL,
+ provider_user_id TEXT NOT NULL,
+ provider_union_id TEXT NULL,
+ provider_username TEXT NOT NULL DEFAULT '',
+ display_name TEXT NOT NULL DEFAULT '',
+ profile_url TEXT NOT NULL DEFAULT '',
+ avatar_url TEXT NOT NULL DEFAULT '',
+ metadata TEXT NOT NULL DEFAULT '{}',
+ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
+`)
+ require.NoError(t, err)
+}
+
+func truncateAuthIdentityLegacyFixtureTables(t *testing.T, tx *sql.Tx, ctx context.Context) {
+ t.Helper()
+
+ _, err := tx.ExecContext(ctx, `
+TRUNCATE TABLE
+ auth_identity_channels,
+ identity_adoption_decisions,
+ pending_auth_sessions,
+ auth_identities,
+ auth_identity_migration_reports,
+ user_provider_default_grants,
+ user_avatars,
+ user_external_identities,
+ users
+RESTART IDENTITY CASCADE;
+`)
+ require.NoError(t, err)
+}
diff --git a/backend/internal/repository/migrations_runner.go b/backend/internal/repository/migrations_runner.go
index 129b6e41..6dbb9fbd 100644
--- a/backend/internal/repository/migrations_runner.go
+++ b/backend/internal/repository/migrations_runner.go
@@ -51,38 +51,30 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
const migrationsAdvisoryLockID int64 = 694208311321144027
const migrationsLockRetryInterval = 500 * time.Millisecond
const nonTransactionalMigrationSuffix = "_notx.sql"
+const paymentOrdersOutTradeNoUniqueMigration = "120_enforce_payment_orders_out_trade_no_unique_notx.sql"
+const paymentOrdersOutTradeNoUniqueIndex = "paymentorder_out_trade_no_unique"
type migrationChecksumCompatibilityRule struct {
- fileChecksum string
- acceptedFileChecksums map[string]struct{}
- acceptedDBChecksum map[string]struct{}
+ fileChecksum string
+ acceptedDBChecksum map[string]struct{}
+ acceptedChecksums map[string]struct{}
}
// migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。
-// 规则必须同时匹配「迁移名 + 当前文件 checksum + 历史库 checksum」才会放行,避免放宽全局校验。
+// 规则必须同时匹配「迁移名 + 数据库 checksum + 当前文件 checksum」且两者都落在该迁移的已知版本集合内才会放行,
+// 避免放宽全局校验,也允许将误改的历史 migration 回滚为已发布版本而不要求人工修 checksum。
var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibilityRule{
- "054_drop_legacy_cache_columns.sql": {
- fileChecksum: "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
- acceptedDBChecksum: map[string]struct{}{
- "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {},
- },
- },
- "061_add_usage_log_request_type.sql": {
- fileChecksum: "66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
- acceptedDBChecksum: map[string]struct{}{
- "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0": {},
- "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {},
- },
- },
- "082_create_gateway_debug_logs.sql": {
- fileChecksum: "b740d7274afbd37d4448e3a3a9aa1fb562181ded5d0319e47a6444187d22f6b1",
- acceptedFileChecksums: map[string]struct{}{
- "bf5348a22cf1f27c852096beb3583b67ec43819af82b2f9664397a5638e5b386": {},
- },
- acceptedDBChecksum: map[string]struct{}{
- "d00c2e69711cc0c006b0234566101d8639ba08db77283558f07e2ba412ec177d": {},
- },
- },
+ "054_drop_legacy_cache_columns.sql": newMigrationChecksumCompatibilityRule("82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"),
+ "061_add_usage_log_request_type.sql": newMigrationChecksumCompatibilityRule("66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0", "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"),
+ "109_auth_identity_compat_backfill.sql": newMigrationChecksumCompatibilityRule("0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace", "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"),
+ "110_pending_auth_and_provider_default_grants.sql": newMigrationChecksumCompatibilityRule("32cf87ee787b1bb36b5c691367c96eee37518fa3eed6f3322cf68795e3745279", "e3d1f433be2b564cfbdc549adf98fce13c5c7b363ebc20fd05b765d0563b0925"),
+ "112_add_payment_order_provider_key_snapshot.sql": newMigrationChecksumCompatibilityRule("b75f8f56d39455682787696a3d92ad25b055444ca328fb7fca9a460a15d68d99", "ffd3e8a2c9295fa9cbefefd629a78268877e5b51bc970a82d9b3f46ec4ebd15e"),
+ "115_auth_identity_legacy_external_backfill.sql": newMigrationChecksumCompatibilityRule("022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f", "4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f"),
+ "116_auth_identity_legacy_external_safety_reports.sql": newMigrationChecksumCompatibilityRule("07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488", "f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877"),
+ "118_wechat_dual_mode_and_auth_source_defaults.sql": newMigrationChecksumCompatibilityRule("b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0", "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227", "a38243ca0a72c3a01c0a92b7986423054d6133c0399441f853b99802852720fb"),
+ "119_enforce_payment_orders_out_trade_no_unique.sql": newMigrationChecksumCompatibilityRule("0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e", "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"),
+ "120_enforce_payment_orders_out_trade_no_unique_notx.sql": newMigrationChecksumCompatibilityRule("34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074", "e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61", "707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"),
+ "123_fix_legacy_auth_source_grant_on_signup_defaults.sql": newMigrationChecksumCompatibilityRule("2ce43c2cd89e9f9e1febd34a407ed9e84d177386c5544b6f02c1f58a21129f57", "6cd33422f215dcd1f486ab6f35c0ea5805d9ca69bb25906d94bc649156657145"),
}
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
@@ -209,6 +201,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
}
if nonTx {
+ if err := prepareNonTransactionalMigration(ctx, db, name); err != nil {
+ return fmt.Errorf("prepare migration %s: %w", name, err)
+ }
+
// *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
// 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
statements := splitSQLStatements(content)
@@ -258,6 +254,90 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
return nil
}
+func prepareNonTransactionalMigration(ctx context.Context, db *sql.DB, name string) error {
+ switch name {
+ case paymentOrdersOutTradeNoUniqueMigration:
+ return preparePaymentOrdersOutTradeNoUniqueMigration(ctx, db)
+ default:
+ return nil
+ }
+}
+
+func preparePaymentOrdersOutTradeNoUniqueMigration(ctx context.Context, db *sql.DB) error {
+ duplicates, err := findDuplicatePaymentOrderOutTradeNos(ctx, db)
+ if err != nil {
+ return fmt.Errorf("precheck duplicate out_trade_no: %w", err)
+ }
+ if len(duplicates) > 0 {
+ return fmt.Errorf(
+ "duplicate out_trade_no values block %s; remediate duplicates before retrying: %s",
+ paymentOrdersOutTradeNoUniqueMigration,
+ strings.Join(duplicates, ", "),
+ )
+ }
+
+ invalid, err := indexIsInvalid(ctx, db, paymentOrdersOutTradeNoUniqueIndex)
+ if err != nil {
+ return fmt.Errorf("check invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err)
+ }
+ if !invalid {
+ return nil
+ }
+
+ if _, err := db.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", paymentOrdersOutTradeNoUniqueIndex)); err != nil {
+ return fmt.Errorf("drop invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err)
+ }
+ return nil
+}
+
+func findDuplicatePaymentOrderOutTradeNos(ctx context.Context, db *sql.DB) ([]string, error) {
+ rows, err := db.QueryContext(ctx, `
+ SELECT out_trade_no, COUNT(*) AS duplicate_count
+ FROM payment_orders
+ WHERE out_trade_no <> ''
+ GROUP BY out_trade_no
+ HAVING COUNT(*) > 1
+ ORDER BY duplicate_count DESC, out_trade_no
+ LIMIT 5
+ `)
+ if err != nil {
+ return nil, err
+ }
+ defer func() {
+ _ = rows.Close()
+ }()
+
+ duplicates := make([]string, 0, 5)
+ for rows.Next() {
+ var outTradeNo string
+ var duplicateCount int
+ if err := rows.Scan(&outTradeNo, &duplicateCount); err != nil {
+ return nil, err
+ }
+ duplicates = append(duplicates, fmt.Sprintf("%s (count=%d)", outTradeNo, duplicateCount))
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return duplicates, nil
+}
+
+func indexIsInvalid(ctx context.Context, db *sql.DB, indexName string) (bool, error) {
+ var invalid bool
+ err := db.QueryRowContext(ctx, `
+ SELECT EXISTS (
+ SELECT 1
+ FROM pg_class idx
+ JOIN pg_namespace ns ON ns.oid = idx.relnamespace
+ JOIN pg_index i ON i.indexrelid = idx.oid
+ WHERE ns.nspname = 'public'
+ AND idx.relname = $1
+ AND NOT i.indisvalid
+ )
+ `, indexName).Scan(&invalid)
+ return invalid, err
+}
+
func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error {
hasLegacy, err := tableExists(ctx, db, "schema_migrations")
if err != nil {
@@ -332,18 +412,33 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
return version, version, hash, nil
}
+func checksumSet(values ...string) map[string]struct{} {
+ out := make(map[string]struct{}, len(values))
+ for _, value := range values {
+ out[value] = struct{}{}
+ }
+ return out
+}
+
+func newMigrationChecksumCompatibilityRule(fileChecksum string, acceptedDBChecksums ...string) migrationChecksumCompatibilityRule {
+ return migrationChecksumCompatibilityRule{
+ fileChecksum: fileChecksum,
+ acceptedDBChecksum: checksumSet(acceptedDBChecksums...),
+ acceptedChecksums: checksumSet(append([]string{fileChecksum}, acceptedDBChecksums...)...),
+ }
+}
+
func isMigrationChecksumCompatible(name, dbChecksum, fileChecksum string) bool {
rule, ok := migrationChecksumCompatibilityRules[name]
if !ok {
return false
}
- if rule.fileChecksum != fileChecksum {
- if _, ok := rule.acceptedFileChecksums[fileChecksum]; !ok {
- return false
- }
+ _, dbOK := rule.acceptedChecksums[dbChecksum]
+ if !dbOK {
+ return false
}
- _, ok = rule.acceptedDBChecksum[dbChecksum]
- return ok
+ _, fileOK := rule.acceptedChecksums[fileChecksum]
+ return fileOK
}
func validateMigrationExecutionMode(name, content string) (bool, error) {
diff --git a/backend/internal/repository/migrations_runner_checksum_test.go b/backend/internal/repository/migrations_runner_checksum_test.go
index 6c3ad725..1fcb3be1 100644
--- a/backend/internal/repository/migrations_runner_checksum_test.go
+++ b/backend/internal/repository/migrations_runner_checksum_test.go
@@ -51,4 +51,114 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
)
require.False(t, ok)
})
+
+ t.Run("109历史checksum可兼容", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "109_auth_identity_compat_backfill.sql",
+ "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
+ "0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("109当前checksum可兼容历史checksum", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "109_auth_identity_compat_backfill.sql",
+ "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
+ "0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("109回滚到历史文件后仍兼容已应用的新checksum", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "109_auth_identity_compat_backfill.sql",
+ "0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace",
+ "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("110历史checksum可兼容", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "110_pending_auth_and_provider_default_grants.sql",
+ "e3d1f433be2b564cfbdc549adf98fce13c5c7b363ebc20fd05b765d0563b0925",
+ "32cf87ee787b1bb36b5c691367c96eee37518fa3eed6f3322cf68795e3745279",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("112历史checksum可兼容", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "112_add_payment_order_provider_key_snapshot.sql",
+ "ffd3e8a2c9295fa9cbefefd629a78268877e5b51bc970a82d9b3f46ec4ebd15e",
+ "b75f8f56d39455682787696a3d92ad25b055444ca328fb7fca9a460a15d68d99",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("115历史checksum可兼容修复后的legacy external backfill", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "115_auth_identity_legacy_external_backfill.sql",
+ "4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f",
+ "022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("116历史checksum可兼容修复后的legacy external safety reports", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "116_auth_identity_legacy_external_safety_reports.sql",
+ "f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877",
+ "07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("119历史checksum可兼容占位文件", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "119_enforce_payment_orders_out_trade_no_unique.sql",
+ "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34",
+ "0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("118多个历史checksum都可兼容当前版本", func(t *testing.T) {
+ for _, dbChecksum := range []string{
+ "a38243ca0a72c3a01c0a92b7986423054d6133c0399441f853b99802852720fb",
+ "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227",
+ } {
+ ok := isMigrationChecksumCompatible(
+ "118_wechat_dual_mode_and_auth_source_defaults.sql",
+ dbChecksum,
+ "b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0",
+ )
+ require.True(t, ok)
+ }
+ })
+
+ t.Run("120多个历史checksum都可兼容新的notx修复版本", func(t *testing.T) {
+ for _, dbChecksum := range []string{
+ "e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61",
+ "707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22",
+ "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a",
+ } {
+ ok := isMigrationChecksumCompatible(
+ "120_enforce_payment_orders_out_trade_no_unique_notx.sql",
+ dbChecksum,
+ "34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074",
+ )
+ require.True(t, ok)
+ }
+ })
+
+ t.Run("119未知checksum不兼容", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "119_enforce_payment_orders_out_trade_no_unique.sql",
+ "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34",
+ "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
+ )
+ require.False(t, ok)
+ })
}
diff --git a/backend/internal/repository/migrations_runner_extra_test.go b/backend/internal/repository/migrations_runner_extra_test.go
index 4c6cdb60..f09c148b 100644
--- a/backend/internal/repository/migrations_runner_extra_test.go
+++ b/backend/internal/repository/migrations_runner_extra_test.go
@@ -99,6 +99,24 @@ func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
}
}
+func TestMigrationChecksumCompatibilityRules_CoverEditedUpgradeCompatibilityMigrations(t *testing.T) {
+ for _, name := range []string{
+ "109_auth_identity_compat_backfill.sql",
+ "110_pending_auth_and_provider_default_grants.sql",
+ "112_add_payment_order_provider_key_snapshot.sql",
+ "115_auth_identity_legacy_external_backfill.sql",
+ "116_auth_identity_legacy_external_safety_reports.sql",
+ "118_wechat_dual_mode_and_auth_source_defaults.sql",
+ "120_enforce_payment_orders_out_trade_no_unique_notx.sql",
+ "123_fix_legacy_auth_source_grant_on_signup_defaults.sql",
+ } {
+ rule, ok := migrationChecksumCompatibilityRules[name]
+ require.Truef(t, ok, "missing compatibility rule for %s", name)
+ require.NotEmpty(t, rule.fileChecksum)
+ require.NotEmpty(t, rule.acceptedDBChecksum)
+ }
+}
+
func TestEnsureAtlasBaselineAligned(t *testing.T) {
t.Run("skip_when_no_legacy_table", func(t *testing.T) {
db, mock, err := sqlmock.New()
diff --git a/backend/internal/repository/migrations_runner_notx_test.go b/backend/internal/repository/migrations_runner_notx_test.go
index db1183cd..b7cb396c 100644
--- a/backend/internal/repository/migrations_runner_notx_test.go
+++ b/backend/internal/repository/migrations_runner_notx_test.go
@@ -116,6 +116,84 @@ CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b);
require.NoError(t, mock.ExpectationsWereMet())
}
+func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_FailsFastOnDuplicatePrecheck(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ prepareMigrationsBootstrapExpectations(mock)
+ mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
+ WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql").
+ WillReturnError(sql.ErrNoRows)
+ mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders").
+ WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"}).AddRow("dup-out-trade-no", 2))
+ mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
+ WithArgs(migrationsAdvisoryLockID).
+ WillReturnResult(sqlmock.NewResult(0, 1))
+
+ fsys := fstest.MapFS{
+ "120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{
+ Data: []byte(`
+CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
+ ON payment_orders (out_trade_no)
+ WHERE out_trade_no <> '';
+
+DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
+`),
+ },
+ }
+
+ err = applyMigrationsFS(context.Background(), db, fsys)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "duplicate out_trade_no")
+ require.Contains(t, err.Error(), "dup-out-trade-no")
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_DropsInvalidIndexBeforeRetry(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ prepareMigrationsBootstrapExpectations(mock)
+ mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
+ WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql").
+ WillReturnError(sql.ErrNoRows)
+ mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders").
+ WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"}))
+ mock.ExpectQuery("SELECT EXISTS \\(").
+ WithArgs("paymentorder_out_trade_no_unique").
+ WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
+ mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique").
+ WillReturnResult(sqlmock.NewResult(0, 0))
+ mock.ExpectExec("CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique").
+ WillReturnResult(sqlmock.NewResult(0, 0))
+ mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no").
+ WillReturnResult(sqlmock.NewResult(0, 0))
+ mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
+ WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql", sqlmock.AnyArg()).
+ WillReturnResult(sqlmock.NewResult(1, 1))
+ mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
+ WithArgs(migrationsAdvisoryLockID).
+ WillReturnResult(sqlmock.NewResult(0, 1))
+
+ fsys := fstest.MapFS{
+ "120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{
+ Data: []byte(`
+CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
+ ON payment_orders (out_trade_no)
+ WHERE out_trade_no <> '';
+
+DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
+`),
+ },
+ }
+
+ err = applyMigrationsFS(context.Background(), db, fsys)
+ require.NoError(t, err)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go
index dd3019bb..eeee5c23 100644
--- a/backend/internal/repository/migrations_schema_integration_test.go
+++ b/backend/internal/repository/migrations_schema_integration_test.go
@@ -89,6 +89,35 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
}
+func TestMigrationsRunner_AuthIdentityAndPaymentSchemaStayAligned(t *testing.T) {
+ tx := testTx(t)
+
+ requireColumn(t, tx, "auth_identity_migration_reports", "report_type", "character varying", 80, false)
+ requireColumn(t, tx, "users", "signup_source", "character varying", 20, false)
+ requireColumnDefaultContains(t, tx, "users", "signup_source", "email")
+ requireConstraintDefinitionContains(
+ t,
+ tx,
+ "users",
+ "users_signup_source_check",
+ "signup_source",
+ "'email'",
+ "'linuxdo'",
+ "'wechat'",
+ "'oidc'",
+ )
+
+ requireForeignKeyOnDelete(t, tx, "auth_identities", "user_id", "users", "CASCADE")
+ requireForeignKeyOnDelete(t, tx, "auth_identity_channels", "identity_id", "auth_identities", "CASCADE")
+ requireForeignKeyOnDelete(t, tx, "pending_auth_sessions", "target_user_id", "users", "SET NULL")
+ requireForeignKeyOnDelete(t, tx, "identity_adoption_decisions", "pending_auth_session_id", "pending_auth_sessions", "CASCADE")
+ requireForeignKeyOnDelete(t, tx, "identity_adoption_decisions", "identity_id", "auth_identities", "SET NULL")
+
+ requireIndex(t, tx, "payment_orders", "paymentorder_out_trade_no")
+ requirePartialUniqueIndexDefinition(t, tx, "payment_orders", "paymentorder_out_trade_no", "out_trade_no", "WHERE")
+ requireIndexAbsent(t, tx, "payment_orders", "paymentorder_out_trade_no_unique")
+}
+
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
t.Helper()
@@ -106,6 +135,118 @@ SELECT EXISTS (
require.True(t, exists, "expected index %s on %s", index, table)
}
+func requireIndexAbsent(t *testing.T, tx *sql.Tx, table, index string) {
+ t.Helper()
+
+ var exists bool
+ err := tx.QueryRowContext(context.Background(), `
+SELECT EXISTS (
+ SELECT 1
+ FROM pg_indexes
+ WHERE schemaname = 'public'
+ AND tablename = $1
+ AND indexname = $2
+)
+`, table, index).Scan(&exists)
+ require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
+ require.False(t, exists, "expected index %s on %s to be absent", index, table)
+}
+
+func requirePartialUniqueIndexDefinition(t *testing.T, tx *sql.Tx, table, index string, fragments ...string) {
+ t.Helper()
+
+ var (
+ unique bool
+ def string
+ )
+
+ err := tx.QueryRowContext(context.Background(), `
+SELECT
+ i.indisunique,
+ pg_get_indexdef(i.indexrelid)
+FROM pg_class idx
+JOIN pg_index i ON i.indexrelid = idx.oid
+JOIN pg_class tbl ON tbl.oid = i.indrelid
+JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
+WHERE ns.nspname = 'public'
+ AND tbl.relname = $1
+ AND idx.relname = $2
+`, table, index).Scan(&unique, &def)
+ require.NoError(t, err, "query index definition for %s.%s", table, index)
+ require.True(t, unique, "expected index %s on %s to be unique", index, table)
+
+ for _, fragment := range fragments {
+ require.Contains(t, def, fragment, "expected index definition for %s.%s to contain %q", table, index, fragment)
+ }
+}
+
+func requireForeignKeyOnDelete(t *testing.T, tx *sql.Tx, table, column, refTable, expected string) {
+ t.Helper()
+
+ var actual string
+ err := tx.QueryRowContext(context.Background(), `
+SELECT CASE c.confdeltype
+ WHEN 'a' THEN 'NO ACTION'
+ WHEN 'r' THEN 'RESTRICT'
+ WHEN 'c' THEN 'CASCADE'
+ WHEN 'n' THEN 'SET NULL'
+ WHEN 'd' THEN 'SET DEFAULT'
+END
+FROM pg_constraint c
+JOIN pg_class tbl ON tbl.oid = c.conrelid
+JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
+JOIN pg_class ref_tbl ON ref_tbl.oid = c.confrelid
+JOIN pg_attribute attr ON attr.attrelid = tbl.oid AND attr.attnum = ANY(c.conkey)
+WHERE ns.nspname = 'public'
+ AND c.contype = 'f'
+ AND tbl.relname = $1
+ AND attr.attname = $2
+ AND ref_tbl.relname = $3
+LIMIT 1
+`, table, column, refTable).Scan(&actual)
+ require.NoError(t, err, "query foreign key action for %s.%s -> %s", table, column, refTable)
+ require.Equal(t, expected, actual, "unexpected ON DELETE action for %s.%s -> %s", table, column, refTable)
+}
+
+func requireConstraintDefinitionContains(t *testing.T, tx *sql.Tx, table, constraint string, fragments ...string) {
+ t.Helper()
+
+ var def string
+ err := tx.QueryRowContext(context.Background(), `
+SELECT pg_get_constraintdef(c.oid)
+FROM pg_constraint c
+JOIN pg_class tbl ON tbl.oid = c.conrelid
+JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
+WHERE ns.nspname = 'public'
+ AND tbl.relname = $1
+ AND c.conname = $2
+`, table, constraint).Scan(&def)
+ require.NoError(t, err, "query constraint definition for %s.%s", table, constraint)
+
+ for _, fragment := range fragments {
+ require.Contains(t, def, fragment, "expected constraint definition for %s.%s to contain %q", table, constraint, fragment)
+ }
+}
+
+func requireColumnDefaultContains(t *testing.T, tx *sql.Tx, table, column string, fragments ...string) {
+ t.Helper()
+
+ var columnDefault sql.NullString
+ err := tx.QueryRowContext(context.Background(), `
+SELECT column_default
+FROM information_schema.columns
+WHERE table_schema = 'public'
+ AND table_name = $1
+ AND column_name = $2
+`, table, column).Scan(&columnDefault)
+ require.NoError(t, err, "query column_default for %s.%s", table, column)
+ require.True(t, columnDefault.Valid, "expected column_default for %s.%s", table, column)
+
+ for _, fragment := range fragments {
+ require.Contains(t, columnDefault.String, fragment, "expected default for %s.%s to contain %q", table, column, fragment)
+ }
+}
+
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
t.Helper()
diff --git a/backend/internal/repository/user_profile_identity_repo.go b/backend/internal/repository/user_profile_identity_repo.go
new file mode 100644
index 00000000..b2b03746
--- /dev/null
+++ b/backend/internal/repository/user_profile_identity_repo.go
@@ -0,0 +1,880 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "hash/fnv"
+ "reflect"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+ "unsafe"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+var (
+ ErrAuthIdentityOwnershipConflict = infraerrors.Conflict(
+ "AUTH_IDENTITY_OWNERSHIP_CONFLICT",
+ "auth identity already belongs to another user",
+ )
+ ErrAuthIdentityChannelOwnershipConflict = infraerrors.Conflict(
+ "AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT",
+ "auth identity channel already belongs to another user",
+ )
+ ErrAuthIdentityChannelProviderMismatch = infraerrors.BadRequest(
+ "AUTH_IDENTITY_CHANNEL_PROVIDER_MISMATCH",
+ "auth identity channel provider must match canonical identity",
+ )
+)
+
+type ProviderGrantReason string
+
+const (
+ ProviderGrantReasonSignup ProviderGrantReason = "signup"
+ ProviderGrantReasonFirstBind ProviderGrantReason = "first_bind"
+)
+
+type AuthIdentityKey struct {
+ ProviderType string
+ ProviderKey string
+ ProviderSubject string
+}
+
+type AuthIdentityChannelKey struct {
+ ProviderType string
+ ProviderKey string
+ Channel string
+ ChannelAppID string
+ ChannelSubject string
+}
+
+type CreateAuthIdentityInput struct {
+ UserID int64
+ Canonical AuthIdentityKey
+ Channel *AuthIdentityChannelKey
+ Issuer *string
+ VerifiedAt *time.Time
+ Metadata map[string]any
+ ChannelMetadata map[string]any
+}
+
+type BindAuthIdentityInput = CreateAuthIdentityInput
+
+type CreateAuthIdentityResult struct {
+ Identity *dbent.AuthIdentity
+ Channel *dbent.AuthIdentityChannel
+}
+
+func (r *CreateAuthIdentityResult) IdentityRef() AuthIdentityKey {
+ if r == nil || r.Identity == nil {
+ return AuthIdentityKey{}
+ }
+ return AuthIdentityKey{
+ ProviderType: r.Identity.ProviderType,
+ ProviderKey: r.Identity.ProviderKey,
+ ProviderSubject: r.Identity.ProviderSubject,
+ }
+}
+
+func (r *CreateAuthIdentityResult) ChannelRef() *AuthIdentityChannelKey {
+ if r == nil || r.Channel == nil {
+ return nil
+ }
+ return &AuthIdentityChannelKey{
+ ProviderType: r.Channel.ProviderType,
+ ProviderKey: r.Channel.ProviderKey,
+ Channel: r.Channel.Channel,
+ ChannelAppID: r.Channel.ChannelAppID,
+ ChannelSubject: r.Channel.ChannelSubject,
+ }
+}
+
+type UserAuthIdentityLookup struct {
+ User *dbent.User
+ Identity *dbent.AuthIdentity
+ Channel *dbent.AuthIdentityChannel
+}
+
+type ProviderGrantRecordInput struct {
+ UserID int64
+ ProviderType string
+ GrantReason ProviderGrantReason
+}
+
+type IdentityAdoptionDecisionInput struct {
+ PendingAuthSessionID int64
+ IdentityID *int64
+ AdoptDisplayName bool
+ AdoptAvatar bool
+}
+
+type sqlQueryExecutor interface {
+ ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
+ QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
+}
+
+var repositoryScopedKeyLocks = newScopedKeyLockRegistry()
+
+type scopedKeyLockRegistry struct {
+ mu sync.Mutex
+ locks map[string]*scopedKeyLockEntry
+}
+
+type scopedKeyLockEntry struct {
+ mu sync.Mutex
+ refs int
+}
+
+func newScopedKeyLockRegistry() *scopedKeyLockRegistry {
+ return &scopedKeyLockRegistry{
+ locks: make(map[string]*scopedKeyLockEntry),
+ }
+}
+
+func (r *scopedKeyLockRegistry) lock(keys ...string) func() {
+ normalized := normalizeLockKeys(keys...)
+ if len(normalized) == 0 {
+ return func() {}
+ }
+
+ entries := make([]*scopedKeyLockEntry, 0, len(normalized))
+ r.mu.Lock()
+ for _, key := range normalized {
+ entry := r.locks[key]
+ if entry == nil {
+ entry = &scopedKeyLockEntry{}
+ r.locks[key] = entry
+ }
+ entry.refs++
+ entries = append(entries, entry)
+ }
+ r.mu.Unlock()
+
+ for _, entry := range entries {
+ entry.mu.Lock()
+ }
+
+ return func() {
+ for i := len(entries) - 1; i >= 0; i-- {
+ entries[i].mu.Unlock()
+ }
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ for idx, key := range normalized {
+ entry := entries[idx]
+ entry.refs--
+ if entry.refs == 0 {
+ delete(r.locks, key)
+ }
+ }
+ }
+}
+
+func normalizeLockKeys(keys ...string) []string {
+ if len(keys) == 0 {
+ return nil
+ }
+
+ deduped := make(map[string]struct{}, len(keys))
+ for _, key := range keys {
+ trimmed := strings.TrimSpace(key)
+ if trimmed == "" {
+ continue
+ }
+ deduped[trimmed] = struct{}{}
+ }
+ if len(deduped) == 0 {
+ return nil
+ }
+
+ normalized := make([]string, 0, len(deduped))
+ for key := range deduped {
+ normalized = append(normalized, key)
+ }
+ sort.Strings(normalized)
+ return normalized
+}
+
+func advisoryLockHash(key string) int64 {
+ hasher := fnv.New64a()
+ _, _ = hasher.Write([]byte(key))
+ return int64(hasher.Sum64())
+}
+
+func lockRepositoryScopedKeys(ctx context.Context, client *dbent.Client, exec sqlQueryExecutor, keys ...string) (func(), error) {
+ release := repositoryScopedKeyLocks.lock(keys...)
+ normalized := normalizeLockKeys(keys...)
+ if len(normalized) == 0 || client == nil || exec == nil || client.Driver().Dialect() != dialect.Postgres {
+ return release, nil
+ }
+
+ for _, key := range normalized {
+ rows, err := exec.QueryContext(ctx, "SELECT pg_advisory_xact_lock($1)", advisoryLockHash(key))
+ if err != nil {
+ release()
+ return nil, err
+ }
+ _ = rows.Close()
+ }
+ return release, nil
+}
+
+func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
+ if dbent.TxFromContext(ctx) != nil {
+ return fn(ctx)
+ }
+
+ tx, err := r.client.Tx(ctx)
+ if err != nil {
+ return err
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ txCtx := dbent.NewTxContext(ctx, tx)
+ if err := fn(txCtx); err != nil {
+ return err
+ }
+ return tx.Commit()
+}
+
+func (r *userRepository) CreateAuthIdentity(ctx context.Context, input CreateAuthIdentityInput) (*CreateAuthIdentityResult, error) {
+ if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil {
+ return nil, err
+ }
+
+ client := clientFromContext(ctx, r.client)
+
+ create := client.AuthIdentity.Create().
+ SetUserID(input.UserID).
+ SetProviderType(strings.TrimSpace(input.Canonical.ProviderType)).
+ SetProviderKey(strings.TrimSpace(input.Canonical.ProviderKey)).
+ SetProviderSubject(strings.TrimSpace(input.Canonical.ProviderSubject)).
+ SetMetadata(copyMetadata(input.Metadata)).
+ SetNillableIssuer(input.Issuer).
+ SetNillableVerifiedAt(input.VerifiedAt)
+
+ identity, err := create.Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ var channel *dbent.AuthIdentityChannel
+ if input.Channel != nil {
+ channel, err = client.AuthIdentityChannel.Create().
+ SetIdentityID(identity.ID).
+ SetProviderType(strings.TrimSpace(input.Channel.ProviderType)).
+ SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)).
+ SetChannel(strings.TrimSpace(input.Channel.Channel)).
+ SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)).
+ SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)).
+ SetMetadata(copyMetadata(input.ChannelMetadata)).
+ Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return &CreateAuthIdentityResult{Identity: identity, Channel: channel}, nil
+}
+
+func (r *userRepository) GetUserByCanonicalIdentity(ctx context.Context, key AuthIdentityKey) (*UserAuthIdentityLookup, error) {
+ identity, err := clientFromContext(ctx, r.client).AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)),
+ authidentity.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(key.ProviderSubject)),
+ ).
+ WithUser().
+ Only(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return &UserAuthIdentityLookup{
+ User: identity.Edges.User,
+ Identity: identity,
+ }, nil
+}
+
+func (r *userRepository) GetUserByChannelIdentity(ctx context.Context, key AuthIdentityChannelKey) (*UserAuthIdentityLookup, error) {
+ channel, err := clientFromContext(ctx, r.client).AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)),
+ authidentitychannel.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)),
+ authidentitychannel.ChannelEQ(strings.TrimSpace(key.Channel)),
+ authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(key.ChannelAppID)),
+ authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(key.ChannelSubject)),
+ ).
+ WithIdentity(func(q *dbent.AuthIdentityQuery) {
+ q.WithUser()
+ }).
+ Only(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return &UserAuthIdentityLookup{
+ User: channel.Edges.Identity.Edges.User,
+ Identity: channel.Edges.Identity,
+ Channel: channel,
+ }, nil
+}
+
+func (r *userRepository) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) {
+ identities, err := clientFromContext(ctx, r.client).AuthIdentity.Query().
+ Where(authidentity.UserIDEQ(userID)).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ records := make([]service.UserAuthIdentityRecord, 0, len(identities))
+ for _, identity := range identities {
+ if identity == nil {
+ continue
+ }
+ records = append(records, service.UserAuthIdentityRecord{
+ ProviderType: strings.TrimSpace(identity.ProviderType),
+ ProviderKey: strings.TrimSpace(identity.ProviderKey),
+ ProviderSubject: strings.TrimSpace(identity.ProviderSubject),
+ VerifiedAt: identity.VerifiedAt,
+ Issuer: identity.Issuer,
+ Metadata: copyMetadata(identity.Metadata),
+ CreatedAt: identity.CreatedAt,
+ UpdatedAt: identity.UpdatedAt,
+ })
+ }
+
+ return records, nil
+}
+
+func (r *userRepository) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) error {
+ provider = strings.ToLower(strings.TrimSpace(provider))
+ if provider == "" || provider == "email" {
+ return service.ErrIdentityProviderInvalid
+ }
+
+ return r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
+ client := clientFromContext(txCtx, r.client)
+ identityIDs, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(userID),
+ authidentity.ProviderTypeEQ(provider),
+ ).
+ IDs(txCtx)
+ if err != nil {
+ return err
+ }
+ if len(identityIDs) == 0 {
+ return nil
+ }
+
+ if _, err := client.IdentityAdoptionDecision.Update().
+ Where(identityadoptiondecision.IdentityIDIn(identityIDs...)).
+ ClearIdentityID().
+ Save(txCtx); err != nil {
+ return err
+ }
+ if _, err := client.AuthIdentityChannel.Delete().
+ Where(authidentitychannel.IdentityIDIn(identityIDs...)).
+ Exec(txCtx); err != nil {
+ return err
+ }
+ _, err = client.AuthIdentity.Delete().
+ Where(
+ authidentity.UserIDEQ(userID),
+ authidentity.ProviderTypeEQ(provider),
+ ).
+ Exec(txCtx)
+ return err
+ })
+}
+
+func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) {
+ if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil {
+ return nil, err
+ }
+
+ var result *CreateAuthIdentityResult
+ err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
+ client := clientFromContext(txCtx, r.client)
+ canonical := input.Canonical
+
+ identityRecords, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(canonical.ProviderType)),
+ authidentity.ProviderKeyIn(compatibleIdentityProviderKeys(canonical.ProviderType, canonical.ProviderKey)...),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(canonical.ProviderSubject)),
+ ).
+ All(txCtx)
+ if err != nil {
+ return err
+ }
+ identity := selectOwnedCompatibleIdentity(identityRecords, input.UserID)
+ if identity == nil && hasCompatibleIdentityConflict(identityRecords, input.UserID) {
+ return ErrAuthIdentityOwnershipConflict
+ }
+ if identity == nil {
+ identity, err = client.AuthIdentity.Create().
+ SetUserID(input.UserID).
+ SetProviderType(strings.TrimSpace(canonical.ProviderType)).
+ SetProviderKey(strings.TrimSpace(canonical.ProviderKey)).
+ SetProviderSubject(strings.TrimSpace(canonical.ProviderSubject)).
+ SetMetadata(copyMetadata(input.Metadata)).
+ SetNillableIssuer(input.Issuer).
+ SetNillableVerifiedAt(input.VerifiedAt).
+ Save(txCtx)
+ if err != nil {
+ return err
+ }
+ } else {
+ targetProviderKey := canonicalizeCompatibleIdentityProviderKey(canonical.ProviderType, identity.ProviderKey, canonical.ProviderKey)
+ update := client.AuthIdentity.UpdateOneID(identity.ID)
+ if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, identity.ProviderKey) {
+ update = update.SetProviderKey(targetProviderKey)
+ }
+ if input.Metadata != nil {
+ update = update.SetMetadata(copyMetadata(input.Metadata))
+ }
+ if input.Issuer != nil {
+ update = update.SetIssuer(strings.TrimSpace(*input.Issuer))
+ }
+ if input.VerifiedAt != nil {
+ update = update.SetVerifiedAt(*input.VerifiedAt)
+ }
+ identity, err = update.Save(txCtx)
+ if err != nil {
+ return err
+ }
+ }
+
+ var channel *dbent.AuthIdentityChannel
+ if input.Channel != nil {
+ channelRecords, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ(strings.TrimSpace(input.Channel.ProviderType)),
+ authidentitychannel.ProviderKeyIn(compatibleIdentityProviderKeys(input.Channel.ProviderType, input.Channel.ProviderKey)...),
+ authidentitychannel.ChannelEQ(strings.TrimSpace(input.Channel.Channel)),
+ authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(input.Channel.ChannelAppID)),
+ authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(input.Channel.ChannelSubject)),
+ ).
+ WithIdentity().
+ All(txCtx)
+ if err != nil {
+ return err
+ }
+ channel = selectOwnedCompatibleChannel(channelRecords, input.UserID)
+ if channel == nil && hasCompatibleChannelConflict(channelRecords, input.UserID) {
+ return ErrAuthIdentityChannelOwnershipConflict
+ }
+ if channel == nil {
+ channel, err = client.AuthIdentityChannel.Create().
+ SetIdentityID(identity.ID).
+ SetProviderType(strings.TrimSpace(input.Channel.ProviderType)).
+ SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)).
+ SetChannel(strings.TrimSpace(input.Channel.Channel)).
+ SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)).
+ SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)).
+ SetMetadata(copyMetadata(input.ChannelMetadata)).
+ Save(txCtx)
+ if err != nil {
+ return err
+ }
+ } else {
+ targetProviderKey := canonicalizeCompatibleIdentityProviderKey(input.Channel.ProviderType, channel.ProviderKey, input.Channel.ProviderKey)
+ update := client.AuthIdentityChannel.UpdateOneID(channel.ID).
+ SetIdentityID(identity.ID)
+ if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, channel.ProviderKey) {
+ update = update.SetProviderKey(targetProviderKey)
+ }
+ if input.ChannelMetadata != nil {
+ update = update.SetMetadata(copyMetadata(input.ChannelMetadata))
+ }
+ channel, err = update.Save(txCtx)
+ if err != nil {
+ return err
+ }
+ }
+ }
+
+ result = &CreateAuthIdentityResult{Identity: identity, Channel: channel}
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
+func compatibleIdentityProviderKeys(providerType, providerKey string) []string {
+ providerType = strings.TrimSpace(strings.ToLower(providerType))
+ providerKey = strings.TrimSpace(providerKey)
+ if providerKey == "" {
+ return []string{providerKey}
+ }
+ if providerType != "wechat" {
+ return []string{providerKey}
+ }
+ keys := []string{providerKey}
+ if !strings.EqualFold(providerKey, "wechat-main") {
+ keys = append(keys, "wechat-main")
+ }
+ if !strings.EqualFold(providerKey, "wechat") {
+ keys = append(keys, "wechat")
+ }
+ return keys
+}
+
+func canonicalizeCompatibleIdentityProviderKey(providerType, existingKey, requestedKey string) string {
+ providerType = strings.TrimSpace(strings.ToLower(providerType))
+ existingKey = strings.TrimSpace(existingKey)
+ requestedKey = strings.TrimSpace(requestedKey)
+ if providerType != "wechat" {
+ if requestedKey != "" {
+ return requestedKey
+ }
+ return existingKey
+ }
+ if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") {
+ return "wechat-main"
+ }
+ if requestedKey != "" {
+ return requestedKey
+ }
+ return existingKey
+}
+
+func compatibleIdentityProviderKeyRank(providerType, providerKey string) int {
+ providerType = strings.TrimSpace(strings.ToLower(providerType))
+ providerKey = strings.TrimSpace(providerKey)
+ if providerType != "wechat" {
+ return 0
+ }
+ switch {
+ case strings.EqualFold(providerKey, "wechat-main"):
+ return 0
+ case strings.EqualFold(providerKey, "wechat"):
+ return 2
+ default:
+ return 1
+ }
+}
+
+func selectOwnedCompatibleIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
+ var selected *dbent.AuthIdentity
+ for _, record := range records {
+ if record.UserID != userID {
+ continue
+ }
+ if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
+ selected = record
+ }
+ }
+ return selected
+}
+
+func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) bool {
+ for _, record := range records {
+ if record.UserID != userID {
+ return true
+ }
+ }
+ return false
+}
+
+func selectOwnedCompatibleChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
+ var selected *dbent.AuthIdentityChannel
+ for _, record := range records {
+ if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID {
+ continue
+ }
+ if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
+ selected = record
+ }
+ }
+ return selected
+}
+
+func hasCompatibleChannelConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
+ for _, record := range records {
+ if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
+ return true
+ }
+ }
+ return false
+}
+
+func (r *userRepository) RecordProviderGrant(ctx context.Context, input ProviderGrantRecordInput) (bool, error) {
+ exec := txAwareSQLExecutor(ctx, r.sql, r.client)
+ if exec == nil {
+ return false, fmt.Errorf("sql executor is not configured")
+ }
+
+ result, err := exec.ExecContext(ctx, `
+INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason)
+VALUES ($1, $2, $3)
+ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
+ input.UserID,
+ strings.TrimSpace(input.ProviderType),
+ string(input.GrantReason),
+ )
+ if err != nil {
+ return false, err
+ }
+ affected, err := result.RowsAffected()
+ if err != nil {
+ return false, err
+ }
+ return affected > 0, nil
+}
+
+func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
+ var result *dbent.IdentityAdoptionDecision
+ err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
+ client := clientFromContext(txCtx, r.client)
+ releaseLocks, err := lockRepositoryScopedKeys(
+ txCtx,
+ client,
+ txAwareSQLExecutor(txCtx, r.sql, r.client),
+ identityAdoptionDecisionLockKeys(input.PendingAuthSessionID, input.IdentityID)...,
+ )
+ if err != nil {
+ return err
+ }
+ defer releaseLocks()
+
+ if input.IdentityID != nil && *input.IdentityID > 0 {
+ if _, err := client.IdentityAdoptionDecision.Update().
+ Where(
+ identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
+ dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
+ col := s.C(identityadoptiondecision.FieldPendingAuthSessionID)
+ s.Where(entsql.Or(
+ entsql.IsNull(col),
+ entsql.NEQ(col, input.PendingAuthSessionID),
+ ))
+ }),
+ ).
+ ClearIdentityID().
+ Save(txCtx); err != nil {
+ return err
+ }
+ }
+
+ create := client.IdentityAdoptionDecision.Create().
+ SetPendingAuthSessionID(input.PendingAuthSessionID).
+ SetAdoptDisplayName(input.AdoptDisplayName).
+ SetAdoptAvatar(input.AdoptAvatar).
+ SetDecidedAt(time.Now().UTC())
+ if input.IdentityID != nil && *input.IdentityID > 0 {
+ create = create.SetIdentityID(*input.IdentityID)
+ }
+
+ decisionID, err := create.
+ OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID).
+ UpdateNewValues().
+ ID(txCtx)
+ if err != nil {
+ return err
+ }
+
+ result, err = client.IdentityAdoptionDecision.Get(txCtx, decisionID)
+ return err
+ })
+ if err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
+func identityAdoptionDecisionLockKeys(pendingAuthSessionID int64, identityID *int64) []string {
+ keys := []string{fmt.Sprintf("identity-adoption:pending:%d", pendingAuthSessionID)}
+ if identityID != nil && *identityID > 0 {
+ keys = append(keys, fmt.Sprintf("identity-adoption:identity:%d", *identityID))
+ }
+ return keys
+}
+
+func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) {
+ return clientFromContext(ctx, r.client).IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingAuthSessionID)).
+ Only(ctx)
+}
+
+func (r *userRepository) UpdateUserLastLoginAt(ctx context.Context, userID int64, loginAt time.Time) error {
+ _, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID).
+ SetLastLoginAt(loginAt).
+ Save(ctx)
+ return err
+}
+
+func (r *userRepository) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+ _, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID).
+ SetLastActiveAt(activeAt).
+ Save(ctx)
+ return err
+}
+
+func (r *userRepository) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
+ exec, err := r.userProfileIdentitySQL(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ rows, err := exec.QueryContext(ctx, `
+SELECT storage_provider, storage_key, url, content_type, byte_size, sha256
+FROM user_avatars
+WHERE user_id = $1`, userID)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ if !rows.Next() {
+ return nil, rows.Err()
+ }
+
+ var avatar service.UserAvatar
+ if err := rows.Scan(
+ &avatar.StorageProvider,
+ &avatar.StorageKey,
+ &avatar.URL,
+ &avatar.ContentType,
+ &avatar.ByteSize,
+ &avatar.SHA256,
+ ); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return &avatar, nil
+}
+
+func (r *userRepository) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ exec, err := r.userProfileIdentitySQL(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ _, err = exec.ExecContext(ctx, `
+INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at)
+VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
+ON CONFLICT (user_id) DO UPDATE SET
+ storage_provider = EXCLUDED.storage_provider,
+ storage_key = EXCLUDED.storage_key,
+ url = EXCLUDED.url,
+ content_type = EXCLUDED.content_type,
+ byte_size = EXCLUDED.byte_size,
+ sha256 = EXCLUDED.sha256,
+ updated_at = NOW()`,
+ userID,
+ strings.TrimSpace(input.StorageProvider),
+ strings.TrimSpace(input.StorageKey),
+ strings.TrimSpace(input.URL),
+ strings.TrimSpace(input.ContentType),
+ input.ByteSize,
+ strings.TrimSpace(input.SHA256),
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ return &service.UserAvatar{
+ StorageProvider: strings.TrimSpace(input.StorageProvider),
+ StorageKey: strings.TrimSpace(input.StorageKey),
+ URL: strings.TrimSpace(input.URL),
+ ContentType: strings.TrimSpace(input.ContentType),
+ ByteSize: input.ByteSize,
+ SHA256: strings.TrimSpace(input.SHA256),
+ }, nil
+}
+
+func (r *userRepository) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ exec, err := r.userProfileIdentitySQL(ctx)
+ if err != nil {
+ return err
+ }
+ _, err = exec.ExecContext(ctx, `DELETE FROM user_avatars WHERE user_id = $1`, userID)
+ return err
+}
+
+func copyMetadata(in map[string]any) map[string]any {
+ if len(in) == 0 {
+ return map[string]any{}
+ }
+ out := make(map[string]any, len(in))
+ for k, v := range in {
+ out[k] = v
+ }
+ return out
+}
+
+func validateAuthIdentityChannelProviderMatch(canonical AuthIdentityKey, channel *AuthIdentityChannelKey) error {
+ if channel == nil {
+ return nil
+ }
+
+ canonicalProviderType := strings.TrimSpace(canonical.ProviderType)
+ canonicalProviderKey := strings.TrimSpace(canonical.ProviderKey)
+ channelProviderType := strings.TrimSpace(channel.ProviderType)
+ channelProviderKey := strings.TrimSpace(channel.ProviderKey)
+
+ if canonicalProviderType != channelProviderType || canonicalProviderKey != channelProviderKey {
+ return ErrAuthIdentityChannelProviderMismatch
+ }
+
+ return nil
+}
+
+func txAwareSQLExecutor(ctx context.Context, fallback sqlExecutor, client *dbent.Client) sqlQueryExecutor {
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ if exec := sqlExecutorFromEntClient(tx.Client()); exec != nil {
+ return exec
+ }
+ }
+ if fallback != nil {
+ return fallback
+ }
+ return sqlExecutorFromEntClient(client)
+}
+
+func (r *userRepository) userProfileIdentitySQL(ctx context.Context) (sqlQueryExecutor, error) {
+ exec := txAwareSQLExecutor(ctx, r.sql, r.client)
+ if exec == nil {
+ return nil, fmt.Errorf("sql executor is not configured")
+ }
+ return exec, nil
+}
+
+func sqlExecutorFromEntClient(client *dbent.Client) sqlQueryExecutor {
+ if client == nil {
+ return nil
+ }
+
+ clientValue := reflect.ValueOf(client).Elem()
+ configValue := clientValue.FieldByName("config")
+ driverValue := configValue.FieldByName("driver")
+ if !driverValue.IsValid() {
+ return nil
+ }
+
+ driver := reflect.NewAt(driverValue.Type(), unsafe.Pointer(driverValue.UnsafeAddr())).Elem().Interface()
+ exec, ok := driver.(sqlQueryExecutor)
+ if !ok {
+ return nil
+ }
+ return exec
+}
diff --git a/backend/internal/repository/user_profile_identity_repo_contract_test.go b/backend/internal/repository/user_profile_identity_repo_contract_test.go
new file mode 100644
index 00000000..d4f9e8b3
--- /dev/null
+++ b/backend/internal/repository/user_profile_identity_repo_contract_test.go
@@ -0,0 +1,578 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/suite"
+)
+
+type UserProfileIdentityRepoSuite struct {
+ suite.Suite
+ ctx context.Context
+ client *dbent.Client
+ repo *userRepository
+}
+
+func TestUserProfileIdentityRepoSuite(t *testing.T) {
+ suite.Run(t, new(UserProfileIdentityRepoSuite))
+}
+
+func (s *UserProfileIdentityRepoSuite) SetupTest() {
+ s.ctx = context.Background()
+ s.client = testEntClient(s.T())
+ s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
+
+ _, err := integrationDB.ExecContext(s.ctx, `
+TRUNCATE TABLE
+ identity_adoption_decisions,
+ auth_identity_channels,
+ auth_identities,
+ pending_auth_sessions,
+ user_provider_default_grants,
+ user_avatars
+RESTART IDENTITY`)
+ s.Require().NoError(err)
+}
+
+func (s *UserProfileIdentityRepoSuite) mustCreateUser(label string) *dbent.User {
+ s.T().Helper()
+
+ user, err := s.client.User.Create().
+ SetEmail(fmt.Sprintf("%s-%d@example.com", label, time.Now().UnixNano())).
+ SetPasswordHash("test-password-hash").
+ SetRole("user").
+ SetStatus("active").
+ Save(s.ctx)
+ s.Require().NoError(err)
+ return user
+}
+
+func (s *UserProfileIdentityRepoSuite) mustCreatePendingAuthSession(key AuthIdentityKey) *dbent.PendingAuthSession {
+ s.T().Helper()
+
+ session, err := s.client.PendingAuthSession.Create().
+ SetSessionToken(fmt.Sprintf("pending-%d", time.Now().UnixNano())).
+ SetIntent("bind_current_user").
+ SetProviderType(key.ProviderType).
+ SetProviderKey(key.ProviderKey).
+ SetProviderSubject(key.ProviderSubject).
+ SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)).
+ SetUpstreamIdentityClaims(map[string]any{"provider_subject": key.ProviderSubject}).
+ SetLocalFlowState(map[string]any{"step": "pending"}).
+ Save(s.ctx)
+ s.Require().NoError(err)
+ return session
+}
+
+func (s *UserProfileIdentityRepoSuite) TestCreateAndLookupCanonicalAndChannelIdentity() {
+ user := s.mustCreateUser("canonical-channel")
+
+ verifiedAt := time.Now().UTC().Truncate(time.Second)
+ created, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-123",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ Channel: "mp",
+ ChannelAppID: "wx-app",
+ ChannelSubject: "openid-123",
+ },
+ Issuer: stringPtr("https://issuer.example"),
+ VerifiedAt: &verifiedAt,
+ Metadata: map[string]any{"unionid": "union-123"},
+ ChannelMetadata: map[string]any{"openid": "openid-123"},
+ })
+ s.Require().NoError(err)
+ s.Require().NotNil(created.Identity)
+ s.Require().NotNil(created.Channel)
+
+ canonical, err := s.repo.GetUserByCanonicalIdentity(s.ctx, created.IdentityRef())
+ s.Require().NoError(err)
+ s.Require().Equal(user.ID, canonical.User.ID)
+ s.Require().Equal(created.Identity.ID, canonical.Identity.ID)
+ s.Require().Equal("union-123", canonical.Identity.ProviderSubject)
+
+ channel, err := s.repo.GetUserByChannelIdentity(s.ctx, *created.ChannelRef())
+ s.Require().NoError(err)
+ s.Require().Equal(user.ID, channel.User.ID)
+ s.Require().Equal(created.Identity.ID, channel.Identity.ID)
+ s.Require().Equal(created.Channel.ID, channel.Channel.ID)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAndRejectsOtherOwners() {
+ owner := s.mustCreateUser("owner")
+ other := s.mustCreateUser("other")
+
+ first, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+ UserID: owner.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ ProviderSubject: "subject-1",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ Channel: "oauth",
+ ChannelAppID: "linuxdo-web",
+ ChannelSubject: "subject-1",
+ },
+ Metadata: map[string]any{"username": "first"},
+ ChannelMetadata: map[string]any{"scope": "read"},
+ })
+ s.Require().NoError(err)
+
+ second, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+ UserID: owner.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ ProviderSubject: "subject-1",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ Channel: "oauth",
+ ChannelAppID: "linuxdo-web",
+ ChannelSubject: "subject-1",
+ },
+ Metadata: map[string]any{"username": "second"},
+ ChannelMetadata: map[string]any{"scope": "write"},
+ })
+ s.Require().NoError(err)
+ s.Require().Equal(first.Identity.ID, second.Identity.ID)
+ s.Require().Equal(first.Channel.ID, second.Channel.ID)
+ s.Require().Equal("second", second.Identity.Metadata["username"])
+ s.Require().Equal("write", second.Channel.Metadata["scope"])
+
+ _, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+ UserID: other.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ ProviderSubject: "subject-1",
+ },
+ })
+ s.Require().ErrorIs(err, ErrAuthIdentityOwnershipConflict)
+
+ _, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+ UserID: other.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ ProviderSubject: "subject-2",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ Channel: "oauth",
+ ChannelAppID: "linuxdo-web",
+ ChannelSubject: "subject-1",
+ },
+ })
+ s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_ReusesLegacyWeChatAliasRecords() {
+ user := s.mustCreateUser("wechat-legacy-alias")
+
+ legacyIdentity, err := s.client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat").
+ SetProviderSubject("union-legacy-123").
+ SetMetadata(map[string]any{"source": "legacy-alias"}).
+ Save(s.ctx)
+ s.Require().NoError(err)
+
+ legacyChannel, err := s.client.AuthIdentityChannel.Create().
+ SetIdentityID(legacyIdentity.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat").
+ SetChannel("oa").
+ SetChannelAppID("wx-app-legacy").
+ SetChannelSubject("openid-legacy-123").
+ SetMetadata(map[string]any{"scene": "legacy-alias"}).
+ Save(s.ctx)
+ s.Require().NoError(err)
+
+ bound, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "union-legacy-123",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ Channel: "oa",
+ ChannelAppID: "wx-app-legacy",
+ ChannelSubject: "openid-legacy-123",
+ },
+ Metadata: map[string]any{"source": "canonical-bind"},
+ ChannelMetadata: map[string]any{"scene": "canonical-bind"},
+ })
+ s.Require().NoError(err)
+ s.Require().NotNil(bound)
+ s.Require().NotNil(bound.Identity)
+ s.Require().NotNil(bound.Channel)
+ s.Require().Equal(legacyIdentity.ID, bound.Identity.ID)
+ s.Require().Equal(legacyChannel.ID, bound.Channel.ID)
+ s.Require().Equal("wechat-main", bound.Identity.ProviderKey)
+ s.Require().Equal("wechat-main", bound.Channel.ProviderKey)
+ s.Require().Equal("canonical-bind", bound.Identity.Metadata["source"])
+ s.Require().Equal("canonical-bind", bound.Channel.Metadata["scene"])
+
+ identityCount, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderSubjectEQ("union-legacy-123"),
+ ).
+ Count(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Equal(1, identityCount)
+
+ channelCount, err := s.client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ChannelEQ("oa"),
+ authidentitychannel.ChannelAppIDEQ("wx-app-legacy"),
+ authidentitychannel.ChannelSubjectEQ("openid-legacy-123"),
+ ).
+ Count(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Equal(1, channelCount)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestCreateAuthIdentity_RejectsChannelProviderMismatch() {
+ user := s.mustCreateUser("provider-mismatch-create")
+
+ _, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "union-create-mismatch",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ Channel: "oauth",
+ ChannelAppID: "app-mismatch",
+ ChannelSubject: "openid-create-mismatch",
+ },
+ })
+ s.Require().ErrorIs(err, ErrAuthIdentityChannelProviderMismatch)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_RejectsChannelProviderMismatch() {
+ user := s.mustCreateUser("provider-mismatch-bind")
+
+ _, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "union-bind-mismatch",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-legacy",
+ Channel: "oa",
+ ChannelAppID: "wx-app-bind-mismatch",
+ ChannelSubject: "openid-bind-mismatch",
+ },
+ })
+ s.Require().ErrorIs(err, ErrAuthIdentityChannelProviderMismatch)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_RollsBackIdentityAndGrantOnError() {
+ user := s.mustCreateUser("tx-rollback")
+ expectedErr := errors.New("rollback")
+
+ err := s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error {
+ _, err := s.repo.CreateAuthIdentity(txCtx, CreateAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-rollback",
+ },
+ })
+ s.Require().NoError(err)
+
+ inserted, err := s.repo.RecordProviderGrant(txCtx, ProviderGrantRecordInput{
+ UserID: user.ID,
+ ProviderType: "oidc",
+ GrantReason: ProviderGrantReasonFirstBind,
+ })
+ s.Require().NoError(err)
+ s.Require().True(inserted)
+ return expectedErr
+ })
+ s.Require().ErrorIs(err, expectedErr)
+
+ _, err = s.repo.GetUserByCanonicalIdentity(s.ctx, AuthIdentityKey{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-rollback",
+ })
+ s.Require().True(dbent.IsNotFound(err))
+
+ var count int
+ s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
+SELECT COUNT(*)
+FROM user_provider_default_grants
+WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3`,
+ user.ID,
+ "oidc",
+ string(ProviderGrantReasonFirstBind),
+ ).Scan(&count))
+ s.Require().Zero(count)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestRecordProviderGrant_IsIdempotentPerReason() {
+ user := s.mustCreateUser("grant")
+
+ inserted, err := s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
+ UserID: user.ID,
+ ProviderType: "wechat",
+ GrantReason: ProviderGrantReasonFirstBind,
+ })
+ s.Require().NoError(err)
+ s.Require().True(inserted)
+
+ inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
+ UserID: user.ID,
+ ProviderType: "wechat",
+ GrantReason: ProviderGrantReasonFirstBind,
+ })
+ s.Require().NoError(err)
+ s.Require().False(inserted)
+
+ inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
+ UserID: user.ID,
+ ProviderType: "wechat",
+ GrantReason: ProviderGrantReasonSignup,
+ })
+ s.Require().NoError(err)
+ s.Require().True(inserted)
+
+ var count int
+ s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
+SELECT COUNT(*)
+FROM user_provider_default_grants
+WHERE user_id = $1 AND provider_type = $2`,
+ user.ID,
+ "wechat",
+ ).Scan(&count))
+ s.Require().Equal(2, count)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_PersistsAndLinksIdentity() {
+ user := s.mustCreateUser("adoption")
+ identity, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-adoption",
+ },
+ })
+ s.Require().NoError(err)
+
+ session := s.mustCreatePendingAuthSession(identity.IdentityRef())
+
+ first, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: false,
+ })
+ s.Require().NoError(err)
+ s.Require().True(first.AdoptDisplayName)
+ s.Require().False(first.AdoptAvatar)
+ s.Require().Nil(first.IdentityID)
+
+ second, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ IdentityID: &identity.Identity.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: true,
+ })
+ s.Require().NoError(err)
+ s.Require().Equal(first.ID, second.ID)
+ s.Require().NotNil(second.IdentityID)
+ s.Require().Equal(identity.Identity.ID, *second.IdentityID)
+ s.Require().True(second.AdoptAvatar)
+
+ loaded, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, session.ID)
+ s.Require().NoError(err)
+ s.Require().Equal(second.ID, loaded.ID)
+ s.Require().Equal(identity.Identity.ID, *loaded.IdentityID)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_ReassignsExistingIdentityReference() {
+ user := s.mustCreateUser("adoption-reassign")
+ identity, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-adoption-reassign",
+ },
+ })
+ s.Require().NoError(err)
+
+ firstSession := s.mustCreatePendingAuthSession(identity.IdentityRef())
+ firstDecision, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
+ PendingAuthSessionID: firstSession.ID,
+ IdentityID: &identity.Identity.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: false,
+ })
+ s.Require().NoError(err)
+ s.Require().NotNil(firstDecision.IdentityID)
+ s.Require().Equal(identity.Identity.ID, *firstDecision.IdentityID)
+
+ secondSession := s.mustCreatePendingAuthSession(identity.IdentityRef())
+ secondDecision, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
+ PendingAuthSessionID: secondSession.ID,
+ IdentityID: &identity.Identity.ID,
+ AdoptDisplayName: false,
+ AdoptAvatar: true,
+ })
+ s.Require().NoError(err)
+ s.Require().NotNil(secondDecision.IdentityID)
+ s.Require().Equal(identity.Identity.ID, *secondDecision.IdentityID)
+
+ reloadedFirst, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, firstSession.ID)
+ s.Require().NoError(err)
+ s.Require().Nil(reloadedFirst.IdentityID)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_AllowsAvatarOnlyProfileUpdate() {
+ user := s.mustCreateUser("avatar-only-update")
+
+ model, err := s.repo.GetByID(s.ctx, user.ID)
+ s.Require().NoError(err)
+ s.Require().NotNil(model)
+
+ err = s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error {
+ _, err := s.repo.UpsertUserAvatar(txCtx, user.ID, service.UpsertUserAvatarInput{
+ StorageProvider: "remote_url",
+ URL: "https://cdn.example.com/avatar.png",
+ })
+ if err != nil {
+ return err
+ }
+ return s.repo.Update(txCtx, model)
+ })
+ s.Require().NoError(err)
+
+ avatar, err := s.repo.GetUserAvatar(s.ctx, user.ID)
+ s.Require().NoError(err)
+ s.Require().NotNil(avatar)
+ s.Require().Equal("https://cdn.example.com/avatar.png", avatar.URL)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestUserAvatarCRUDAndUserLookup() {
+ user := s.mustCreateUser("avatar")
+
+ inlineAvatar, err := s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{
+ StorageProvider: "inline",
+ URL: "data:image/png;base64,QUJD",
+ ContentType: "image/png",
+ ByteSize: 3,
+ SHA256: "902fbdd2b1df0c4f70b4a5d23525e932",
+ })
+ s.Require().NoError(err)
+ s.Require().Equal("inline", inlineAvatar.StorageProvider)
+ s.Require().Equal("data:image/png;base64,QUJD", inlineAvatar.URL)
+
+ loadedAvatar, err := s.repo.GetUserAvatar(s.ctx, user.ID)
+ s.Require().NoError(err)
+ s.Require().NotNil(loadedAvatar)
+ s.Require().Equal("image/png", loadedAvatar.ContentType)
+ s.Require().Equal(3, loadedAvatar.ByteSize)
+
+ _, err = s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{
+ StorageProvider: "remote_url",
+ URL: "https://cdn.example.com/avatar.png",
+ })
+ s.Require().NoError(err)
+
+ loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID)
+ s.Require().NoError(err)
+ s.Require().NotNil(loadedAvatar)
+ s.Require().Equal("remote_url", loadedAvatar.StorageProvider)
+ s.Require().Equal("https://cdn.example.com/avatar.png", loadedAvatar.URL)
+ s.Require().Zero(loadedAvatar.ByteSize)
+
+ s.Require().NoError(s.repo.DeleteUserAvatar(s.ctx, user.ID))
+ loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID)
+ s.Require().NoError(err)
+ s.Require().Nil(loadedAvatar)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestUpdateUserLastLoginAndActiveAt_UsesDedicatedColumns() {
+ user := s.mustCreateUser("activity")
+ loginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC)
+ activeAt := loginAt.Add(5 * time.Minute)
+
+ s.Require().NoError(s.repo.UpdateUserLastLoginAt(s.ctx, user.ID, loginAt))
+ s.Require().NoError(s.repo.UpdateUserLastActiveAt(s.ctx, user.ID, activeAt))
+
+ var storedLoginAt sqlNullTime
+ var storedActiveAt sqlNullTime
+ s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
+SELECT last_login_at, last_active_at
+FROM users
+WHERE id = $1`,
+ user.ID,
+ ).Scan(&storedLoginAt, &storedActiveAt))
+ s.Require().True(storedLoginAt.Valid)
+ s.Require().True(storedActiveAt.Valid)
+ s.Require().True(storedLoginAt.Time.Equal(loginAt))
+ s.Require().True(storedActiveAt.Time.Equal(activeAt))
+}
+
+type sqlNullTime struct {
+ Time time.Time
+ Valid bool
+}
+
+func (t *sqlNullTime) Scan(value any) error {
+ switch v := value.(type) {
+ case time.Time:
+ t.Time = v
+ t.Valid = true
+ return nil
+ case nil:
+ t.Time = time.Time{}
+ t.Valid = false
+ return nil
+ default:
+ return fmt.Errorf("unsupported scan type %T", value)
+ }
+}
+
+func stringPtr(v string) *string {
+ return &v
+}
diff --git a/backend/internal/repository/user_profile_identity_repo_unit_test.go b/backend/internal/repository/user_profile_identity_repo_unit_test.go
new file mode 100644
index 00000000..689f32f9
--- /dev/null
+++ b/backend/internal/repository/user_profile_identity_repo_unit_test.go
@@ -0,0 +1,212 @@
+package repository
+
+import (
+ "context"
+ "sync"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+func TestUserRepositoryBindAuthIdentityToUserCanonicalizesLegacyWeChatAlias(t *testing.T) {
+ repo, client := newUserEntRepo(t)
+ ctx := context.Background()
+
+ user := &service.User{
+ Email: "wechat-legacy@example.com",
+ Username: "wechat-legacy",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ }
+ require.NoError(t, repo.Create(ctx, user))
+
+ legacyIdentity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat").
+ SetProviderSubject("union-legacy-123").
+ SetMetadata(map[string]any{"source": "legacy-alias"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ legacyChannel, err := client.AuthIdentityChannel.Create().
+ SetIdentityID(legacyIdentity.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat").
+ SetChannel("oa").
+ SetChannelAppID("wx-app-legacy").
+ SetChannelSubject("openid-legacy-123").
+ SetMetadata(map[string]any{"scene": "legacy-alias"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ bound, err := repo.BindAuthIdentityToUser(ctx, BindAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "union-legacy-123",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ Channel: "oa",
+ ChannelAppID: "wx-app-legacy",
+ ChannelSubject: "openid-legacy-123",
+ },
+ Metadata: map[string]any{"source": "canonical-bind"},
+ ChannelMetadata: map[string]any{"scene": "canonical-bind"},
+ })
+ require.NoError(t, err)
+ require.NotNil(t, bound)
+ require.NotNil(t, bound.Identity)
+ require.NotNil(t, bound.Channel)
+ require.Equal(t, legacyIdentity.ID, bound.Identity.ID)
+ require.Equal(t, legacyChannel.ID, bound.Channel.ID)
+ require.Equal(t, "wechat-main", bound.Identity.ProviderKey)
+ require.Equal(t, "wechat-main", bound.Channel.ProviderKey)
+
+ reloadedIdentity, err := client.AuthIdentity.Get(ctx, legacyIdentity.ID)
+ require.NoError(t, err)
+ require.Equal(t, "wechat-main", reloadedIdentity.ProviderKey)
+ require.Equal(t, "canonical-bind", reloadedIdentity.Metadata["source"])
+
+ reloadedChannel, err := client.AuthIdentityChannel.Get(ctx, legacyChannel.ID)
+ require.NoError(t, err)
+ require.Equal(t, "wechat-main", reloadedChannel.ProviderKey)
+ require.Equal(t, "canonical-bind", reloadedChannel.Metadata["scene"])
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderSubjectEQ("union-legacy-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, identityCount)
+
+ channelCount, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ChannelEQ("oa"),
+ authidentitychannel.ChannelAppIDEQ("wx-app-legacy"),
+ authidentitychannel.ChannelSubjectEQ("openid-legacy-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, channelCount)
+}
+
+func TestUserRepositoryUpsertIdentityAdoptionDecisionIsIdempotentUnderConcurrency(t *testing.T) {
+ repo, client := newUserEntRepo(t)
+ ctx := context.Background()
+
+ user := &service.User{
+ Email: "repo-adoption@example.com",
+ Username: "repo-adoption",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ }
+ require.NoError(t, repo.Create(ctx, user))
+
+ identity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat-main").
+ SetProviderSubject("union-repo-adoption").
+ SetMetadata(map[string]any{}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("pending-repo-adoption").
+ SetIntent("bind_current_user").
+ SetProviderType("wechat").
+ SetProviderKey("wechat-main").
+ SetProviderSubject("union-repo-adoption").
+ SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)).
+ SetUpstreamIdentityClaims(map[string]any{"provider_subject": "union-repo-adoption"}).
+ SetLocalFlowState(map[string]any{"step": "pending"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ firstCreateStarted := make(chan struct{})
+ releaseFirstCreate := make(chan struct{})
+ var firstCreate sync.Once
+ client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator {
+ return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
+ blocked := false
+ if m.Op().Is(dbent.OpCreate) {
+ firstCreate.Do(func() {
+ blocked = true
+ close(firstCreateStarted)
+ })
+ }
+ if blocked {
+ <-releaseFirstCreate
+ }
+ return next.Mutate(ctx, m)
+ })
+ })
+
+ type adoptionResult struct {
+ decision *dbent.IdentityAdoptionDecision
+ err error
+ }
+
+ input := IdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ IdentityID: &identity.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: true,
+ }
+
+ results := make(chan adoptionResult, 2)
+ go func() {
+ decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input)
+ results <- adoptionResult{decision: decision, err: err}
+ }()
+
+ <-firstCreateStarted
+
+ go func() {
+ decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input)
+ results <- adoptionResult{decision: decision, err: err}
+ }()
+
+ time.Sleep(100 * time.Millisecond)
+ close(releaseFirstCreate)
+
+ first := <-results
+ second := <-results
+
+ require.NoError(t, first.err)
+ require.NoError(t, second.err)
+ require.NotNil(t, first.decision)
+ require.NotNil(t, second.decision)
+ require.Equal(t, first.decision.ID, second.decision.ID)
+
+ count, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, count)
+
+ loaded, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, loaded.IdentityID)
+ require.Equal(t, identity.ID, *loaded.IdentityID)
+ require.True(t, loaded.AdoptDisplayName)
+ require.True(t, loaded.AdoptAvatar)
+}
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index 913e1c40..c5db3dc4 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -11,12 +11,17 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/lib/pq"
entsql "entgo.io/ent/dialect/sql"
)
@@ -47,12 +52,33 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
}
var txClient *dbent.Client
+ txCtx := ctx
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
+ txCtx = dbent.NewTxContext(ctx, tx)
} else {
- // 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
- txClient = r.client
+ // 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
+ if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
+ txClient = existingTx.Client()
+ } else {
+ txClient = r.client
+ }
+ }
+
+ releaseEmailLock, err := lockRepositoryScopedKeys(
+ txCtx,
+ txClient,
+ txAwareSQLExecutor(txCtx, r.sql, r.client),
+ normalizedEmailUniquenessLockKey(userIn.Email),
+ )
+ if err != nil {
+ return err
+ }
+ defer releaseEmailLock()
+
+ if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, 0, userIn.Email); err != nil {
+ return err
}
created, err := txClient.User.Create().
@@ -64,12 +90,18 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
- Save(ctx)
+ SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
+ SetNillableLastLoginAt(userIn.LastLoginAt).
+ SetNillableLastActiveAt(userIn.LastActiveAt).
+ Save(txCtx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
- if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
+ if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, created.ID, userIn.AllowedGroups); err != nil {
+ return err
+ }
+ if err := ensureEmailAuthIdentityWithClient(txCtx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
return err
}
@@ -101,10 +133,20 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User,
}
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
- m, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
+ matches, err := r.client.User.Query().
+ Where(userEmailLookupPredicate(email)).
+ Order(dbent.Asc(dbuser.FieldID)).
+ All(ctx)
if err != nil {
- return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
+ return nil, err
}
+ if len(matches) == 0 {
+ return nil, service.ErrUserNotFound
+ }
+ if len(matches) > 1 {
+ return nil, fmt.Errorf("normalized email lookup matched multiple users for %q", strings.TrimSpace(email))
+ }
+ m := matches[0]
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
@@ -129,14 +171,41 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
}
var txClient *dbent.Client
+ txCtx := ctx
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
+ txCtx = dbent.NewTxContext(ctx, tx)
} else {
- // 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
- txClient = r.client
+ // 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
+ if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
+ txClient = existingTx.Client()
+ } else {
+ txClient = r.client
+ }
}
+ releaseEmailLock, err := lockRepositoryScopedKeys(
+ txCtx,
+ txClient,
+ txAwareSQLExecutor(txCtx, r.sql, r.client),
+ normalizedEmailUniquenessLockKey(userIn.Email),
+ )
+ if err != nil {
+ return err
+ }
+ defer releaseEmailLock()
+
+ if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, userIn.ID, userIn.Email); err != nil {
+ return err
+ }
+
+ existing, err := clientFromContext(txCtx, txClient).User.Get(txCtx, userIn.ID)
+ if err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ oldEmail := existing.Email
+
updateOp := txClient.User.UpdateOneID(userIn.ID).
SetEmail(userIn.Email).
SetUsername(userIn.Username).
@@ -151,15 +220,27 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
SetTotalRecharged(userIn.TotalRecharged)
+ if userIn.SignupSource != "" {
+ updateOp = updateOp.SetSignupSource(userIn.SignupSource)
+ }
+ if userIn.LastLoginAt != nil {
+ updateOp = updateOp.SetLastLoginAt(*userIn.LastLoginAt)
+ }
+ if userIn.LastActiveAt != nil {
+ updateOp = updateOp.SetLastActiveAt(*userIn.LastActiveAt)
+ }
if userIn.BalanceNotifyThreshold == nil {
updateOp = updateOp.ClearBalanceNotifyThreshold()
}
- updated, err := updateOp.Save(ctx)
+ updated, err := updateOp.Save(txCtx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
}
- if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
+ if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
+ return err
+ }
+ if err := replaceEmailAuthIdentityWithClient(txCtx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
return err
}
@@ -173,14 +254,146 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
return nil
}
+func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, email string, source string) error {
+ client = clientFromContext(ctx, client)
+ if client == nil || userID <= 0 {
+ return nil
+ }
+
+ subject := normalizeEmailAuthIdentitySubject(email)
+ if subject == "" {
+ return nil
+ }
+
+ if err := client.AuthIdentity.Create().
+ SetUserID(userID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject(subject).
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": source}).
+ OnConflictColumns(
+ authidentity.FieldProviderType,
+ authidentity.FieldProviderKey,
+ authidentity.FieldProviderSubject,
+ ).
+ DoNothing().
+ Exec(ctx); err != nil {
+ if !isSQLNoRowsError(err) {
+ return err
+ }
+ }
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(subject),
+ ).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil
+ }
+ return err
+ }
+ if identity.UserID != userID {
+ return ErrAuthIdentityOwnershipConflict
+ }
+ return nil
+}
+
+func replaceEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, oldEmail, newEmail string, source string) error {
+ newSubject := normalizeEmailAuthIdentitySubject(newEmail)
+ if err := ensureEmailAuthIdentityWithClient(ctx, client, userID, newEmail, source); err != nil {
+ return err
+ }
+
+ oldSubject := normalizeEmailAuthIdentitySubject(oldEmail)
+ if oldSubject == "" || oldSubject == newSubject {
+ return nil
+ }
+
+ _, err := clientFromContext(ctx, client).AuthIdentity.Delete().
+ Where(
+ authidentity.UserIDEQ(userID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(oldSubject),
+ ).
+ Exec(ctx)
+ return err
+}
+
+func normalizeEmailAuthIdentitySubject(email string) string {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ if normalized == "" {
+ return ""
+ }
+ if strings.HasSuffix(normalized, service.LinuxDoConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(normalized, service.OIDCConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(normalized, service.WeChatConnectSyntheticEmailDomain) {
+ return ""
+ }
+ return normalized
+}
+
func (r *userRepository) Delete(ctx context.Context, id int64) error {
- affected, err := r.client.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
+ tx, err := r.client.Tx(ctx)
+ if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+
+ var txClient *dbent.Client
+ if err == nil {
+ defer func() { _ = tx.Rollback() }()
+ txClient = tx.Client()
+ } else {
+ if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
+ txClient = existingTx.Client()
+ } else {
+ txClient = r.client
+ }
+ }
+
+ identityIDs, err := txClient.AuthIdentity.Query().
+ Where(authidentity.UserIDEQ(id)).
+ IDs(ctx)
+ if err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ if len(identityIDs) > 0 {
+ if _, err := txClient.IdentityAdoptionDecision.Update().
+ Where(identityadoptiondecision.IdentityIDIn(identityIDs...)).
+ ClearIdentityID().
+ Save(ctx); err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ if _, err := txClient.AuthIdentityChannel.Delete().
+ Where(authidentitychannel.IdentityIDIn(identityIDs...)).
+ Exec(ctx); err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ if _, err := txClient.AuthIdentity.Delete().
+ Where(authidentity.UserIDEQ(id)).
+ Exec(ctx); err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ }
+
+ affected, err := txClient.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
if affected == 0 {
return service.ErrUserNotFound
}
+
+ if tx != nil {
+ if err := tx.Commit(); err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ }
return nil
}
@@ -298,8 +511,13 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
+ if sortBy == "last_used_at" {
+ return userLastUsedAtOrder(sortOrder)
+ }
+
var field string
defaultField := true
+ nullsLastField := false
switch sortBy {
case "email":
field = dbuser.FieldEmail
@@ -322,6 +540,10 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
case "created_at":
field = dbuser.FieldCreatedAt
defaultField = false
+ case "last_active_at":
+ field = dbuser.FieldLastActiveAt
+ defaultField = false
+ nullsLastField = true
default:
field = dbuser.FieldID
}
@@ -330,14 +552,92 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)}
}
+ if nullsLastField {
+ return []func(*entsql.Selector){
+ entsql.OrderByField(field, entsql.OrderNullsLast()).ToFunc(),
+ dbent.Asc(dbuser.FieldID),
+ }
+ }
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)}
}
if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)}
}
+ if nullsLastField {
+ return []func(*entsql.Selector){
+ entsql.OrderByField(field, entsql.OrderDesc(), entsql.OrderNullsLast()).ToFunc(),
+ dbent.Desc(dbuser.FieldID),
+ }
+ }
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)}
}
+func (r *userRepository) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ result := make(map[int64]*time.Time, len(userIDs))
+ if len(userIDs) == 0 {
+ return result, nil
+ }
+ if r.sql == nil {
+ return nil, fmt.Errorf("sql executor is not configured")
+ }
+
+ const query = `
+ SELECT user_id, MAX(created_at) AS last_used_at
+ FROM usage_logs
+ WHERE user_id = ANY($1)
+ GROUP BY user_id
+ `
+
+ rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ for rows.Next() {
+ var (
+ userID int64
+ lastUsedAt time.Time
+ )
+ if scanErr := rows.Scan(&userID, &lastUsedAt); scanErr != nil {
+ return nil, scanErr
+ }
+ ts := lastUsedAt.UTC()
+ result[userID] = &ts
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
+func (r *userRepository) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ latestByUserID, err := r.GetLatestUsedAtByUserIDs(ctx, []int64{userID})
+ if err != nil {
+ return nil, err
+ }
+ return latestByUserID[userID], nil
+}
+
+func userLastUsedAtOrder(sortOrder string) []func(*entsql.Selector) {
+ orderExpr := func(direction, nulls string, tieOrder func(string) string) func(*entsql.Selector) {
+ return func(s *entsql.Selector) {
+ subquery := fmt.Sprintf("(SELECT MAX(created_at) FROM usage_logs WHERE user_id = %s)", s.C(dbuser.FieldID))
+ s.OrderExpr(entsql.Expr(subquery + " " + direction + " NULLS " + nulls))
+ s.OrderBy(tieOrder(s.C(dbuser.FieldID)))
+ }
+ }
+
+ if sortOrder == pagination.SortOrderAsc {
+ return []func(*entsql.Selector){
+ orderExpr("ASC", "FIRST", entsql.Asc),
+ }
+ }
+ return []func(*entsql.Selector){
+ orderExpr("DESC", "LAST", entsql.Desc),
+ }
+}
+
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
if len(attrs) == 0 {
@@ -436,17 +736,68 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
}
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
- return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
+ return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx)
+}
+
+func ensureNormalizedEmailAvailableWithClient(ctx context.Context, client *dbent.Client, userID int64, email string) error {
+ client = clientFromContext(ctx, client)
+ if client == nil {
+ return nil
+ }
+
+ matches, err := client.User.Query().
+ Where(userEmailLookupPredicate(email)).
+ All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, match := range matches {
+ if match.ID != userID {
+ return service.ErrEmailExists
+ }
+ }
+ return nil
+}
+
+func userEmailLookupPredicate(email string) predicate.User {
+ normalized := normalizeEmailLookupValue(email)
+ if normalized == "" {
+ return dbuser.EmailEQ(email)
+ }
+ return predicate.User(func(s *entsql.Selector) {
+ s.Where(entsql.P(func(b *entsql.Builder) {
+ b.WriteString("LOWER(TRIM(").
+ Ident(s.C(dbuser.FieldEmail)).
+ WriteString(")) = ").
+ Arg(normalized)
+ }))
+ })
+}
+
+func normalizeEmailLookupValue(email string) string {
+ return strings.ToLower(strings.TrimSpace(email))
+}
+
+func normalizedEmailUniquenessLockKey(email string) string {
+ normalized := normalizeEmailLookupValue(email)
+ if normalized == "" {
+ return ""
+ }
+ return "users:normalized-email:" + normalized
}
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
client := clientFromContext(ctx, r.client)
- return client.UserAllowedGroup.Create().
+ err := client.UserAllowedGroup.Create().
SetUserID(userID).
SetGroupID(groupID).
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
DoNothing().
Exec(ctx)
+ if isSQLNoRowsError(err) {
+ return nil
+ }
+ return err
}
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
@@ -546,6 +897,9 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
DoNothing().
Exec(ctx); err != nil {
+ if isSQLNoRowsError(err) {
+ return nil
+ }
return err
}
}
@@ -558,10 +912,24 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
return
}
dst.ID = src.ID
+ dst.SignupSource = src.SignupSource
+ dst.LastLoginAt = src.LastLoginAt
+ dst.LastActiveAt = src.LastActiveAt
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}
+func userSignupSourceOrDefault(signupSource string) string {
+ switch strings.TrimSpace(strings.ToLower(signupSource)) {
+ case "", "email":
+ return "email"
+ case "linuxdo", "wechat", "oidc":
+ return strings.TrimSpace(strings.ToLower(signupSource))
+ default:
+ return "email"
+ }
+}
+
// marshalExtraEmails serializes notify email entries to JSON for storage.
func marshalExtraEmails(entries []service.NotifyEmailEntry) string {
return service.MarshalNotifyEmails(entries)
diff --git a/backend/internal/repository/user_repo_email_identity_integration_test.go b/backend/internal/repository/user_repo_email_identity_integration_test.go
new file mode 100644
index 00000000..fddd82c5
--- /dev/null
+++ b/backend/internal/repository/user_repo_email_identity_integration_test.go
@@ -0,0 +1,86 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+func (s *UserRepoSuite) TestCreate_CreatesEmailAuthIdentityForNormalEmail() {
+ user := &service.User{
+ Email: "repo-create@example.com",
+ PasswordHash: "test-password-hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Concurrency: 2,
+ }
+
+ s.Require().NoError(s.repo.Create(s.ctx, user))
+
+ identity, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("repo-create@example.com"),
+ ).
+ Only(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Equal(user.ID, identity.UserID)
+}
+
+func (s *UserRepoSuite) TestCreate_SkipsEmailAuthIdentityForSyntheticLinuxDoEmail() {
+ user := &service.User{
+ Email: "linuxdo-legacy-user@linuxdo-connect.invalid",
+ PasswordHash: "test-password-hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Concurrency: 2,
+ }
+
+ s.Require().NoError(s.repo.Create(s.ctx, user))
+
+ count, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ ).
+ Count(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Zero(count)
+}
+
+func (s *UserRepoSuite) TestUpdate_ReplacesEmailAuthIdentityWhenEmailChanges() {
+ user := s.mustCreateUser(&service.User{
+ Email: "before-update@example.com",
+ })
+
+ user.Email = "after-update@example.com"
+ s.Require().NoError(s.repo.Update(s.ctx, user))
+
+ newIdentity, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("after-update@example.com"),
+ ).
+ Only(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Equal(user.ID, newIdentity.UserID)
+
+ oldCount, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("before-update@example.com"),
+ ).
+ Count(context.Background())
+ s.Require().NoError(err)
+ s.Require().Zero(oldCount)
+}
diff --git a/backend/internal/repository/user_repo_email_lookup_unit_test.go b/backend/internal/repository/user_repo_email_lookup_unit_test.go
new file mode 100644
index 00000000..7da3db9b
--- /dev/null
+++ b/backend/internal/repository/user_repo_email_lookup_unit_test.go
@@ -0,0 +1,227 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "sync"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+func newUserEntRepo(t *testing.T) (*userRepository, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", t.Name()))
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+ db.SetMaxOpenConns(10)
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ return newUserRepositoryWithSQL(client, db), client
+}
+
+func TestUserRepositoryGetByEmailNormalizesLegacySpacingAndCase(t *testing.T) {
+ repo, _ := newUserEntRepo(t)
+ ctx := context.Background()
+
+ err := repo.Create(ctx, &service.User{
+ Email: " Legacy@Example.com ",
+ Username: "legacy-user",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+ require.NoError(t, err)
+
+ got, err := repo.GetByEmail(ctx, "legacy@example.com")
+ require.NoError(t, err)
+ require.Equal(t, " Legacy@Example.com ", got.Email)
+}
+
+func TestUserRepositoryExistsByEmailNormalizesLegacySpacingAndCase(t *testing.T) {
+ repo, _ := newUserEntRepo(t)
+ ctx := context.Background()
+
+ err := repo.Create(ctx, &service.User{
+ Email: " Legacy@Example.com ",
+ Username: "legacy-user",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+ require.NoError(t, err)
+
+ exists, err := repo.ExistsByEmail(ctx, " LEGACY@example.com ")
+ require.NoError(t, err)
+ require.True(t, exists)
+}
+
+func TestUserRepositoryCreateRejectsNormalizedEmailDuplicate(t *testing.T) {
+ repo, _ := newUserEntRepo(t)
+ ctx := context.Background()
+
+ err := repo.Create(ctx, &service.User{
+ Email: " Existing@Example.com ",
+ Username: "existing-user",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+ require.NoError(t, err)
+
+ err = repo.Create(ctx, &service.User{
+ Email: "existing@example.com",
+ Username: "duplicate-user",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+ require.ErrorIs(t, err, service.ErrEmailExists)
+}
+
+func TestUserRepositoryUpdateRejectsNormalizedEmailDuplicate(t *testing.T) {
+ repo, _ := newUserEntRepo(t)
+ ctx := context.Background()
+
+ first := &service.User{
+ Email: " Existing@Example.com ",
+ Username: "existing-user",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ }
+ require.NoError(t, repo.Create(ctx, first))
+
+ second := &service.User{
+ Email: "second@example.com",
+ Username: "second-user",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ }
+ require.NoError(t, repo.Create(ctx, second))
+
+ second.Email = " existing@example.com "
+ err := repo.Update(ctx, second)
+ require.ErrorIs(t, err, service.ErrEmailExists)
+}
+
+func TestUserRepositoryGetByEmailReportsNormalizedEmailConflict(t *testing.T) {
+ repo, client := newUserEntRepo(t)
+ ctx := context.Background()
+
+ _, err := client.User.Create().
+ SetEmail("Conflict@Example.com").
+ SetUsername("conflict-user-1").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.User.Create().
+ SetEmail(" conflict@example.com ").
+ SetUsername("conflict-user-2").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = repo.GetByEmail(ctx, "conflict@example.com")
+ require.Error(t, err)
+ require.ErrorContains(t, err, "normalized email lookup matched multiple users")
+}
+
+func TestUserRepositoryCreateSerializesNormalizedEmailConflictsUnderConcurrency(t *testing.T) {
+ repo, client := newUserEntRepo(t)
+ ctx := context.Background()
+
+ firstCreateStarted := make(chan struct{})
+ releaseFirstCreate := make(chan struct{})
+ var firstCreate sync.Once
+ client.User.Use(func(next dbent.Mutator) dbent.Mutator {
+ return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
+ blocked := false
+ if m.Op().Is(dbent.OpCreate) {
+ firstCreate.Do(func() {
+ blocked = true
+ close(firstCreateStarted)
+ })
+ }
+ if blocked {
+ <-releaseFirstCreate
+ }
+ return next.Mutate(ctx, m)
+ })
+ })
+
+ type createResult struct {
+ err error
+ }
+
+ results := make(chan createResult, 2)
+ go func() {
+ results <- createResult{err: repo.Create(ctx, &service.User{
+ Email: " Race@Example.com ",
+ Username: "race-user-1",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })}
+ }()
+
+ <-firstCreateStarted
+
+ go func() {
+ results <- createResult{err: repo.Create(ctx, &service.User{
+ Email: "race@example.com",
+ Username: "race-user-2",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })}
+ }()
+
+ time.Sleep(100 * time.Millisecond)
+ close(releaseFirstCreate)
+
+ first := <-results
+ second := <-results
+
+ errors := []error{first.err, second.err}
+ successes := 0
+ conflicts := 0
+ for _, err := range errors {
+ switch err {
+ case nil:
+ successes++
+ case service.ErrEmailExists:
+ conflicts++
+ default:
+ t.Fatalf("unexpected create error: %v", err)
+ }
+ }
+ require.Equal(t, 1, successes)
+ require.Equal(t, 1, conflicts)
+
+ count, err := client.User.Query().Where(userEmailLookupPredicate("race@example.com")).Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, count)
+}
diff --git a/backend/internal/repository/user_repo_integration_test.go b/backend/internal/repository/user_repo_integration_test.go
index f5d0f9ff..13a605a2 100644
--- a/backend/internal/repository/user_repo_integration_test.go
+++ b/backend/internal/repository/user_repo_integration_test.go
@@ -8,6 +8,8 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
@@ -26,6 +28,8 @@ func (s *UserRepoSuite) SetupTest() {
s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
// 清理测试数据,确保每个测试从干净状态开始
+ _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM auth_identity_channels")
+ _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM auth_identities")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_subscriptions")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_allowed_groups")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM users")
@@ -122,11 +126,27 @@ func (s *UserRepoSuite) TestGetByEmail() {
s.Require().Equal(user.ID, got.ID)
}
+func (s *UserRepoSuite) TestGetByEmail_NormalizesSpacingAndCaseOnPostgres() {
+ user := s.mustCreateUser(&service.User{Email: " Legacy@Example.com "})
+
+ got, err := s.repo.GetByEmail(s.ctx, " legacy@example.com ")
+ s.Require().NoError(err, "GetByEmail normalized lookup")
+ s.Require().Equal(user.ID, got.ID)
+}
+
func (s *UserRepoSuite) TestGetByEmail_NotFound() {
_, err := s.repo.GetByEmail(s.ctx, "nonexistent@test.com")
s.Require().Error(err, "expected error for non-existent email")
}
+func (s *UserRepoSuite) TestExistsByEmail_NormalizesSpacingAndCaseOnPostgres() {
+ s.mustCreateUser(&service.User{Email: " Legacy@Example.com "})
+
+ exists, err := s.repo.ExistsByEmail(s.ctx, " LEGACY@example.com ")
+ s.Require().NoError(err, "ExistsByEmail normalized lookup")
+ s.Require().True(exists)
+}
+
func (s *UserRepoSuite) TestUpdate() {
user := s.mustCreateUser(&service.User{Email: "update@test.com", Username: "original"})
@@ -140,6 +160,30 @@ func (s *UserRepoSuite) TestUpdate() {
s.Require().Equal("updated", updated.Username)
}
+func (s *UserRepoSuite) TestUpdateIgnoresNoRowsFromConflictingEmailIdentityUpsert() {
+ user := s.mustCreateUser(&service.User{Email: "update-existing-identity@test.com", Username: "original"})
+
+ identityCount, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("update-existing-identity@test.com"),
+ ).
+ Count(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Equal(1, identityCount)
+
+ got, err := s.repo.GetByID(s.ctx, user.ID)
+ s.Require().NoError(err)
+ got.Username = "updated"
+ s.Require().NoError(s.repo.Update(s.ctx, got), "Update should tolerate ON CONFLICT DO NOTHING returning no rows")
+
+ updated, err := s.repo.GetByID(s.ctx, user.ID)
+ s.Require().NoError(err)
+ s.Require().Equal("updated", updated.Username)
+}
+
func (s *UserRepoSuite) TestDelete() {
user := s.mustCreateUser(&service.User{Email: "delete@test.com"})
@@ -150,6 +194,39 @@ func (s *UserRepoSuite) TestDelete() {
s.Require().Error(err, "expected error after delete")
}
+func (s *UserRepoSuite) TestDeleteRemovesAuthIdentitiesAndChannels() {
+ user := s.mustCreateUser(&service.User{Email: "delete-oauth@test.com"})
+
+ identity, err := s.client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("delete-oauth-subject").
+ Save(s.ctx)
+ s.Require().NoError(err)
+
+ _, err = s.client.AuthIdentityChannel.Create().
+ SetIdentityID(identity.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat").
+ SetChannel("open").
+ SetChannelAppID("app-id").
+ SetChannelSubject("openid-123").
+ Save(s.ctx)
+ s.Require().NoError(err)
+
+ err = s.repo.Delete(s.ctx, user.ID)
+ s.Require().NoError(err)
+
+ identityCount, err := s.client.AuthIdentity.Query().Where(authidentity.UserIDEQ(user.ID)).Count(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Zero(identityCount)
+
+ channelCount, err := s.client.AuthIdentityChannel.Query().Where(authidentitychannel.IdentityIDEQ(identity.ID)).Count(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Zero(channelCount)
+}
+
// --- List / ListWithFilters ---
func (s *UserRepoSuite) TestList() {
diff --git a/backend/internal/repository/user_repo_sort_integration_test.go b/backend/internal/repository/user_repo_sort_integration_test.go
index ab84b0e9..3a15bc10 100644
--- a/backend/internal/repository/user_repo_sort_integration_test.go
+++ b/backend/internal/repository/user_repo_sort_integration_test.go
@@ -4,11 +4,30 @@ package repository
import (
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
+func (s *UserRepoSuite) mustInsertUsageLog(userID int64, createdAt time.Time) {
+ s.T().Helper()
+
+ account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "usage-log-account"})
+ apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: userID})
+
+ _, err := integrationDB.ExecContext(
+ s.ctx,
+ `INSERT INTO usage_logs (user_id, api_key_id, account_id, model, input_tokens, output_tokens, total_cost, actual_cost, created_at)
+ VALUES ($1, $2, $3, 'gpt-test', 1, 1, 0.01, 0.01, $4)`,
+ userID,
+ apiKey.ID,
+ account.ID,
+ createdAt.UTC(),
+ )
+ s.Require().NoError(err)
+}
+
func (s *UserRepoSuite) TestListWithFilters_SortByEmailAsc() {
s.mustCreateUser(&service.User{Email: "z-last@example.com", Username: "z-user"})
s.mustCreateUser(&service.User{Email: "a-first@example.com", Username: "a-user"})
@@ -36,4 +55,110 @@ func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() {
s.Require().Equal(first.ID, users[1].ID)
}
+func (s *UserRepoSuite) TestCreateAndRead_PreservesSignupSourceAndActivityTimestamps() {
+ lastLoginAt := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Microsecond)
+ lastActiveAt := time.Now().Add(-30 * time.Minute).UTC().Truncate(time.Microsecond)
+
+ created := s.mustCreateUser(&service.User{
+ Email: "identity-meta@example.com",
+ SignupSource: "linuxdo",
+ LastLoginAt: &lastLoginAt,
+ LastActiveAt: &lastActiveAt,
+ })
+
+ got, err := s.repo.GetByID(s.ctx, created.ID)
+ s.Require().NoError(err)
+ s.Require().Equal("linuxdo", got.SignupSource)
+ s.Require().NotNil(got.LastLoginAt)
+ s.Require().NotNil(got.LastActiveAt)
+ s.Require().True(got.LastLoginAt.Equal(lastLoginAt))
+ s.Require().True(got.LastActiveAt.Equal(lastActiveAt))
+}
+
+func (s *UserRepoSuite) TestUpdate_PersistsSignupSourceAndActivityTimestamps() {
+ created := s.mustCreateUser(&service.User{Email: "identity-update@example.com"})
+ lastLoginAt := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Microsecond)
+ lastActiveAt := time.Now().Add(-15 * time.Minute).UTC().Truncate(time.Microsecond)
+
+ created.SignupSource = "oidc"
+ created.LastLoginAt = &lastLoginAt
+ created.LastActiveAt = &lastActiveAt
+
+ s.Require().NoError(s.repo.Update(s.ctx, created))
+
+ got, err := s.repo.GetByID(s.ctx, created.ID)
+ s.Require().NoError(err)
+ s.Require().Equal("oidc", got.SignupSource)
+ s.Require().NotNil(got.LastLoginAt)
+ s.Require().NotNil(got.LastActiveAt)
+ s.Require().True(got.LastLoginAt.Equal(lastLoginAt))
+ s.Require().True(got.LastActiveAt.Equal(lastActiveAt))
+}
+
+func (s *UserRepoSuite) TestListWithFilters_SortByLastActiveAtAsc() {
+ earlier := time.Now().Add(-3 * time.Hour).UTC().Truncate(time.Microsecond)
+ later := time.Now().Add(-45 * time.Minute).UTC().Truncate(time.Microsecond)
+
+ s.mustCreateUser(&service.User{Email: "nil-active@example.com"})
+ s.mustCreateUser(&service.User{Email: "later-active@example.com", LastActiveAt: &later})
+ s.mustCreateUser(&service.User{Email: "earlier-active@example.com", LastActiveAt: &earlier})
+
+ users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
+ Page: 1,
+ PageSize: 10,
+ SortBy: "last_active_at",
+ SortOrder: "asc",
+ }, service.UserListFilters{})
+ s.Require().NoError(err)
+ s.Require().Len(users, 3)
+ s.Require().Equal("earlier-active@example.com", users[0].Email)
+ s.Require().Equal("later-active@example.com", users[1].Email)
+ s.Require().Equal("nil-active@example.com", users[2].Email)
+}
+
+func (s *UserRepoSuite) TestGetLatestUsedAtByUserIDs_UsesUsageLogs() {
+ older := time.Now().Add(-4 * time.Hour).UTC().Truncate(time.Second)
+ newer := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Second)
+
+ userWithUsage := s.mustCreateUser(&service.User{Email: "usage-source@example.com"})
+ userWithoutUsage := s.mustCreateUser(&service.User{Email: "usage-missing@example.com"})
+ s.mustInsertUsageLog(userWithUsage.ID, older)
+ s.mustInsertUsageLog(userWithUsage.ID, newer)
+
+ got, err := s.repo.GetLatestUsedAtByUserIDs(s.ctx, []int64{userWithUsage.ID, userWithoutUsage.ID})
+ s.Require().NoError(err)
+ s.Require().Contains(got, userWithUsage.ID)
+ s.Require().NotContains(got, userWithoutUsage.ID)
+ s.Require().NotNil(got[userWithUsage.ID])
+ s.Require().True(got[userWithUsage.ID].Equal(newer))
+}
+
+func (s *UserRepoSuite) TestListWithFilters_SortByLastUsedAtDesc_UsesUsageLogsNotLastActiveAt() {
+ lastUsedOlder := time.Now().Add(-6 * time.Hour).UTC().Truncate(time.Second)
+ lastUsedNewer := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Second)
+ lastActiveVeryRecent := time.Now().Add(-10 * time.Minute).UTC().Truncate(time.Second)
+
+ nilUsage := s.mustCreateUser(&service.User{Email: "nil-last-used@example.com"})
+ wrongSource := s.mustCreateUser(&service.User{
+ Email: "active-not-usage@example.com",
+ LastActiveAt: &lastActiveVeryRecent,
+ })
+ rightSource := s.mustCreateUser(&service.User{Email: "usage-wins@example.com"})
+
+ s.mustInsertUsageLog(wrongSource.ID, lastUsedOlder)
+ s.mustInsertUsageLog(rightSource.ID, lastUsedNewer)
+
+ users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
+ Page: 1,
+ PageSize: 10,
+ SortBy: "last_used_at",
+ SortOrder: "desc",
+ }, service.UserListFilters{})
+ s.Require().NoError(err)
+ s.Require().Len(users, 3)
+ s.Require().Equal(rightSource.ID, users[0].ID)
+ s.Require().Equal(wrongSource.ID, users[1].ID)
+ s.Require().Equal(nilUsage.ID, users[2].ID)
+}
+
func TestUserRepoSortSuiteSmoke(_ *testing.T) {}
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index b686b986..d2b108f5 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -50,6 +50,7 @@ func TestAPIContracts(t *testing.T) {
"data": {
"id": 1,
"email": "alice@example.com",
+ "email_bound": true,
"username": "alice",
"role": "user",
"balance": 12.5,
@@ -63,6 +64,123 @@ func TestAPIContracts(t *testing.T) {
"balance_notify_threshold": null,
"balance_notify_extra_emails": null,
"total_recharged": 0,
+ "linuxdo_bound": false,
+ "oidc_bound": false,
+ "wechat_bound": false,
+ "identities": {
+ "email": {
+ "provider": "email",
+ "provider_key": "email",
+ "bound": true,
+ "bound_count": 1,
+ "can_bind": false,
+ "can_unbind": false,
+ "display_name": "alice@example.com",
+ "subject_hint": "a***e@example.com",
+ "note_key": "profile.authBindings.notes.emailManagedFromProfile",
+ "note": "Primary account email is managed from the profile form."
+ },
+ "linuxdo": {
+ "provider": "linuxdo",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "oidc": {
+ "provider": "oidc",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "wechat": {
+ "provider": "wechat",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ }
+ },
+ "identity_bindings": {
+ "email": {
+ "provider": "email",
+ "provider_key": "email",
+ "bound": true,
+ "bound_count": 1,
+ "can_bind": false,
+ "can_unbind": false,
+ "display_name": "alice@example.com",
+ "subject_hint": "a***e@example.com",
+ "note_key": "profile.authBindings.notes.emailManagedFromProfile",
+ "note": "Primary account email is managed from the profile form."
+ },
+ "linuxdo": {
+ "provider": "linuxdo",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "oidc": {
+ "provider": "oidc",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "wechat": {
+ "provider": "wechat",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ }
+ },
+ "auth_bindings": {
+ "email": {
+ "provider": "email",
+ "provider_key": "email",
+ "bound": true,
+ "bound_count": 1,
+ "can_bind": false,
+ "can_unbind": false,
+ "display_name": "alice@example.com",
+ "subject_hint": "a***e@example.com",
+ "note_key": "profile.authBindings.notes.emailManagedFromProfile",
+ "note": "Primary account email is managed from the profile form."
+ },
+ "linuxdo": {
+ "provider": "linuxdo",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "oidc": {
+ "provider": "oidc",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "wechat": {
+ "provider": "wechat",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ }
+ },
"run_mode": "standard"
}
}`,
@@ -479,7 +597,7 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyOIDCConnectRedirectURL: "",
service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
- service.SettingKeyOIDCConnectUsePKCE: "false",
+ service.SettingKeyOIDCConnectUsePKCE: "true",
service.SettingKeyOIDCConnectValidateIDToken: "true",
service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
service.SettingKeyOIDCConnectClockSkewSeconds: "120",
@@ -500,10 +618,15 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyTableDefaultPageSize: "20",
service.SettingKeyTablePageSizeOptions: "[10,20,50,100]",
- service.SettingKeyOpsMonitoringEnabled: "false",
- service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
- service.SettingKeyOpsQueryModeDefault: "auto",
- service.SettingKeyOpsMetricsIntervalSeconds: "60",
+ service.SettingKeyOpsMonitoringEnabled: "false",
+ service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
+ service.SettingKeyOpsQueryModeDefault: "auto",
+ service.SettingKeyOpsMetricsIntervalSeconds: "60",
+ service.SettingPaymentVisibleMethodAlipaySource: service.VisibleMethodSourceEasyPayAlipay,
+ service.SettingPaymentVisibleMethodWxpaySource: service.VisibleMethodSourceOfficialWechat,
+ service.SettingPaymentVisibleMethodAlipayEnabled: "true",
+ service.SettingPaymentVisibleMethodWxpayEnabled: "false",
+ "openai_advanced_scheduler_enabled": "true",
})
},
method: http.MethodGet,
@@ -549,7 +672,7 @@ func TestAPIContracts(t *testing.T) {
"oidc_connect_redirect_url": "",
"oidc_connect_frontend_redirect_url": "/auth/oidc/callback",
"oidc_connect_token_auth_method": "client_secret_post",
- "oidc_connect_use_pkce": false,
+ "oidc_connect_use_pkce": true,
"oidc_connect_validate_id_token": true,
"oidc_connect_allowed_signing_algs": "RS256,ES256,PS256",
"oidc_connect_clock_skew_seconds": 120,
@@ -567,6 +690,27 @@ func TestAPIContracts(t *testing.T) {
"api_base_url": "https://api.example.com",
"contact_info": "support",
"doc_url": "https://docs.example.com",
+ "auth_source_default_email_balance": 0,
+ "auth_source_default_email_concurrency": 5,
+ "auth_source_default_email_subscriptions": [],
+ "auth_source_default_email_grant_on_signup": false,
+ "auth_source_default_email_grant_on_first_bind": false,
+ "auth_source_default_linuxdo_balance": 0,
+ "auth_source_default_linuxdo_concurrency": 5,
+ "auth_source_default_linuxdo_subscriptions": [],
+ "auth_source_default_linuxdo_grant_on_signup": false,
+ "auth_source_default_linuxdo_grant_on_first_bind": false,
+ "auth_source_default_oidc_balance": 0,
+ "auth_source_default_oidc_concurrency": 5,
+ "auth_source_default_oidc_subscriptions": [],
+ "auth_source_default_oidc_grant_on_signup": false,
+ "auth_source_default_oidc_grant_on_first_bind": false,
+ "auth_source_default_wechat_balance": 0,
+ "auth_source_default_wechat_concurrency": 5,
+ "auth_source_default_wechat_subscriptions": [],
+ "auth_source_default_wechat_grant_on_signup": false,
+ "auth_source_default_wechat_grant_on_first_bind": false,
+ "force_email_on_third_party_signup": false,
"default_concurrency": 5,
"default_balance": 1.25,
"default_subscriptions": [],
@@ -592,6 +736,11 @@ func TestAPIContracts(t *testing.T) {
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
"web_search_emulation_enabled": false,
+ "payment_visible_method_alipay_source": "easypay_alipay",
+ "payment_visible_method_wxpay_source": "official_wxpay",
+ "payment_visible_method_alipay_enabled": true,
+ "payment_visible_method_wxpay_enabled": false,
+ "openai_advanced_scheduler_enabled": true,
"custom_menu_items": [],
"custom_endpoints": [],
"payment_enabled": false,
@@ -618,7 +767,215 @@ func TestAPIContracts(t *testing.T) {
"account_quota_notify_enabled": false,
"balance_low_notify_threshold": 0,
"balance_low_notify_recharge_url": "",
- "account_quota_notify_emails": []
+ "account_quota_notify_emails": [],
+ "wechat_connect_enabled": false,
+ "wechat_connect_app_id": "",
+ "wechat_connect_app_secret_configured": false,
+ "wechat_connect_mode": "open",
+ "wechat_connect_open_enabled": false,
+ "wechat_connect_open_app_id": "",
+ "wechat_connect_open_app_secret_configured": false,
+ "wechat_connect_mp_enabled": false,
+ "wechat_connect_mp_app_id": "",
+ "wechat_connect_mp_app_secret_configured": false,
+ "wechat_connect_mobile_enabled": false,
+ "wechat_connect_mobile_app_id": "",
+ "wechat_connect_mobile_app_secret_configured": false,
+ "wechat_connect_redirect_url": "",
+ "wechat_connect_frontend_redirect_url": "/auth/wechat/callback",
+ "wechat_connect_scopes": "snsapi_login"
+ }
+ }`,
+ },
+ {
+ name: "GET /api/v1/admin/settings falls back to config oauth defaults",
+ setup: func(t *testing.T, deps *contractDeps) {
+ t.Helper()
+ deps.cfg.OIDC = config.OIDCConnectConfig{
+ Enabled: true,
+ ProviderName: "ConfigOIDC",
+ ClientID: "oidc-config-client",
+ ClientSecret: "oidc-config-secret",
+ IssuerURL: "https://issuer.example.com",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ Scopes: "openid email profile",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ ValidateIDToken: true,
+ AllowedSigningAlgs: "RS256,ES256,PS256",
+ ClockSkewSeconds: 120,
+ }
+ deps.cfg.WeChat = config.WeChatConnectConfig{
+ Enabled: true,
+ OpenEnabled: true,
+ OpenAppID: "wx-open-config",
+ OpenAppSecret: "wx-open-secret",
+ Mode: "open",
+ Scopes: "snsapi_login",
+ FrontendRedirectURL: "/auth/wechat/callback",
+ }
+ deps.settingRepo.SetAll(map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyEmailVerifyEnabled: "false",
+ service.SettingKeyRegistrationEmailSuffixWhitelist: "[]",
+ })
+ },
+ method: http.MethodGet,
+ path: "/api/v1/admin/settings",
+ wantStatus: http.StatusOK,
+ wantJSON: `{
+ "code": 0,
+ "message": "success",
+ "data": {
+ "registration_enabled": true,
+ "email_verify_enabled": false,
+ "registration_email_suffix_whitelist": [],
+ "promo_code_enabled": true,
+ "password_reset_enabled": false,
+ "frontend_url": "",
+ "invitation_code_enabled": false,
+ "totp_enabled": false,
+ "totp_encryption_key_configured": false,
+ "smtp_host": "",
+ "smtp_port": 587,
+ "smtp_username": "",
+ "smtp_password_configured": false,
+ "smtp_from_email": "",
+ "smtp_from_name": "",
+ "smtp_use_tls": false,
+ "turnstile_enabled": false,
+ "turnstile_site_key": "",
+ "turnstile_secret_key_configured": false,
+ "linuxdo_connect_enabled": false,
+ "linuxdo_connect_client_id": "",
+ "linuxdo_connect_client_secret_configured": false,
+ "linuxdo_connect_redirect_url": "",
+ "oidc_connect_enabled": true,
+ "oidc_connect_provider_name": "ConfigOIDC",
+ "oidc_connect_client_id": "oidc-config-client",
+ "oidc_connect_client_secret_configured": true,
+ "oidc_connect_issuer_url": "https://issuer.example.com",
+ "oidc_connect_discovery_url": "",
+ "oidc_connect_authorize_url": "",
+ "oidc_connect_token_url": "",
+ "oidc_connect_userinfo_url": "",
+ "oidc_connect_jwks_url": "",
+ "oidc_connect_scopes": "openid email profile",
+ "oidc_connect_redirect_url": "https://api.example.com/api/v1/auth/oauth/oidc/callback",
+ "oidc_connect_frontend_redirect_url": "/auth/oidc/callback",
+ "oidc_connect_token_auth_method": "client_secret_post",
+ "oidc_connect_use_pkce": true,
+ "oidc_connect_validate_id_token": true,
+ "oidc_connect_allowed_signing_algs": "RS256,ES256,PS256",
+ "oidc_connect_clock_skew_seconds": 120,
+ "oidc_connect_require_email_verified": false,
+ "oidc_connect_userinfo_email_path": "",
+ "oidc_connect_userinfo_id_path": "",
+ "oidc_connect_userinfo_username_path": "",
+ "site_name": "Sub2API",
+ "site_logo": "",
+ "site_subtitle": "Subscription to API Conversion Platform",
+ "api_base_url": "",
+ "contact_info": "",
+ "doc_url": "",
+ "home_content": "",
+ "hide_ccs_import_button": false,
+ "purchase_subscription_enabled": false,
+ "purchase_subscription_url": "",
+ "table_default_page_size": 20,
+ "table_page_size_options": [10, 20, 50],
+ "custom_menu_items": [],
+ "custom_endpoints": [],
+ "default_concurrency": 0,
+ "default_balance": 0,
+ "default_subscriptions": [],
+ "enable_model_fallback": false,
+ "fallback_model_anthropic": "claude-3-5-sonnet-20241022",
+ "fallback_model_openai": "gpt-4o",
+ "fallback_model_gemini": "gemini-2.5-pro",
+ "fallback_model_antigravity": "gemini-2.5-pro",
+ "enable_identity_patch": true,
+ "identity_patch_prompt": "",
+ "ops_monitoring_enabled": false,
+ "ops_realtime_monitoring_enabled": true,
+ "ops_query_mode_default": "auto",
+ "ops_metrics_interval_seconds": 60,
+ "min_claude_code_version": "",
+ "max_claude_code_version": "",
+ "allow_ungrouped_key_scheduling": false,
+ "backend_mode_enabled": false,
+ "enable_fingerprint_unification": true,
+ "enable_metadata_passthrough": false,
+ "enable_cch_signing": false,
+ "web_search_emulation_enabled": false,
+ "payment_visible_method_alipay_source": "",
+ "payment_visible_method_wxpay_source": "",
+ "payment_visible_method_alipay_enabled": false,
+ "payment_visible_method_wxpay_enabled": false,
+ "openai_advanced_scheduler_enabled": false,
+ "payment_enabled": false,
+ "payment_min_amount": 0,
+ "payment_max_amount": 0,
+ "payment_daily_limit": 0,
+ "payment_order_timeout_minutes": 0,
+ "payment_max_pending_orders": 0,
+ "payment_enabled_types": null,
+ "payment_balance_disabled": false,
+ "payment_balance_recharge_multiplier": 0,
+ "payment_recharge_fee_rate": 0,
+ "payment_load_balance_strategy": "",
+ "payment_product_name_prefix": "",
+ "payment_product_name_suffix": "",
+ "payment_help_image_url": "",
+ "payment_help_text": "",
+ "payment_cancel_rate_limit_enabled": false,
+ "payment_cancel_rate_limit_max": 0,
+ "payment_cancel_rate_limit_window": 0,
+ "payment_cancel_rate_limit_unit": "",
+ "payment_cancel_rate_limit_window_mode": "",
+ "balance_low_notify_enabled": false,
+ "account_quota_notify_enabled": false,
+ "balance_low_notify_threshold": 0,
+ "balance_low_notify_recharge_url": "",
+ "account_quota_notify_emails": [],
+ "wechat_connect_enabled": true,
+ "wechat_connect_app_id": "wx-open-config",
+ "wechat_connect_app_secret_configured": true,
+ "wechat_connect_mode": "open",
+ "wechat_connect_open_enabled": true,
+ "wechat_connect_open_app_id": "wx-open-config",
+ "wechat_connect_open_app_secret_configured": true,
+ "wechat_connect_mp_enabled": false,
+ "wechat_connect_mp_app_id": "wx-open-config",
+ "wechat_connect_mp_app_secret_configured": true,
+ "wechat_connect_mobile_enabled": false,
+ "wechat_connect_mobile_app_id": "wx-open-config",
+ "wechat_connect_mobile_app_secret_configured": true,
+ "wechat_connect_redirect_url": "",
+ "wechat_connect_frontend_redirect_url": "/auth/wechat/callback",
+ "wechat_connect_scopes": "snsapi_login",
+ "auth_source_default_email_balance": 0,
+ "auth_source_default_email_concurrency": 5,
+ "auth_source_default_email_subscriptions": [],
+ "auth_source_default_email_grant_on_signup": false,
+ "auth_source_default_email_grant_on_first_bind": false,
+ "auth_source_default_linuxdo_balance": 0,
+ "auth_source_default_linuxdo_concurrency": 5,
+ "auth_source_default_linuxdo_subscriptions": [],
+ "auth_source_default_linuxdo_grant_on_signup": false,
+ "auth_source_default_linuxdo_grant_on_first_bind": false,
+ "auth_source_default_oidc_balance": 0,
+ "auth_source_default_oidc_concurrency": 5,
+ "auth_source_default_oidc_subscriptions": [],
+ "auth_source_default_oidc_grant_on_signup": false,
+ "auth_source_default_oidc_grant_on_first_bind": false,
+ "auth_source_default_wechat_balance": 0,
+ "auth_source_default_wechat_concurrency": 5,
+ "auth_source_default_wechat_subscriptions": [],
+ "auth_source_default_wechat_grant_on_signup": false,
+ "auth_source_default_wechat_grant_on_first_bind": false,
+ "force_email_on_third_party_signup": false
}
}`,
},
@@ -665,6 +1022,7 @@ func TestAPIContracts(t *testing.T) {
type contractDeps struct {
now time.Time
router http.Handler
+ cfg *config.Config
apiKeyRepo *stubApiKeyRepo
groupRepo *stubGroupRepo
userSubRepo *stubUserSubscriptionRepo
@@ -785,6 +1143,7 @@ func newContractDeps(t *testing.T) *contractDeps {
return &contractDeps{
now: now,
router: r,
+ cfg: cfg,
apiKeyRepo: apiKeyRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
@@ -858,6 +1217,18 @@ func (r *stubUserRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
+func (r *stubUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
+ return nil, nil
+}
+
+func (r *stubUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (r *stubUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ return errors.New("not implemented")
+}
+
func (r *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
@@ -894,6 +1265,26 @@ func (r *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64
return errors.New("not implemented")
}
+func (r *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) {
+ return nil, nil
+}
+
+func (r *stubUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error {
+ return errors.New("not implemented")
+}
+
+func (r *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+
+func (r *stubUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ return nil, nil
+}
+
+func (r *stubUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+ return nil
+}
+
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
return errors.New("not implemented")
}
diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go
index ed2578c8..06e3355e 100644
--- a/backend/internal/server/middleware/admin_auth_test.go
+++ b/backend/internal/server/middleware/admin_auth_test.go
@@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -153,6 +154,18 @@ func (s *stubUserRepo) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
}
+func (s *stubUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
+ return nil, nil
+}
+
+func (s *stubUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ panic("unexpected UpsertUserAvatar call")
+}
+
+func (s *stubUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ panic("unexpected DeleteUserAvatar call")
+}
+
func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
@@ -161,6 +174,18 @@ func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.Pa
panic("unexpected ListWithFilters call")
}
+func (s *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ panic("unexpected GetLatestUsedAtByUserIDs call")
+}
+
+func (s *stubUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ panic("unexpected GetLatestUsedAtByUserID call")
+}
+
+func (s *stubUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+ panic("unexpected UpdateUserLastActiveAt call")
+}
+
func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
@@ -189,6 +214,14 @@ func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64
panic("unexpected AddGroupToAllowedGroups call")
}
+func (s *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) {
+ panic("unexpected ListUserAuthIdentities call")
+}
+
+func (s *stubUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error {
+ panic("unexpected UnbindUserAuthProvider call")
+}
+
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
diff --git a/backend/internal/server/middleware/backend_mode_guard.go b/backend/internal/server/middleware/backend_mode_guard.go
index 46482af3..ae53037e 100644
--- a/backend/internal/server/middleware/backend_mode_guard.go
+++ b/backend/internal/server/middleware/backend_mode_guard.go
@@ -27,23 +27,50 @@ func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFun
}
}
+func backendModeAllowsAuthPath(path string) bool {
+ path = strings.ToLower(strings.TrimSpace(path))
+ for _, suffix := range []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"} {
+ if strings.HasSuffix(path, suffix) {
+ return true
+ }
+ }
+
+ for _, suffix := range []string{
+ "/auth/oauth/linuxdo/callback",
+ "/auth/oauth/wechat/callback",
+ "/auth/oauth/wechat/payment/callback",
+ "/auth/oauth/oidc/callback",
+ "/auth/oauth/linuxdo/complete-registration",
+ "/auth/oauth/wechat/complete-registration",
+ "/auth/oauth/oidc/complete-registration",
+ "/auth/oauth/linuxdo/create-account",
+ "/auth/oauth/wechat/create-account",
+ "/auth/oauth/oidc/create-account",
+ "/auth/oauth/linuxdo/bind-login",
+ "/auth/oauth/wechat/bind-login",
+ "/auth/oauth/oidc/bind-login",
+ } {
+ if strings.HasSuffix(path, suffix) {
+ return true
+ }
+ }
+
+ return strings.Contains(path, "/auth/oauth/pending/")
+}
+
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
-// Allows: login, login/2fa, logout, refresh (admin needs these).
-// Blocks: register, forgot-password, reset-password, OAuth, etc.
+// Allows the minimal auth surface admins still need in backend mode, including
+// OAuth callbacks and pending continuations. Handler-level backend mode checks
+// still enforce admin-only login and forbid self-service registration.
func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc {
return func(c *gin.Context) {
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
c.Next()
return
}
- path := c.Request.URL.Path
- // Allow login, 2FA, logout, refresh, public settings
- allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"}
- for _, suffix := range allowedSuffixes {
- if strings.HasSuffix(path, suffix) {
- c.Next()
- return
- }
+ if backendModeAllowsAuthPath(c.Request.URL.Path) {
+ c.Next()
+ return
}
response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.")
c.Abort()
diff --git a/backend/internal/server/middleware/backend_mode_guard_test.go b/backend/internal/server/middleware/backend_mode_guard_test.go
index 8878ebc9..bd77677b 100644
--- a/backend/internal/server/middleware/backend_mode_guard_test.go
+++ b/backend/internal/server/middleware/backend_mode_guard_test.go
@@ -198,6 +198,96 @@ func TestBackendModeAuthGuard(t *testing.T) {
path: "/api/v1/auth/refresh",
wantStatus: http.StatusOK,
},
+ {
+ name: "enabled_blocks_linuxdo_oauth_start",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/linuxdo/start",
+ wantStatus: http.StatusForbidden,
+ },
+ {
+ name: "enabled_allows_linuxdo_oauth_callback",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/linuxdo/callback",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_blocks_wechat_oauth_start",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/wechat/start",
+ wantStatus: http.StatusForbidden,
+ },
+ {
+ name: "enabled_allows_wechat_oauth_callback",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/wechat/callback",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_blocks_wechat_payment_oauth_start",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/wechat/payment/start",
+ wantStatus: http.StatusForbidden,
+ },
+ {
+ name: "enabled_allows_wechat_payment_oauth_callback",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/wechat/payment/callback",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_blocks_oidc_oauth_start",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/oidc/start",
+ wantStatus: http.StatusForbidden,
+ },
+ {
+ name: "enabled_allows_oidc_oauth_callback",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/oidc/callback",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_oauth_pending_exchange",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/pending/exchange",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_oauth_pending_send_verify_code",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/pending/send-verify-code",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_oauth_pending_create_account",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/pending/create-account",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_oauth_pending_bind_login",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/pending/bind-login",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_provider_bind_login",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/oidc/bind-login",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_provider_create_account",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/wechat/create-account",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_legacy_complete_registration",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/linuxdo/complete-registration",
+ wantStatus: http.StatusOK,
+ },
{
name: "enabled_blocks_register",
enabled: "true",
diff --git a/backend/internal/server/middleware/jwt_auth.go b/backend/internal/server/middleware/jwt_auth.go
index 4aceb355..48cb9004 100644
--- a/backend/internal/server/middleware/jwt_auth.go
+++ b/backend/internal/server/middleware/jwt_auth.go
@@ -1,6 +1,7 @@
package middleware
import (
+ "context"
"errors"
"strings"
@@ -11,11 +12,19 @@ import (
// NewJWTAuthMiddleware 创建 JWT 认证中间件
func NewJWTAuthMiddleware(authService *service.AuthService, userService *service.UserService) JWTAuthMiddleware {
- return JWTAuthMiddleware(jwtAuth(authService, userService))
+ return JWTAuthMiddleware(jwtAuth(authService, userService, userService))
+}
+
+type jwtUserReader interface {
+ GetByID(ctx context.Context, id int64) (*service.User, error)
+}
+
+type userActivityToucher interface {
+ TouchLastActiveForUser(ctx context.Context, user *service.User)
}
// jwtAuth JWT认证中间件实现
-func jwtAuth(authService *service.AuthService, userService *service.UserService) gin.HandlerFunc {
+func jwtAuth(authService *service.AuthService, userService jwtUserReader, activityToucher userActivityToucher) gin.HandlerFunc {
return func(c *gin.Context) {
// 从Authorization header中提取token
authHeader := c.GetHeader("Authorization")
@@ -73,6 +82,9 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
Concurrency: user.Concurrency,
})
c.Set(string(ContextKeyUserRole), user.Role)
+ if activityToucher != nil {
+ activityToucher.TouchLastActiveForUser(c.Request.Context(), user)
+ }
c.Next()
}
diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go
index c483a51e..84fd6967 100644
--- a/backend/internal/server/middleware/jwt_auth_test.go
+++ b/backend/internal/server/middleware/jwt_auth_test.go
@@ -9,6 +9,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -30,6 +31,25 @@ func (r *stubJWTUserRepo) GetByID(_ context.Context, id int64) (*service.User, e
return u, nil
}
+func (r *stubJWTUserRepo) GetUserAvatar(_ context.Context, _ int64) (*service.UserAvatar, error) {
+ return nil, nil
+}
+
+func (r *stubJWTUserRepo) UpdateUserLastActiveAt(_ context.Context, _ int64, _ time.Time) error {
+ return nil
+}
+
+type recordingActivityToucher struct {
+ userIDs []int64
+}
+
+func (r *recordingActivityToucher) TouchLastActiveForUser(_ context.Context, user *service.User) {
+ if user == nil {
+ return
+ }
+ r.userIDs = append(r.userIDs, user.ID)
+}
+
// newJWTTestEnv 创建 JWT 认证中间件测试环境。
// 返回 gin.Engine(已注册 JWT 中间件)和 AuthService(用于生成 Token)。
func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthService) {
@@ -106,6 +126,45 @@ func TestJWTAuth_ValidToken_LowercaseBearer(t *testing.T) {
require.Equal(t, http.StatusOK, w.Code)
}
+func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) {
+ user := &service.User{
+ ID: 1,
+ Email: "test@example.com",
+ Role: "user",
+ Status: service.StatusActive,
+ Concurrency: 5,
+ TokenVersion: 1,
+ }
+
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.JWT.Secret = "test-jwt-secret-32bytes-long!!!"
+ cfg.JWT.AccessTokenExpireMinutes = 60
+
+ userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}}
+ authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
+ userSvc := service.NewUserService(userRepo, nil, nil, nil)
+ toucher := &recordingActivityToucher{}
+
+ r := gin.New()
+ r.Use(jwtAuth(authSvc, userSvc, toucher))
+ r.GET("/protected", func(c *gin.Context) {
+ c.Status(http.StatusOK)
+ })
+
+ token, err := authSvc.GenerateToken(user)
+ require.NoError(t, err)
+
+ w := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/protected", nil)
+ req.Header.Set("Authorization", "Bearer "+token)
+ r.ServeHTTP(w, req)
+
+ require.Equal(t, http.StatusOK, w.Code)
+ require.Equal(t, []int64{1}, toucher.userIDs)
+}
+
func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) {
router, _ := newJWTTestEnv(nil)
diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go
index 7021ab2e..398c0351 100644
--- a/backend/internal/server/middleware/security_headers.go
+++ b/backend/internal/server/middleware/security_headers.go
@@ -96,7 +96,8 @@ func isAPIRoutePath(c *gin.Context) bool {
return strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") ||
strings.HasPrefix(path, "/antigravity/") ||
- strings.HasPrefix(path, "/responses")
+ strings.HasPrefix(path, "/responses") ||
+ strings.HasPrefix(path, "/images")
}
// enhanceCSPPolicy ensures the CSP policy includes nonce support, Cloudflare Insights,
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index 9af0fd8e..84c963ec 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -212,6 +212,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
+ users.POST("/:id/auth-identities", h.Admin.User.BindAuthIdentity)
users.POST("", h.Admin.User.Create)
users.PUT("/:id", h.Admin.User.Update)
users.DELETE("/:id", h.Admin.User.Delete)
diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go
index c143b030..642a2103 100644
--- a/backend/internal/server/routes/auth.go
+++ b/backend/internal/server/routes/auth.go
@@ -63,14 +63,90 @@ func RegisterAuthRoutes(
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
+ auth.GET("/oauth/linuxdo/bind/start", func(c *gin.Context) {
+ query := c.Request.URL.Query()
+ query.Set("intent", "bind_current_user")
+ c.Request.URL.RawQuery = query.Encode()
+ h.Auth.LinuxDoOAuthStart(c)
+ })
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
+ auth.GET("/oauth/wechat/start", h.Auth.WeChatOAuthStart)
+ auth.GET("/oauth/wechat/bind/start", func(c *gin.Context) {
+ query := c.Request.URL.Query()
+ query.Set("intent", "bind_current_user")
+ c.Request.URL.RawQuery = query.Encode()
+ h.Auth.WeChatOAuthStart(c)
+ })
+ auth.GET("/oauth/wechat/callback", h.Auth.WeChatOAuthCallback)
+ auth.GET("/oauth/wechat/payment/start", h.Auth.WeChatPaymentOAuthStart)
+ auth.GET("/oauth/wechat/payment/callback", h.Auth.WeChatPaymentOAuthCallback)
+ auth.POST("/oauth/pending/exchange",
+ rateLimiter.LimitWithOptions("oauth-pending-exchange", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.ExchangePendingOAuthCompletion,
+ )
+ auth.POST("/oauth/pending/send-verify-code",
+ rateLimiter.LimitWithOptions("oauth-pending-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.SendPendingOAuthVerifyCode,
+ )
+ auth.POST("/oauth/pending/create-account",
+ rateLimiter.LimitWithOptions("oauth-pending-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreatePendingOAuthAccount,
+ )
+ auth.POST("/oauth/pending/bind-login",
+ rateLimiter.LimitWithOptions("oauth-pending-bind-login", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindPendingOAuthLogin,
+ )
auth.POST("/oauth/linuxdo/complete-registration",
rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteLinuxDoOAuthRegistration,
)
+ auth.POST("/oauth/linuxdo/bind-login",
+ rateLimiter.LimitWithOptions("oauth-linuxdo-bind-login", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindLinuxDoOAuthLogin,
+ )
+ auth.POST("/oauth/linuxdo/create-account",
+ rateLimiter.LimitWithOptions("oauth-linuxdo-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreateLinuxDoOAuthAccount,
+ )
+ auth.POST("/oauth/wechat/complete-registration",
+ rateLimiter.LimitWithOptions("oauth-wechat-complete", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CompleteWeChatOAuthRegistration,
+ )
+ auth.POST("/oauth/wechat/bind-login",
+ rateLimiter.LimitWithOptions("oauth-wechat-bind-login", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindWeChatOAuthLogin,
+ )
+ auth.POST("/oauth/wechat/create-account",
+ rateLimiter.LimitWithOptions("oauth-wechat-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreateWeChatOAuthAccount,
+ )
auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart)
+ auth.GET("/oauth/oidc/bind/start", func(c *gin.Context) {
+ query := c.Request.URL.Query()
+ query.Set("intent", "bind_current_user")
+ c.Request.URL.RawQuery = query.Encode()
+ h.Auth.OIDCOAuthStart(c)
+ })
auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback)
auth.POST("/oauth/oidc/complete-registration",
rateLimiter.LimitWithOptions("oauth-oidc-complete", 10, time.Minute, middleware.RateLimitOptions{
@@ -78,6 +154,18 @@ func RegisterAuthRoutes(
}),
h.Auth.CompleteOIDCOAuthRegistration,
)
+ auth.POST("/oauth/oidc/bind-login",
+ rateLimiter.LimitWithOptions("oauth-oidc-bind-login", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindOIDCOAuthLogin,
+ )
+ auth.POST("/oauth/oidc/create-account",
+ rateLimiter.LimitWithOptions("oauth-oidc-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreateOIDCOAuthAccount,
+ )
}
// 公开设置(无需认证)
@@ -94,5 +182,6 @@ func RegisterAuthRoutes(
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 撤销所有会话(需要认证)
authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions)
+ authenticated.POST("/auth/oauth/bind-token", h.Auth.PrepareOAuthBindAccessTokenCookie)
}
}
diff --git a/backend/internal/server/routes/auth_rate_limit_test.go b/backend/internal/server/routes/auth_rate_limit_test.go
index 4f411cec..07a66efb 100644
--- a/backend/internal/server/routes/auth_rate_limit_test.go
+++ b/backend/internal/server/routes/auth_rate_limit_test.go
@@ -52,6 +52,7 @@ func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) {
"/api/v1/auth/login",
"/api/v1/auth/login/2fa",
"/api/v1/auth/send-verify-code",
+ "/api/v1/auth/oauth/pending/send-verify-code",
}
for _, path := range paths {
diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go
index cbf98293..5982e1cc 100644
--- a/backend/internal/server/routes/gateway.go
+++ b/backend/internal/server/routes/gateway.go
@@ -88,6 +88,30 @@ func RegisterGatewayRoutes(
}
h.Gateway.ChatCompletions(c)
})
+ gateway.POST("/images/generations", func(c *gin.Context) {
+ if getGroupPlatform(c) != service.PlatformOpenAI {
+ c.JSON(http.StatusNotFound, gin.H{
+ "error": gin.H{
+ "type": "not_found_error",
+ "message": "Images API is not supported for this platform",
+ },
+ })
+ return
+ }
+ h.OpenAIGateway.Images(c)
+ })
+ gateway.POST("/images/edits", func(c *gin.Context) {
+ if getGroupPlatform(c) != service.PlatformOpenAI {
+ c.JSON(http.StatusNotFound, gin.H{
+ "error": gin.H{
+ "type": "not_found_error",
+ "message": "Images API is not supported for this platform",
+ },
+ })
+ return
+ }
+ h.OpenAIGateway.Images(c)
+ })
}
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
@@ -124,6 +148,30 @@ func RegisterGatewayRoutes(
}
h.Gateway.ChatCompletions(c)
})
+ r.POST("/images/generations", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) {
+ if getGroupPlatform(c) != service.PlatformOpenAI {
+ c.JSON(http.StatusNotFound, gin.H{
+ "error": gin.H{
+ "type": "not_found_error",
+ "message": "Images API is not supported for this platform",
+ },
+ })
+ return
+ }
+ h.OpenAIGateway.Images(c)
+ })
+ r.POST("/images/edits", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) {
+ if getGroupPlatform(c) != service.PlatformOpenAI {
+ c.JSON(http.StatusNotFound, gin.H{
+ "error": gin.H{
+ "type": "not_found_error",
+ "message": "Images API is not supported for this platform",
+ },
+ })
+ return
+ }
+ h.OpenAIGateway.Images(c)
+ })
// Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
diff --git a/backend/internal/server/routes/gateway_test.go b/backend/internal/server/routes/gateway_test.go
index 4d65a626..87a77cbc 100644
--- a/backend/internal/server/routes/gateway_test.go
+++ b/backend/internal/server/routes/gateway_test.go
@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -24,6 +25,11 @@ func newGatewayRoutesTestRouter() *gin.Engine {
OpenAIGateway: &handler.OpenAIGatewayHandler{},
},
servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) {
+ groupID := int64(1)
+ c.Set(string(servermiddleware.ContextKeyAPIKey), &service.APIKey{
+ GroupID: &groupID,
+ Group: &service.Group{Platform: service.PlatformOpenAI},
+ })
c.Next()
}),
nil,
@@ -48,3 +54,21 @@ func TestGatewayRoutesOpenAIResponsesCompactPathIsRegistered(t *testing.T) {
require.NotEqual(t, http.StatusNotFound, w.Code, "path=%s should hit OpenAI responses handler", path)
}
}
+
+func TestGatewayRoutesOpenAIImagesPathsAreRegistered(t *testing.T) {
+ router := newGatewayRoutesTestRouter()
+
+ for _, path := range []string{
+ "/v1/images/generations",
+ "/v1/images/edits",
+ "/images/generations",
+ "/images/edits",
+ } {
+ req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"model":"gpt-image-2","prompt":"draw a cat"}`))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+
+ router.ServeHTTP(w, req)
+ require.NotEqual(t, http.StatusNotFound, w.Code, "path=%s should hit OpenAI images handler", path)
+ }
+}
diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go
index 23bd58ad..e4828ead 100644
--- a/backend/internal/server/routes/payment.go
+++ b/backend/internal/server/routes/payment.go
@@ -44,11 +44,13 @@ func RegisterPaymentRoutes(
}
// --- Public payment endpoints (no auth) ---
- // Payment result page needs to verify order status without login
- // (user session may have expired during provider redirect).
+ // Signed resume-token recovery is the preferred public lookup path.
+ // The legacy anonymous out_trade_no verify endpoint remains available as a
+ // persisted-state compatibility path for staggered upgrades.
public := v1.Group("/payment/public")
{
public.POST("/orders/verify", paymentHandler.VerifyOrderPublic)
+ public.POST("/orders/resolve", paymentHandler.ResolveOrderPublicByResumeToken)
}
// --- Webhook endpoints (no auth) ---
diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go
index d004f8b4..b76bb3cd 100644
--- a/backend/internal/server/routes/user.go
+++ b/backend/internal/server/routes/user.go
@@ -25,6 +25,10 @@ func RegisterUserRoutes(
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
+ user.POST("/account-bindings/email/send-code", h.User.SendEmailBindingCode)
+ user.POST("/account-bindings/email", h.User.BindEmailIdentity)
+ user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity)
+ user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding)
// 通知邮箱管理
notifyEmail := user.Group("/notify-email")
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index 52db3073..801eac1b 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -121,6 +121,9 @@ func (a *Account) IsSchedulable() bool {
if a.TempUnschedulableUntil != nil && now.Before(*a.TempUnschedulableUntil) {
return false
}
+ if a.IsAPIKeyOrBedrock() && a.IsQuotaExceeded() {
+ return false
+ }
return true
}
@@ -908,6 +911,34 @@ func (a *Account) GetChatGPTAccountID() string {
return a.GetCredential("chatgpt_account_id")
}
+func (a *Account) GetOpenAIDeviceID() string {
+ if !a.IsOpenAIOAuth() {
+ return ""
+ }
+ return strings.TrimSpace(a.GetExtraString("openai_device_id"))
+}
+
+func (a *Account) GetOpenAISessionID() string {
+ if !a.IsOpenAIOAuth() {
+ return ""
+ }
+ return strings.TrimSpace(a.GetExtraString("openai_session_id"))
+}
+
+func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapability) bool {
+ if !a.IsOpenAI() {
+ return false
+ }
+ switch capability {
+ case OpenAIImagesCapabilityBasic:
+ return a.Type == AccountTypeOAuth || a.Type == AccountTypeAPIKey
+ case OpenAIImagesCapabilityNative:
+ return a.Type == AccountTypeAPIKey
+ default:
+ return true
+ }
+}
+
func (a *Account) GetChatGPTUserID() string {
if !a.IsOpenAIOAuth() {
return ""
diff --git a/backend/internal/service/account_quota_schedulable_test.go b/backend/internal/service/account_quota_schedulable_test.go
new file mode 100644
index 00000000..2895b34c
--- /dev/null
+++ b/backend/internal/service/account_quota_schedulable_test.go
@@ -0,0 +1,123 @@
+//go:build unit
+
+package service
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAccountIsSchedulable_QuotaExceeded(t *testing.T) {
+ now := time.Now()
+
+ tests := []struct {
+ name string
+ account *Account
+ want bool
+ }{
+ {
+ name: "apikey daily quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: false,
+ },
+ {
+ name: "apikey weekly quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_weekly_limit": 50.0,
+ "quota_weekly_used": 50.0,
+ "quota_weekly_start": now.Add(-2 * 24 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: false,
+ },
+ {
+ name: "apikey total quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_limit": 100.0,
+ "quota_used": 100.0,
+ },
+ },
+ want: false,
+ },
+ {
+ name: "apikey quota not exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 5.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: true,
+ },
+ {
+ name: "apikey expired daily period restores schedulable",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-25 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: true,
+ },
+ {
+ name: "oauth ignores quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: true,
+ },
+ {
+ name: "bedrock quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeBedrock,
+ Extra: map[string]any{
+ "quota_limit": 200.0,
+ "quota_used": 200.0,
+ },
+ },
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ require.Equal(t, tt.want, tt.account.IsSchedulable())
+ })
+ }
+}
diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go
index a5559b7d..52d53013 100644
--- a/backend/internal/service/account_test_service.go
+++ b/backend/internal/service/account_test_service.go
@@ -5,6 +5,7 @@ import (
"bytes"
"context"
"crypto/rand"
+ "encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
@@ -52,8 +53,14 @@ type TestEvent struct {
const (
defaultGeminiTextTestPrompt = "hi"
defaultGeminiImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background."
+ defaultOpenAIImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background."
)
+// isOpenAIImageModel checks if the model is an OpenAI image generation model (e.g. gpt-image-2).
+func isOpenAIImageModel(model string) bool {
+ return strings.HasPrefix(strings.ToLower(model), "gpt-image-")
+}
+
// AccountTestService handles account testing operations
type AccountTestService struct {
accountRepo AccountRepository
@@ -170,7 +177,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
// Route to platform-specific test method
if account.IsOpenAI() {
- return s.testOpenAIAccountConnection(c, account, modelID)
+ return s.testOpenAIAccountConnection(c, account, modelID, prompt)
}
if account.IsGemini() {
@@ -410,8 +417,9 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co
}
// testOpenAIAccountConnection tests an OpenAI account's connection
-func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
+func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
ctx := c.Request.Context()
+ _ = prompt
// Default to openai.DefaultTestModel for OpenAI testing
testModelID := modelID
@@ -429,6 +437,18 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
}
}
+ // Route to image generation test if an image model is selected
+ if isOpenAIImageModel(testModelID) {
+ imagePrompt := strings.TrimSpace(prompt)
+ if imagePrompt == "" {
+ imagePrompt = defaultOpenAIImageTestPrompt
+ }
+ if account.Type == "apikey" {
+ return s.testOpenAIImageAPIKey(c, ctx, account, testModelID, imagePrompt)
+ }
+ return s.testOpenAIImageOAuth(c, ctx, account, testModelID, imagePrompt)
+ }
+
// Determine authentication method and API URL
var authToken string
var apiURL string
@@ -1025,7 +1045,336 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
}
}
-// sendEvent sends a SSE event to the client
+// testOpenAIImageAPIKey tests OpenAI image generation using an API Key account.
+func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error {
+ authToken := account.GetOpenAIApiKey()
+ if authToken == "" {
+ return s.sendErrorAndEnd(c, "No API key available")
+ }
+
+ baseURL := account.GetOpenAIBaseURL()
+ if baseURL == "" {
+ baseURL = "https://api.openai.com"
+ }
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
+ }
+ apiURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/images/generations"
+
+ // Set SSE headers
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+ c.Writer.Flush()
+
+ s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID})
+
+ payload := map[string]any{
+ "model": modelID,
+ "prompt": prompt,
+ "n": 1,
+ "response_format": "b64_json",
+ }
+ payloadBytes, _ := json.Marshal(payload)
+
+ req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
+ if err != nil {
+ return s.sendErrorAndEnd(c, "Failed to create request")
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", "Bearer "+authToken)
+
+ proxyURL := ""
+ if account.ProxyID != nil && account.Proxy != nil {
+ proxyURL = account.Proxy.URL()
+ }
+
+ resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to read response: %s", err.Error()))
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
+ }
+
+ // Parse {"data": [{"b64_json": "...", "revised_prompt": "..."}]}
+ var result struct {
+ Data []struct {
+ B64JSON string `json:"b64_json"`
+ RevisedPrompt string `json:"revised_prompt"`
+ } `json:"data"`
+ }
+ if err := json.Unmarshal(body, &result); err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse response: %s", err.Error()))
+ }
+
+ if len(result.Data) == 0 {
+ return s.sendErrorAndEnd(c, "No images returned from API")
+ }
+
+ for _, item := range result.Data {
+ if item.RevisedPrompt != "" {
+ s.sendEvent(c, TestEvent{Type: "content", Text: item.RevisedPrompt})
+ }
+ if item.B64JSON != "" {
+ s.sendEvent(c, TestEvent{
+ Type: "image",
+ ImageURL: "data:image/png;base64," + item.B64JSON,
+ MimeType: "image/png",
+ })
+ }
+ }
+
+ s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
+ return nil
+}
+
+// testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via ChatGPT backend API.
+func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error {
+ authToken := account.GetOpenAIAccessToken()
+ if authToken == "" {
+ return s.sendErrorAndEnd(c, "No access token available")
+ }
+
+ // Set SSE headers
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+ c.Writer.Flush()
+
+ s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID})
+ s.sendEvent(c, TestEvent{Type: "content", Text: "Initializing ChatGPT backend...\n"})
+
+ // Build headers (replicating buildOpenAIBackendAPIHeaders logic)
+ headers := buildOpenAIBackendAPIHeadersForTest(ctx, account, authToken, s.accountRepo)
+
+ proxyURL := ""
+ if account.ProxyID != nil && account.Proxy != nil {
+ proxyURL = account.Proxy.URL()
+ }
+
+ client, err := newOpenAIBackendAPIClient(proxyURL)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create client: %s", err.Error()))
+ }
+
+ // Bootstrap
+ if bootstrapErr := bootstrapOpenAIBackendAPI(ctx, client, headers); bootstrapErr != nil {
+ log.Printf("OpenAI image test bootstrap warning: %v", bootstrapErr)
+ }
+
+ // Fetch chat requirements
+ s.sendEvent(c, TestEvent{Type: "content", Text: "Fetching chat requirements...\n"})
+ chatReqs, err := fetchOpenAIChatRequirements(ctx, client, headers)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Chat requirements failed: %s", err.Error()))
+ }
+ if chatReqs.Arkose.Required {
+ return s.sendErrorAndEnd(c, "Unsupported challenge: arkose required")
+ }
+
+ // Initialize and prepare conversation
+ s.sendEvent(c, TestEvent{Type: "content", Text: "Preparing image conversation...\n"})
+ parentMessageID := uuid.NewString()
+ proofToken := generateOpenAIProofToken(chatReqs.ProofOfWork.Required, chatReqs.ProofOfWork.Seed, chatReqs.ProofOfWork.Difficulty, headers.Get("User-Agent"))
+ _ = initializeOpenAIImageConversation(ctx, client, headers)
+ conduitToken, err := prepareOpenAIImageConversation(ctx, client, headers, prompt, parentMessageID, chatReqs.Token, proofToken)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation prepare failed: %s", err.Error()))
+ }
+
+ // Build simplified conversation request (no file uploads)
+ convReq := buildOpenAIImageTestConversationRequest(prompt, parentMessageID)
+ convHeaders := cloneHTTPHeader(headers)
+ convHeaders.Set("Accept", "text/event-stream")
+ convHeaders.Set("Content-Type", "application/json")
+ convHeaders.Set("openai-sentinel-chat-requirements-token", chatReqs.Token)
+ if conduitToken != "" {
+ convHeaders.Set("x-conduit-token", conduitToken)
+ }
+ if proofToken != "" {
+ convHeaders.Set("openai-sentinel-proof-token", proofToken)
+ }
+
+ s.sendEvent(c, TestEvent{Type: "content", Text: "Generating image...\n"})
+
+ resp, err := client.R().
+ SetContext(ctx).
+ DisableAutoReadResponse().
+ SetHeaders(headerToMap(convHeaders)).
+ SetBodyJsonMarshal(convReq).
+ Post(openAIChatGPTConversationURL)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation request failed: %s", err.Error()))
+ }
+ defer func() {
+ if resp != nil && resp.Body != nil {
+ _ = resp.Body.Close()
+ }
+ }()
+ if resp.StatusCode >= 400 {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation API returned %d", resp.StatusCode))
+ }
+
+ startTime := time.Now()
+ conversationID, pointerInfos, _, _, err := readOpenAIImageConversationStream(resp, startTime)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read failed: %s", err.Error()))
+ }
+
+ pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil)
+ if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) {
+ s.sendEvent(c, TestEvent{Type: "content", Text: "Waiting for image generation to complete...\n"})
+ polledPointers, pollErr := pollOpenAIImageConversation(ctx, client, headers, conversationID)
+ if pollErr != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Poll failed: %s", pollErr.Error()))
+ }
+ pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, polledPointers)
+ }
+ pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos)
+ if len(pointerInfos) == 0 {
+ return s.sendErrorAndEnd(c, "No images returned from conversation")
+ }
+
+ s.sendEvent(c, TestEvent{Type: "content", Text: "Downloading generated image...\n"})
+
+ // Download and encode each image
+ for _, pointer := range pointerInfos {
+ downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Download URL fetch failed: %s", err.Error()))
+ }
+ data, err := downloadOpenAIImageBytes(ctx, client, headers, downloadURL)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Image download failed: %s", err.Error()))
+ }
+ b64 := base64.StdEncoding.EncodeToString(data)
+ mimeType := http.DetectContentType(data)
+ if pointer.Prompt != "" {
+ s.sendEvent(c, TestEvent{Type: "content", Text: pointer.Prompt})
+ }
+ s.sendEvent(c, TestEvent{
+ Type: "image",
+ ImageURL: "data:" + mimeType + ";base64," + b64,
+ MimeType: mimeType,
+ })
+ }
+
+ s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
+ return nil
+}
+
+// buildOpenAIBackendAPIHeadersForTest builds ChatGPT backend API headers for test purposes.
+// Replicates the logic from OpenAIGatewayService.buildOpenAIBackendAPIHeaders without
+// requiring the full gateway service dependency.
+func buildOpenAIBackendAPIHeadersForTest(ctx context.Context, account *Account, token string, repo AccountRepository) http.Header {
+ // Ensure device and session IDs exist
+ deviceID := account.GetOpenAIDeviceID()
+ sessionID := account.GetOpenAISessionID()
+ if deviceID == "" || sessionID == "" {
+ updates := map[string]any{}
+ if deviceID == "" {
+ deviceID = uuid.NewString()
+ updates["openai_device_id"] = deviceID
+ }
+ if sessionID == "" {
+ sessionID = uuid.NewString()
+ updates["openai_session_id"] = sessionID
+ }
+ if account.Extra == nil {
+ account.Extra = map[string]any{}
+ }
+ for key, value := range updates {
+ account.Extra[key] = value
+ }
+ if repo != nil {
+ updateCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+ defer cancel()
+ _ = repo.UpdateExtra(updateCtx, account.ID, updates)
+ }
+ }
+
+ headers := make(http.Header)
+ headers.Set("Authorization", "Bearer "+token)
+ headers.Set("Accept", "application/json")
+ headers.Set("Origin", "https://chatgpt.com")
+ headers.Set("Referer", "https://chatgpt.com/")
+ headers.Set("Sec-Fetch-Dest", "empty")
+ headers.Set("Sec-Fetch-Mode", "cors")
+ headers.Set("Sec-Fetch-Site", "same-origin")
+ headers.Set("User-Agent", openAIImageBackendUserAgent)
+ if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" {
+ headers.Set("User-Agent", customUA)
+ }
+ if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" {
+ headers.Set("chatgpt-account-id", chatgptAccountID)
+ }
+ if deviceID != "" {
+ headers.Set("oai-device-id", deviceID)
+ headers.Set("Cookie", "oai-did="+deviceID)
+ }
+ if sessionID != "" {
+ headers.Set("oai-session-id", sessionID)
+ }
+ return headers
+}
+
+// buildOpenAIImageTestConversationRequest creates a simplified image generation conversation request.
+func buildOpenAIImageTestConversationRequest(prompt, parentMessageID string) map[string]any {
+ promptText := strings.TrimSpace(prompt)
+ if promptText == "" {
+ promptText = "Generate an image."
+ }
+ metadata := map[string]any{
+ "developer_mode_connector_ids": []any{},
+ "selected_github_repos": []any{},
+ "selected_all_github_repos": false,
+ "system_hints": []string{"picture_v2"},
+ "serialization_metadata": map[string]any{
+ "custom_symbol_offsets": []any{},
+ },
+ }
+ message := map[string]any{
+ "id": uuid.NewString(),
+ "author": map[string]any{"role": "user"},
+ "content": map[string]any{
+ "content_type": "text",
+ "parts": []any{promptText},
+ },
+ "metadata": metadata,
+ "create_time": float64(time.Now().UnixMilli()) / 1000,
+ }
+ return map[string]any{
+ "action": "next",
+ "client_prepare_state": "sent",
+ "parent_message_id": parentMessageID,
+ "messages": []any{message},
+ "model": "auto",
+ "timezone_offset_min": openAITimezoneOffsetMinutes(),
+ "timezone": openAITimezoneName(),
+ "conversation_mode": map[string]any{"kind": "primary_assistant"},
+ "system_hints": []string{"picture_v2"},
+ "supports_buffering": true,
+ "supported_encodings": []string{"v1"},
+ "client_contextual_info": map[string]any{"app_name": "chatgpt.com"},
+ "force_nulligen": false,
+ "force_paragen": false,
+ "force_paragen_model_slug": "",
+ "force_rate_limit": false,
+ "websocket_request_id": uuid.NewString(),
+ }
+}
+
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
eventJSON, _ := json.Marshal(event)
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go
index 82606979..82ff0a8b 100644
--- a/backend/internal/service/account_test_service_openai_test.go
+++ b/backend/internal/service/account_test_service_openai_test.go
@@ -103,7 +103,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
Credentials: map[string]any{"access_token": "test-token"},
}
- err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "")
require.NoError(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"])
@@ -134,7 +134,7 @@ func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing
Credentials: map[string]any{"access_token": "test-token"},
}
- err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "")
require.Error(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 7c26a47c..4ae66613 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -2,6 +2,7 @@ package service
import (
"context"
+ "encoding/json"
"errors"
"fmt"
"io"
@@ -11,6 +12,8 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -33,6 +36,7 @@ type AdminService interface {
// codeType is optional - pass empty string to return all types.
// Also returns totalRecharged (sum of all positive balance top-ups).
GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error)
+ BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error)
// Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error)
@@ -127,6 +131,44 @@ type UpdateUserInput struct {
GroupRates map[int64]*float64
}
+type AdminBindAuthIdentityInput struct {
+ ProviderType string
+ ProviderKey string
+ ProviderSubject string
+ Issuer *string
+ Metadata map[string]any
+ Channel *AdminBindAuthIdentityChannelInput
+}
+
+type AdminBindAuthIdentityChannelInput struct {
+ Channel string
+ ChannelAppID string
+ ChannelSubject string
+ Metadata map[string]any
+}
+
+type AdminBoundAuthIdentity struct {
+ UserID int64 `json:"user_id"`
+ ProviderType string `json:"provider_type"`
+ ProviderKey string `json:"provider_key"`
+ ProviderSubject string `json:"provider_subject"`
+ VerifiedAt *time.Time `json:"verified_at,omitempty"`
+ Issuer *string `json:"issuer,omitempty"`
+ Metadata map[string]any `json:"metadata"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+ Channel *AdminBoundAuthIdentityChannel `json:"channel,omitempty"`
+}
+
+type AdminBoundAuthIdentityChannel struct {
+ Channel string `json:"channel"`
+ ChannelAppID string `json:"channel_app_id"`
+ ChannelSubject string `json:"channel_subject"`
+ Metadata map[string]any `json:"metadata"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+}
+
type CreateGroupInput struct {
Name string
Description string
@@ -491,6 +533,20 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
if err != nil {
return nil, 0, err
}
+ if len(users) > 0 {
+ userIDs := make([]int64, 0, len(users))
+ for i := range users {
+ userIDs = append(userIDs, users[i].ID)
+ }
+ lastUsedByUserID, latestErr := s.userRepo.GetLatestUsedAtByUserIDs(ctx, userIDs)
+ if latestErr != nil {
+ logger.LegacyPrintf("service.admin", "failed to load user last_used_at in batch: err=%v", latestErr)
+ } else {
+ for i := range users {
+ users[i].LastUsedAt = lastUsedByUserID[users[i].ID]
+ }
+ }
+ }
// 批量加载用户专属分组倍率
if s.userGroupRateRepo != nil && len(users) > 0 {
if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok {
@@ -535,6 +591,12 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
if err != nil {
return nil, err
}
+ lastUsedAt, latestErr := s.userRepo.GetLatestUsedAtByUserID(ctx, id)
+ if latestErr != nil {
+ logger.LegacyPrintf("service.admin", "failed to load user last_used_at: user_id=%d err=%v", id, latestErr)
+ } else {
+ user.LastUsedAt = lastUsedAt
+ }
// 加载用户专属分组倍率
if s.userGroupRateRepo != nil {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, id)
@@ -586,6 +648,15 @@ func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userI
}
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
+ // 校验用户专属分组倍率:必须 > 0(nil 合法,表示清除专属倍率)
+ if input.GroupRates != nil {
+ for groupID, rate := range input.GroupRates {
+ if rate != nil && *rate <= 0 {
+ return nil, fmt.Errorf("rate_multiplier must be > 0 (group_id=%d)", groupID)
+ }
+ }
+ }
+
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return nil, err
@@ -788,6 +859,334 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int
return codes, result.Total, totalRecharged, nil
}
+func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) {
+ if userID <= 0 {
+ return nil, infraerrors.BadRequest("INVALID_INPUT", "user_id must be greater than 0")
+ }
+ if s == nil || s.entClient == nil || s.userRepo == nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_UNAVAILABLE", "auth identity binding service is unavailable")
+ }
+ if _, err := s.userRepo.GetByID(ctx, userID); err != nil {
+ return nil, err
+ }
+
+ providerType := normalizeAdminAuthIdentityProviderType(input.ProviderType)
+ providerKey := strings.TrimSpace(input.ProviderKey)
+ providerSubject := strings.TrimSpace(input.ProviderSubject)
+ if providerType == "" {
+ return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, or wechat")
+ }
+ if providerKey == "" || providerSubject == "" {
+ return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required")
+ }
+ canonicalProviderKey := canonicalAdminAuthIdentityProviderKey(providerType, "", providerKey)
+ compatibleProviderKeys := compatibleAdminAuthIdentityProviderKeys(providerType, providerKey)
+
+ var issuer *string
+ if input.Issuer != nil {
+ trimmed := strings.TrimSpace(*input.Issuer)
+ if trimmed != "" {
+ issuer = &trimmed
+ }
+ }
+
+ channelInput := normalizeAdminBindChannelInput(input.Channel)
+ if input.Channel != nil && channelInput == nil {
+ return nil, infraerrors.BadRequest("INVALID_INPUT", "channel, channel_app_id, and channel_subject are required when channel binding is provided")
+ }
+
+ verifiedAt := time.Now().UTC()
+ tx, err := s.entClient.Tx(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_TX_FAILED", "failed to start auth identity bind transaction").WithCause(err)
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ identityRecords, err := tx.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(providerType),
+ authidentity.ProviderKeyIn(compatibleProviderKeys...),
+ authidentity.ProviderSubjectEQ(providerSubject),
+ ).
+ All(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
+ }
+ if hasAdminAuthIdentityOwnershipConflict(identityRecords, userID) {
+ return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
+ identity := selectOwnedAdminAuthIdentity(identityRecords, userID)
+
+ if identity == nil {
+ create := tx.AuthIdentity.Create().
+ SetUserID(userID).
+ SetProviderType(providerType).
+ SetProviderKey(canonicalProviderKey).
+ SetProviderSubject(providerSubject).
+ SetVerifiedAt(verifiedAt)
+ if issuer != nil {
+ create = create.SetIssuer(*issuer)
+ }
+ if input.Metadata != nil {
+ create = create.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata))
+ }
+ identity, err = create.Save(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
+ }
+ } else {
+ update := tx.AuthIdentity.UpdateOneID(identity.ID).
+ SetVerifiedAt(verifiedAt).
+ SetProviderKey(canonicalProviderKey)
+ if issuer != nil {
+ update = update.SetIssuer(*issuer)
+ }
+ if input.Metadata != nil {
+ update = update.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata))
+ }
+ identity, err = update.Save(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
+ }
+ }
+
+ var channel *dbent.AuthIdentityChannel
+ if channelInput != nil {
+ channelRecords, err := tx.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ(providerType),
+ authidentitychannel.ProviderKeyIn(compatibleProviderKeys...),
+ authidentitychannel.ChannelEQ(channelInput.Channel),
+ authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID),
+ authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject),
+ ).
+ WithIdentity().
+ All(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
+ }
+ if hasAdminAuthIdentityChannelOwnershipConflict(channelRecords, userID) {
+ return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
+ }
+ channel = selectOwnedAdminAuthIdentityChannel(channelRecords, userID)
+ if channel == nil {
+ create := tx.AuthIdentityChannel.Create().
+ SetIdentityID(identity.ID).
+ SetProviderType(providerType).
+ SetProviderKey(canonicalProviderKey).
+ SetChannel(channelInput.Channel).
+ SetChannelAppID(channelInput.ChannelAppID).
+ SetChannelSubject(channelInput.ChannelSubject)
+ if channelInput.Metadata != nil {
+ create = create.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
+ }
+ channel, err = create.Save(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
+ }
+ } else {
+ update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).
+ SetIdentityID(identity.ID).
+ SetProviderKey(canonicalProviderKey)
+ if channelInput.Metadata != nil {
+ update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
+ }
+ channel, err = update.Save(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
+ }
+ }
+ }
+
+ if err := tx.Commit(); err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_COMMIT_FAILED", "failed to commit auth identity bind").WithCause(err)
+ }
+ return buildAdminBoundAuthIdentity(identity, channel), nil
+}
+
+func compatibleAdminAuthIdentityProviderKeys(providerType, providerKey string) []string {
+ providerType = strings.TrimSpace(strings.ToLower(providerType))
+ providerKey = strings.TrimSpace(providerKey)
+ if providerKey == "" {
+ return []string{providerKey}
+ }
+ if providerType != "wechat" {
+ return []string{providerKey}
+ }
+
+ keys := []string{providerKey}
+ if !strings.EqualFold(providerKey, "wechat-main") {
+ keys = append(keys, "wechat-main")
+ }
+ if !strings.EqualFold(providerKey, "wechat") {
+ keys = append(keys, "wechat")
+ }
+ return keys
+}
+
+func canonicalAdminAuthIdentityProviderKey(providerType, existingKey, requestedKey string) string {
+ providerType = strings.TrimSpace(strings.ToLower(providerType))
+ existingKey = strings.TrimSpace(existingKey)
+ requestedKey = strings.TrimSpace(requestedKey)
+ if providerType != "wechat" {
+ if requestedKey != "" {
+ return requestedKey
+ }
+ return existingKey
+ }
+ if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") {
+ return "wechat-main"
+ }
+ if requestedKey != "" {
+ return requestedKey
+ }
+ return existingKey
+}
+
+func adminAuthIdentityProviderKeyRank(providerType, providerKey string) int {
+ providerType = strings.TrimSpace(strings.ToLower(providerType))
+ providerKey = strings.TrimSpace(providerKey)
+ if providerType != "wechat" {
+ return 0
+ }
+ switch {
+ case strings.EqualFold(providerKey, "wechat-main"):
+ return 0
+ case strings.EqualFold(providerKey, "wechat"):
+ return 2
+ default:
+ return 1
+ }
+}
+
+func selectOwnedAdminAuthIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
+ var selected *dbent.AuthIdentity
+ for _, record := range records {
+ if record.UserID != userID {
+ continue
+ }
+ if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
+ selected = record
+ }
+ }
+ return selected
+}
+
+func hasAdminAuthIdentityOwnershipConflict(records []*dbent.AuthIdentity, userID int64) bool {
+ for _, record := range records {
+ if record.UserID != userID {
+ return true
+ }
+ }
+ return false
+}
+
+func selectOwnedAdminAuthIdentityChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
+ var selected *dbent.AuthIdentityChannel
+ for _, record := range records {
+ if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID {
+ continue
+ }
+ if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
+ selected = record
+ }
+ }
+ return selected
+}
+
+func hasAdminAuthIdentityChannelOwnershipConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
+ for _, record := range records {
+ if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
+ return true
+ }
+ }
+ return false
+}
+
+func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput {
+ if input == nil {
+ return nil
+ }
+ channel := &AdminBindAuthIdentityChannelInput{
+ Channel: strings.TrimSpace(input.Channel),
+ ChannelAppID: strings.TrimSpace(input.ChannelAppID),
+ ChannelSubject: strings.TrimSpace(input.ChannelSubject),
+ Metadata: cloneAdminAuthIdentityMetadata(input.Metadata),
+ }
+ if channel.Channel == "" || channel.ChannelAppID == "" || channel.ChannelSubject == "" {
+ return nil
+ }
+ return channel
+}
+
+func normalizeAdminAuthIdentityProviderType(input string) string {
+ switch strings.ToLower(strings.TrimSpace(input)) {
+ case "email":
+ return "email"
+ case "linuxdo":
+ return "linuxdo"
+ case "oidc":
+ return "oidc"
+ case "wechat":
+ return "wechat"
+ default:
+ return ""
+ }
+}
+
+func buildAdminBoundAuthIdentity(identity *dbent.AuthIdentity, channel *dbent.AuthIdentityChannel) *AdminBoundAuthIdentity {
+ if identity == nil {
+ return nil
+ }
+ result := &AdminBoundAuthIdentity{
+ UserID: identity.UserID,
+ ProviderType: strings.TrimSpace(identity.ProviderType),
+ ProviderKey: strings.TrimSpace(identity.ProviderKey),
+ ProviderSubject: strings.TrimSpace(identity.ProviderSubject),
+ VerifiedAt: identity.VerifiedAt,
+ Issuer: identity.Issuer,
+ Metadata: cloneAdminAuthIdentityMetadata(identity.Metadata),
+ CreatedAt: identity.CreatedAt,
+ UpdatedAt: identity.UpdatedAt,
+ }
+ if channel != nil {
+ result.Channel = &AdminBoundAuthIdentityChannel{
+ Channel: strings.TrimSpace(channel.Channel),
+ ChannelAppID: strings.TrimSpace(channel.ChannelAppID),
+ ChannelSubject: strings.TrimSpace(channel.ChannelSubject),
+ Metadata: cloneAdminAuthIdentityMetadata(channel.Metadata),
+ CreatedAt: channel.CreatedAt,
+ UpdatedAt: channel.UpdatedAt,
+ }
+ }
+ return result
+}
+
+func cloneAdminAuthIdentityMetadata(input map[string]any) map[string]any {
+ if input == nil {
+ return nil
+ }
+ if len(input) == 0 {
+ return map[string]any{}
+ }
+ data, err := json.Marshal(input)
+ if err != nil {
+ out := make(map[string]any, len(input))
+ for key, value := range input {
+ out[key] = value
+ }
+ return out
+ }
+ var out map[string]any
+ if err := json.Unmarshal(data, &out); err != nil {
+ out = make(map[string]any, len(input))
+ for key, value := range input {
+ out[key] = value
+ }
+ }
+ return out
+}
+
// Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
@@ -811,6 +1210,10 @@ func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, erro
}
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) {
+ if input.RateMultiplier <= 0 {
+ return nil, errors.New("rate_multiplier must be > 0")
+ }
+
platform := input.Platform
if platform == "" {
platform = PlatformAnthropic
@@ -1050,6 +1453,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
group.Platform = input.Platform
}
if input.RateMultiplier != nil {
+ if *input.RateMultiplier <= 0 {
+ return nil, errors.New("rate_multiplier must be > 0")
+ }
group.RateMultiplier = *input.RateMultiplier
}
if input.IsExclusive != nil {
@@ -1286,6 +1692,11 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro
if s.userGroupRateRepo == nil {
return nil
}
+ for _, e := range entries {
+ if e.RateMultiplier <= 0 {
+ return fmt.Errorf("rate_multiplier must be > 0 (user_id=%d)", e.UserID)
+ }
+ }
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
}
diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go
index 419ddbc3..fcde5cbf 100644
--- a/backend/internal/service/admin_service_apikey_test.go
+++ b/backend/internal/service/admin_service_apikey_test.go
@@ -44,6 +44,15 @@ func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, erro
}
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) GetUserAvatar(context.Context, int64) (*UserAvatar, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) DeleteUserAvatar(context.Context, int64) error {
+ panic("unexpected")
+}
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected")
}
@@ -70,6 +79,23 @@ func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *s
}
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
+ panic("unexpected")
+}
+
+func (s *userRepoStubForGroupUpdate) UnbindUserAuthProvider(context.Context, int64, string) error {
+ panic("unexpected")
+}
+
+func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) UpdateUserLastActiveAt(context.Context, int64, time.Time) error {
+ panic("unexpected")
+}
func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
panic("unexpected")
}
diff --git a/backend/internal/service/admin_service_auth_identity_binding_test.go b/backend/internal/service/admin_service_auth_identity_binding_test.go
new file mode 100644
index 00000000..719199f2
--- /dev/null
+++ b/backend/internal/service/admin_service_auth_identity_binding_test.go
@@ -0,0 +1,302 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+func newAdminServiceAuthIdentityBindingTestClient(t *testing.T) *dbent.Client {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:admin_service_auth_identity_binding?mode=memory&cache=shared&_fk=1")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+ return client
+}
+
+func TestAdminServiceBindUserAuthIdentityCreatesCanonicalAndChannelBinding(t *testing.T) {
+ client := newAdminServiceAuthIdentityBindingTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("bind-target@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{
+ userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
+ entClient: client,
+ }
+
+ result, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "union-123",
+ Metadata: map[string]any{"source": "admin-repair"},
+ Channel: &AdminBindAuthIdentityChannelInput{
+ Channel: "open",
+ ChannelAppID: "wx-open",
+ ChannelSubject: "openid-123",
+ Metadata: map[string]any{"scene": "migration"},
+ },
+ })
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, user.ID, result.UserID)
+ require.Equal(t, "wechat", result.ProviderType)
+ require.Equal(t, "wechat-main", result.ProviderKey)
+ require.NotNil(t, result.VerifiedAt)
+ require.NotNil(t, result.Channel)
+ require.Equal(t, "open", result.Channel.Channel)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ("wechat-main"),
+ authidentity.ProviderSubjectEQ("union-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, user.ID, identity.UserID)
+ require.Equal(t, "admin-repair", identity.Metadata["source"])
+ require.NotNil(t, identity.VerifiedAt)
+
+ channel, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ProviderKeyEQ("wechat-main"),
+ authidentitychannel.ChannelEQ("open"),
+ authidentitychannel.ChannelAppIDEQ("wx-open"),
+ authidentitychannel.ChannelSubjectEQ("openid-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, identity.ID, channel.IdentityID)
+ require.Equal(t, "migration", channel.Metadata["scene"])
+}
+
+func TestAdminServiceBindUserAuthIdentityRejectsOtherOwner(t *testing.T) {
+ client := newAdminServiceAuthIdentityBindingTestClient(t)
+ ctx := context.Background()
+
+ owner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ target, err := client.User.Create().
+ SetEmail("target@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(owner.ID).
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("subject-1").
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{
+ userRepo: &userRepoStub{user: &User{ID: target.ID, Email: target.Email, Status: StatusActive}},
+ entClient: client,
+ }
+
+ _, err = svc.BindUserAuthIdentity(ctx, target.ID, AdminBindAuthIdentityInput{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-1",
+ })
+ require.Error(t, err)
+ require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", infraerrors.Reason(err))
+}
+
+func TestAdminServiceBindUserAuthIdentityIsIdempotentForSameUser(t *testing.T) {
+ client := newAdminServiceAuthIdentityBindingTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("same-user@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{
+ userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
+ entClient: client,
+ }
+
+ first, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-2",
+ Metadata: map[string]any{"source": "first"},
+ })
+ require.NoError(t, err)
+
+ second, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-2",
+ Metadata: map[string]any{"source": "second"},
+ })
+ require.NoError(t, err)
+ require.Equal(t, first.UserID, second.UserID)
+ require.Equal(t, "second", second.Metadata["source"])
+
+ identities, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("subject-2"),
+ ).
+ All(ctx)
+ require.NoError(t, err)
+ require.Len(t, identities, 1)
+ require.Equal(t, "second", identities[0].Metadata["source"])
+}
+
+func TestAdminServiceBindUserAuthIdentityReusesLegacyWeChatAliasRecords(t *testing.T) {
+ client := newAdminServiceAuthIdentityBindingTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("wechat-alias@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ legacyIdentity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat").
+ SetProviderSubject("union-legacy-123").
+ SetMetadata(map[string]any{"source": "legacy"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ legacyChannel, err := client.AuthIdentityChannel.Create().
+ SetIdentityID(legacyIdentity.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat").
+ SetChannel("open").
+ SetChannelAppID("wx-open").
+ SetChannelSubject("openid-legacy-123").
+ SetMetadata(map[string]any{"scene": "legacy"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{
+ userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
+ entClient: client,
+ }
+
+ result, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "union-legacy-123",
+ Metadata: map[string]any{"source": "admin-repair"},
+ Channel: &AdminBindAuthIdentityChannelInput{
+ Channel: "open",
+ ChannelAppID: "wx-open",
+ ChannelSubject: "openid-legacy-123",
+ Metadata: map[string]any{"scene": "admin-repair"},
+ },
+ })
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "wechat-main", result.ProviderKey)
+ require.NotNil(t, result.Channel)
+ require.Equal(t, "open", result.Channel.Channel)
+
+ identity, err := client.AuthIdentity.Get(ctx, legacyIdentity.ID)
+ require.NoError(t, err)
+ require.Equal(t, "wechat-main", identity.ProviderKey)
+ require.Equal(t, "admin-repair", identity.Metadata["source"])
+
+ channel, err := client.AuthIdentityChannel.Get(ctx, legacyChannel.ID)
+ require.NoError(t, err)
+ require.Equal(t, "wechat-main", channel.ProviderKey)
+ require.Equal(t, legacyIdentity.ID, channel.IdentityID)
+ require.Equal(t, "admin-repair", channel.Metadata["scene"])
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderSubjectEQ("union-legacy-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, identityCount)
+
+ channelCount, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ChannelEQ("open"),
+ authidentitychannel.ChannelAppIDEQ("wx-open"),
+ authidentitychannel.ChannelSubjectEQ("openid-legacy-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, channelCount)
+}
+
+func TestAdminServiceBindUserAuthIdentityRejectsInvalidProviderType(t *testing.T) {
+ client := newAdminServiceAuthIdentityBindingTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("invalid-provider@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{
+ userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
+ entClient: client,
+ }
+
+ _, err = svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
+ ProviderType: "github",
+ ProviderKey: "github-main",
+ ProviderSubject: "subject-3",
+ })
+ require.Error(t, err)
+ require.Equal(t, "INVALID_INPUT", infraerrors.Reason(err))
+}
diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go
index fbc856cf..fe9e7701 100644
--- a/backend/internal/service/admin_service_delete_test.go
+++ b/backend/internal/service/admin_service_delete_test.go
@@ -13,15 +13,18 @@ import (
)
type userRepoStub struct {
- user *User
- getErr error
- createErr error
- deleteErr error
- exists bool
- existsErr error
- nextID int64
- created []*User
- deletedIDs []int64
+ user *User
+ getErr error
+ createErr error
+ deleteErr error
+ exists bool
+ existsErr error
+ nextID int64
+ created []*User
+ updated []*User
+ deletedIDs []int64
+ usersByEmail map[string]*User
+ getByEmailErr error
}
func (s *userRepoStub) Create(ctx context.Context, user *User) error {
@@ -32,6 +35,11 @@ func (s *userRepoStub) Create(ctx context.Context, user *User) error {
user.ID = s.nextID
}
s.created = append(s.created, user)
+ if s.usersByEmail == nil {
+ s.usersByEmail = make(map[string]*User)
+ }
+ s.usersByEmail[user.Email] = user
+ s.user = user
return nil
}
@@ -46,7 +54,18 @@ func (s *userRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
}
func (s *userRepoStub) GetByEmail(ctx context.Context, email string) (*User, error) {
- panic("unexpected GetByEmail call")
+ if s.getByEmailErr != nil {
+ return nil, s.getByEmailErr
+ }
+ if s.usersByEmail != nil {
+ if user, ok := s.usersByEmail[email]; ok {
+ return user, nil
+ }
+ }
+ if s.user != nil && s.user.Email == email {
+ return s.user, nil
+ }
+ return nil, ErrUserNotFound
}
func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) {
@@ -54,7 +73,13 @@ func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) {
}
func (s *userRepoStub) Update(ctx context.Context, user *User) error {
- panic("unexpected Update call")
+ s.updated = append(s.updated, user)
+ if s.usersByEmail == nil {
+ s.usersByEmail = make(map[string]*User)
+ }
+ s.usersByEmail[user.Email] = user
+ s.user = user
+ return nil
}
func (s *userRepoStub) Delete(ctx context.Context, id int64) error {
@@ -62,6 +87,18 @@ func (s *userRepoStub) Delete(ctx context.Context, id int64) error {
return s.deleteErr
}
+func (s *userRepoStub) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
+ panic("unexpected GetUserAvatar call")
+}
+
+func (s *userRepoStub) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
+ panic("unexpected UpsertUserAvatar call")
+}
+
+func (s *userRepoStub) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ panic("unexpected DeleteUserAvatar call")
+}
+
func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
@@ -70,6 +107,18 @@ func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.Pa
panic("unexpected ListWithFilters call")
}
+func (s *userRepoStub) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ panic("unexpected GetLatestUsedAtByUserIDs call")
+}
+
+func (s *userRepoStub) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ panic("unexpected GetLatestUsedAtByUserID call")
+}
+
+func (s *userRepoStub) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+ panic("unexpected UpdateUserLastActiveAt call")
+}
+
func (s *userRepoStub) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
@@ -101,6 +150,14 @@ func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64
panic("unexpected AddGroupToAllowedGroups call")
}
+func (s *userRepoStub) ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) {
+ panic("unexpected ListUserAuthIdentities call")
+}
+
+func (s *userRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error {
+ panic("unexpected UnbindUserAuthProvider call")
+}
+
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
diff --git a/backend/internal/service/admin_service_email_identity_sync_test.go b/backend/internal/service/admin_service_email_identity_sync_test.go
new file mode 100644
index 00000000..2232c9c3
--- /dev/null
+++ b/backend/internal/service/admin_service_email_identity_sync_test.go
@@ -0,0 +1,187 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+)
+
+type ensureEmailCall struct {
+ userID int64
+ email string
+}
+
+type replaceEmailCall struct {
+ userID int64
+ oldEmail string
+ newEmail string
+}
+
+type emailSyncRepoStub struct {
+ user *User
+ nextID int64
+ updateCalls int
+ created []*User
+ updated []*User
+ ensureCalls []ensureEmailCall
+ replaceCalls []replaceEmailCall
+ ensureErr error
+ replaceErr error
+}
+
+func (s *emailSyncRepoStub) Create(_ context.Context, user *User) error {
+ if s.nextID != 0 && user.ID == 0 {
+ user.ID = s.nextID
+ }
+ s.created = append(s.created, user)
+ s.user = user
+ return nil
+}
+
+func (s *emailSyncRepoStub) GetByID(_ context.Context, _ int64) (*User, error) {
+ if s.user == nil {
+ return nil, ErrUserNotFound
+ }
+ cloned := *s.user
+ return &cloned, nil
+}
+
+func (s *emailSyncRepoStub) GetByEmail(_ context.Context, _ string) (*User, error) {
+ return nil, ErrUserNotFound
+}
+
+func (s *emailSyncRepoStub) GetFirstAdmin(context.Context) (*User, error) {
+ return nil, fmt.Errorf("unexpected GetFirstAdmin call")
+}
+
+func (s *emailSyncRepoStub) Update(_ context.Context, user *User) error {
+ s.updateCalls++
+ s.updated = append(s.updated, user)
+ s.user = user
+ return nil
+}
+
+func (s *emailSyncRepoStub) Delete(context.Context, int64) error { return nil }
+
+func (s *emailSyncRepoStub) GetUserAvatar(context.Context, int64) (*UserAvatar, error) {
+ return nil, fmt.Errorf("unexpected GetUserAvatar call")
+}
+
+func (s *emailSyncRepoStub) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) {
+ return nil, fmt.Errorf("unexpected UpsertUserAvatar call")
+}
+
+func (s *emailSyncRepoStub) DeleteUserAvatar(context.Context, int64) error {
+ return fmt.Errorf("unexpected DeleteUserAvatar call")
+}
+
+func (s *emailSyncRepoStub) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
+ return nil, nil, fmt.Errorf("unexpected List call")
+}
+
+func (s *emailSyncRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
+ return nil, nil, fmt.Errorf("unexpected ListWithFilters call")
+}
+
+func (s *emailSyncRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+
+func (s *emailSyncRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
+
+func (s *emailSyncRepoStub) UpdateUserLastActiveAt(context.Context, int64, time.Time) error {
+ return nil
+}
+
+func (s *emailSyncRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
+
+func (s *emailSyncRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
+
+func (s *emailSyncRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
+
+func (s *emailSyncRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
+
+func (s *emailSyncRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *emailSyncRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
+
+func (s *emailSyncRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+
+func (s *emailSyncRepoStub) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
+ return nil, nil
+}
+
+func (s *emailSyncRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error { return nil }
+
+func (s *emailSyncRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+
+func (s *emailSyncRepoStub) EnableTotp(context.Context, int64) error { return nil }
+
+func (s *emailSyncRepoStub) DisableTotp(context.Context, int64) error { return nil }
+
+func (s *emailSyncRepoStub) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error {
+ s.ensureCalls = append(s.ensureCalls, ensureEmailCall{userID: userID, email: email})
+ return s.ensureErr
+}
+
+func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error {
+ s.replaceCalls = append(s.replaceCalls, replaceEmailCall{
+ userID: userID,
+ oldEmail: oldEmail,
+ newEmail: newEmail,
+ })
+ return s.replaceErr
+}
+
+func TestAdminService_CreateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
+ repo := &emailSyncRepoStub{
+ nextID: 55,
+ ensureErr: fmt.Errorf("unexpected email resync"),
+ }
+ svc := &adminServiceImpl{userRepo: repo}
+
+ user, err := svc.CreateUser(context.Background(), &CreateUserInput{
+ Email: "admin-created@example.com",
+ Password: "strong-pass",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.Equal(t, int64(55), user.ID)
+ require.Empty(t, repo.ensureCalls)
+ require.Empty(t, repo.replaceCalls)
+}
+
+func TestAdminService_UpdateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
+ repo := &emailSyncRepoStub{
+ user: &User{
+ ID: 91,
+ Email: "before@example.com",
+ Role: RoleUser,
+ Status: StatusActive,
+ Concurrency: 3,
+ },
+ replaceErr: fmt.Errorf("unexpected email resync"),
+ }
+ svc := &adminServiceImpl{userRepo: repo}
+
+ updated, err := svc.UpdateUser(context.Background(), 91, &UpdateUserInput{
+ Email: "after@example.com",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, updated)
+ require.Equal(t, "after@example.com", updated.Email)
+ require.Empty(t, repo.replaceCalls)
+ require.Empty(t, repo.ensureCalls)
+}
diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go
index a4c6d0ca..41d2c26a 100644
--- a/backend/internal/service/admin_service_group_test.go
+++ b/backend/internal/service/admin_service_group_test.go
@@ -621,6 +621,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatfo
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformOpenAI,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -641,6 +642,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *t
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeSubscription,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -695,6 +697,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -713,6 +716,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) {
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -733,6 +737,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAntigravity,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -750,6 +755,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &zero,
})
diff --git a/backend/internal/service/admin_service_list_users_test.go b/backend/internal/service/admin_service_list_users_test.go
index ceeb52c2..657616c4 100644
--- a/backend/internal/service/admin_service_list_users_test.go
+++ b/backend/internal/service/admin_service_list_users_test.go
@@ -6,6 +6,7 @@ import (
"context"
"errors"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
@@ -16,6 +17,8 @@ type userRepoStubForListUsers struct {
users []User
err error
listWithFiltersParams pagination.PaginationParams
+ lastUsedByUserID map[int64]*time.Time
+ lastUsedErr error
}
func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) {
@@ -32,6 +35,26 @@ func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pag
}, nil
}
+func (s *userRepoStubForListUsers) GetLatestUsedAtByUserIDs(_ context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ if s.lastUsedErr != nil {
+ return nil, s.lastUsedErr
+ }
+ result := make(map[int64]*time.Time, len(userIDs))
+ for _, userID := range userIDs {
+ if ts, ok := s.lastUsedByUserID[userID]; ok {
+ result[userID] = ts
+ }
+ }
+ return result, nil
+}
+
+func (s *userRepoStubForListUsers) GetLatestUsedAtByUserID(_ context.Context, userID int64) (*time.Time, error) {
+ if s.lastUsedErr != nil {
+ return nil, s.lastUsedErr
+ }
+ return s.lastUsedByUserID[userID], nil
+}
+
type userGroupRateRepoStubForListUsers struct {
batchCalls int
singleCall []int64
@@ -130,3 +153,21 @@ func TestAdminService_ListUsers_PassesSortParams(t *testing.T) {
SortOrder: "ASC",
}, userRepo.listWithFiltersParams)
}
+
+func TestAdminService_ListUsers_PopulatesLastUsedAt(t *testing.T) {
+ lastUsed := time.Now().UTC().Add(-30 * time.Minute).Truncate(time.Second)
+ userRepo := &userRepoStubForListUsers{
+ users: []User{{ID: 101, Email: "u@example.com"}},
+ lastUsedByUserID: map[int64]*time.Time{
+ 101: &lastUsed,
+ },
+ }
+ svc := &adminServiceImpl{userRepo: userRepo}
+
+ users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}, "", "")
+ require.NoError(t, err)
+ require.Equal(t, int64(1), total)
+ require.Len(t, users, 1)
+ require.NotNil(t, users[0].LastUsedAt)
+ require.WithinDuration(t, lastUsed, *users[0].LastUsedAt, time.Second)
+}
diff --git a/backend/internal/service/announcement.go b/backend/internal/service/announcement.go
index 25c66eb4..02741d37 100644
--- a/backend/internal/service/announcement.go
+++ b/backend/internal/service/announcement.go
@@ -5,6 +5,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
@@ -34,8 +35,23 @@ const (
)
var (
- ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound
- ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget
+ ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound
+ ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget
+ ErrAnnouncementNilInput = infraerrors.BadRequest("ANNOUNCEMENT_INPUT_REQUIRED", "announcement input is required")
+ ErrAnnouncementInvalidTitle = infraerrors.BadRequest("ANNOUNCEMENT_TITLE_INVALID", "announcement title is invalid")
+ ErrAnnouncementContentRequired = infraerrors.BadRequest(
+ "ANNOUNCEMENT_CONTENT_REQUIRED",
+ "announcement content is required",
+ )
+ ErrAnnouncementInvalidStatus = infraerrors.BadRequest("ANNOUNCEMENT_STATUS_INVALID", "announcement status is invalid")
+ ErrAnnouncementInvalidNotifyMode = infraerrors.BadRequest(
+ "ANNOUNCEMENT_NOTIFY_MODE_INVALID",
+ "announcement notify_mode is invalid",
+ )
+ ErrAnnouncementInvalidSchedule = infraerrors.BadRequest(
+ "ANNOUNCEMENT_TIME_RANGE_INVALID",
+ "starts_at must be before ends_at",
+ )
)
type AnnouncementTargeting = domain.AnnouncementTargeting
diff --git a/backend/internal/service/announcement_service.go b/backend/internal/service/announcement_service.go
index c0a0681a..12479041 100644
--- a/backend/internal/service/announcement_service.go
+++ b/backend/internal/service/announcement_service.go
@@ -70,16 +70,16 @@ type AnnouncementUserReadStatus struct {
func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncementInput) (*Announcement, error) {
if input == nil {
- return nil, fmt.Errorf("create announcement: nil input")
+ return nil, ErrAnnouncementNilInput
}
title := strings.TrimSpace(input.Title)
content := strings.TrimSpace(input.Content)
if title == "" || len(title) > 200 {
- return nil, fmt.Errorf("create announcement: invalid title")
+ return nil, ErrAnnouncementInvalidTitle
}
if content == "" {
- return nil, fmt.Errorf("create announcement: content is required")
+ return nil, ErrAnnouncementContentRequired
}
status := strings.TrimSpace(input.Status)
@@ -87,7 +87,7 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
status = AnnouncementStatusDraft
}
if !isValidAnnouncementStatus(status) {
- return nil, fmt.Errorf("create announcement: invalid status")
+ return nil, ErrAnnouncementInvalidStatus
}
targeting, err := domain.AnnouncementTargeting(input.Targeting).NormalizeAndValidate()
@@ -100,12 +100,12 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
notifyMode = AnnouncementNotifyModeSilent
}
if !isValidAnnouncementNotifyMode(notifyMode) {
- return nil, fmt.Errorf("create announcement: invalid notify_mode")
+ return nil, ErrAnnouncementInvalidNotifyMode
}
if input.StartsAt != nil && input.EndsAt != nil {
if !input.StartsAt.Before(*input.EndsAt) {
- return nil, fmt.Errorf("create announcement: starts_at must be before ends_at")
+ return nil, ErrAnnouncementInvalidSchedule
}
}
@@ -131,7 +131,7 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
func (s *AnnouncementService) Update(ctx context.Context, id int64, input *UpdateAnnouncementInput) (*Announcement, error) {
if input == nil {
- return nil, fmt.Errorf("update announcement: nil input")
+ return nil, ErrAnnouncementNilInput
}
a, err := s.announcementRepo.GetByID(ctx, id)
@@ -142,21 +142,21 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat
if input.Title != nil {
title := strings.TrimSpace(*input.Title)
if title == "" || len(title) > 200 {
- return nil, fmt.Errorf("update announcement: invalid title")
+ return nil, ErrAnnouncementInvalidTitle
}
a.Title = title
}
if input.Content != nil {
content := strings.TrimSpace(*input.Content)
if content == "" {
- return nil, fmt.Errorf("update announcement: content is required")
+ return nil, ErrAnnouncementContentRequired
}
a.Content = content
}
if input.Status != nil {
status := strings.TrimSpace(*input.Status)
if !isValidAnnouncementStatus(status) {
- return nil, fmt.Errorf("update announcement: invalid status")
+ return nil, ErrAnnouncementInvalidStatus
}
a.Status = status
}
@@ -164,7 +164,7 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat
if input.NotifyMode != nil {
notifyMode := strings.TrimSpace(*input.NotifyMode)
if !isValidAnnouncementNotifyMode(notifyMode) {
- return nil, fmt.Errorf("update announcement: invalid notify_mode")
+ return nil, ErrAnnouncementInvalidNotifyMode
}
a.NotifyMode = notifyMode
}
@@ -186,7 +186,7 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat
if a.StartsAt != nil && a.EndsAt != nil {
if !a.StartsAt.Before(*a.EndsAt) {
- return nil, fmt.Errorf("update announcement: starts_at must be before ends_at")
+ return nil, ErrAnnouncementInvalidSchedule
}
}
diff --git a/backend/internal/service/announcement_service_test.go b/backend/internal/service/announcement_service_test.go
new file mode 100644
index 00000000..77fb9896
--- /dev/null
+++ b/backend/internal/service/announcement_service_test.go
@@ -0,0 +1,81 @@
+package service
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+)
+
+type announcementRepoStub struct {
+ item *Announcement
+}
+
+func (s *announcementRepoStub) Create(_ context.Context, a *Announcement) error {
+ s.item = a
+ return nil
+}
+
+func (s *announcementRepoStub) GetByID(_ context.Context, _ int64) (*Announcement, error) {
+ if s.item == nil {
+ return nil, ErrAnnouncementNotFound
+ }
+ return s.item, nil
+}
+
+func (s *announcementRepoStub) Update(_ context.Context, a *Announcement) error {
+ s.item = a
+ return nil
+}
+
+func (*announcementRepoStub) Delete(context.Context, int64) error {
+ return nil
+}
+
+func (*announcementRepoStub) List(context.Context, pagination.PaginationParams, AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+
+func (*announcementRepoStub) ListActive(context.Context, time.Time) ([]Announcement, error) {
+ return nil, nil
+}
+
+func TestAnnouncementServiceCreateRejectsEqualStartEndTimes(t *testing.T) {
+ repo := &announcementRepoStub{}
+ svc := NewAnnouncementService(repo, nil, nil, nil)
+ now := time.Unix(1776790020, 0)
+
+ _, err := svc.Create(context.Background(), &CreateAnnouncementInput{
+ Title: "公告",
+ Content: "内容",
+ Status: AnnouncementStatusActive,
+ NotifyMode: AnnouncementNotifyModePopup,
+ StartsAt: &now,
+ EndsAt: &now,
+ })
+ require.ErrorIs(t, err, ErrAnnouncementInvalidSchedule)
+}
+
+func TestAnnouncementServiceUpdateRejectsEqualStartEndTimes(t *testing.T) {
+ repo := &announcementRepoStub{
+ item: &Announcement{
+ ID: 1,
+ Title: "公告",
+ Content: "内容",
+ Status: AnnouncementStatusActive,
+ NotifyMode: AnnouncementNotifyModePopup,
+ },
+ }
+ svc := NewAnnouncementService(repo, nil, nil, nil)
+ now := time.Unix(1776790020, 0)
+ startsAt := &now
+ endsAt := &now
+
+ _, err := svc.Update(context.Background(), 1, &UpdateAnnouncementInput{
+ StartsAt: &startsAt,
+ EndsAt: &endsAt,
+ })
+ require.ErrorIs(t, err, ErrAnnouncementInvalidSchedule)
+}
diff --git a/backend/internal/service/auth_email_binding.go b/backend/internal/service/auth_email_binding.go
new file mode 100644
index 00000000..78f1185d
--- /dev/null
+++ b/backend/internal/service/auth_email_binding.go
@@ -0,0 +1,319 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/mail"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+)
+
+// BindEmailIdentity verifies and binds a local email/password identity to the
+// current user, or replaces the existing bound primary email.
+func (s *AuthService) BindEmailIdentity(
+ ctx context.Context,
+ userID int64,
+ email string,
+ verifyCode string,
+ password string,
+) (*User, error) {
+ if s == nil {
+ return nil, ErrServiceUnavailable
+ }
+
+ normalizedEmail, err := normalizeEmailForIdentityBinding(email)
+ if err != nil {
+ return nil, err
+ }
+ if isReservedEmail(normalizedEmail) {
+ return nil, ErrEmailReserved
+ }
+ if strings.TrimSpace(password) == "" {
+ return nil, ErrPasswordRequired
+ }
+ if err := s.VerifyOAuthEmailCode(ctx, normalizedEmail, verifyCode); err != nil {
+ return nil, err
+ }
+
+ currentUser, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+ firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email)
+ if firstRealEmailBind && len(password) < 6 {
+ return nil, infraerrors.BadRequest("PASSWORD_TOO_SHORT", "password must be at least 6 characters")
+ }
+ if !firstRealEmailBind && !s.CheckPassword(password, currentUser.PasswordHash) {
+ return nil, ErrPasswordIncorrect
+ }
+
+ existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail)
+ switch {
+ case err == nil && existingUser != nil && existingUser.ID != userID:
+ return nil, ErrEmailExists
+ case err != nil && !errors.Is(err, ErrUserNotFound):
+ return nil, ErrServiceUnavailable
+ }
+
+ hashedPassword, err := s.HashPassword(password)
+ if err != nil {
+ return nil, fmt.Errorf("hash password: %w", err)
+ }
+
+ if s.entClient != nil {
+ if err := s.updateBoundEmailIdentityTx(ctx, currentUser, normalizedEmail, hashedPassword, firstRealEmailBind); err != nil {
+ return nil, err
+ }
+ s.revokeEmailIdentitySessions(ctx, userID)
+ return currentUser, nil
+ }
+
+ currentUser.Email = normalizedEmail
+ currentUser.PasswordHash = hashedPassword
+ if err := s.userRepo.Update(ctx, currentUser); err != nil {
+ if errors.Is(err, ErrEmailExists) {
+ return nil, ErrEmailExists
+ }
+ return nil, ErrServiceUnavailable
+ }
+
+ if firstRealEmailBind {
+ if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, userID, "email"); err != nil {
+ return nil, fmt.Errorf("apply email first bind defaults: %w", err)
+ }
+ }
+
+ s.revokeEmailIdentitySessions(ctx, userID)
+ return currentUser, nil
+}
+
+// SendEmailIdentityBindCode sends a verification code for authenticated email binding flows.
+func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string) error {
+ if s == nil {
+ return ErrServiceUnavailable
+ }
+
+ normalizedEmail, err := normalizeEmailForIdentityBinding(email)
+ if err != nil {
+ return err
+ }
+ if isReservedEmail(normalizedEmail) {
+ return ErrEmailReserved
+ }
+ if s.emailService == nil {
+ return ErrServiceUnavailable
+ }
+ if _, err := s.userRepo.GetByID(ctx, userID); err != nil {
+ if errors.Is(err, ErrUserNotFound) {
+ return ErrUserNotFound
+ }
+ return ErrServiceUnavailable
+ }
+
+ existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail)
+ switch {
+ case err == nil && existingUser != nil && existingUser.ID != userID:
+ return ErrEmailExists
+ case err != nil && !errors.Is(err, ErrUserNotFound):
+ return ErrServiceUnavailable
+ }
+
+ siteName := "Sub2API"
+ if s.settingService != nil {
+ siteName = s.settingService.GetSiteName(ctx)
+ }
+ return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName)
+}
+
+func normalizeEmailForIdentityBinding(email string) (string, error) {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ if normalized == "" || len(normalized) > 255 {
+ return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
+ }
+ if _, err := mail.ParseAddress(normalized); err != nil {
+ return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
+ }
+ return normalized, nil
+}
+
+func hasBindableEmailIdentitySubject(email string) bool {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ return normalized != "" && !isReservedEmail(normalized)
+}
+
+func (s *AuthService) updateBoundEmailIdentityTx(
+ ctx context.Context,
+ currentUser *User,
+ email string,
+ hashedPassword string,
+ applyFirstBindDefaults bool,
+) error {
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ return s.updateBoundEmailIdentityWithClient(ctx, tx.Client(), currentUser, email, hashedPassword, applyFirstBindDefaults)
+ }
+
+ tx, err := s.entClient.Tx(ctx)
+ if err != nil {
+ return ErrServiceUnavailable
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ txCtx := dbent.NewTxContext(ctx, tx)
+ if err := s.updateBoundEmailIdentityWithClient(txCtx, tx.Client(), currentUser, email, hashedPassword, applyFirstBindDefaults); err != nil {
+ return err
+ }
+ if err := tx.Commit(); err != nil {
+ return ErrServiceUnavailable
+ }
+ return nil
+}
+
+func (s *AuthService) updateBoundEmailIdentityWithClient(
+ ctx context.Context,
+ client *dbent.Client,
+ currentUser *User,
+ email string,
+ hashedPassword string,
+ applyFirstBindDefaults bool,
+) error {
+ if client == nil || currentUser == nil || currentUser.ID <= 0 {
+ return ErrServiceUnavailable
+ }
+
+ oldEmail := currentUser.Email
+ if _, err := client.User.UpdateOneID(currentUser.ID).
+ SetEmail(email).
+ SetPasswordHash(hashedPassword).
+ Save(ctx); err != nil {
+ if dbent.IsConstraintError(err) {
+ return ErrEmailExists
+ }
+ return ErrServiceUnavailable
+ }
+
+ if err := replaceBoundEmailAuthIdentityWithClient(ctx, client, currentUser.ID, oldEmail, email, "auth_service_email_bind"); err != nil {
+ if errors.Is(err, ErrEmailExists) {
+ return ErrEmailExists
+ }
+ return ErrServiceUnavailable
+ }
+
+ if applyFirstBindDefaults {
+ if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, currentUser.ID, "email"); err != nil {
+ return fmt.Errorf("apply email first bind defaults: %w", err)
+ }
+ }
+
+ updatedUser, err := client.User.Get(ctx, currentUser.ID)
+ if err != nil {
+ return ErrServiceUnavailable
+ }
+ currentUser.Email = updatedUser.Email
+ currentUser.PasswordHash = updatedUser.PasswordHash
+ currentUser.Balance = updatedUser.Balance
+ currentUser.Concurrency = updatedUser.Concurrency
+ currentUser.UpdatedAt = updatedUser.UpdatedAt
+ return nil
+}
+
+func (s *AuthService) revokeEmailIdentitySessions(ctx context.Context, userID int64) {
+ if err := s.RevokeAllUserSessions(ctx, userID); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh sessions after email identity bind for user %d: %v", userID, err)
+ }
+}
+
+func replaceBoundEmailAuthIdentityWithClient(
+ ctx context.Context,
+ client *dbent.Client,
+ userID int64,
+ oldEmail string,
+ newEmail string,
+ source string,
+) error {
+ newSubject := normalizeBoundEmailAuthIdentitySubject(newEmail)
+ if err := ensureBoundEmailAuthIdentityWithClient(ctx, client, userID, newSubject, source); err != nil {
+ return err
+ }
+
+ oldSubject := normalizeBoundEmailAuthIdentitySubject(oldEmail)
+ if oldSubject == "" || oldSubject == newSubject {
+ return nil
+ }
+
+ _, err := client.AuthIdentity.Delete().
+ Where(
+ authidentity.UserIDEQ(userID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(oldSubject),
+ ).
+ Exec(ctx)
+ return err
+}
+
+func ensureBoundEmailAuthIdentityWithClient(
+ ctx context.Context,
+ client *dbent.Client,
+ userID int64,
+ subject string,
+ source string,
+) error {
+ if client == nil || userID <= 0 || subject == "" {
+ return nil
+ }
+
+ if strings.TrimSpace(source) == "" {
+ source = "auth_service_email_bind"
+ }
+
+ if err := client.AuthIdentity.Create().
+ SetUserID(userID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject(subject).
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": strings.TrimSpace(source)}).
+ OnConflictColumns(
+ authidentity.FieldProviderType,
+ authidentity.FieldProviderKey,
+ authidentity.FieldProviderSubject,
+ ).
+ DoNothing().
+ Exec(ctx); err != nil {
+ if !isSQLNoRowsError(err) {
+ return err
+ }
+ }
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(subject),
+ ).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil
+ }
+ return err
+ }
+ if identity.UserID != userID {
+ return ErrEmailExists
+ }
+ return nil
+}
+
+func normalizeBoundEmailAuthIdentitySubject(email string) string {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ if normalized == "" || isReservedEmail(normalized) {
+ return ""
+ }
+ return normalized
+}
diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go
new file mode 100644
index 00000000..a18cf39c
--- /dev/null
+++ b/backend/internal/service/auth_oauth_email_flow.go
@@ -0,0 +1,385 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/mail"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/redeemcode"
+)
+
+func normalizeOAuthSignupSource(signupSource string) string {
+ signupSource = strings.TrimSpace(strings.ToLower(signupSource))
+ switch signupSource {
+ case "", "email":
+ return "email"
+ case "linuxdo", "wechat", "oidc":
+ return signupSource
+ default:
+ return "email"
+ }
+}
+
+// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth
+// account-creation flows without relying on the public registration gate.
+func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
+ email = strings.TrimSpace(strings.ToLower(email))
+ if email == "" {
+ return nil, ErrEmailVerifyRequired
+ }
+ if _, err := mail.ParseAddress(email); err != nil {
+ return nil, ErrEmailVerifyRequired
+ }
+ if isReservedEmail(email) {
+ return nil, ErrEmailReserved
+ }
+ if s == nil || s.emailService == nil {
+ return nil, ErrServiceUnavailable
+ }
+
+ siteName := "Sub2API"
+ if s.settingService != nil {
+ siteName = s.settingService.GetSiteName(ctx)
+ }
+ if err := s.emailService.SendVerifyCode(ctx, email, siteName); err != nil {
+ return nil, err
+ }
+ return &SendVerifyCodeResult{
+ Countdown: int(verifyCodeCooldown / time.Second),
+ }, nil
+}
+
+func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) {
+ if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
+ return nil, nil
+ }
+ if s.redeemRepo == nil && s.oauthEmailFlowClient(ctx) == nil {
+ return nil, ErrServiceUnavailable
+ }
+
+ invitationCode = strings.TrimSpace(invitationCode)
+ if invitationCode == "" {
+ return nil, ErrInvitationCodeRequired
+ }
+
+ redeemCode, err := s.loadOAuthRegistrationInvitation(ctx, invitationCode)
+ if err != nil {
+ return nil, ErrInvitationCodeInvalid
+ }
+ if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
+ return nil, ErrInvitationCodeInvalid
+ }
+ return redeemCode, nil
+}
+
+// VerifyOAuthEmailCode verifies the locally entered email verification code for
+// third-party signup and binding flows. This is intentionally independent from
+// the global registration email verification toggle.
+func (s *AuthService) VerifyOAuthEmailCode(ctx context.Context, email, verifyCode string) error {
+ email = strings.TrimSpace(strings.ToLower(email))
+ verifyCode = strings.TrimSpace(verifyCode)
+
+ if email == "" {
+ return ErrEmailVerifyRequired
+ }
+ if verifyCode == "" {
+ return ErrEmailVerifyRequired
+ }
+ if s == nil || s.emailService == nil {
+ return ErrServiceUnavailable
+ }
+ return s.emailService.VerifyCode(ctx, email, verifyCode)
+}
+
+// RegisterOAuthEmailAccount creates a local account from a third-party first
+// login after the user has verified a local email address.
+func (s *AuthService) RegisterOAuthEmailAccount(
+ ctx context.Context,
+ email string,
+ password string,
+ verifyCode string,
+ invitationCode string,
+ signupSource string,
+) (*TokenPair, *User, error) {
+ if s == nil {
+ return nil, nil, ErrServiceUnavailable
+ }
+ if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
+ return nil, nil, ErrRegDisabled
+ }
+
+ email = strings.TrimSpace(strings.ToLower(email))
+ if isReservedEmail(email) {
+ return nil, nil, ErrEmailReserved
+ }
+ if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
+ return nil, nil, err
+ }
+ if err := s.VerifyOAuthEmailCode(ctx, email, verifyCode); err != nil {
+ return nil, nil, err
+ }
+
+ if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil {
+ return nil, nil, err
+ }
+
+ existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
+ if err != nil {
+ return nil, nil, ErrServiceUnavailable
+ }
+ if existsEmail {
+ return nil, nil, ErrEmailExists
+ }
+
+ hashedPassword, err := s.HashPassword(password)
+ if err != nil {
+ return nil, nil, fmt.Errorf("hash password: %w", err)
+ }
+
+ signupSource = normalizeOAuthSignupSource(signupSource)
+ grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
+
+ user := &User{
+ Email: email,
+ PasswordHash: hashedPassword,
+ Role: RoleUser,
+ Balance: grantPlan.Balance,
+ Concurrency: grantPlan.Concurrency,
+ Status: StatusActive,
+ SignupSource: signupSource,
+ }
+
+ if err := s.userRepo.Create(ctx, user); err != nil {
+ if errors.Is(err, ErrEmailExists) {
+ return nil, nil, ErrEmailExists
+ }
+ return nil, nil, ErrServiceUnavailable
+ }
+
+ tokenPair, err := s.GenerateTokenPair(ctx, user, "")
+ if err != nil {
+ _ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, "")
+ return nil, nil, fmt.Errorf("generate token pair: %w", err)
+ }
+ return tokenPair, user, nil
+}
+
+// FinalizeOAuthEmailAccount applies invitation usage and normal signup bootstrap
+// only after the pending OAuth flow has fully reached its last reversible step.
+func (s *AuthService) FinalizeOAuthEmailAccount(
+ ctx context.Context,
+ user *User,
+ invitationCode string,
+ signupSource string,
+) error {
+ if s == nil || user == nil || user.ID <= 0 {
+ return ErrServiceUnavailable
+ }
+
+ signupSource = normalizeOAuthSignupSource(signupSource)
+ invitationRedeemCode, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode)
+ if err != nil {
+ return err
+ }
+ if invitationRedeemCode != nil {
+ if err := s.useOAuthRegistrationInvitation(ctx, invitationRedeemCode.ID, user.ID); err != nil {
+ return ErrInvitationCodeInvalid
+ }
+ }
+
+ s.updateOAuthSignupSource(ctx, user.ID, signupSource)
+ grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
+ return nil
+}
+
+// RollbackOAuthEmailAccountCreation removes a partially-created local account
+// and restores any invitation code already consumed by that account.
+func (s *AuthService) RollbackOAuthEmailAccountCreation(ctx context.Context, userID int64, invitationCode string) error {
+ if s == nil || s.userRepo == nil || userID <= 0 {
+ return ErrServiceUnavailable
+ }
+ if err := s.restoreOAuthRegistrationInvitation(ctx, invitationCode, userID); err != nil {
+ return err
+ }
+ if err := s.userRepo.Delete(ctx, userID); err != nil {
+ return fmt.Errorf("delete created oauth user: %w", err)
+ }
+ return nil
+}
+
+func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, invitationCode string, userID int64) error {
+ if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
+ return nil
+ }
+ if s.redeemRepo == nil && s.oauthEmailFlowClient(ctx) == nil {
+ return ErrServiceUnavailable
+ }
+
+ invitationCode = strings.TrimSpace(invitationCode)
+ if invitationCode == "" || userID <= 0 {
+ return nil
+ }
+
+ redeemCode, err := s.loadOAuthRegistrationInvitation(ctx, invitationCode)
+ if err != nil {
+ if errors.Is(err, ErrRedeemCodeNotFound) {
+ return nil
+ }
+ return fmt.Errorf("load invitation code: %w", err)
+ }
+ if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUsed || redeemCode.UsedBy == nil || *redeemCode.UsedBy != userID {
+ return nil
+ }
+
+ redeemCode.Status = StatusUnused
+ redeemCode.UsedBy = nil
+ redeemCode.UsedAt = nil
+ if err := s.updateOAuthRegistrationInvitation(ctx, redeemCode); err != nil {
+ return fmt.Errorf("restore invitation code: %w", err)
+ }
+ return nil
+}
+
+func (s *AuthService) oauthEmailFlowClient(ctx context.Context) *dbent.Client {
+ if s == nil || s.entClient == nil {
+ return nil
+ }
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ return tx.Client()
+ }
+ return s.entClient
+}
+
+func (s *AuthService) loadOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) {
+ if client := s.oauthEmailFlowClient(ctx); client != nil {
+ entity, err := client.RedeemCode.Query().Where(redeemcode.CodeEQ(invitationCode)).Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, ErrRedeemCodeNotFound
+ }
+ return nil, err
+ }
+ return &RedeemCode{
+ ID: entity.ID,
+ Code: entity.Code,
+ Type: entity.Type,
+ Value: entity.Value,
+ Status: entity.Status,
+ UsedBy: entity.UsedBy,
+ UsedAt: entity.UsedAt,
+ Notes: oauthEmailFlowStringValue(entity.Notes),
+ CreatedAt: entity.CreatedAt,
+ GroupID: entity.GroupID,
+ ValidityDays: entity.ValidityDays,
+ }, nil
+ }
+ return s.redeemRepo.GetByCode(ctx, invitationCode)
+}
+
+func (s *AuthService) useOAuthRegistrationInvitation(ctx context.Context, invitationID, userID int64) error {
+ if client := s.oauthEmailFlowClient(ctx); client != nil {
+ affected, err := client.RedeemCode.Update().
+ Where(redeemcode.IDEQ(invitationID), redeemcode.StatusEQ(StatusUnused)).
+ SetStatus(StatusUsed).
+ SetUsedBy(userID).
+ SetUsedAt(time.Now().UTC()).
+ Save(ctx)
+ if err != nil {
+ return err
+ }
+ if affected == 0 {
+ return ErrRedeemCodeUsed
+ }
+ return nil
+ }
+ return s.redeemRepo.Use(ctx, invitationID, userID)
+}
+
+func (s *AuthService) updateOAuthRegistrationInvitation(ctx context.Context, code *RedeemCode) error {
+ if code == nil {
+ return nil
+ }
+ if client := s.oauthEmailFlowClient(ctx); client != nil {
+ update := client.RedeemCode.UpdateOneID(code.ID).
+ SetCode(code.Code).
+ SetType(code.Type).
+ SetValue(code.Value).
+ SetStatus(code.Status).
+ SetNotes(code.Notes).
+ SetValidityDays(code.ValidityDays)
+ if code.UsedBy != nil {
+ update = update.SetUsedBy(*code.UsedBy)
+ } else {
+ update = update.ClearUsedBy()
+ }
+ if code.UsedAt != nil {
+ update = update.SetUsedAt(*code.UsedAt)
+ } else {
+ update = update.ClearUsedAt()
+ }
+ if code.GroupID != nil {
+ update = update.SetGroupID(*code.GroupID)
+ } else {
+ update = update.ClearGroupID()
+ }
+ _, err := update.Save(ctx)
+ return err
+ }
+ return s.redeemRepo.Update(ctx, code)
+}
+
+func (s *AuthService) updateOAuthSignupSource(ctx context.Context, userID int64, signupSource string) {
+ client := s.oauthEmailFlowClient(ctx)
+ if client == nil || userID <= 0 || strings.TrimSpace(signupSource) == "" {
+ return
+ }
+ _ = client.User.UpdateOneID(userID).SetSignupSource(signupSource).Exec(ctx)
+}
+
+func oauthEmailFlowStringValue(value *string) string {
+ if value == nil {
+ return ""
+ }
+ return *value
+}
+
+// ValidatePasswordCredentials checks the local password without completing the
+// login flow. This is used by pending third-party account adoption flows before
+// the external identity has been bound.
+func (s *AuthService) ValidatePasswordCredentials(ctx context.Context, email, password string) (*User, error) {
+ if s == nil {
+ return nil, ErrServiceUnavailable
+ }
+
+ user, err := s.userRepo.GetByEmail(ctx, strings.TrimSpace(strings.ToLower(email)))
+ if err != nil {
+ if errors.Is(err, ErrUserNotFound) {
+ return nil, ErrInvalidCredentials
+ }
+ return nil, ErrServiceUnavailable
+ }
+ if !user.IsActive() {
+ return nil, ErrUserNotActive
+ }
+ if !s.CheckPassword(password, user.PasswordHash) {
+ return nil, ErrInvalidCredentials
+ }
+ return user, nil
+}
+
+// RecordSuccessfulLogin updates last-login activity after a non-standard login
+// flow finishes with a real session.
+func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) {
+ if s != nil && s.userRepo != nil && userID > 0 {
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err == nil && user != nil && !isReservedEmail(user.Email) {
+ s.backfillEmailIdentityOnSuccessfulLogin(ctx, user)
+ }
+ }
+ s.touchUserLogin(ctx, userID)
+}
diff --git a/backend/internal/service/auth_oauth_email_flow_test.go b/backend/internal/service/auth_oauth_email_flow_test.go
new file mode 100644
index 00000000..e3fb2f85
--- /dev/null
+++ b/backend/internal/service/auth_oauth_email_flow_test.go
@@ -0,0 +1,325 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+)
+
+type redeemCodeRepoStub struct {
+ codesByCode map[string]*RedeemCode
+ useCalls []struct {
+ id int64
+ userID int64
+ }
+ updateCalls []*RedeemCode
+}
+
+func (s *redeemCodeRepoStub) Create(context.Context, *RedeemCode) error {
+ panic("unexpected Create call")
+}
+
+func (s *redeemCodeRepoStub) CreateBatch(context.Context, []RedeemCode) error {
+ panic("unexpected CreateBatch call")
+}
+
+func (s *redeemCodeRepoStub) GetByID(context.Context, int64) (*RedeemCode, error) {
+ panic("unexpected GetByID call")
+}
+
+func (s *redeemCodeRepoStub) GetByCode(_ context.Context, code string) (*RedeemCode, error) {
+ if s.codesByCode == nil {
+ return nil, ErrRedeemCodeNotFound
+ }
+ redeemCode, ok := s.codesByCode[code]
+ if !ok {
+ return nil, ErrRedeemCodeNotFound
+ }
+ cloned := *redeemCode
+ return &cloned, nil
+}
+
+func (s *redeemCodeRepoStub) Update(_ context.Context, code *RedeemCode) error {
+ if code == nil {
+ return nil
+ }
+ cloned := *code
+ s.updateCalls = append(s.updateCalls, &cloned)
+ if s.codesByCode == nil {
+ s.codesByCode = make(map[string]*RedeemCode)
+ }
+ s.codesByCode[cloned.Code] = &cloned
+ return nil
+}
+
+func (s *redeemCodeRepoStub) Delete(context.Context, int64) error {
+ panic("unexpected Delete call")
+}
+
+func (s *redeemCodeRepoStub) Use(_ context.Context, id, userID int64) error {
+ for code, redeemCode := range s.codesByCode {
+ if redeemCode.ID != id {
+ continue
+ }
+ now := time.Now().UTC()
+ redeemCode.Status = StatusUsed
+ redeemCode.UsedBy = &userID
+ redeemCode.UsedAt = &now
+ s.codesByCode[code] = redeemCode
+ s.useCalls = append(s.useCalls, struct {
+ id int64
+ userID int64
+ }{id: id, userID: userID})
+ return nil
+ }
+ return ErrRedeemCodeNotFound
+}
+
+func (s *redeemCodeRepoStub) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected List call")
+}
+
+func (s *redeemCodeRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected ListWithFilters call")
+}
+
+func (s *redeemCodeRepoStub) ListByUser(context.Context, int64, int) ([]RedeemCode, error) {
+ panic("unexpected ListByUser call")
+}
+
+func (s *redeemCodeRepoStub) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected ListByUserPaginated call")
+}
+
+func (s *redeemCodeRepoStub) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
+ panic("unexpected SumPositiveBalanceByUser call")
+}
+
+func newOAuthEmailFlowAuthService(
+ userRepo UserRepository,
+ redeemRepo RedeemCodeRepository,
+ refreshTokenCache RefreshTokenCache,
+ settings map[string]string,
+ emailCache EmailCache,
+) *AuthService {
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 3.5,
+ UserConcurrency: 2,
+ },
+ }
+
+ settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
+ emailService := NewEmailService(&settingRepoStub{values: settings}, emailCache)
+
+ return NewAuthService(
+ nil,
+ userRepo,
+ redeemRepo,
+ refreshTokenCache,
+ cfg,
+ settingService,
+ emailService,
+ nil,
+ nil,
+ nil,
+ nil,
+ )
+}
+
+func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFails(t *testing.T) {
+ userRepo := &userRepoStub{nextID: 42}
+ redeemRepo := &redeemCodeRepoStub{
+ codesByCode: map[string]*RedeemCode{
+ "INVITE123": {
+ ID: 7,
+ Code: "INVITE123",
+ Type: RedeemTypeInvitation,
+ Status: StatusUnused,
+ },
+ },
+ }
+ emailCache := &emailCacheStub{
+ data: &VerificationCodeData{
+ Code: "246810",
+ Attempts: 0,
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ }
+ authService := newOAuthEmailFlowAuthService(
+ userRepo,
+ redeemRepo,
+ nil,
+ map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyInvitationCodeEnabled: "true",
+ SettingKeyEmailVerifyEnabled: "true",
+ },
+ emailCache,
+ )
+
+ tokenPair, user, err := authService.RegisterOAuthEmailAccount(
+ context.Background(),
+ "fresh@example.com",
+ "secret-123",
+ "246810",
+ "INVITE123",
+ "oidc",
+ )
+
+ require.Nil(t, tokenPair)
+ require.Nil(t, user)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "generate token pair")
+ require.Equal(t, []int64{42}, userRepo.deletedIDs)
+ require.Len(t, userRepo.created, 1)
+ require.Empty(t, redeemRepo.useCalls)
+ require.Empty(t, redeemRepo.updateCalls)
+}
+
+func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *testing.T) {
+ userRepo := &userRepoStub{nextID: 42}
+ emailCache := &emailCacheStub{
+ data: &VerificationCodeData{
+ Code: "246810",
+ Attempts: 0,
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ }
+ authService := newOAuthEmailFlowAuthService(
+ userRepo,
+ &redeemCodeRepoStub{},
+ &refreshTokenCacheStub{},
+ map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyEmailVerifyEnabled: "true",
+ },
+ emailCache,
+ )
+
+ tokenPair, user, err := authService.RegisterOAuthEmailAccount(
+ context.Background(),
+ "fresh@example.com",
+ "secret-123",
+ "246810",
+ "",
+ " OIDC ",
+ )
+
+ require.NoError(t, err)
+ require.NotNil(t, tokenPair)
+ require.NotNil(t, user)
+ require.Len(t, userRepo.created, 1)
+ require.Equal(t, "oidc", userRepo.created[0].SignupSource)
+}
+
+func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing.T) {
+ userRepo := &userRepoStub{nextID: 43}
+ emailCache := &emailCacheStub{
+ data: &VerificationCodeData{
+ Code: "246810",
+ Attempts: 0,
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ }
+ authService := newOAuthEmailFlowAuthService(
+ userRepo,
+ &redeemCodeRepoStub{},
+ &refreshTokenCacheStub{},
+ map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyEmailVerifyEnabled: "true",
+ },
+ emailCache,
+ )
+
+ tokenPair, user, err := authService.RegisterOAuthEmailAccount(
+ context.Background(),
+ "fallback@example.com",
+ "secret-123",
+ "246810",
+ "",
+ "github",
+ )
+
+ require.NoError(t, err)
+ require.NotNil(t, tokenPair)
+ require.NotNil(t, user)
+ require.Len(t, userRepo.created, 1)
+ require.Equal(t, "email", userRepo.created[0].SignupSource)
+}
+
+func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) {
+ userRepo := &userRepoStub{}
+ redeemRepo := &redeemCodeRepoStub{
+ codesByCode: map[string]*RedeemCode{
+ "INVITE123": {
+ ID: 7,
+ Code: "INVITE123",
+ Type: RedeemTypeInvitation,
+ Status: StatusUsed,
+ UsedBy: func() *int64 {
+ v := int64(42)
+ return &v
+ }(),
+ UsedAt: func() *time.Time {
+ v := time.Now().UTC()
+ return &v
+ }(),
+ },
+ },
+ }
+ authService := newOAuthEmailFlowAuthService(
+ userRepo,
+ redeemRepo,
+ &refreshTokenCacheStub{},
+ map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyInvitationCodeEnabled: "true",
+ },
+ &emailCacheStub{},
+ )
+
+ err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "INVITE123")
+
+ require.NoError(t, err)
+ require.Equal(t, []int64{42}, userRepo.deletedIDs)
+ require.Len(t, redeemRepo.updateCalls, 1)
+ require.Equal(t, StatusUnused, redeemRepo.updateCalls[0].Status)
+ require.Nil(t, redeemRepo.updateCalls[0].UsedBy)
+ require.Nil(t, redeemRepo.updateCalls[0].UsedAt)
+}
+
+func TestRollbackOAuthEmailAccountCreationPropagatesDeleteError(t *testing.T) {
+ userRepo := &userRepoStub{deleteErr: errors.New("delete failed")}
+ authService := newOAuthEmailFlowAuthService(
+ userRepo,
+ &redeemCodeRepoStub{},
+ &refreshTokenCacheStub{},
+ map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ },
+ &emailCacheStub{},
+ )
+
+ err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "")
+
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "delete created oauth user")
+}
diff --git a/backend/internal/service/auth_oauth_first_bind.go b/backend/internal/service/auth_oauth_first_bind.go
new file mode 100644
index 00000000..aa06e59f
--- /dev/null
+++ b/backend/internal/service/auth_oauth_first_bind.go
@@ -0,0 +1,104 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+
+ entsql "entgo.io/ent/dialect/sql"
+)
+
+// ApplyProviderDefaultSettingsOnFirstBind applies provider-specific bootstrap
+// settings the first time a user binds a third-party identity. The grant is
+// idempotent per user/provider pair.
+func (s *AuthService) ApplyProviderDefaultSettingsOnFirstBind(
+ ctx context.Context,
+ userID int64,
+ providerType string,
+) error {
+ if s == nil || s.entClient == nil || s.settingService == nil || userID <= 0 {
+ return nil
+ }
+
+ if dbent.TxFromContext(ctx) != nil {
+ return s.applyProviderDefaultSettingsOnFirstBind(ctx, userID, providerType)
+ }
+
+ tx, err := s.entClient.Tx(ctx)
+ if err != nil {
+ return fmt.Errorf("begin first bind defaults transaction: %w", err)
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ txCtx := dbent.NewTxContext(ctx, tx)
+ if err := s.applyProviderDefaultSettingsOnFirstBind(txCtx, userID, providerType); err != nil {
+ return err
+ }
+ return tx.Commit()
+}
+
+func (s *AuthService) applyProviderDefaultSettingsOnFirstBind(
+ ctx context.Context,
+ userID int64,
+ providerType string,
+) error {
+ providerDefaults, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, providerType, true)
+ if err != nil {
+ return fmt.Errorf("load auth source defaults: %w", err)
+ }
+ if !enabled {
+ return nil
+ }
+
+ client := s.entClient
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ client = tx.Client()
+ }
+
+ var result entsql.Result
+ if err := client.Driver().Exec(
+ ctx,
+ `INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason)
+VALUES ($1, $2, $3)
+ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
+ []any{userID, strings.TrimSpace(providerType), "first_bind"},
+ &result,
+ ); err != nil {
+ return fmt.Errorf("record first bind provider grant: %w", err)
+ }
+
+ affected, err := result.RowsAffected()
+ if err != nil {
+ return fmt.Errorf("read first bind provider grant result: %w", err)
+ }
+ if affected == 0 {
+ return nil
+ }
+
+ if providerDefaults.Balance != 0 {
+ if err := client.User.UpdateOneID(userID).AddBalance(providerDefaults.Balance).Exec(ctx); err != nil {
+ return fmt.Errorf("apply first bind balance default: %w", err)
+ }
+ }
+ if providerDefaults.Concurrency != 0 {
+ if err := client.User.UpdateOneID(userID).AddConcurrency(providerDefaults.Concurrency).Exec(ctx); err != nil {
+ return fmt.Errorf("apply first bind concurrency default: %w", err)
+ }
+ }
+ if s.defaultSubAssigner != nil {
+ for _, item := range providerDefaults.Subscriptions {
+ if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
+ UserID: userID,
+ GroupID: item.GroupID,
+ ValidityDays: item.ValidityDays,
+ Notes: "auto assigned by first bind defaults",
+ }); err != nil {
+ return fmt.Errorf("apply first bind subscription default: %w", err)
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/backend/internal/service/auth_pending_identity_service.go b/backend/internal/service/auth_pending_identity_service.go
new file mode 100644
index 00000000..6e69c121
--- /dev/null
+++ b/backend/internal/service/auth_pending_identity_service.go
@@ -0,0 +1,543 @@
+package service
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "hash/fnv"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+
+ "entgo.io/ent/dialect"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+
+ entsql "entgo.io/ent/dialect/sql"
+)
+
+var (
+ ErrPendingAuthSessionNotFound = infraerrors.NotFound("PENDING_AUTH_SESSION_NOT_FOUND", "pending auth session not found")
+ ErrPendingAuthSessionExpired = infraerrors.Unauthorized("PENDING_AUTH_SESSION_EXPIRED", "pending auth session has expired")
+ ErrPendingAuthSessionConsumed = infraerrors.Unauthorized("PENDING_AUTH_SESSION_CONSUMED", "pending auth session has already been used")
+ ErrPendingAuthCodeInvalid = infraerrors.Unauthorized("PENDING_AUTH_CODE_INVALID", "pending auth completion code is invalid")
+ ErrPendingAuthCodeExpired = infraerrors.Unauthorized("PENDING_AUTH_CODE_EXPIRED", "pending auth completion code has expired")
+ ErrPendingAuthCodeConsumed = infraerrors.Unauthorized("PENDING_AUTH_CODE_CONSUMED", "pending auth completion code has already been used")
+ ErrPendingAuthBrowserMismatch = infraerrors.Unauthorized("PENDING_AUTH_BROWSER_MISMATCH", "pending auth completion code does not match this browser session")
+)
+
+const (
+ defaultPendingAuthTTL = 15 * time.Minute
+ defaultPendingAuthCompletionTTL = 5 * time.Minute
+)
+
+type PendingAuthIdentityKey struct {
+ ProviderType string
+ ProviderKey string
+ ProviderSubject string
+}
+
+type CreatePendingAuthSessionInput struct {
+ SessionToken string
+ Intent string
+ Identity PendingAuthIdentityKey
+ TargetUserID *int64
+ RedirectTo string
+ ResolvedEmail string
+ RegistrationPasswordHash string
+ BrowserSessionKey string
+ UpstreamIdentityClaims map[string]any
+ LocalFlowState map[string]any
+ ExpiresAt time.Time
+}
+
+type IssuePendingAuthCompletionCodeInput struct {
+ PendingAuthSessionID int64
+ BrowserSessionKey string
+ TTL time.Duration
+}
+
+type IssuePendingAuthCompletionCodeResult struct {
+ Code string
+ ExpiresAt time.Time
+}
+
+type PendingIdentityAdoptionDecisionInput struct {
+ PendingAuthSessionID int64
+ IdentityID *int64
+ AdoptDisplayName bool
+ AdoptAvatar bool
+}
+
+type AuthPendingIdentityService struct {
+ entClient *dbent.Client
+}
+
+var authPendingIdentityScopedKeyLocks = newAuthPendingIdentityScopedKeyLockRegistry()
+
+type authPendingIdentityScopedKeyLockRegistry struct {
+ mu sync.Mutex
+ locks map[string]*authPendingIdentityScopedKeyLockEntry
+}
+
+type authPendingIdentityScopedKeyLockEntry struct {
+ mu sync.Mutex
+ refs int
+}
+
+func newAuthPendingIdentityScopedKeyLockRegistry() *authPendingIdentityScopedKeyLockRegistry {
+ return &authPendingIdentityScopedKeyLockRegistry{
+ locks: make(map[string]*authPendingIdentityScopedKeyLockEntry),
+ }
+}
+
+func (r *authPendingIdentityScopedKeyLockRegistry) lock(keys ...string) func() {
+ normalized := normalizeAuthPendingIdentityLockKeys(keys...)
+ if len(normalized) == 0 {
+ return func() {}
+ }
+
+ entries := make([]*authPendingIdentityScopedKeyLockEntry, 0, len(normalized))
+ r.mu.Lock()
+ for _, key := range normalized {
+ entry := r.locks[key]
+ if entry == nil {
+ entry = &authPendingIdentityScopedKeyLockEntry{}
+ r.locks[key] = entry
+ }
+ entry.refs++
+ entries = append(entries, entry)
+ }
+ r.mu.Unlock()
+
+ for _, entry := range entries {
+ entry.mu.Lock()
+ }
+
+ return func() {
+ for i := len(entries) - 1; i >= 0; i-- {
+ entries[i].mu.Unlock()
+ }
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ for idx, key := range normalized {
+ entry := entries[idx]
+ entry.refs--
+ if entry.refs == 0 {
+ delete(r.locks, key)
+ }
+ }
+ }
+}
+
+func normalizeAuthPendingIdentityLockKeys(keys ...string) []string {
+ if len(keys) == 0 {
+ return nil
+ }
+
+ deduped := make(map[string]struct{}, len(keys))
+ for _, key := range keys {
+ trimmed := strings.TrimSpace(key)
+ if trimmed == "" {
+ continue
+ }
+ deduped[trimmed] = struct{}{}
+ }
+ if len(deduped) == 0 {
+ return nil
+ }
+
+ normalized := make([]string, 0, len(deduped))
+ for key := range deduped {
+ normalized = append(normalized, key)
+ }
+ sort.Strings(normalized)
+ return normalized
+}
+
+func authPendingIdentityAdvisoryLockHash(key string) int64 {
+ hasher := fnv.New64a()
+ _, _ = hasher.Write([]byte(key))
+ return int64(hasher.Sum64())
+}
+
+func lockAuthPendingIdentityKeys(ctx context.Context, client *dbent.Client, keys ...string) (func(), error) {
+ release := authPendingIdentityScopedKeyLocks.lock(keys...)
+ normalized := normalizeAuthPendingIdentityLockKeys(keys...)
+ if len(normalized) == 0 || client == nil || client.Driver().Dialect() != dialect.Postgres {
+ return release, nil
+ }
+
+ for _, key := range normalized {
+ var rows entsql.Rows
+ if err := client.Driver().Query(ctx, "SELECT pg_advisory_xact_lock($1)", []any{authPendingIdentityAdvisoryLockHash(key)}, &rows); err != nil {
+ release()
+ return nil, err
+ }
+ _ = rows.Close()
+ }
+
+ return release, nil
+}
+
+func pendingIdentityAdoptionLockKeys(pendingAuthSessionID int64, identityID *int64) []string {
+ keys := []string{fmt.Sprintf("pending-auth-adoption:pending:%d", pendingAuthSessionID)}
+ if identityID != nil && *identityID > 0 {
+ keys = append(keys, fmt.Sprintf("pending-auth-adoption:identity:%d", *identityID))
+ }
+ return keys
+}
+
+func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService {
+ return &AuthPendingIdentityService{entClient: entClient}
+}
+
+func (s *AuthPendingIdentityService) CreatePendingSession(ctx context.Context, input CreatePendingAuthSessionInput) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ sessionToken := strings.TrimSpace(input.SessionToken)
+ if sessionToken == "" {
+ var err error
+ sessionToken, err = randomOpaqueToken(24)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ expiresAt := input.ExpiresAt.UTC()
+ if expiresAt.IsZero() {
+ expiresAt = time.Now().UTC().Add(defaultPendingAuthTTL)
+ }
+
+ create := s.entClient.PendingAuthSession.Create().
+ SetSessionToken(sessionToken).
+ SetIntent(strings.TrimSpace(input.Intent)).
+ SetProviderType(strings.TrimSpace(input.Identity.ProviderType)).
+ SetProviderKey(strings.TrimSpace(input.Identity.ProviderKey)).
+ SetProviderSubject(strings.TrimSpace(input.Identity.ProviderSubject)).
+ SetRedirectTo(strings.TrimSpace(input.RedirectTo)).
+ SetResolvedEmail(strings.TrimSpace(input.ResolvedEmail)).
+ SetRegistrationPasswordHash(strings.TrimSpace(input.RegistrationPasswordHash)).
+ SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)).
+ SetUpstreamIdentityClaims(copyPendingMap(input.UpstreamIdentityClaims)).
+ SetLocalFlowState(copyPendingMap(input.LocalFlowState)).
+ SetExpiresAt(expiresAt)
+ if input.TargetUserID != nil {
+ create = create.SetTargetUserID(*input.TargetUserID)
+ }
+ return create.Save(ctx)
+}
+
+func (s *AuthPendingIdentityService) IssueCompletionCode(ctx context.Context, input IssuePendingAuthCompletionCodeInput) (*IssuePendingAuthCompletionCodeResult, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ session, err := s.entClient.PendingAuthSession.Get(ctx, input.PendingAuthSessionID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, ErrPendingAuthSessionNotFound
+ }
+ return nil, err
+ }
+
+ code, err := randomOpaqueToken(24)
+ if err != nil {
+ return nil, err
+ }
+ ttl := input.TTL
+ if ttl <= 0 {
+ ttl = defaultPendingAuthCompletionTTL
+ }
+ expiresAt := time.Now().UTC().Add(ttl)
+
+ update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
+ SetCompletionCodeHash(hashPendingAuthCode(code)).
+ SetCompletionCodeExpiresAt(expiresAt)
+ if strings.TrimSpace(input.BrowserSessionKey) != "" {
+ update = update.SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey))
+ }
+ if _, err := update.Save(ctx); err != nil {
+ return nil, err
+ }
+
+ return &IssuePendingAuthCompletionCodeResult{
+ Code: code,
+ ExpiresAt: expiresAt,
+ }, nil
+}
+
+func (s *AuthPendingIdentityService) ConsumeCompletionCode(ctx context.Context, rawCode, browserSessionKey string) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ codeHash := hashPendingAuthCode(strings.TrimSpace(rawCode))
+ session, err := s.entClient.PendingAuthSession.Query().
+ Where(pendingauthsession.CompletionCodeHashEQ(codeHash)).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, ErrPendingAuthCodeInvalid
+ }
+ return nil, err
+ }
+
+ return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthCodeExpired, ErrPendingAuthCodeConsumed)
+}
+
+func (s *AuthPendingIdentityService) ConsumeBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ session, err := s.getBrowserSession(ctx, sessionToken)
+ if err != nil {
+ return nil, err
+ }
+
+ return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
+}
+
+func (s *AuthPendingIdentityService) GetBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ session, err := s.getBrowserSession(ctx, sessionToken)
+ if err != nil {
+ return nil, err
+ }
+ if err := validatePendingSessionState(session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed); err != nil {
+ return nil, err
+ }
+ return session, nil
+}
+
+func (s *AuthPendingIdentityService) getBrowserSession(ctx context.Context, sessionToken string) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ sessionToken = strings.TrimSpace(sessionToken)
+ if sessionToken == "" {
+ return nil, ErrPendingAuthSessionNotFound
+ }
+
+ session, err := s.entClient.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(sessionToken)).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, ErrPendingAuthSessionNotFound
+ }
+ return nil, err
+ }
+ return session, nil
+}
+
+func (s *AuthPendingIdentityService) consumeSession(
+ ctx context.Context,
+ session *dbent.PendingAuthSession,
+ browserSessionKey string,
+ expiredErr error,
+ consumedErr error,
+) (*dbent.PendingAuthSession, error) {
+ if err := validatePendingSessionState(session, browserSessionKey, expiredErr, consumedErr); err != nil {
+ return nil, err
+ }
+
+ sanitizedLocalFlowState := sanitizePendingAuthLocalFlowState(session.LocalFlowState)
+ now := time.Now().UTC()
+ update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
+ Where(
+ pendingauthsession.ConsumedAtIsNil(),
+ pendingauthsession.ExpiresAtGTE(now),
+ pendingauthsession.Or(
+ pendingauthsession.CompletionCodeExpiresAtIsNil(),
+ pendingauthsession.CompletionCodeExpiresAtGTE(now),
+ ),
+ ).
+ SetConsumedAt(now).
+ SetLocalFlowState(sanitizedLocalFlowState).
+ SetCompletionCodeHash("").
+ ClearCompletionCodeExpiresAt()
+ if expectedBrowserSessionKey := strings.TrimSpace(session.BrowserSessionKey); expectedBrowserSessionKey != "" {
+ update = update.Where(pendingauthsession.BrowserSessionKeyEQ(expectedBrowserSessionKey))
+ }
+ updated, err := update.Save(ctx)
+ if err == nil {
+ return updated, nil
+ }
+ if !dbent.IsNotFound(err) {
+ return nil, err
+ }
+
+ current, currentErr := s.entClient.PendingAuthSession.Get(ctx, session.ID)
+ if currentErr != nil {
+ if dbent.IsNotFound(currentErr) {
+ return nil, ErrPendingAuthSessionNotFound
+ }
+ return nil, currentErr
+ }
+ if err := validatePendingSessionState(current, browserSessionKey, expiredErr, consumedErr); err != nil {
+ return nil, err
+ }
+ return nil, consumedErr
+}
+
+func sanitizePendingAuthLocalFlowState(localFlowState map[string]any) map[string]any {
+ sanitized := copyPendingMap(localFlowState)
+ if len(sanitized) == 0 {
+ return sanitized
+ }
+
+ rawCompletion, ok := sanitized["completion_response"]
+ if !ok {
+ return sanitized
+ }
+ completion, ok := rawCompletion.(map[string]any)
+ if !ok {
+ return sanitized
+ }
+
+ cleanedCompletion := copyPendingMap(completion)
+ for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} {
+ delete(cleanedCompletion, key)
+ }
+ sanitized["completion_response"] = cleanedCompletion
+ return sanitized
+}
+
+func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error {
+ if session == nil {
+ return ErrPendingAuthSessionNotFound
+ }
+
+ now := time.Now().UTC()
+ if session.ConsumedAt != nil {
+ return consumedErr
+ }
+ if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
+ return expiredErr
+ }
+ if session.CompletionCodeExpiresAt != nil && now.After(*session.CompletionCodeExpiresAt) {
+ return expiredErr
+ }
+ if strings.TrimSpace(session.BrowserSessionKey) != "" && strings.TrimSpace(browserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) {
+ return ErrPendingAuthBrowserMismatch
+ }
+ return nil
+}
+
+func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, input PendingIdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ tx, err := s.entClient.Tx(ctx)
+ if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
+ return nil, err
+ }
+
+ client := s.entClient
+ txCtx := ctx
+ if err == nil {
+ defer func() { _ = tx.Rollback() }()
+ client = tx.Client()
+ txCtx = dbent.NewTxContext(ctx, tx)
+ } else if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
+ client = existingTx.Client()
+ }
+
+ releaseLocks, err := lockAuthPendingIdentityKeys(txCtx, client, pendingIdentityAdoptionLockKeys(input.PendingAuthSessionID, input.IdentityID)...)
+ if err != nil {
+ return nil, err
+ }
+ defer releaseLocks()
+
+ if input.IdentityID != nil && *input.IdentityID > 0 {
+ if _, err := client.IdentityAdoptionDecision.Update().
+ Where(
+ identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
+ dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
+ col := s.C(identityadoptiondecision.FieldPendingAuthSessionID)
+ s.Where(entsql.Or(
+ entsql.IsNull(col),
+ entsql.NEQ(col, input.PendingAuthSessionID),
+ ))
+ }),
+ ).
+ ClearIdentityID().
+ Save(txCtx); err != nil {
+ return nil, err
+ }
+ }
+
+ create := client.IdentityAdoptionDecision.Create().
+ SetPendingAuthSessionID(input.PendingAuthSessionID).
+ SetAdoptDisplayName(input.AdoptDisplayName).
+ SetAdoptAvatar(input.AdoptAvatar).
+ SetDecidedAt(time.Now().UTC())
+ if input.IdentityID != nil && *input.IdentityID > 0 {
+ create = create.SetIdentityID(*input.IdentityID)
+ }
+
+ decisionID, err := create.
+ OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID).
+ UpdateNewValues().
+ ID(txCtx)
+ if err != nil {
+ return nil, err
+ }
+
+ decision, err := client.IdentityAdoptionDecision.Get(txCtx, decisionID)
+ if err != nil {
+ return nil, err
+ }
+
+ if tx != nil {
+ if err := tx.Commit(); err != nil {
+ return nil, err
+ }
+ }
+
+ return decision, nil
+}
+
+func copyPendingMap(in map[string]any) map[string]any {
+ if len(in) == 0 {
+ return map[string]any{}
+ }
+ out := make(map[string]any, len(in))
+ for k, v := range in {
+ out[k] = v
+ }
+ return out
+}
+
+func randomOpaqueToken(byteLen int) (string, error) {
+ if byteLen <= 0 {
+ byteLen = 16
+ }
+ buf := make([]byte, byteLen)
+ if _, err := rand.Read(buf); err != nil {
+ return "", err
+ }
+ return hex.EncodeToString(buf), nil
+}
+
+func hashPendingAuthCode(code string) string {
+ sum := sha256.Sum256([]byte(code))
+ return hex.EncodeToString(sum[:])
+}
diff --git a/backend/internal/service/auth_pending_identity_service_test.go b/backend/internal/service/auth_pending_identity_service_test.go
new file mode 100644
index 00000000..555bb0e7
--- /dev/null
+++ b/backend/internal/service/auth_pending_identity_service_test.go
@@ -0,0 +1,526 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "database/sql"
+ "sync"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+func newAuthPendingIdentityServiceTestClient(t *testing.T) (*AuthPendingIdentityService, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_pending_identity_service?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ return NewAuthPendingIdentityService(client), client
+}
+
+func TestAuthPendingIdentityService_CreatePendingSessionStoresSeparatedState(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ targetUser, err := client.User.Create().
+ SetEmail("pending-target@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-123",
+ },
+ TargetUserID: &targetUser.ID,
+ RedirectTo: "/profile",
+ ResolvedEmail: "user@example.com",
+ BrowserSessionKey: "browser-1",
+ UpstreamIdentityClaims: map[string]any{"nickname": "wx-user", "avatar_url": "https://cdn.example/avatar.png"},
+ LocalFlowState: map[string]any{"step": "email_required"},
+ })
+ require.NoError(t, err)
+ require.NotEmpty(t, session.SessionToken)
+ require.Equal(t, "bind_current_user", session.Intent)
+ require.Equal(t, "wechat", session.ProviderType)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, targetUser.ID, *session.TargetUserID)
+ require.Equal(t, "wx-user", session.UpstreamIdentityClaims["nickname"])
+ require.Equal(t, "email_required", session.LocalFlowState["step"])
+}
+
+func TestAuthPendingIdentityService_CompletionCodeIsBrowserBoundAndOneTime(t *testing.T) {
+ svc, _ := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ ProviderSubject: "subject-1",
+ },
+ BrowserSessionKey: "browser-expected",
+ UpstreamIdentityClaims: map[string]any{"nickname": "linux-user"},
+ LocalFlowState: map[string]any{"step": "pending"},
+ })
+ require.NoError(t, err)
+
+ issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{
+ PendingAuthSessionID: session.ID,
+ BrowserSessionKey: "browser-expected",
+ })
+ require.NoError(t, err)
+ require.NotEmpty(t, issued.Code)
+
+ _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-other")
+ require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch)
+
+ consumed, err := svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected")
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+ require.Empty(t, consumed.CompletionCodeHash)
+ require.Nil(t, consumed.CompletionCodeExpiresAt)
+
+ _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected")
+ require.ErrorIs(t, err, ErrPendingAuthCodeInvalid)
+}
+
+func TestAuthPendingIdentityService_CompletionCodeExpires(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-1",
+ },
+ BrowserSessionKey: "browser-expired",
+ })
+ require.NoError(t, err)
+
+ issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{
+ PendingAuthSessionID: session.ID,
+ BrowserSessionKey: "browser-expired",
+ TTL: time.Second,
+ })
+ require.NoError(t, err)
+
+ _, err = client.PendingAuthSession.UpdateOneID(session.ID).
+ SetCompletionCodeExpiresAt(time.Now().UTC().Add(-time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expired")
+ require.ErrorIs(t, err, ErrPendingAuthCodeExpired)
+}
+
+func TestAuthPendingIdentityService_UpsertAdoptionDecision(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("adoption@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ identity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat-open").
+ SetProviderSubject("union-adoption").
+ SetMetadata(map[string]any{}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-adoption",
+ },
+ })
+ require.NoError(t, err)
+
+ first, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: false,
+ })
+ require.NoError(t, err)
+ require.True(t, first.AdoptDisplayName)
+ require.False(t, first.AdoptAvatar)
+ require.Nil(t, first.IdentityID)
+
+ second, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ IdentityID: &identity.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: true,
+ })
+ require.NoError(t, err)
+ require.Equal(t, first.ID, second.ID)
+ require.NotNil(t, second.IdentityID)
+ require.Equal(t, identity.ID, *second.IdentityID)
+ require.True(t, second.AdoptAvatar)
+}
+
+func TestAuthPendingIdentityService_UpsertAdoptionDecision_ReassignsExistingIdentityReference(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("adoption-reassign@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ identity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat-open").
+ SetProviderSubject("union-reassign").
+ SetMetadata(map[string]any{}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ firstSession, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-reassign",
+ },
+ })
+ require.NoError(t, err)
+
+ firstDecision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: firstSession.ID,
+ IdentityID: &identity.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: false,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, firstDecision.IdentityID)
+ require.Equal(t, identity.ID, *firstDecision.IdentityID)
+
+ secondSession, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-reassign",
+ },
+ })
+ require.NoError(t, err)
+
+ secondDecision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: secondSession.ID,
+ IdentityID: &identity.ID,
+ AdoptDisplayName: false,
+ AdoptAvatar: true,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, secondDecision.IdentityID)
+ require.Equal(t, identity.ID, *secondDecision.IdentityID)
+
+ reloadedFirst, err := client.IdentityAdoptionDecision.Get(ctx, firstDecision.ID)
+ require.NoError(t, err)
+ require.Nil(t, reloadedFirst.IdentityID)
+}
+
+func TestAuthPendingIdentityService_UpsertAdoptionDecision_IsIdempotentUnderConcurrency(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("adoption-concurrent@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ identity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat-main").
+ SetProviderSubject("union-concurrent").
+ SetMetadata(map[string]any{}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "union-concurrent",
+ },
+ })
+ require.NoError(t, err)
+
+ firstCreateStarted := make(chan struct{})
+ releaseFirstCreate := make(chan struct{})
+ var firstCreate sync.Once
+ client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator {
+ return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
+ blocked := false
+ if m.Op().Is(dbent.OpCreate) {
+ firstCreate.Do(func() {
+ blocked = true
+ close(firstCreateStarted)
+ })
+ }
+ if blocked {
+ <-releaseFirstCreate
+ }
+ return next.Mutate(ctx, m)
+ })
+ })
+
+ type adoptionResult struct {
+ decision *dbent.IdentityAdoptionDecision
+ err error
+ }
+
+ input := PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ IdentityID: &identity.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: true,
+ }
+
+ results := make(chan adoptionResult, 2)
+ go func() {
+ decision, err := svc.UpsertAdoptionDecision(ctx, input)
+ results <- adoptionResult{decision: decision, err: err}
+ }()
+
+ <-firstCreateStarted
+
+ go func() {
+ decision, err := svc.UpsertAdoptionDecision(ctx, input)
+ results <- adoptionResult{decision: decision, err: err}
+ }()
+
+ time.Sleep(100 * time.Millisecond)
+ close(releaseFirstCreate)
+
+ first := <-results
+ second := <-results
+
+ require.NoError(t, first.err)
+ require.NoError(t, second.err)
+ require.NotNil(t, first.decision)
+ require.NotNil(t, second.decision)
+ require.Equal(t, first.decision.ID, second.decision.ID)
+
+ count, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, count)
+
+ loaded, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, loaded.IdentityID)
+ require.Equal(t, identity.ID, *loaded.IdentityID)
+}
+
+func TestAuthPendingIdentityService_UpsertAdoptionDecision_ClearsLegacyNullSessionReference(t *testing.T) {
+ t.Skip("legacy NULL pending_auth_session_id rows only exist in production PostgreSQL history; sqlite unit schema rejects NULL")
+
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("legacy-null-session@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ identity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat-main").
+ SetProviderSubject("legacy-null-session").
+ SetMetadata(map[string]any{}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.ExecContext(
+ ctx,
+ `INSERT INTO identity_adoption_decisions
+ (identity_id, adopt_display_name, adopt_avatar, decided_at, created_at, updated_at, pending_auth_session_id)
+ VALUES (?, ?, ?, ?, ?, ?, NULL)`,
+ identity.ID,
+ true,
+ false,
+ time.Now().UTC(),
+ time.Now().UTC(),
+ time.Now().UTC(),
+ )
+ require.NoError(t, err)
+ legacyDecision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.IdentityIDEQ(identity.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, legacyDecision.IdentityID)
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "legacy-null-session",
+ },
+ })
+ require.NoError(t, err)
+
+ decision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ IdentityID: &identity.ID,
+ AdoptDisplayName: false,
+ AdoptAvatar: true,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+
+ reloadedLegacy, err := client.IdentityAdoptionDecision.Get(ctx, legacyDecision.ID)
+ require.NoError(t, err)
+ require.Nil(t, reloadedLegacy.IdentityID)
+}
+
+func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) {
+ svc, _ := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "subject-session-token",
+ },
+ BrowserSessionKey: "browser-session",
+ LocalFlowState: map[string]any{
+ "completion_response": map[string]any{
+ "access_token": "token",
+ },
+ },
+ })
+ require.NoError(t, err)
+
+ _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-other")
+ require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch)
+
+ consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+
+ _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
+ require.ErrorIs(t, err, ErrPendingAuthSessionConsumed)
+}
+
+func TestAuthPendingIdentityService_ConsumeBrowserSessionRejectsStaleLoadedSessionReplay(t *testing.T) {
+ svc, _ := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "stale-replay-subject",
+ },
+ BrowserSessionKey: "browser-session",
+ })
+ require.NoError(t, err)
+
+ loaded, err := svc.getBrowserSession(ctx, session.SessionToken)
+ require.NoError(t, err)
+
+ consumed, err := svc.consumeSession(ctx, loaded, "browser-session", ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+
+ _, err = svc.consumeSession(ctx, loaded, "browser-session", ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
+ require.ErrorIs(t, err, ErrPendingAuthSessionConsumed)
+}
+
+func TestAuthPendingIdentityService_ConsumeBrowserSessionScrubsLegacyCompletionTokens(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "legacy-token-subject",
+ },
+ BrowserSessionKey: "browser-session",
+ LocalFlowState: map[string]any{
+ "completion_response": map[string]any{
+ "access_token": "legacy-access-token",
+ "refresh_token": "legacy-refresh-token",
+ "expires_in": float64(3600),
+ "token_type": "Bearer",
+ "redirect": "/dashboard",
+ },
+ },
+ })
+ require.NoError(t, err)
+
+ consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+
+ stored, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+
+ completion, ok := stored.LocalFlowState["completion_response"].(map[string]any)
+ require.True(t, ok)
+ require.NotContains(t, completion, "access_token")
+ require.NotContains(t, completion, "refresh_token")
+ require.NotContains(t, completion, "expires_in")
+ require.NotContains(t, completion, "token_type")
+ require.Equal(t, "/dashboard", completion["redirect"])
+}
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index fd28cd42..3bf9da3d 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"crypto/sha256"
+ "encoding/binary"
"encoding/hex"
"errors"
"fmt"
@@ -13,6 +14,7 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -77,6 +79,12 @@ type DefaultSubscriptionAssigner interface {
AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error)
}
+type signupGrantPlan struct {
+ Balance float64
+ Concurrency int
+ Subscriptions []DefaultSubscriptionSetting
+}
+
// NewAuthService 创建认证服务实例
func NewAuthService(
entClient *dbent.Client,
@@ -106,6 +114,13 @@ func NewAuthService(
}
}
+func (s *AuthService) EntClient() *dbent.Client {
+ if s == nil {
+ return nil
+ }
+ return s.entClient
+}
+
// Register 用户注册,返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "", "", "")
@@ -179,21 +194,15 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, fmt.Errorf("hash password: %w", err)
}
- // 获取默认配置
- defaultBalance := s.cfg.Default.UserBalance
- defaultConcurrency := s.cfg.Default.UserConcurrency
- if s.settingService != nil {
- defaultBalance = s.settingService.GetDefaultBalance(ctx)
- defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
- }
+ grantPlan := s.resolveSignupGrantPlan(ctx, "email")
// 创建用户
user := &User{
Email: email,
PasswordHash: hashedPassword,
Role: RoleUser,
- Balance: defaultBalance,
- Concurrency: defaultConcurrency,
+ Balance: grantPlan.Balance,
+ Concurrency: grantPlan.Concurrency,
Status: StatusActive,
}
@@ -205,7 +214,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable
}
- s.assignDefaultSubscriptions(ctx, user.ID)
+ s.postAuthUserBootstrap(ctx, user, "email", true)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
// 标记邀请码为已使用(如果使用了邀请码)
if invitationRedeemCode != nil {
@@ -469,22 +479,18 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
return "", nil, fmt.Errorf("hash password: %w", err)
}
- // 新用户默认值。
- defaultBalance := s.cfg.Default.UserBalance
- defaultConcurrency := s.cfg.Default.UserConcurrency
- if s.settingService != nil {
- defaultBalance = s.settingService.GetDefaultBalance(ctx)
- defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
- }
+ signupSource := inferLegacySignupSource(email)
+ grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
newUser := &User{
Email: email,
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
- Balance: defaultBalance,
- Concurrency: defaultConcurrency,
+ Balance: grantPlan.Balance,
+ Concurrency: grantPlan.Concurrency,
Status: StatusActive,
+ SignupSource: signupSource,
}
if err := s.userRepo.Create(ctx, newUser); err != nil {
@@ -501,7 +507,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
}
} else {
user = newUser
- s.assignDefaultSubscriptions(ctx, user.ID)
+ s.postAuthUserBootstrap(ctx, user, signupSource, false)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
@@ -520,7 +527,6 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
-
token, err := s.GenerateToken(user)
if err != nil {
return "", nil, fmt.Errorf("generate token: %w", err)
@@ -584,21 +590,18 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, fmt.Errorf("hash password: %w", err)
}
- defaultBalance := s.cfg.Default.UserBalance
- defaultConcurrency := s.cfg.Default.UserConcurrency
- if s.settingService != nil {
- defaultBalance = s.settingService.GetDefaultBalance(ctx)
- defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
- }
+ signupSource := inferLegacySignupSource(email)
+ grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
newUser := &User{
Email: email,
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
- Balance: defaultBalance,
- Concurrency: defaultConcurrency,
+ Balance: grantPlan.Balance,
+ Concurrency: grantPlan.Concurrency,
Status: StatusActive,
+ SignupSource: signupSource,
}
if s.entClient != nil && invitationRedeemCode != nil {
@@ -630,7 +633,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, ErrServiceUnavailable
}
user = newUser
- s.assignDefaultSubscriptions(ctx, user.ID)
+ s.postAuthUserBootstrap(ctx, user, signupSource, false)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
}
} else {
if err := s.userRepo.Create(ctx, newUser); err != nil {
@@ -646,7 +650,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
}
} else {
user = newUser
- s.assignDefaultSubscriptions(ctx, user.ID)
+ s.postAuthUserBootstrap(ctx, user, signupSource, false)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid
@@ -670,7 +675,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
-
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
return nil, nil, fmt.Errorf("generate token pair: %w", err)
@@ -678,80 +682,273 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return tokenPair, user, nil
}
-// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
-const pendingOAuthTokenTTL = 10 * time.Minute
-
-// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
-const pendingOAuthPurpose = "pending_oauth_registration"
-
-type pendingOAuthClaims struct {
- Email string `json:"email"`
- Username string `json:"username"`
- Purpose string `json:"purpose"`
- jwt.RegisteredClaims
-}
-
-// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity
-// while waiting for the user to supply an invitation code.
-func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) {
- now := time.Now()
- claims := &pendingOAuthClaims{
- Email: email,
- Username: username,
- Purpose: pendingOAuthPurpose,
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
- IssuedAt: jwt.NewNumericDate(now),
- NotBefore: jwt.NewNumericDate(now),
- },
- }
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- return token.SignedString([]byte(s.cfg.JWT.Secret))
-}
-
-// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity.
-// Returns ErrInvalidToken when the token is invalid or expired.
-func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) {
- if len(tokenStr) > maxTokenLength {
- return "", "", ErrInvalidToken
- }
- parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
- token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) {
- if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
- return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
- }
- return []byte(s.cfg.JWT.Secret), nil
- })
- if parseErr != nil {
- return "", "", ErrInvalidToken
- }
- claims, ok := token.Claims.(*pendingOAuthClaims)
- if !ok || !token.Valid {
- return "", "", ErrInvalidToken
- }
- if claims.Purpose != pendingOAuthPurpose {
- return "", "", ErrInvalidToken
- }
- return claims.Email, claims.Username, nil
-}
-
-func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
+func (s *AuthService) assignSubscriptions(ctx context.Context, userID int64, items []DefaultSubscriptionSetting, notes string) {
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
return
}
- items := s.settingService.GetDefaultSubscriptions(ctx)
for _, item := range items {
if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
UserID: userID,
GroupID: item.GroupID,
ValidityDays: item.ValidityDays,
- Notes: "auto assigned by default user subscriptions setting",
+ Notes: notes,
}); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err)
}
}
}
+func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource string) signupGrantPlan {
+ plan := signupGrantPlan{}
+ if s != nil && s.cfg != nil {
+ plan.Balance = s.cfg.Default.UserBalance
+ plan.Concurrency = s.cfg.Default.UserConcurrency
+ }
+ if s == nil || s.settingService == nil {
+ return plan
+ }
+
+ plan.Balance = s.settingService.GetDefaultBalance(ctx)
+ plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx)
+ plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx)
+
+ resolved, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, signupSource, false)
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err)
+ return plan
+ }
+ if !enabled {
+ return plan
+ }
+
+ plan.Balance = resolved.Balance
+ plan.Concurrency = resolved.Concurrency
+ plan.Subscriptions = resolved.Subscriptions
+ return plan
+}
+
+func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource string) (ProviderDefaultGrantSettings, bool) {
+ if defaults == nil {
+ return ProviderDefaultGrantSettings{}, false
+ }
+
+ switch strings.ToLower(strings.TrimSpace(signupSource)) {
+ case "email":
+ return defaults.Email, true
+ case "linuxdo":
+ return defaults.LinuxDo, true
+ case "oidc":
+ return defaults.OIDC, true
+ case "wechat":
+ return defaults.WeChat, true
+ default:
+ return ProviderDefaultGrantSettings{}, false
+ }
+}
+
+func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
+ if user == nil || user.ID <= 0 {
+ return
+ }
+
+ if strings.TrimSpace(signupSource) == "" {
+ signupSource = "email"
+ }
+ s.updateUserSignupSource(ctx, user.ID, signupSource)
+
+ if touchLogin {
+ s.touchUserLogin(ctx, user.ID)
+ }
+}
+
+func (s *AuthService) updateUserSignupSource(ctx context.Context, userID int64, signupSource string) {
+ if s == nil || s.entClient == nil || userID <= 0 {
+ return
+ }
+ if strings.TrimSpace(signupSource) == "" {
+ return
+ }
+ if err := s.entClient.User.UpdateOneID(userID).
+ SetSignupSource(signupSource).
+ Exec(ctx); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to update signup source: user_id=%d source=%s err=%v", userID, signupSource, err)
+ }
+}
+
+func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) {
+ if s == nil || s.entClient == nil || userID <= 0 {
+ return
+ }
+ now := time.Now().UTC()
+ if err := s.entClient.User.UpdateOneID(userID).
+ SetLastLoginAt(now).
+ SetLastActiveAt(now).
+ Exec(ctx); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to touch login timestamps: user_id=%d err=%v", userID, err)
+ }
+}
+
+func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context, user *User) {
+ if s == nil || user == nil || user.ID <= 0 {
+ return
+ }
+ identity, created := s.ensureEmailAuthIdentity(ctx, user, "auth_service_login_backfill")
+ if s.shouldApplyEmailFirstBindDefaults(ctx, user.ID, identity, created) {
+ if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, "email"); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to apply email first bind defaults: user_id=%d err=%v", user.ID, err)
+ }
+ }
+}
+
+func (s *AuthService) shouldApplyEmailFirstBindDefaults(
+ ctx context.Context,
+ userID int64,
+ identity *dbent.AuthIdentity,
+ created bool,
+) bool {
+ source := emailAuthIdentitySource(identity.Metadata)
+ if source == "auth_service_login_backfill" {
+ return false
+ }
+ if created {
+ return true
+ }
+ if s == nil || s.entClient == nil || userID <= 0 || identity == nil || identity.UserID != userID {
+ return false
+ }
+ if source != "auth_service_dual_write" {
+ return false
+ }
+
+ hasGrant, err := s.hasProviderGrantRecord(ctx, userID, "email", "first_bind")
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email first bind grant state: user_id=%d err=%v", userID, err)
+ return false
+ }
+ return !hasGrant
+}
+
+func emailAuthIdentitySource(metadata map[string]any) string {
+ if len(metadata) == 0 {
+ return ""
+ }
+ raw, ok := metadata["source"]
+ if !ok {
+ return ""
+ }
+ return strings.TrimSpace(fmt.Sprint(raw))
+}
+
+func (s *AuthService) hasProviderGrantRecord(
+ ctx context.Context,
+ userID int64,
+ providerType string,
+ grantReason string,
+) (bool, error) {
+ if s == nil || s.entClient == nil || userID <= 0 {
+ return false, nil
+ }
+
+ rows, err := s.entClient.QueryContext(
+ ctx,
+ `SELECT 1 FROM user_provider_default_grants WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3 LIMIT 1`,
+ userID,
+ strings.TrimSpace(providerType),
+ strings.TrimSpace(grantReason),
+ )
+ if err != nil {
+ return false, err
+ }
+ defer func() { _ = rows.Close() }()
+ return rows.Next(), rows.Err()
+}
+
+func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User, source string) (*dbent.AuthIdentity, bool) {
+ if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
+ return nil, false
+ }
+
+ email := strings.ToLower(strings.TrimSpace(user.Email))
+ if email == "" || isReservedEmail(email) {
+ return nil, false
+ }
+ if strings.TrimSpace(source) == "" {
+ source = "auth_service_dual_write"
+ }
+
+ client := s.entClient
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ client = tx.Client()
+ }
+
+ buildQuery := func() *dbent.AuthIdentityQuery {
+ return client.AuthIdentity.Query().Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(email),
+ )
+ }
+
+ existed, err := buildQuery().Exist(ctx)
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
+ return nil, false
+ }
+
+ if !existed {
+ if err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject(email).
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{
+ "source": strings.TrimSpace(source),
+ }).
+ OnConflictColumns(
+ authidentity.FieldProviderType,
+ authidentity.FieldProviderKey,
+ authidentity.FieldProviderSubject,
+ ).
+ DoNothing().
+ Exec(ctx); err != nil {
+ if isSQLNoRowsError(err) {
+ return nil, false
+ }
+ }
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
+ return nil, false
+ }
+ }
+
+ identity, err := buildQuery().Only(ctx)
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to reload email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
+ return nil, false
+ }
+ if identity.UserID != user.ID {
+ logger.LegacyPrintf("service.auth", "[Auth] Email auth identity ownership mismatch: user_id=%d email=%s owner_id=%d", user.ID, email, identity.UserID)
+ return nil, false
+ }
+
+ return identity, !existed
+}
+
+func inferLegacySignupSource(email string) string {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ switch {
+ case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain):
+ return "linuxdo"
+ case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain):
+ return "oidc"
+ case strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain):
+ return "wechat"
+ default:
+ return "email"
+ }
+}
+
func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error {
if s.settingService == nil {
return nil
@@ -834,7 +1031,8 @@ func randomHexString(byteLength int) (string, error) {
func isReservedEmail(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) ||
- strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain)
+ strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain)
}
// GenerateToken 生成JWT access token
@@ -853,7 +1051,7 @@ func (s *AuthService) GenerateToken(user *User) (string, error) {
UserID: user.ID,
Email: user.Email,
Role: user.Role,
- TokenVersion: user.TokenVersion,
+ TokenVersion: resolvedTokenVersion(user),
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt),
IssuedAt: jwt.NewNumericDate(now),
@@ -919,7 +1117,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// Security: Check TokenVersion to prevent refreshing revoked tokens
// This ensures tokens issued before a password change cannot be refreshed
- if claims.TokenVersion != user.TokenVersion {
+ if claims.TokenVersion != resolvedTokenVersion(user) {
return "", ErrTokenRevoked
}
@@ -1147,7 +1345,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
data := &RefreshTokenData{
UserID: user.ID,
- TokenVersion: user.TokenVersion,
+ TokenVersion: resolvedTokenVersion(user),
FamilyID: familyID,
CreatedAt: now,
ExpiresAt: now.Add(ttl),
@@ -1227,7 +1425,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
}
// 检查TokenVersion(密码更改后所有Token失效)
- if data.TokenVersion != user.TokenVersion {
+ if data.TokenVersion != resolvedTokenVersion(user) {
// TokenVersion不匹配,撤销整个Token家族
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
return nil, ErrTokenRevoked
@@ -1272,8 +1470,42 @@ func (s *AuthService) RevokeAllUserSessions(ctx context.Context, userID int64) e
return s.refreshTokenCache.DeleteUserRefreshTokens(ctx, userID)
}
+// RevokeAllUserTokens invalidates both stateless access tokens and refresh sessions.
+// Access/refresh token verification both depend on TokenVersion, so bumping it provides
+// immediate revocation even if refresh-token cache cleanup later fails.
+func (s *AuthService) RevokeAllUserTokens(ctx context.Context, userID int64) error {
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return fmt.Errorf("get user: %w", err)
+ }
+
+ user.TokenVersion++
+ if err := s.userRepo.Update(ctx, user); err != nil {
+ return fmt.Errorf("update user: %w", err)
+ }
+
+ if err := s.RevokeAllUserSessions(ctx, userID); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh sessions after token invalidation for user %d: %v", userID, err)
+ }
+ return nil
+}
+
// hashToken 计算Token的SHA256哈希
func hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}
+
+func resolvedTokenVersion(user *User) int64 {
+ if user == nil {
+ return 0
+ }
+ if user.TokenVersionResolved {
+ return user.TokenVersion
+ }
+
+ material := strings.ToLower(strings.TrimSpace(user.Email)) + "\n" + user.PasswordHash
+ sum := sha256.Sum256([]byte(material))
+ fingerprint := int64(binary.BigEndian.Uint64(sum[:8]) & 0x7fffffffffffffff)
+ return user.TokenVersion ^ fingerprint
+}
diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go
new file mode 100644
index 00000000..cced842a
--- /dev/null
+++ b/backend/internal/service/auth_service_email_bind_test.go
@@ -0,0 +1,853 @@
+//go:build unit
+
+package service_test
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "sync"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/repository"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+type emailBindDefaultSubAssignerStub struct {
+ calls []*service.AssignSubscriptionInput
+}
+
+func (s *emailBindDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ cloned := *input
+ s.calls = append(s.calls, &cloned)
+ return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
+}
+
+type flakyEmailBindDefaultSubAssignerStub struct {
+ err error
+ calls []*service.AssignSubscriptionInput
+}
+
+func (s *flakyEmailBindDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ cloned := *input
+ s.calls = append(s.calls, &cloned)
+ return nil, false, s.err
+}
+
+func newAuthServiceForEmailBind(
+ t *testing.T,
+ settings map[string]string,
+ emailCache service.EmailCache,
+ defaultSubAssigner service.DefaultSubscriptionAssigner,
+) (*service.AuthService, service.UserRepository, *dbent.Client) {
+ return newAuthServiceForEmailBindWithRefreshCache(t, settings, emailCache, defaultSubAssigner, nil)
+}
+
+func newAuthServiceForEmailBindWithRefreshCache(
+ t *testing.T,
+ settings map[string]string,
+ emailCache service.EmailCache,
+ defaultSubAssigner service.DefaultSubscriptionAssigner,
+ refreshTokenCache service.RefreshTokenCache,
+) (*service.AuthService, service.UserRepository, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_service_email_bind?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+ _, err = db.Exec(`
+CREATE TABLE IF NOT EXISTS user_provider_default_grants (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ provider_type TEXT NOT NULL,
+ grant_reason TEXT NOT NULL DEFAULT 'first_bind',
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ UNIQUE(user_id, provider_type, grant_reason)
+)`)
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ repo := repository.NewUserRepository(client, db)
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-bind-email-secret",
+ ExpireHour: 1,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 3.5,
+ UserConcurrency: 2,
+ },
+ }
+
+ settingRepo := &emailBindSettingRepoStub{values: settings}
+ settingSvc := service.NewSettingService(settingRepo, cfg)
+
+ var emailSvc *service.EmailService
+ if emailCache != nil {
+ emailSvc = service.NewEmailService(settingRepo, emailCache)
+ }
+
+ svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
+ return svc, repo, client
+}
+
+func TestAuthServiceBindEmailIdentity_UpdatesEmailAndAppliesFirstBindDefaults(t *testing.T) {
+ assigner := &emailBindDefaultSubAssignerStub{}
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, cache, assigner)
+
+ ctx := context.Background()
+ user, err := client.User.Create().
+ SetEmail("legacy-user" + service.LinuxDoConnectSyntheticEmailDomain).
+ SetUsername("legacy-user").
+ SetPasswordHash("old-hash").
+ SetBalance(2.5).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, " NewEmail@Example.com ", "123456", "new-password")
+ require.NoError(t, err)
+ require.NotNil(t, updatedUser)
+ require.Equal(t, "newemail@example.com", updatedUser.Email)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, "newemail@example.com", storedUser.Email)
+ require.Equal(t, 11.0, storedUser.Balance)
+ require.Equal(t, 5, storedUser.Concurrency)
+ require.True(t, svc.CheckPassword("new-password", storedUser.PasswordHash))
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("newemail@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, identityCount)
+
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, user.ID, assigner.calls[0].UserID)
+ require.Equal(t, int64(11), assigner.calls[0].GroupID)
+ require.Equal(t, 30, assigner.calls[0].ValidityDays)
+ require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceBindEmailIdentity_RejectsExistingEmailOnAnotherUser(t *testing.T) {
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
+
+ ctx := context.Background()
+ sourceUser, err := client.User.Create().
+ SetEmail("source-user" + service.OIDCConnectSyntheticEmailDomain).
+ SetUsername("source-user").
+ SetPasswordHash("old-hash").
+ SetBalance(1).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.User.Create().
+ SetEmail("taken@example.com").
+ SetUsername("taken-user").
+ SetPasswordHash("hash").
+ SetBalance(1).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, sourceUser.ID, "taken@example.com", "123456", "new-password")
+ require.ErrorIs(t, err, service.ErrEmailExists)
+ require.Nil(t, updatedUser)
+
+ storedUser, err := client.User.Get(ctx, sourceUser.ID)
+ require.NoError(t, err)
+ require.Equal(t, "source-user"+service.OIDCConnectSyntheticEmailDomain, storedUser.Email)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, sourceUser.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceBindEmailIdentity_RollsBackWhenFirstBindDefaultsFail(t *testing.T) {
+ assigner := &flakyEmailBindDefaultSubAssignerStub{err: errors.New("temporary assign failure")}
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, cache, assigner)
+
+ ctx := context.Background()
+ originalEmail := "legacy-rollback" + service.LinuxDoConnectSyntheticEmailDomain
+ user, err := client.User.Create().
+ SetEmail(originalEmail).
+ SetUsername("legacy-rollback").
+ SetPasswordHash("old-hash").
+ SetBalance(2.5).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "rollback@example.com", "123456", "new-password")
+ require.ErrorContains(t, err, "apply email first bind defaults")
+ require.ErrorContains(t, err, "temporary assign failure")
+ require.Nil(t, updatedUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, originalEmail, storedUser.Email)
+ require.Equal(t, "old-hash", storedUser.PasswordHash)
+ require.Equal(t, 2.5, storedUser.Balance)
+ require.Equal(t, 1, storedUser.Concurrency)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("rollback@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 0, identityCount)
+
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceBindEmailIdentity_RejectsReservedEmail(t *testing.T) {
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
+
+ ctx := context.Background()
+ user, err := client.User.Create().
+ SetEmail("source-user@example.com").
+ SetUsername("source-user").
+ SetPasswordHash("old-hash").
+ SetBalance(1).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "reserved"+service.LinuxDoConnectSyntheticEmailDomain, "123456", "new-password")
+ require.ErrorIs(t, err, service.ErrEmailReserved)
+ require.Nil(t, updatedUser)
+}
+
+func TestAuthServiceBindEmailIdentity_ReplacesBoundEmailAndSkipsFirstBindDefaults(t *testing.T) {
+ assigner := &emailBindDefaultSubAssignerStub{}
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, cache, assigner)
+
+ ctx := context.Background()
+ hashedPassword, err := svc.HashPassword("current-password")
+ require.NoError(t, err)
+
+ user, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("bound-user").
+ SetPasswordHash(hashedPassword).
+ SetBalance(7.5).
+ SetConcurrency(3).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ require.NoError(t, client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject("current@example.com").
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": "test"}).
+ Exec(ctx))
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "new@example.com", "123456", "current-password")
+ require.NoError(t, err)
+ require.NotNil(t, updatedUser)
+ require.Equal(t, "new@example.com", updatedUser.Email)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, "new@example.com", storedUser.Email)
+ require.Equal(t, 7.5, storedUser.Balance)
+ require.Equal(t, 3, storedUser.Concurrency)
+ require.True(t, svc.CheckPassword("current-password", storedUser.PasswordHash))
+
+ newIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("new@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, newIdentityCount)
+
+ oldIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("current@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 0, oldIdentityCount)
+
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceBindEmailIdentity_RejectsWrongCurrentPasswordForBoundEmail(t *testing.T) {
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
+
+ ctx := context.Background()
+ hashedPassword, err := svc.HashPassword("current-password")
+ require.NoError(t, err)
+
+ user, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("bound-user").
+ SetPasswordHash(hashedPassword).
+ SetBalance(1).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ require.NoError(t, client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject("current@example.com").
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": "test"}).
+ Exec(ctx))
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "new@example.com", "123456", "wrong-password")
+ require.ErrorIs(t, err, service.ErrPasswordIncorrect)
+ require.Nil(t, updatedUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, "current@example.com", storedUser.Email)
+ require.True(t, svc.CheckPassword("current-password", storedUser.PasswordHash))
+
+ oldIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("current@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, oldIdentityCount)
+
+ newIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("new@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 0, newIdentityCount)
+}
+
+func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *testing.T) {
+ ctx := context.Background()
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ refreshTokenCache := newEmailBindRefreshTokenCacheStub()
+ userRepo := newEmailBindUserRepoStub(&service.User{
+ ID: 41,
+ Email: "legacy-user" + service.OIDCConnectSyntheticEmailDomain,
+ Username: "legacy-user",
+ PasswordHash: "old-hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ TokenVersion: 4,
+ })
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-bind-email-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ }
+ emailService := service.NewEmailService(nil, cache)
+ svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil)
+
+ oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{
+ ID: 41,
+ Email: "legacy-user" + service.OIDCConnectSyntheticEmailDomain,
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ TokenVersion: 4,
+ }, "")
+ require.NoError(t, err)
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, 41, "new@example.com", "123456", "new-password")
+ require.NoError(t, err)
+ require.NotNil(t, updatedUser)
+
+ storedUser, err := userRepo.GetByID(ctx, 41)
+ require.NoError(t, err)
+ require.Equal(t, "new@example.com", storedUser.Email)
+ require.True(t, svc.CheckPassword("new-password", storedUser.PasswordHash))
+
+ _, err = svc.RefreshToken(ctx, oldTokenPair.AccessToken)
+ require.ErrorIs(t, err, service.ErrTokenRevoked)
+
+ _, err = svc.RefreshTokenPair(ctx, oldTokenPair.RefreshToken)
+ require.True(t, errors.Is(err, service.ErrTokenRevoked) || errors.Is(err, service.ErrRefreshTokenInvalid))
+}
+
+type emailBindSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *emailBindSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *emailBindSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ if v, ok := s.values[key]; ok {
+ return v, nil
+ }
+ return "", service.ErrSettingNotFound
+}
+
+func (s *emailBindSettingRepoStub) Set(context.Context, string, string) error {
+ panic("unexpected Set call")
+}
+
+func (s *emailBindSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if v, ok := s.values[key]; ok {
+ out[key] = v
+ }
+ }
+ return out, nil
+}
+
+func (s *emailBindSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ panic("unexpected SetMultiple call")
+}
+
+func (s *emailBindSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *emailBindSettingRepoStub) Delete(context.Context, string) error {
+ panic("unexpected Delete call")
+}
+
+type emailBindCacheStub struct {
+ data *service.VerificationCodeData
+ err error
+}
+
+func (s *emailBindCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) {
+ if s.err != nil {
+ return nil, s.err
+ }
+ return s.data, nil
+}
+
+func (s *emailBindCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) DeleteVerificationCode(context.Context, string) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
+ return nil, nil
+}
+
+func (s *emailBindCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
+ return nil, nil
+}
+
+func (s *emailBindCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) DeletePasswordResetToken(context.Context, string) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
+ return false
+}
+
+func (s *emailBindCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *emailBindCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
+ return 0, nil
+}
+
+type emailBindRefreshTokenCacheStub struct {
+ mu sync.Mutex
+ tokens map[string]*service.RefreshTokenData
+ userSets map[int64]map[string]struct{}
+ families map[string]map[string]struct{}
+}
+
+func newEmailBindRefreshTokenCacheStub() *emailBindRefreshTokenCacheStub {
+ return &emailBindRefreshTokenCacheStub{
+ tokens: make(map[string]*service.RefreshTokenData),
+ userSets: make(map[int64]map[string]struct{}),
+ families: make(map[string]map[string]struct{}),
+ }
+}
+
+func (s *emailBindRefreshTokenCacheStub) StoreRefreshToken(_ context.Context, tokenHash string, data *service.RefreshTokenData, _ time.Duration) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ cloned := *data
+ s.tokens[tokenHash] = &cloned
+ return nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) GetRefreshToken(_ context.Context, tokenHash string) (*service.RefreshTokenData, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ data, ok := s.tokens[tokenHash]
+ if !ok {
+ return nil, service.ErrRefreshTokenNotFound
+ }
+ cloned := *data
+ return &cloned, nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) DeleteRefreshToken(_ context.Context, tokenHash string) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ delete(s.tokens, tokenHash)
+ for _, tokenSet := range s.userSets {
+ delete(tokenSet, tokenHash)
+ }
+ for _, tokenSet := range s.families {
+ delete(tokenSet, tokenHash)
+ }
+ return nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) DeleteUserRefreshTokens(_ context.Context, userID int64) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ for tokenHash := range s.userSets[userID] {
+ delete(s.tokens, tokenHash)
+ for _, tokenSet := range s.families {
+ delete(tokenSet, tokenHash)
+ }
+ }
+ delete(s.userSets, userID)
+ return nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) DeleteTokenFamily(_ context.Context, familyID string) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ for tokenHash := range s.families[familyID] {
+ delete(s.tokens, tokenHash)
+ for _, tokenSet := range s.userSets {
+ delete(tokenSet, tokenHash)
+ }
+ }
+ delete(s.families, familyID)
+ return nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) AddToUserTokenSet(_ context.Context, userID int64, tokenHash string, _ time.Duration) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.userSets[userID] == nil {
+ s.userSets[userID] = make(map[string]struct{})
+ }
+ s.userSets[userID][tokenHash] = struct{}{}
+ return nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) AddToFamilyTokenSet(_ context.Context, familyID string, tokenHash string, _ time.Duration) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.families[familyID] == nil {
+ s.families[familyID] = make(map[string]struct{})
+ }
+ s.families[familyID][tokenHash] = struct{}{}
+ return nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) GetUserTokenHashes(_ context.Context, userID int64) ([]string, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ tokenSet := s.userSets[userID]
+ out := make([]string, 0, len(tokenSet))
+ for tokenHash := range tokenSet {
+ out = append(out, tokenHash)
+ }
+ return out, nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) GetFamilyTokenHashes(_ context.Context, familyID string) ([]string, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ tokenSet := s.families[familyID]
+ out := make([]string, 0, len(tokenSet))
+ for tokenHash := range tokenSet {
+ out = append(out, tokenHash)
+ }
+ return out, nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) IsTokenInFamily(_ context.Context, familyID string, tokenHash string) (bool, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ _, ok := s.families[familyID][tokenHash]
+ return ok, nil
+}
+
+type emailBindUserRepoStub struct {
+ mu sync.Mutex
+ usersByID map[int64]*service.User
+ usersByEmail map[string]*service.User
+}
+
+func newEmailBindUserRepoStub(user *service.User) *emailBindUserRepoStub {
+ cloned := cloneEmailBindUser(user)
+ return &emailBindUserRepoStub{
+ usersByID: map[int64]*service.User{
+ cloned.ID: cloned,
+ },
+ usersByEmail: map[string]*service.User{
+ cloned.Email: cloned,
+ },
+ }
+}
+
+func (s *emailBindUserRepoStub) Create(context.Context, *service.User) error { return nil }
+
+func (s *emailBindUserRepoStub) GetByID(_ context.Context, id int64) (*service.User, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ user, ok := s.usersByID[id]
+ if !ok {
+ return nil, service.ErrUserNotFound
+ }
+ return cloneEmailBindUser(user), nil
+}
+
+func (s *emailBindUserRepoStub) GetByEmail(_ context.Context, email string) (*service.User, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ user, ok := s.usersByEmail[email]
+ if !ok {
+ return nil, service.ErrUserNotFound
+ }
+ return cloneEmailBindUser(user), nil
+}
+
+func (s *emailBindUserRepoStub) GetFirstAdmin(context.Context) (*service.User, error) {
+ panic("unexpected GetFirstAdmin call")
+}
+
+func (s *emailBindUserRepoStub) Update(_ context.Context, user *service.User) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ existing, ok := s.usersByID[user.ID]
+ if !ok {
+ return service.ErrUserNotFound
+ }
+ delete(s.usersByEmail, existing.Email)
+ cloned := cloneEmailBindUser(user)
+ s.usersByID[user.ID] = cloned
+ s.usersByEmail[cloned.Email] = cloned
+ return nil
+}
+
+func (s *emailBindUserRepoStub) Delete(context.Context, int64) error { return nil }
+
+func (s *emailBindUserRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
+ return nil, nil
+}
+
+func (s *emailBindUserRepoStub) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ panic("unexpected UpsertUserAvatar call")
+}
+
+func (s *emailBindUserRepoStub) DeleteUserAvatar(context.Context, int64) error {
+ panic("unexpected DeleteUserAvatar call")
+}
+
+func (s *emailBindUserRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
+ panic("unexpected List call")
+}
+
+func (s *emailBindUserRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
+ panic("unexpected ListWithFilters call")
+}
+
+func (s *emailBindUserRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+
+func (s *emailBindUserRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
+
+func (s *emailBindUserRepoStub) UpdateUserLastActiveAt(context.Context, int64, time.Time) error {
+ return nil
+}
+
+func (s *emailBindUserRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
+func (s *emailBindUserRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
+func (s *emailBindUserRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
+
+func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) (bool, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ _, ok := s.usersByEmail[email]
+ return ok, nil
+}
+
+func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *emailBindUserRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+
+func (s *emailBindUserRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+
+func (s *emailBindUserRepoStub) ListUserAuthIdentities(context.Context, int64) ([]service.UserAuthIdentityRecord, error) {
+ return nil, nil
+}
+
+func (s *emailBindUserRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error {
+ return nil
+}
+
+func (s *emailBindUserRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+func (s *emailBindUserRepoStub) EnableTotp(context.Context, int64) error { return nil }
+func (s *emailBindUserRepoStub) DisableTotp(context.Context, int64) error { return nil }
+
+func cloneEmailBindUser(user *service.User) *service.User {
+ if user == nil {
+ return nil
+ }
+ cloned := *user
+ return &cloned
+}
diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go
new file mode 100644
index 00000000..2233e427
--- /dev/null
+++ b/backend/internal/service/auth_service_identity_sync_test.go
@@ -0,0 +1,482 @@
+//go:build unit
+
+package service_test
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/repository"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+type authIdentityDefaultSubAssignerStub struct {
+ calls []*service.AssignSubscriptionInput
+}
+
+func (s *authIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ cloned := *input
+ s.calls = append(s.calls, &cloned)
+ return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
+}
+
+type flakyAuthIdentityDefaultSubAssignerStub struct {
+ failuresRemaining int
+ calls []*service.AssignSubscriptionInput
+}
+
+func (s *flakyAuthIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ cloned := *input
+ s.calls = append(s.calls, &cloned)
+ if s.failuresRemaining > 0 {
+ s.failuresRemaining--
+ return nil, false, errors.New("temporary assign failure")
+ }
+ return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
+}
+
+type authIdentitySettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *authIdentitySettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *authIdentitySettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ if v, ok := s.values[key]; ok {
+ return v, nil
+ }
+ return "", service.ErrSettingNotFound
+}
+
+func (s *authIdentitySettingRepoStub) Set(context.Context, string, string) error {
+ panic("unexpected Set call")
+}
+
+func (s *authIdentitySettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if v, ok := s.values[key]; ok {
+ out[key] = v
+ }
+ }
+ return out, nil
+}
+
+func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ panic("unexpected SetMultiple call")
+}
+
+func (s *authIdentitySettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *authIdentitySettingRepoStub) Delete(context.Context, string) error {
+ panic("unexpected Delete call")
+}
+
+func newAuthServiceWithEnt(
+ t *testing.T,
+ settings map[string]string,
+ defaultSubAssigner service.DefaultSubscriptionAssigner,
+) (*service.AuthService, service.UserRepository, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_service_identity_sync?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+ _, err = db.Exec(`
+CREATE TABLE IF NOT EXISTS user_provider_default_grants (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ provider_type TEXT NOT NULL,
+ grant_reason TEXT NOT NULL DEFAULT 'first_bind',
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ UNIQUE(user_id, provider_type, grant_reason)
+)`)
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ repo := repository.NewUserRepository(client, db)
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-auth-identity-secret",
+ ExpireHour: 1,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 3.5,
+ UserConcurrency: 2,
+ },
+ }
+ settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{
+ values: settings,
+ }, cfg)
+
+ svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner)
+ return svc, repo, client
+}
+
+func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ }, nil)
+ ctx := context.Background()
+
+ token, user, err := svc.Register(ctx, "user@example.com", "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, user)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, "email", storedUser.SignupSource)
+ require.NotNil(t, storedUser.LastLoginAt)
+ require.NotNil(t, storedUser.LastActiveAt)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("user@example.com"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, user.ID, identity.UserID)
+ require.NotNil(t, identity.VerifiedAt)
+}
+
+func TestAuthServiceLoginDefersLastLoginTouchUntilRecordSuccessfulLogin(t *testing.T) {
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ }, nil)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("login@example.com").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ SetBalance(1).
+ SetConcurrency(1).
+ Save(ctx)
+ require.NoError(t, err)
+
+ old := time.Now().Add(-2 * time.Hour).UTC().Round(time.Second)
+ _, err = client.User.UpdateOneID(user.ID).
+ SetLastLoginAt(old).
+ SetLastActiveAt(old).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedUser.LastLoginAt)
+ require.NotNil(t, storedUser.LastActiveAt)
+ require.True(t, storedUser.LastLoginAt.Equal(old))
+ require.True(t, storedUser.LastActiveAt.Equal(old))
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("login@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("login@example.com"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, user.ID, identity.UserID)
+}
+
+func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) {
+ svc, repo, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ }, nil)
+ ctx := context.Background()
+
+ user := &service.User{
+ Email: "record@example.com",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Balance: 1,
+ Concurrency: 1,
+ }
+ require.NoError(t, user.SetPassword("password"))
+ require.NoError(t, repo.Create(ctx, user))
+
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("record@example.com"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, user.ID, identity.UserID)
+}
+
+func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenBackfillingLegacyEmailIdentity(t *testing.T) {
+ assigner := &authIdentityDefaultSubAssignerStub{}
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, assigner)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("legacy@example.com").
+ SetUsername("legacy-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(1.5).
+ SetConcurrency(2).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("legacy@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, identityCount)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+
+ token, gotUser, err = svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+
+ storedUser, err = client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceLogin_DoesNotApplyMergedEmailFirstBindDefaultsWhenBackfillingLegacyEmailIdentity(t *testing.T) {
+ assigner := &authIdentityDefaultSubAssignerStub{}
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyDefaultSubscriptions: `[{"group_id":21,"validity_days":14}]`,
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "5",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, assigner)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("merged-first-bind@example.com").
+ SetUsername("merged-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(1.5).
+ SetConcurrency(2).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyExists(t *testing.T) {
+ assigner := &authIdentityDefaultSubAssignerStub{}
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, assigner)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("bound@example.com").
+ SetUsername("bound-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(2).
+ SetConcurrency(3).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject("bound@example.com").
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": "preexisting"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 2.0, storedUser.Balance)
+ require.Equal(t, 3, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceLogin_DoesNotRetryEmailFirstBindDefaultsForBackfilledEmailIdentity(t *testing.T) {
+ assigner := &flakyAuthIdentityDefaultSubAssignerStub{failuresRemaining: 1}
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, assigner)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("retry-first-bind@example.com").
+ SetUsername("retry-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(1.5).
+ SetConcurrency(2).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+
+ token, gotUser, err = svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ storedUser, err = client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func countProviderGrantRecords(
+ t *testing.T,
+ client *dbent.Client,
+ userID int64,
+ providerType string,
+ grantReason string,
+) int {
+ t.Helper()
+
+ var count int
+ rows, err := client.QueryContext(
+ context.Background(),
+ `SELECT COUNT(*) FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ?`,
+ userID,
+ providerType,
+ grantReason,
+ )
+ require.NoError(t, err)
+ defer rows.Close()
+ require.True(t, rows.Next())
+ require.NoError(t, rows.Scan(&count))
+ require.NoError(t, rows.Err())
+ return count
+}
diff --git a/backend/internal/service/auth_service_pending_oauth_test.go b/backend/internal/service/auth_service_pending_oauth_test.go
deleted file mode 100644
index 0472e06c..00000000
--- a/backend/internal/service/auth_service_pending_oauth_test.go
+++ /dev/null
@@ -1,146 +0,0 @@
-//go:build unit
-
-package service
-
-import (
- "testing"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- "github.com/golang-jwt/jwt/v5"
- "github.com/stretchr/testify/require"
-)
-
-func newAuthServiceForPendingOAuthTest() *AuthService {
- cfg := &config.Config{
- JWT: config.JWTConfig{
- Secret: "test-secret-pending-oauth",
- ExpireHour: 1,
- },
- }
- return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
-}
-
-// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。
-func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- token, err := svc.CreatePendingOAuthToken("user@example.com", "alice")
- require.NoError(t, err)
- require.NotEmpty(t, token)
-
- email, username, err := svc.VerifyPendingOAuthToken(token)
- require.NoError(t, err)
- require.Equal(t, "user@example.com", email)
- require.Equal(t, "alice", username)
-}
-
-// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- // 签发一个普通 access token(JWTClaims,无 Purpose 字段)
- accessToken, err := svc.GenerateToken(&User{
- ID: 1,
- Email: "user@example.com",
- Role: RoleUser,
- })
- require.NoError(t, err)
-
- _, _, err = svc.VerifyPendingOAuthToken(accessToken)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- now := time.Now()
- claims := &pendingOAuthClaims{
- Email: "user@example.com",
- Username: "alice",
- Purpose: "some_other_purpose",
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
- IssuedAt: jwt.NewNumericDate(now),
- NotBefore: jwt.NewNumericDate(now),
- },
- }
- tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
- require.NoError(t, err)
-
- _, _, err = svc.VerifyPendingOAuthToken(tokenStr)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- now := time.Now()
- claims := &pendingOAuthClaims{
- Email: "user@example.com",
- Username: "alice",
- Purpose: "", // 旧 token 无此字段,反序列化后为零值
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
- IssuedAt: jwt.NewNumericDate(now),
- NotBefore: jwt.NewNumericDate(now),
- },
- }
- tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
- require.NoError(t, err)
-
- _, _, err = svc.VerifyPendingOAuthToken(tokenStr)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- past := time.Now().Add(-1 * time.Hour)
- claims := &pendingOAuthClaims{
- Email: "user@example.com",
- Username: "alice",
- Purpose: pendingOAuthPurpose,
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(past),
- IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
- NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
- },
- }
- tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
- require.NoError(t, err)
-
- _, _, err = svc.VerifyPendingOAuthToken(tokenStr)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) {
- other := NewAuthService(nil, nil, nil, nil, &config.Config{
- JWT: config.JWTConfig{Secret: "other-secret"},
- }, nil, nil, nil, nil, nil, nil)
-
- token, err := other.CreatePendingOAuthToken("user@example.com", "alice")
- require.NoError(t, err)
-
- svc := newAuthServiceForPendingOAuthTest()
- _, _, err = svc.VerifyPendingOAuthToken(token)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_TooLong(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
- giant := make([]byte, maxTokenLength+1)
- for i := range giant {
- giant[i] = 'a'
- }
- _, _, err := svc.VerifyPendingOAuthToken(string(giant))
- require.ErrorIs(t, err, ErrInvalidToken)
-}
diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go
index 103bafe7..dbd18a20 100644
--- a/backend/internal/service/auth_service_register_test.go
+++ b/backend/internal/service/auth_service_register_test.go
@@ -37,7 +37,16 @@ func (s *settingRepoStub) Set(ctx context.Context, key, value string) error {
}
func (s *settingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
- panic("unexpected GetMultiple call")
+ if s.err != nil {
+ return nil, s.err
+ }
+ result := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if v, ok := s.values[key]; ok {
+ result[key] = v
+ }
+ }
+ return result, nil
}
func (s *settingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
@@ -62,6 +71,8 @@ type defaultSubscriptionAssignerStub struct {
err error
}
+type refreshTokenCacheStub struct{}
+
func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
if input != nil {
s.calls = append(s.calls, *input)
@@ -72,6 +83,46 @@ func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.C
return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
}
+func (s *refreshTokenCacheStub) StoreRefreshToken(context.Context, string, *RefreshTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) GetRefreshToken(context.Context, string) (*RefreshTokenData, error) {
+ return nil, ErrRefreshTokenNotFound
+}
+
+func (s *refreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+
+func (s *refreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
+ return nil, nil
+}
+
+func (s *refreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
+ return false, nil
+}
+
func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) {
if s.err != nil {
return nil, s.err
@@ -322,7 +373,8 @@ func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) {
func TestAuthService_Register_Success(t *testing.T) {
repo := &userRepoStub{nextID: 5}
service := newAuthService(repo, map[string]string{
- SettingKeyRegistrationEnabled: "true",
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
}, nil)
token, user, err := service.Register(context.Background(), "user@test.com", "password")
@@ -469,8 +521,9 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
repo := &userRepoStub{nextID: 42}
assigner := &defaultSubscriptionAssignerStub{}
service := newAuthService(repo, map[string]string{
- SettingKeyRegistrationEnabled: "true",
- SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
}, nil)
service.defaultSubAssigner = assigner
@@ -484,3 +537,132 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
require.Equal(t, int64(12), assigner.calls[1].GroupID)
require.Equal(t, 7, assigner.calls[1].ValidityDays)
}
+
+func TestAuthService_Register_UsesEmailAuthSourceDefaultsWhenGrantEnabled(t *testing.T) {
+ repo := &userRepoStub{nextID: 52}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":91,"validity_days":3}]`,
+ SettingKeyAuthSourceDefaultEmailBalance: "12.5",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "7",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
+ }, nil)
+ service.defaultSubAssigner = assigner
+
+ _, user, err := service.Register(context.Background(), "email-defaults@test.com", "password")
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.Equal(t, 12.5, user.Balance)
+ require.Equal(t, 7, user.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(11), assigner.calls[0].GroupID)
+ require.Equal(t, 30, assigner.calls[0].ValidityDays)
+}
+
+func TestAuthService_Register_GrantOnSignupFalseFallsBackToGlobalDefaults(t *testing.T) {
+ repo := &userRepoStub{nextID: 53}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`,
+ SettingKeyAuthSourceDefaultEmailBalance: "99",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "88",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":32,"validity_days":9}]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
+ }, nil)
+ service.defaultSubAssigner = assigner
+
+ _, user, err := service.Register(context.Background(), "email-global@test.com", "password")
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.Equal(t, 3.5, user.Balance)
+ require.Equal(t, 2, user.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(31), assigner.calls[0].GroupID)
+ require.Equal(t, 5, assigner.calls[0].ValidityDays)
+}
+
+func TestAuthService_Register_GrantOnSignupMergesSourceOverridesWithGlobalDefaults(t *testing.T) {
+ repo := &userRepoStub{nextID: 54}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`,
+ SettingKeyAuthSourceDefaultEmailBalance: "9.5",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "5",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
+ }, nil)
+ service.defaultSubAssigner = assigner
+
+ _, user, err := service.Register(context.Background(), "email-merged@test.com", "password")
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.Equal(t, 9.5, user.Balance)
+ require.Equal(t, 2, user.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(31), assigner.calls[0].GroupID)
+ require.Equal(t, 5, assigner.calls[0].ValidityDays)
+}
+
+func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefaultsOnSignup(t *testing.T) {
+ repo := &userRepoStub{nextID: 61}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":81,"validity_days":1}]`,
+ SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75",
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9",
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`,
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
+ }, nil)
+ service.defaultSubAssigner = assigner
+ service.refreshTokenCache = &refreshTokenCacheStub{}
+
+ tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "")
+ require.NoError(t, err)
+ require.NotNil(t, tokenPair)
+ require.NotNil(t, user)
+ require.Equal(t, int64(61), user.ID)
+ require.Equal(t, 21.75, user.Balance)
+ require.Equal(t, 9, user.Concurrency)
+ require.Len(t, repo.created, 1)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(22), assigner.calls[0].GroupID)
+ require.Equal(t, 14, assigner.calls[0].ValidityDays)
+}
+
+func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantAgain(t *testing.T) {
+ existing := &User{
+ ID: 88,
+ Email: "linuxdo-123@linuxdo-connect.invalid",
+ Username: "existing-linuxdo",
+ Role: RoleUser,
+ Status: StatusActive,
+ Balance: 4,
+ Concurrency: 1,
+ TokenVersion: 2,
+ }
+ repo := &userRepoStub{user: existing}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75",
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9",
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`,
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
+ }, nil)
+ service.defaultSubAssigner = assigner
+ service.refreshTokenCache = &refreshTokenCacheStub{}
+
+ tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "")
+ require.NoError(t, err)
+ require.NotNil(t, tokenPair)
+ require.Equal(t, existing.ID, user.ID)
+ require.Equal(t, 4.0, user.Balance)
+ require.Equal(t, 1, user.Concurrency)
+ require.Empty(t, repo.created)
+ require.Empty(t, assigner.calls)
+}
diff --git a/backend/internal/service/billing_cache_service_singleflight_test.go b/backend/internal/service/billing_cache_service_singleflight_test.go
index 4a8b8f03..0eaf4570 100644
--- a/backend/internal/service/billing_cache_service_singleflight_test.go
+++ b/backend/internal/service/billing_cache_service_singleflight_test.go
@@ -86,6 +86,14 @@ func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User,
return &User{ID: id, Balance: s.balance}, nil
}
+func (s *balanceLoadUserRepoStub) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
+ return nil, nil
+}
+
+func (s *balanceLoadUserRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error {
+ return nil
+}
+
func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
cache := &billingCacheMissStub{}
userRepo := &balanceLoadUserRepoStub{
diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go
index 32a54cbe..a45203a3 100644
--- a/backend/internal/service/billing_service.go
+++ b/backend/internal/service/billing_service.go
@@ -203,17 +203,6 @@ func (s *BillingService) initFallbackPricing() {
SupportsCacheBreakdown: false,
}
- // OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费)
- s.fallbackPrices["gpt-5.1"] = &ModelPricing{
- InputPricePerToken: 1.25e-6, // $1.25 per MTok
- InputPricePerTokenPriority: 2.5e-6, // $2.5 per MTok
- OutputPricePerToken: 10e-6, // $10 per MTok
- OutputPricePerTokenPriority: 20e-6, // $20 per MTok
- CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
- CacheReadPricePerToken: 0.125e-6,
- CacheReadPricePerTokenPriority: 0.25e-6,
- SupportsCacheBreakdown: false,
- }
// OpenAI GPT-5.4(业务指定价格)
s.fallbackPrices["gpt-5.4"] = &ModelPricing{
InputPricePerToken: 2.5e-6, // $2.5 per MTok
@@ -234,12 +223,6 @@ func (s *BillingService) initFallbackPricing() {
CacheReadPricePerToken: 7.5e-8,
SupportsCacheBreakdown: false,
}
- s.fallbackPrices["gpt-5.4-nano"] = &ModelPricing{
- InputPricePerToken: 2e-7,
- OutputPricePerToken: 1.25e-6,
- CacheReadPricePerToken: 2e-8,
- SupportsCacheBreakdown: false,
- }
// OpenAI GPT-5.2(本地兜底)
s.fallbackPrices["gpt-5.2"] = &ModelPricing{
InputPricePerToken: 1.75e-6,
@@ -251,8 +234,8 @@ func (s *BillingService) initFallbackPricing() {
CacheReadPricePerTokenPriority: 0.35e-6,
SupportsCacheBreakdown: false,
}
- // Codex 族兜底统一按 GPT-5.1 Codex 价格计费
- s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{
+ // Codex 族兜底统一按 GPT-5.3 Codex 价格计费
+ s.fallbackPrices["gpt-5.3-codex"] = &ModelPricing{
InputPricePerToken: 1.5e-6, // $1.5 per MTok
InputPricePerTokenPriority: 3e-6, // $3 per MTok
OutputPricePerToken: 12e-6, // $12 per MTok
@@ -262,17 +245,6 @@ func (s *BillingService) initFallbackPricing() {
CacheReadPricePerTokenPriority: 0.3e-6,
SupportsCacheBreakdown: false,
}
- s.fallbackPrices["gpt-5.2-codex"] = &ModelPricing{
- InputPricePerToken: 1.75e-6,
- InputPricePerTokenPriority: 3.5e-6,
- OutputPricePerToken: 14e-6,
- OutputPricePerTokenPriority: 28e-6,
- CacheCreationPricePerToken: 1.75e-6,
- CacheReadPricePerToken: 0.175e-6,
- CacheReadPricePerTokenPriority: 0.35e-6,
- SupportsCacheBreakdown: false,
- }
- s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"]
}
// getFallbackPricing 根据模型系列获取回退价格
@@ -318,20 +290,12 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
switch normalized {
case "gpt-5.4-mini":
return s.fallbackPrices["gpt-5.4-mini"]
- case "gpt-5.4-nano":
- return s.fallbackPrices["gpt-5.4-nano"]
case "gpt-5.4":
return s.fallbackPrices["gpt-5.4"]
case "gpt-5.2":
return s.fallbackPrices["gpt-5.2"]
- case "gpt-5.2-codex":
- return s.fallbackPrices["gpt-5.2-codex"]
- case "gpt-5.3-codex":
+ case "gpt-5.3-codex", "gpt-5.3-codex-spark":
return s.fallbackPrices["gpt-5.3-codex"]
- case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest":
- return s.fallbackPrices["gpt-5.1-codex"]
- case "gpt-5.1":
- return s.fallbackPrices["gpt-5.1"]
}
}
@@ -448,8 +412,9 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown,
})
}
- if input.RateMultiplier <= 0 {
- input.RateMultiplier = 1.0
+ // 保存时强制 > 0;若仍有负数泄漏(缓存/迁移残留),按 0 处理避免按 1x 误扣。
+ if input.RateMultiplier < 0 {
+ input.RateMultiplier = 0
}
var breakdown *CostBreakdown
@@ -493,8 +458,9 @@ func (s *BillingService) computeTokenBreakdown(
rateMultiplier float64, serviceTier string,
applyLongCtx bool,
) *CostBreakdown {
- if rateMultiplier <= 0 {
- rateMultiplier = 1.0
+ // 保存时强制 > 0;若仍有负数泄漏,按 0 处理避免按 1x 误扣。
+ if rateMultiplier < 0 {
+ rateMultiplier = 0
}
inputPrice := pricing.InputPricePerToken
@@ -665,8 +631,13 @@ func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens
}
func isOpenAIGPT54Model(model string) bool {
- normalized := normalizeCodexModel(strings.TrimSpace(strings.ToLower(model)))
- return normalized == "gpt-5.4"
+ trimmed := strings.TrimSpace(strings.ToLower(model))
+ // 仅当模型字符串实际属于 GPT-5/Codex 族时才做归一判定,避免 normalizeCodexModel
+ // 的默认兜底把非 OpenAI 模型(claude-*、gemini-*、gpt-4o)误识别为 gpt-5.4。
+ if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") {
+ return false
+ }
+ return normalizeCodexModel(trimmed) == "gpt-5.4"
}
// CalculateCostWithConfig 使用配置中的默认倍率计算费用
@@ -831,9 +802,9 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
// 计算总费用
totalCost := unitPrice * float64(imageCount)
- // 应用倍率
- if rateMultiplier <= 0 {
- rateMultiplier = 1.0
+ // 应用倍率(保存时强制 > 0;负数按 0 处理避免按 1x 误扣)
+ if rateMultiplier < 0 {
+ rateMultiplier = 0
}
actualCost := totalCost * rateMultiplier
diff --git a/backend/internal/service/billing_service_image_test.go b/backend/internal/service/billing_service_image_test.go
index fa90f6bb..8d3ca987 100644
--- a/backend/internal/service/billing_service_image_test.go
+++ b/backend/internal/service/billing_service_image_test.go
@@ -90,13 +90,14 @@ func TestCalculateImageCost_NegativeCount(t *testing.T) {
require.Equal(t, 0.0, cost.ActualCost)
}
-// TestCalculateImageCost_ZeroRateMultiplier 测试费率倍数为 0 时默认使用 1.0
+// TestCalculateImageCost_ZeroRateMultiplier 锁定新行为:倍率 0 直接按 0 计费
+// (保存时已强制 > 0;若仍有 0 泄漏到计费层,零消耗比历史的 1.0 更安全)。
func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) {
svc := &BillingService{}
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0)
require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
- require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理
+ require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
}
// TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格
diff --git a/backend/internal/service/billing_service_rate_multiplier_test.go b/backend/internal/service/billing_service_rate_multiplier_test.go
new file mode 100644
index 00000000..83788196
--- /dev/null
+++ b/backend/internal/service/billing_service_rate_multiplier_test.go
@@ -0,0 +1,63 @@
+//go:build unit
+
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// TestCalculateCost_RateMultiplier_NegativeClampedToZero 锁定负数倍率被
+// 钳制为 0(而非历史上的 1.0),避免配置异常导致静默按标准价扣费。
+func TestCalculateCost_RateMultiplier_NegativeClampedToZero(t *testing.T) {
+ svc := newTestBillingService()
+ tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
+
+ tests := []struct {
+ name string
+ multiplier float64
+ wantRatio float64 // ActualCost / TotalCost
+ }{
+ {"negative clamped to 0", -1.5, 0},
+ {"zero passes through as 0 (defense in depth)", 0, 0},
+ {"positive 2x applied", 2.0, 2.0},
+ {"positive 0.5x applied", 0.5, 0.5},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cost, err := svc.CalculateCost("claude-sonnet-4", tokens, tt.multiplier)
+ require.NoError(t, err)
+ require.Greater(t, cost.TotalCost, 0.0, "TotalCost should be non-zero")
+ require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9)
+ })
+ }
+}
+
+// TestCalculateImageCost_RateMultiplier_NegativeClampedToZero 图片按次计费路径
+// 同样遵循"负数 → 0"语义。
+func TestCalculateImageCost_RateMultiplier_NegativeClampedToZero(t *testing.T) {
+ svc := newTestBillingService()
+ price := 0.04
+ cfg := &ImagePriceConfig{Price1K: &price}
+
+ tests := []struct {
+ name string
+ multiplier float64
+ wantRatio float64
+ }{
+ {"negative clamped to 0", -0.5, 0},
+ {"zero passes through", 0, 0},
+ {"positive 3x applied", 3.0, 3.0},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cost := svc.CalculateImageCost("imagen-3", "1K", 2, cfg, tt.multiplier)
+ require.NotNil(t, cost)
+ require.Greater(t, cost.TotalCost, 0.0)
+ require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9)
+ })
+ }
+}
diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go
index 2cf134e2..222abd69 100644
--- a/backend/internal/service/billing_service_test.go
+++ b/backend/internal/service/billing_service_test.go
@@ -71,34 +71,6 @@ func TestCalculateCost_RateMultiplier(t *testing.T) {
require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10)
}
-func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) {
- svc := newTestBillingService()
-
- tokens := UsageTokens{InputTokens: 1000}
-
- costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0)
- require.NoError(t, err)
-
- costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
- require.NoError(t, err)
-
- require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
-}
-
-func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) {
- svc := newTestBillingService()
-
- tokens := UsageTokens{InputTokens: 1000}
-
- costNeg, err := svc.CalculateCost("claude-sonnet-4", tokens, -1.0)
- require.NoError(t, err)
-
- costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
- require.NoError(t, err)
-
- require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
-}
-
func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) {
svc := newTestBillingService()
@@ -151,15 +123,6 @@ func TestGetModelPricing_UnknownOpenAIModelReturnsError(t *testing.T) {
require.Contains(t, err.Error(), "pricing not found")
}
-func TestGetModelPricing_OpenAIGPT51Fallback(t *testing.T) {
- svc := newTestBillingService()
-
- pricing, err := svc.GetModelPricing("gpt-5.1")
- require.NoError(t, err)
- require.NotNil(t, pricing)
- require.InDelta(t, 1.25e-6, pricing.InputPricePerToken, 1e-12)
-}
-
func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) {
svc := newTestBillingService()
@@ -186,18 +149,6 @@ func TestGetModelPricing_OpenAIGPT54MiniFallback(t *testing.T) {
require.Zero(t, pricing.LongContextInputThreshold)
}
-func TestGetModelPricing_OpenAIGPT54NanoFallback(t *testing.T) {
- svc := newTestBillingService()
-
- pricing, err := svc.GetModelPricing("gpt-5.4-nano")
- require.NoError(t, err)
- require.NotNil(t, pricing)
- require.InDelta(t, 2e-7, pricing.InputPricePerToken, 1e-12)
- require.InDelta(t, 1.25e-6, pricing.OutputPricePerToken, 1e-12)
- require.InDelta(t, 2e-8, pricing.CacheReadPricePerToken, 1e-12)
- require.Zero(t, pricing.LongContextInputThreshold)
-}
-
func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *testing.T) {
svc := newTestBillingService()
@@ -232,13 +183,13 @@ func TestGetFallbackPricing_FamilyMatching(t *testing.T) {
{name: "claude generic model fallback sonnet", model: "claude-foo-bar", expectedInput: 3e-6},
{name: "gemini explicit fallback", model: "gemini-3-1-pro", expectedInput: 2e-6},
{name: "gemini unknown no fallback", model: "gemini-2.0-pro", expectNilPricing: true},
- {name: "openai gpt5.1", model: "gpt-5.1", expectedInput: 1.25e-6},
{name: "openai gpt5.4", model: "gpt-5.4", expectedInput: 2.5e-6},
{name: "openai gpt5.4 mini", model: "gpt-5.4-mini", expectedInput: 7.5e-7},
- {name: "openai gpt5.4 nano", model: "gpt-5.4-nano", expectedInput: 2e-7},
{name: "openai gpt5.3 codex", model: "gpt-5.3-codex", expectedInput: 1.5e-6},
- {name: "openai gpt5.1 codex max alias", model: "gpt-5.1-codex-max", expectedInput: 1.5e-6},
- {name: "openai codex mini latest alias", model: "codex-mini-latest", expectedInput: 1.5e-6},
+ {name: "openai gpt5.3 codex spark", model: "gpt-5.3-codex-spark", expectedInput: 1.5e-6},
+ {name: "openai legacy gpt5.1 falls back to gpt5.4", model: "gpt-5.1", expectedInput: 2.5e-6},
+ {name: "openai legacy gpt5.1 codex falls back to gpt5.3 codex", model: "gpt-5.1-codex", expectedInput: 1.5e-6},
+ {name: "openai legacy codex mini latest falls back to gpt5.3 codex", model: "codex-mini-latest", expectedInput: 1.5e-6},
{name: "openai unknown no fallback", model: "gpt-unknown-model", expectNilPricing: true},
{name: "non supported family", model: "qwen-max", expectNilPricing: true},
}
diff --git a/backend/internal/service/billing_service_unified_test.go b/backend/internal/service/billing_service_unified_test.go
index 694c3384..e6a92d1a 100644
--- a/backend/internal/service/billing_service_unified_test.go
+++ b/backend/internal/service/billing_service_unified_test.go
@@ -147,40 +147,35 @@ func TestCalculateCostUnified_ImageMode(t *testing.T) {
require.Equal(t, string(BillingModeImage), cost.BillingMode)
}
-func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) {
+// TestCalculateCostUnified_RateMultiplierZeroProducesZero 锁定新行为:
+// 保存时强制 > 0;若 0 仍泄漏到计费层,按 0 计费(而非历史上的 1.0)。
+func TestCalculateCostUnified_RateMultiplierZeroProducesZero(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
- costZero, err := bs.CalculateCostUnified(CostInput{
+ cost, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
- RateMultiplier: 0, // should default to 1.0
+ RateMultiplier: 0,
Resolver: resolver,
})
require.NoError(t, err)
-
- costOne, err := bs.CalculateCostUnified(CostInput{
- Ctx: context.Background(),
- Model: "claude-sonnet-4",
- Tokens: tokens,
- RateMultiplier: 1.0,
- Resolver: resolver,
- })
- require.NoError(t, err)
-
- require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
+ require.Greater(t, cost.TotalCost, 0.0)
+ require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
}
-func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) {
+// TestCalculateCostUnified_NegativeRateMultiplierClampedToZero 锁定新行为:
+// 负数倍率按 0 计费,避免历史的 <=0 → 1.0 把配置异常静默按标准价扣费。
+func TestCalculateCostUnified_NegativeRateMultiplierClampedToZero(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000}
- costNeg, err := bs.CalculateCostUnified(CostInput{
+ cost, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
@@ -188,17 +183,8 @@ func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T)
Resolver: resolver,
})
require.NoError(t, err)
-
- costOne, err := bs.CalculateCostUnified(CostInput{
- Ctx: context.Background(),
- Model: "claude-sonnet-4",
- Tokens: tokens,
- RateMultiplier: 1.0,
- Resolver: resolver,
- })
- require.NoError(t, err)
-
- require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
+ require.Greater(t, cost.TotalCost, 0.0)
+ require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
}
func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) {
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index cb452efb..3c6888b8 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -74,6 +74,9 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀(RFC 保留域名)。
const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid"
+// WeChatConnectSyntheticEmailDomain 是 WeChat Connect 用户的合成邮箱后缀(RFC 保留域名)。
+const WeChatConnectSyntheticEmailDomain = "@wechat-connect.invalid"
+
// Setting keys
const (
// 注册设置
@@ -108,6 +111,24 @@ const (
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
+ // WeChat Connect OAuth 登录设置
+ SettingKeyWeChatConnectEnabled = "wechat_connect_enabled"
+ SettingKeyWeChatConnectAppID = "wechat_connect_app_id"
+ SettingKeyWeChatConnectAppSecret = "wechat_connect_app_secret"
+ SettingKeyWeChatConnectOpenAppID = "wechat_connect_open_app_id"
+ SettingKeyWeChatConnectOpenAppSecret = "wechat_connect_open_app_secret"
+ SettingKeyWeChatConnectMPAppID = "wechat_connect_mp_app_id"
+ SettingKeyWeChatConnectMPAppSecret = "wechat_connect_mp_app_secret"
+ SettingKeyWeChatConnectMobileAppID = "wechat_connect_mobile_app_id"
+ SettingKeyWeChatConnectMobileAppSecret = "wechat_connect_mobile_app_secret"
+ SettingKeyWeChatConnectOpenEnabled = "wechat_connect_open_enabled"
+ SettingKeyWeChatConnectMPEnabled = "wechat_connect_mp_enabled"
+ SettingKeyWeChatConnectMobileEnabled = "wechat_connect_mobile_enabled"
+ SettingKeyWeChatConnectMode = "wechat_connect_mode"
+ SettingKeyWeChatConnectScopes = "wechat_connect_scopes"
+ SettingKeyWeChatConnectRedirectURL = "wechat_connect_redirect_url"
+ SettingKeyWeChatConnectFrontendRedirectURL = "wechat_connect_frontend_redirect_url"
+
// Generic OIDC OAuth 登录设置
SettingKeyOIDCConnectEnabled = "oidc_connect_enabled"
SettingKeyOIDCConnectProviderName = "oidc_connect_provider_name"
@@ -153,6 +174,29 @@ const (
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
+ // 第三方认证来源默认授予配置
+ SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance"
+ SettingKeyAuthSourceDefaultEmailConcurrency = "auth_source_default_email_concurrency"
+ SettingKeyAuthSourceDefaultEmailSubscriptions = "auth_source_default_email_subscriptions"
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup = "auth_source_default_email_grant_on_signup"
+ SettingKeyAuthSourceDefaultEmailGrantOnFirstBind = "auth_source_default_email_grant_on_first_bind"
+ SettingKeyAuthSourceDefaultLinuxDoBalance = "auth_source_default_linuxdo_balance"
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency = "auth_source_default_linuxdo_concurrency"
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions = "auth_source_default_linuxdo_subscriptions"
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup = "auth_source_default_linuxdo_grant_on_signup"
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind = "auth_source_default_linuxdo_grant_on_first_bind"
+ SettingKeyAuthSourceDefaultOIDCBalance = "auth_source_default_oidc_balance"
+ SettingKeyAuthSourceDefaultOIDCConcurrency = "auth_source_default_oidc_concurrency"
+ SettingKeyAuthSourceDefaultOIDCSubscriptions = "auth_source_default_oidc_subscriptions"
+ SettingKeyAuthSourceDefaultOIDCGrantOnSignup = "auth_source_default_oidc_grant_on_signup"
+ SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind = "auth_source_default_oidc_grant_on_first_bind"
+ SettingKeyAuthSourceDefaultWeChatBalance = "auth_source_default_wechat_balance"
+ SettingKeyAuthSourceDefaultWeChatConcurrency = "auth_source_default_wechat_concurrency"
+ SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions"
+ SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup"
+ SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind"
+ SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup"
+
// 管理员 API Key
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go
index 55cb2c84..498336a4 100644
--- a/backend/internal/service/gateway_request.go
+++ b/backend/internal/service/gateway_request.go
@@ -962,7 +962,7 @@ func NormalizeClaudeOutputEffort(raw string) *string {
return nil
}
switch value {
- case "low", "medium", "high", "max":
+ case "low", "medium", "high", "xhigh", "max":
return &value
default:
return nil
diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go
index d262456d..40bd1186 100644
--- a/backend/internal/service/gateway_request_test.go
+++ b/backend/internal/service/gateway_request_test.go
@@ -1149,6 +1149,11 @@ func TestParseGatewayRequest_OutputEffort(t *testing.T) {
body: `{"model":"claude-opus-4-6","output_config":{"effort":"max"},"messages":[]}`,
wantEffort: "max",
},
+ {
+ name: "output_config.effort xhigh",
+ body: `{"model":"claude-opus-4-7","output_config":{"effort":"xhigh"},"messages":[]}`,
+ wantEffort: "xhigh",
+ },
{
name: "output_config without effort",
body: `{"model":"claude-opus-4-6","output_config":{},"messages":[]}`,
@@ -1186,9 +1191,10 @@ func TestNormalizeClaudeOutputEffort(t *testing.T) {
{"LOW", strPtr("low")},
{"Max", strPtr("max")},
{" medium ", strPtr("medium")},
+ {"xhigh", strPtr("xhigh")},
+ {"XHIGH", strPtr("xhigh")},
{"", nil},
{"unknown", nil},
- {"xhigh", nil},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 91740ad0..2497d3d0 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -445,26 +445,19 @@ func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) i
}
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
-// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
-// 或请求的模型处于限流状态时,返回 true。
-// 这确保后续请求不会继续使用不可用的账号。
+// 委托 IsSchedulable() 判断账号级可调度性(状态、配额、过载、限流等),
+// 额外检查模型级限流。
//
// shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared.
-// Returns true when account status is error/disabled, schedulable is false,
-// within temporary unschedulable period, or the requested model is rate-limited.
-// This ensures subsequent requests won't continue using unavailable accounts.
+// Delegates to IsSchedulable() for account-level checks, plus model-level rate limiting.
func shouldClearStickySession(account *Account, requestedModel string) bool {
if account == nil {
return false
}
- if account.Status == StatusError || account.Status == StatusDisabled || !account.Schedulable {
+ if !account.IsSchedulable() {
return true
}
- if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
- return true
- }
- // 检查模型限流和 scope 限流,有限流即清除粘性会话
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > 0 {
return true
}
@@ -7419,8 +7412,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
cost := p.Cost
if p.IsSubscriptionBill {
- if cost.TotalCost > 0 {
- if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil {
+ // Subscription usage tracked by ActualCost so group rate multiplier
+ // consumes the quota at the expected speed.
+ if cost.ActualCost > 0 {
+ if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.ActualCost); err != nil {
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
}
}
@@ -7519,9 +7514,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
}
}
+ // Record subscription / balance cost using ActualCost so the group (and any
+ // user-specific) rate multiplier consumes subscription quota at the expected
+ // speed. TotalCost remains the raw (pre-multiplier) value; downstream guards
+ // on "> 0" still correctly skip free subscriptions (RateMultiplier == 0).
if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 {
cmd.SubscriptionID = &p.Subscription.ID
- cmd.SubscriptionCost = p.Cost.TotalCost
+ cmd.SubscriptionCost = p.Cost.ActualCost
} else if p.Cost.ActualCost > 0 {
cmd.BalanceCost = p.Cost.ActualCost
}
@@ -7580,8 +7579,8 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, resu
}
if p.IsSubscriptionBill {
- if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil {
- deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost)
+ if p.Cost.ActualCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil {
+ deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.ActualCost)
}
} else if p.Cost.ActualCost > 0 && p.User != nil {
deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost)
diff --git a/backend/internal/service/gateway_service_subscription_billing_test.go b/backend/internal/service/gateway_service_subscription_billing_test.go
new file mode 100644
index 00000000..42a81035
--- /dev/null
+++ b/backend/internal/service/gateway_service_subscription_billing_test.go
@@ -0,0 +1,85 @@
+//go:build unit
+
+package service
+
+import (
+ "testing"
+)
+
+// TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier locks in the fix
+// that subscription-mode billing honours the group (and any user-specific) rate
+// multiplier — i.e. cmd.SubscriptionCost tracks ActualCost (= TotalCost *
+// RateMultiplier), not raw TotalCost.
+func TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier(t *testing.T) {
+ t.Parallel()
+
+ groupID := int64(7)
+ subID := int64(42)
+
+ tests := []struct {
+ name string
+ totalCost float64
+ actualCost float64
+ isSubscription bool
+ wantSub float64
+ wantBalance float64
+ }{
+ {
+ name: "subscription with 2x multiplier consumes 2x quota",
+ totalCost: 1.0,
+ actualCost: 2.0,
+ isSubscription: true,
+ wantSub: 2.0,
+ wantBalance: 0,
+ },
+ {
+ name: "subscription with 0.5x multiplier consumes 0.5x quota",
+ totalCost: 1.0,
+ actualCost: 0.5,
+ isSubscription: true,
+ wantSub: 0.5,
+ wantBalance: 0,
+ },
+ {
+ name: "free subscription (multiplier 0) consumes no quota",
+ totalCost: 1.0,
+ actualCost: 0,
+ isSubscription: true,
+ wantSub: 0,
+ wantBalance: 0,
+ },
+ {
+ name: "balance billing keeps using ActualCost (regression)",
+ totalCost: 1.0,
+ actualCost: 2.0,
+ isSubscription: false,
+ wantSub: 0,
+ wantBalance: 2.0,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ p := &postUsageBillingParams{
+ Cost: &CostBreakdown{TotalCost: tt.totalCost, ActualCost: tt.actualCost},
+ User: &User{ID: 1},
+ APIKey: &APIKey{ID: 2, GroupID: &groupID},
+ Account: &Account{ID: 3},
+ Subscription: &UserSubscription{ID: subID},
+ IsSubscriptionBill: tt.isSubscription,
+ }
+
+ cmd := buildUsageBillingCommand("req-1", nil, p)
+ if cmd == nil {
+ t.Fatal("buildUsageBillingCommand returned nil")
+ }
+ if cmd.SubscriptionCost != tt.wantSub {
+ t.Errorf("SubscriptionCost = %v, want %v", cmd.SubscriptionCost, tt.wantSub)
+ }
+ if cmd.BalanceCost != tt.wantBalance {
+ t.Errorf("BalanceCost = %v, want %v", cmd.BalanceCost, tt.wantBalance)
+ }
+ })
+ }
+}
diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go
index 12262613..64434ae1 100644
--- a/backend/internal/service/group.go
+++ b/backend/internal/service/group.go
@@ -76,10 +76,6 @@ func (g *Group) IsSubscriptionType() bool {
return g.SubscriptionType == SubscriptionTypeSubscription
}
-func (g *Group) IsFreeSubscription() bool {
- return g.IsSubscriptionType() && g.RateMultiplier == 0
-}
-
func (g *Group) HasDailyLimit() bool {
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
}
diff --git a/backend/internal/service/model_pricing_resolver.go b/backend/internal/service/model_pricing_resolver.go
index b7ca4cb7..58089776 100644
--- a/backend/internal/service/model_pricing_resolver.go
+++ b/backend/internal/service/model_pricing_resolver.go
@@ -61,6 +61,25 @@ type PricingInput struct {
// 1. 获取基础定价(LiteLLM → Fallback)
// 2. 如果指定了 GroupID,查找渠道定价并覆盖
func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput) *ResolvedPricing {
+ var chPricing *ChannelModelPricing
+ if input.GroupID != nil && r.channelService != nil {
+ chPricing = r.channelService.GetChannelModelPricing(ctx, *input.GroupID, input.Model)
+ if chPricing != nil {
+ mode := chPricing.BillingMode
+ if mode == "" {
+ mode = BillingModeToken
+ }
+ if mode == BillingModePerRequest || mode == BillingModeImage {
+ resolved := &ResolvedPricing{
+ Mode: mode,
+ Source: PricingSourceChannel,
+ }
+ r.applyRequestTierOverrides(chPricing, resolved)
+ return resolved
+ }
+ }
+ }
+
// 1. 获取基础定价
basePricing, source := r.resolveBasePricing(input.Model)
@@ -72,7 +91,10 @@ func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput)
}
// 2. 如果有 GroupID,尝试渠道覆盖
- if input.GroupID != nil {
+ if chPricing != nil {
+ resolved.Source = PricingSourceChannel
+ r.applyTokenOverrides(chPricing, resolved)
+ } else if input.GroupID != nil {
r.applyChannelOverrides(ctx, *input.GroupID, input.Model, resolved)
}
diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go
index 6c09e354..808f1229 100644
--- a/backend/internal/service/openai_account_scheduler.go
+++ b/backend/internal/service/openai_account_scheduler.go
@@ -13,22 +13,39 @@ import (
"sync"
"sync/atomic"
"time"
+
+ "golang.org/x/sync/singleflight"
)
const (
openAIAccountScheduleLayerPreviousResponse = "previous_response_id"
openAIAccountScheduleLayerSessionSticky = "session_hash"
openAIAccountScheduleLayerLoadBalance = "load_balance"
+ openAIAdvancedSchedulerSettingKey = "openai_advanced_scheduler_enabled"
)
+const (
+ openAIAdvancedSchedulerSettingCacheTTL = 5 * time.Second
+ openAIAdvancedSchedulerSettingDBTimeout = 2 * time.Second
+)
+
+type cachedOpenAIAdvancedSchedulerSetting struct {
+ enabled bool
+ expiresAt int64
+}
+
+var openAIAdvancedSchedulerSettingCache atomic.Value // *cachedOpenAIAdvancedSchedulerSetting
+var openAIAdvancedSchedulerSettingSF singleflight.Group
+
type OpenAIAccountScheduleRequest struct {
- GroupID *int64
- SessionHash string
- StickyAccountID int64
- PreviousResponseID string
- RequestedModel string
- RequiredTransport OpenAIUpstreamTransport
- ExcludedIDs map[int64]struct{}
+ GroupID *int64
+ SessionHash string
+ StickyAccountID int64
+ PreviousResponseID string
+ RequestedModel string
+ RequiredTransport OpenAIUpstreamTransport
+ RequiredImageCapability OpenAIImagesCapability
+ ExcludedIDs map[int64]struct{}
}
type OpenAIAccountScheduleDecision struct {
@@ -324,7 +341,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
- if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
+ if !s.isAccountRequestCompatible(account, req) {
return nil, nil
}
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
@@ -600,7 +617,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
- if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
+ if !s.isAccountRequestCompatible(account, req) {
continue
}
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
@@ -706,11 +723,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
for i := 0; i < len(selectionOrder); i++ {
candidate := selectionOrder[i]
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
- if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
+ if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
continue
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel)
- if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
+ if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
continue
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
@@ -733,7 +750,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
for _, candidate := range selectionOrder {
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
- if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
+ if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
continue
}
return &AccountSelectionResult{
@@ -751,14 +768,23 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
}
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
- // HTTP 入站可回退到 HTTP 线路,不需要在账号选择阶段做传输协议强过滤。
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
return true
}
- if s == nil || s.service == nil || account == nil {
+ if s == nil || s.service == nil {
return false
}
- return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
+ return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport)
+}
+
+func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(account *Account, req OpenAIAccountScheduleRequest) bool {
+ if account == nil {
+ return false
+ }
+ if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
+ return false
+ }
+ return account.SupportsOpenAIImageCapability(req.RequiredImageCapability)
}
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
@@ -805,10 +831,56 @@ func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountScheduler
return snapshot
}
-func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler {
+func (s *OpenAIGatewayService) openAIAdvancedSchedulerSettingRepo() SettingRepository {
+ if s == nil || s.rateLimitService == nil || s.rateLimitService.settingService == nil {
+ return nil
+ }
+ return s.rateLimitService.settingService.settingRepo
+}
+
+func (s *OpenAIGatewayService) isOpenAIAdvancedSchedulerEnabled(ctx context.Context) bool {
+ if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil {
+ if time.Now().UnixNano() < cached.expiresAt {
+ return cached.enabled
+ }
+ }
+
+ result, _, _ := openAIAdvancedSchedulerSettingSF.Do(openAIAdvancedSchedulerSettingKey, func() (any, error) {
+ if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil {
+ if time.Now().UnixNano() < cached.expiresAt {
+ return cached.enabled, nil
+ }
+ }
+
+ enabled := false
+ if repo := s.openAIAdvancedSchedulerSettingRepo(); repo != nil {
+ dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAIAdvancedSchedulerSettingDBTimeout)
+ defer cancel()
+
+ value, err := repo.GetValue(dbCtx, openAIAdvancedSchedulerSettingKey)
+ if err == nil {
+ enabled = strings.EqualFold(strings.TrimSpace(value), "true")
+ }
+ }
+
+ openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
+ enabled: enabled,
+ expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
+ })
+ return enabled, nil
+ })
+
+ enabled, _ := result.(bool)
+ return enabled
+}
+
+func (s *OpenAIGatewayService) getOpenAIAccountScheduler(ctx context.Context) OpenAIAccountScheduler {
if s == nil {
return nil
}
+ if !s.isOpenAIAdvancedSchedulerEnabled(ctx) {
+ return nil
+ }
s.openaiSchedulerOnce.Do(func() {
if s.openaiAccountStats == nil {
s.openaiAccountStats = newOpenAIAccountRuntimeStats()
@@ -820,6 +892,11 @@ func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountSchedule
return s.openaiScheduler
}
+func resetOpenAIAdvancedSchedulerSettingCacheForTest() {
+ openAIAdvancedSchedulerSettingCache = atomic.Value{}
+ openAIAdvancedSchedulerSettingSF = singleflight.Group{}
+}
+
func (s *OpenAIGatewayService) SelectAccountWithScheduler(
ctx context.Context,
groupID *int64,
@@ -828,13 +905,92 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
requestedModel string,
excludedIDs map[int64]struct{},
requiredTransport OpenAIUpstreamTransport,
+) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
+ return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "")
+}
+
+func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages(
+ ctx context.Context,
+ groupID *int64,
+ sessionHash string,
+ requestedModel string,
+ excludedIDs map[int64]struct{},
+ requiredCapability OpenAIImagesCapability,
+) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
+ selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability)
+ if err == nil && selection != nil && selection.Account != nil {
+ return selection, decision, nil
+ }
+ // 如果要求 native 能力(如指定了模型)但没有可用的 APIKey 账号,回退到 basic(OAuth 账号)
+ if requiredCapability == OpenAIImagesCapabilityNative {
+ return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, OpenAIImagesCapabilityBasic)
+ }
+ return selection, decision, err
+}
+
+func (s *OpenAIGatewayService) selectAccountWithScheduler(
+ ctx context.Context,
+ groupID *int64,
+ previousResponseID string,
+ sessionHash string,
+ requestedModel string,
+ excludedIDs map[int64]struct{},
+ requiredTransport OpenAIUpstreamTransport,
+ requiredImageCapability OpenAIImagesCapability,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
decision := OpenAIAccountScheduleDecision{}
- scheduler := s.getOpenAIAccountScheduler()
+ scheduler := s.getOpenAIAccountScheduler(ctx)
if scheduler == nil {
- selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
decision.Layer = openAIAccountScheduleLayerLoadBalance
- return selection, decision, err
+ if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
+ effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
+ for {
+ selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs)
+ if err != nil {
+ return nil, decision, err
+ }
+ if selection == nil || selection.Account == nil {
+ return selection, decision, nil
+ }
+ if selection.Account.SupportsOpenAIImageCapability(requiredImageCapability) {
+ return selection, decision, nil
+ }
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ if effectiveExcludedIDs == nil {
+ effectiveExcludedIDs = make(map[int64]struct{})
+ }
+ if _, exists := effectiveExcludedIDs[selection.Account.ID]; exists {
+ return nil, decision, ErrNoAvailableAccounts
+ }
+ effectiveExcludedIDs[selection.Account.ID] = struct{}{}
+ }
+ }
+
+ effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
+ for {
+ selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs)
+ if err != nil {
+ return nil, decision, err
+ }
+ if selection == nil || selection.Account == nil {
+ return selection, decision, nil
+ }
+ if s.isOpenAIAccountTransportCompatible(selection.Account, requiredTransport) {
+ return selection, decision, nil
+ }
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ if effectiveExcludedIDs == nil {
+ effectiveExcludedIDs = make(map[int64]struct{})
+ }
+ if _, exists := effectiveExcludedIDs[selection.Account.ID]; exists {
+ return nil, decision, ErrNoAvailableAccounts
+ }
+ effectiveExcludedIDs[selection.Account.ID] = struct{}{}
+ }
}
var stickyAccountID int64
@@ -845,18 +1001,40 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
}
return scheduler.Select(ctx, OpenAIAccountScheduleRequest{
- GroupID: groupID,
- SessionHash: sessionHash,
- StickyAccountID: stickyAccountID,
- PreviousResponseID: previousResponseID,
- RequestedModel: requestedModel,
- RequiredTransport: requiredTransport,
- ExcludedIDs: excludedIDs,
+ GroupID: groupID,
+ SessionHash: sessionHash,
+ StickyAccountID: stickyAccountID,
+ PreviousResponseID: previousResponseID,
+ RequestedModel: requestedModel,
+ RequiredTransport: requiredTransport,
+ RequiredImageCapability: requiredImageCapability,
+ ExcludedIDs: excludedIDs,
})
}
+func cloneExcludedAccountIDs(excludedIDs map[int64]struct{}) map[int64]struct{} {
+ if len(excludedIDs) == 0 {
+ return nil
+ }
+ cloned := make(map[int64]struct{}, len(excludedIDs))
+ for id := range excludedIDs {
+ cloned[id] = struct{}{}
+ }
+ return cloned
+}
+
+func (s *OpenAIGatewayService) isOpenAIAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
+ if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
+ return true
+ }
+ if s == nil || account == nil {
+ return false
+ }
+ return s.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
+}
+
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
- scheduler := s.getOpenAIAccountScheduler()
+ scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return
}
@@ -864,7 +1042,7 @@ func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64
}
func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
- scheduler := s.getOpenAIAccountScheduler()
+ scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return
}
@@ -872,7 +1050,7 @@ func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
}
func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot {
- scheduler := s.getOpenAIAccountScheduler()
+ scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return OpenAIAccountSchedulerMetricsSnapshot{}
}
diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go
index 088815ed..b02370cb 100644
--- a/backend/internal/service/openai_account_scheduler_test.go
+++ b/backend/internal/service/openai_account_scheduler_test.go
@@ -2,6 +2,7 @@ package service
import (
"context"
+ "errors"
"fmt"
"math"
"sync"
@@ -18,6 +19,202 @@ type openAISnapshotCacheStub struct {
accountsByID map[int64]*Account
}
+type schedulerTestOpenAIAccountRepo struct {
+ AccountRepository
+ accounts []Account
+}
+
+func (r schedulerTestOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) {
+ for i := range r.accounts {
+ if r.accounts[i].ID == id {
+ return &r.accounts[i], nil
+ }
+ }
+ return nil, errors.New("account not found")
+}
+
+func (r schedulerTestOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
+ var result []Account
+ for _, acc := range r.accounts {
+ if acc.Platform == platform {
+ result = append(result, acc)
+ }
+ }
+ return result, nil
+}
+
+func (r schedulerTestOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
+ var result []Account
+ for _, acc := range r.accounts {
+ if acc.Platform == platform {
+ result = append(result, acc)
+ }
+ }
+ return result, nil
+}
+
+func (r schedulerTestOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
+ return r.ListSchedulableByPlatform(ctx, platform)
+}
+
+type schedulerTestConcurrencyCache struct {
+ ConcurrencyCache
+ loadBatchErr error
+ loadMap map[int64]*AccountLoadInfo
+ acquireResults map[int64]bool
+ waitCounts map[int64]int
+ skipDefaultLoad bool
+}
+
+func (c schedulerTestConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
+ if c.acquireResults != nil {
+ if result, ok := c.acquireResults[accountID]; ok {
+ return result, nil
+ }
+ }
+ return true, nil
+}
+
+func (c schedulerTestConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
+ return nil
+}
+
+func (c schedulerTestConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
+ if c.loadBatchErr != nil {
+ return nil, c.loadBatchErr
+ }
+ out := make(map[int64]*AccountLoadInfo, len(accounts))
+ if c.skipDefaultLoad && c.loadMap != nil {
+ for _, acc := range accounts {
+ if load, ok := c.loadMap[acc.ID]; ok {
+ out[acc.ID] = load
+ }
+ }
+ return out, nil
+ }
+ for _, acc := range accounts {
+ if c.loadMap != nil {
+ if load, ok := c.loadMap[acc.ID]; ok {
+ out[acc.ID] = load
+ continue
+ }
+ }
+ out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
+ }
+ return out, nil
+}
+
+func (c schedulerTestConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
+ if c.waitCounts != nil {
+ if count, ok := c.waitCounts[accountID]; ok {
+ return count, nil
+ }
+ }
+ return 0, nil
+}
+
+type schedulerTestGatewayCache struct {
+ sessionBindings map[string]int64
+ deletedSessions map[string]int
+}
+
+func (c *schedulerTestGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
+ if id, ok := c.sessionBindings[sessionHash]; ok {
+ return id, nil
+ }
+ return 0, errors.New("not found")
+}
+
+func (c *schedulerTestGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
+ if c.sessionBindings == nil {
+ c.sessionBindings = make(map[string]int64)
+ }
+ c.sessionBindings[sessionHash] = accountID
+ return nil
+}
+
+func (c *schedulerTestGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
+ return nil
+}
+
+func (c *schedulerTestGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
+ if c.sessionBindings == nil {
+ return nil
+ }
+ if c.deletedSessions == nil {
+ c.deletedSessions = make(map[string]int)
+ }
+ c.deletedSessions[sessionHash]++
+ delete(c.sessionBindings, sessionHash)
+ return nil
+}
+
+func newSchedulerTestOpenAIWSV2Config() *config.Config {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
+ return cfg
+}
+
+type openAIAdvancedSchedulerSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
+ value, err := s.GetValue(ctx, key)
+ if err != nil {
+ return nil, err
+ }
+ return &Setting{Key: key, Value: value}, nil
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ if s == nil || s.values == nil {
+ return "", ErrSettingNotFound
+ }
+ value, ok := s.values[key]
+ if !ok {
+ return "", ErrSettingNotFound
+ }
+ return value, nil
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) Set(context.Context, string, string) error {
+ panic("unexpected call to Set")
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
+ panic("unexpected call to GetMultiple")
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ panic("unexpected call to SetMultiple")
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ panic("unexpected call to GetAll")
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) Delete(context.Context, string) error {
+ panic("unexpected call to Delete")
+}
+
+func newOpenAIAdvancedSchedulerRateLimitService(enabled string) *RateLimitService {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+ repo := &openAIAdvancedSchedulerSettingRepoStub{
+ values: map[string]string{},
+ }
+ if enabled != "" {
+ repo.values[openAIAdvancedSchedulerSettingKey] = enabled
+ }
+ return &RateLimitService{
+ settingService: NewSettingService(repo, &config.Config{}),
+ }
+}
+
func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) {
if len(s.snapshotAccounts) == 0 {
return nil, false, nil
@@ -45,6 +242,230 @@ func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int6
return &cloned, nil
}
+func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLegacyLoadAwareness(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ ctx := context.Background()
+ groupID := int64(10106)
+ accounts := []Account{
+ {
+ ID: 36001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 5,
+ },
+ {
+ ID: 36002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ }
+ cfg := &config.Config{}
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+ cache := &schedulerTestGatewayCache{}
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
+
+ store := svc.getOpenAIWSStateStore()
+ require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_disabled_001", 36001, time.Hour))
+ require.False(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx))
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "resp_disabled_001",
+ "",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(36002), selection.Account.ID)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+ require.False(t, decision.StickyPreviousHit)
+}
+
+func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_RequiredWSV2_SkipsHTTPOnlyAccount(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ ctx := context.Background()
+ groupID := int64(10108)
+ accounts := []Account{
+ {
+ ID: 36011,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ {
+ ID: 36012,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 5,
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": true,
+ },
+ },
+ }
+ cfg := newSchedulerTestOpenAIWSV2Config()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportResponsesWebsocketV2,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(36012), selection.Account.ID)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+}
+
+func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_RequiredWSV2_NoAvailableAccount(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ ctx := context.Background()
+ groupID := int64(10109)
+ accounts := []Account{
+ {
+ ID: 36021,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ }
+ cfg := newSchedulerTestOpenAIWSV2Config()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportResponsesWebsocketV2,
+ )
+ require.ErrorContains(t, err, "no available OpenAI accounts")
+ require.Nil(t, selection)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+}
+
+func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ ctx := context.Background()
+ groupID := int64(10107)
+ accounts := []Account{
+ {
+ ID: 37001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 5,
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": true,
+ },
+ },
+ {
+ ID: 37002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ }
+ cfg := &config.Config{}
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
+ cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
+
+ store := svc.getOpenAIWSStateStore()
+ require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_enabled_001", 37001, time.Hour))
+ require.True(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx))
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "resp_enabled_001",
+ "",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(37001), selection.Account.ID)
+ require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer)
+ require.True(t, decision.StickyPreviousHit)
+}
+
+func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics_DisabledNoOp(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ svc := &OpenAIGatewayService{}
+ ttft := 120
+ svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
+ svc.RecordOpenAIAccountSwitch()
+
+ snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
+ require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot)
+}
+
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) {
ctx := context.Background()
groupID := int64(10101)
@@ -53,10 +474,17 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite
staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
- cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
+ cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
- svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})}
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}},
+ cache: cache,
+ cfg: &config.Config{},
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ schedulerSnapshot: snapshotService,
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
require.NoError(t, err)
@@ -76,7 +504,12 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
- svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService}
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}},
+ cfg: &config.Config{},
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ schedulerSnapshot: snapshotService,
+ }
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil)
require.NoError(t, err)
@@ -92,18 +525,19 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR
staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
- cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
+ cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
snapshotCache := &openAISnapshotCacheStub{
snapshotAccounts: []*Account{staleSticky, staleBackup},
accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup},
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
cache: cache,
cfg: &config.Config{},
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
@@ -128,8 +562,9 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeReche
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
cfg: &config.Config{},
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
}
@@ -153,7 +588,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
- cache := &stubGatewayCache{}
+ cache := &schedulerTestGatewayCache{}
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
@@ -163,10 +598,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: cfg,
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
store := svc.getOpenAIWSStateStore()
@@ -204,17 +640,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin
Schedulable: true,
Concurrency: 1,
}
- cache := &stubGatewayCache{
+ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_abc": account.ID,
},
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
@@ -260,7 +697,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
Priority: 9,
},
}
- cache := &stubGatewayCache{
+ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_sticky_busy": 21001,
},
@@ -273,7 +710,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- concurrencyCache := stubConcurrencyCache{
+ concurrencyCache := schedulerTestConcurrencyCache{
acquireResults: map[int64]bool{
21001: false, // sticky 账号已满
21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换)
@@ -288,9 +725,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache,
cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@@ -328,17 +766,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP
"openai_ws_force_http": true,
},
}
- cache := &stubGatewayCache{
+ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_force_http": account.ID,
},
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
@@ -387,15 +826,15 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
},
},
}
- cache := &stubGatewayCache{
+ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_ws_only": 2201,
},
}
- cfg := newOpenAIWSV2TestConfig()
+ cfg := newSchedulerTestOpenAIWSV2Config()
// 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。
- concurrencyCache := stubConcurrencyCache{
+ concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0},
2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5},
@@ -403,9 +842,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache,
cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@@ -445,10 +885,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: accounts},
- cache: &stubGatewayCache{},
- cfg: newOpenAIWSV2TestConfig(),
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
+ cfg: newSchedulerTestOpenAIWSV2Config(),
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
@@ -507,7 +948,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1
- concurrencyCache := stubConcurrencyCache{
+ concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8},
3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1},
@@ -520,9 +961,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: accounts},
- cache: &stubGatewayCache{},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@@ -559,16 +1001,17 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
Schedulable: true,
Concurrency: 1,
}
- cache := &stubGatewayCache{
+ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_metrics": account.ID,
},
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
@@ -749,7 +1192,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
- concurrencyCache := stubConcurrencyCache{
+ concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1},
5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1},
@@ -757,9 +1200,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
},
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: accounts},
- cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{sessionBindings: map[string]int64{}},
cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@@ -905,12 +1349,14 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) {
}
func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
svc := &OpenAIGatewayService{}
ttft := 120
svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
svc.RecordOpenAIAccountSwitch()
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
- require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
+ require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot)
require.Equal(t, 7, svc.openAIWSLBTopK())
require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL())
@@ -947,7 +1393,7 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *
require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE))
require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2))
- cfg := newOpenAIWSV2TestConfig()
+ cfg := newSchedulerTestOpenAIWSV2Config()
scheduler.service = &OpenAIGatewayService{cfg: cfg}
account := &Account{
ID: 8801,
diff --git a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go
index c5de8203..ddafc6eb 100644
--- a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go
+++ b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go
@@ -38,11 +38,12 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{*account}},
- cache: &stubGatewayCache{},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*account}},
+ cache: &schedulerTestGatewayCache{},
cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go
index a266d6a0..457309d3 100644
--- a/backend/internal/service/openai_codex_transform.go
+++ b/backend/internal/service/openai_codex_transform.go
@@ -8,7 +8,6 @@ import (
var codexModelMap = map[string]string{
"gpt-5.4": "gpt-5.4",
"gpt-5.4-mini": "gpt-5.4-mini",
- "gpt-5.4-nano": "gpt-5.4-nano",
"gpt-5.4-none": "gpt-5.4",
"gpt-5.4-low": "gpt-5.4",
"gpt-5.4-medium": "gpt-5.4",
@@ -22,52 +21,21 @@ var codexModelMap = map[string]string{
"gpt-5.3-high": "gpt-5.3-codex",
"gpt-5.3-xhigh": "gpt-5.3-codex",
"gpt-5.3-codex": "gpt-5.3-codex",
- "gpt-5.3-codex-spark": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-low": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-medium": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-high": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
+ "gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-low": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-medium": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
"gpt-5.3-codex-low": "gpt-5.3-codex",
"gpt-5.3-codex-medium": "gpt-5.3-codex",
"gpt-5.3-codex-high": "gpt-5.3-codex",
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
- "gpt-5.1-codex": "gpt-5.1-codex",
- "gpt-5.1-codex-low": "gpt-5.1-codex",
- "gpt-5.1-codex-medium": "gpt-5.1-codex",
- "gpt-5.1-codex-high": "gpt-5.1-codex",
- "gpt-5.1-codex-max": "gpt-5.1-codex-max",
- "gpt-5.1-codex-max-low": "gpt-5.1-codex-max",
- "gpt-5.1-codex-max-medium": "gpt-5.1-codex-max",
- "gpt-5.1-codex-max-high": "gpt-5.1-codex-max",
- "gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max",
"gpt-5.2": "gpt-5.2",
"gpt-5.2-none": "gpt-5.2",
"gpt-5.2-low": "gpt-5.2",
"gpt-5.2-medium": "gpt-5.2",
"gpt-5.2-high": "gpt-5.2",
"gpt-5.2-xhigh": "gpt-5.2",
- "gpt-5.2-codex": "gpt-5.2-codex",
- "gpt-5.2-codex-low": "gpt-5.2-codex",
- "gpt-5.2-codex-medium": "gpt-5.2-codex",
- "gpt-5.2-codex-high": "gpt-5.2-codex",
- "gpt-5.2-codex-xhigh": "gpt-5.2-codex",
- "gpt-5.1-codex-mini": "gpt-5.1-codex-mini",
- "gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini",
- "gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini",
- "gpt-5.1": "gpt-5.1",
- "gpt-5.1-none": "gpt-5.1",
- "gpt-5.1-low": "gpt-5.1",
- "gpt-5.1-medium": "gpt-5.1",
- "gpt-5.1-high": "gpt-5.1",
- "gpt-5.1-chat-latest": "gpt-5.1",
- "gpt-5-codex": "gpt-5.1-codex",
- "codex-mini-latest": "gpt-5.1-codex-mini",
- "gpt-5-codex-mini": "gpt-5.1-codex-mini",
- "gpt-5-codex-mini-medium": "gpt-5.1-codex-mini",
- "gpt-5-codex-mini-high": "gpt-5.1-codex-mini",
- "gpt-5": "gpt-5.1",
- "gpt-5-mini": "gpt-5.1",
- "gpt-5-nano": "gpt-5.1",
}
type codexTransformResult struct {
@@ -220,7 +188,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
func normalizeCodexModel(model string) string {
if model == "" {
- return "gpt-5.1"
+ return "gpt-5.4"
}
modelID := model
@@ -238,49 +206,29 @@ func normalizeCodexModel(model string) string {
if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") {
return "gpt-5.4-mini"
}
- if strings.Contains(normalized, "gpt-5.4-nano") || strings.Contains(normalized, "gpt 5.4 nano") {
- return "gpt-5.4-nano"
- }
if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") {
return "gpt-5.4"
}
- if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") {
- return "gpt-5.2-codex"
- }
if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") {
return "gpt-5.2"
}
+ if strings.Contains(normalized, "gpt-5.3-codex-spark") || strings.Contains(normalized, "gpt 5.3 codex spark") {
+ return "gpt-5.3-codex-spark"
+ }
if strings.Contains(normalized, "gpt-5.3-codex") || strings.Contains(normalized, "gpt 5.3 codex") {
return "gpt-5.3-codex"
}
if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") {
return "gpt-5.3-codex"
}
- if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") {
- return "gpt-5.1-codex-max"
- }
- if strings.Contains(normalized, "gpt-5.1-codex-mini") || strings.Contains(normalized, "gpt 5.1 codex mini") {
- return "gpt-5.1-codex-mini"
- }
- if strings.Contains(normalized, "codex-mini-latest") ||
- strings.Contains(normalized, "gpt-5-codex-mini") ||
- strings.Contains(normalized, "gpt 5 codex mini") {
- return "codex-mini-latest"
- }
- if strings.Contains(normalized, "gpt-5.1-codex") || strings.Contains(normalized, "gpt 5.1 codex") {
- return "gpt-5.1-codex"
- }
- if strings.Contains(normalized, "gpt-5.1") || strings.Contains(normalized, "gpt 5.1") {
- return "gpt-5.1"
- }
if strings.Contains(normalized, "codex") {
- return "gpt-5.1-codex"
+ return "gpt-5.3-codex"
}
if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") {
- return "gpt-5.1"
+ return "gpt-5.4"
}
- return "gpt-5.1"
+ return "gpt-5.4"
}
func normalizeOpenAIModelForUpstream(account *Account, model string) string {
diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go
index 993ade07..22264f5e 100644
--- a/backend/internal/service/openai_codex_transform_test.go
+++ b/backend/internal/service/openai_codex_transform_test.go
@@ -240,15 +240,13 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
"gpt 5.4": "gpt-5.4",
"gpt-5.4-mini": "gpt-5.4-mini",
"gpt 5.4 mini": "gpt-5.4-mini",
- "gpt-5.4-nano": "gpt-5.4-nano",
- "gpt 5.4 nano": "gpt-5.4-nano",
"gpt-5.3": "gpt-5.3-codex",
"gpt-5.3-codex": "gpt-5.3-codex",
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
- "gpt-5.3-codex-spark": "gpt-5.3-codex",
- "gpt 5.3 codex spark": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-high": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
+ "gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
+ "gpt 5.3 codex spark": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
"gpt 5.3 codex": "gpt-5.3-codex",
}
@@ -257,6 +255,26 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
}
}
+func TestNormalizeCodexModel_RemovedModelsFallbackToSupportedTargets(t *testing.T) {
+ cases := map[string]string{
+ "": "gpt-5.4",
+ "gpt-5": "gpt-5.4",
+ "gpt-5-mini": "gpt-5.4",
+ "gpt-5-nano": "gpt-5.4",
+ "gpt-5.1": "gpt-5.4",
+ "gpt-5.1-codex": "gpt-5.3-codex",
+ "gpt-5.1-codex-max": "gpt-5.3-codex",
+ "gpt-5.1-codex-mini": "gpt-5.3-codex",
+ "gpt-5.2-codex": "gpt-5.2",
+ "codex-mini-latest": "gpt-5.3-codex",
+ "gpt-5-codex": "gpt-5.3-codex",
+ }
+
+ for input, expected := range cases {
+ require.Equal(t, expected, normalizeCodexModel(input))
+ }
+}
+
func TestApplyCodexOAuthTransform_PreservesBareSparkModel(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.3-codex-spark",
diff --git a/backend/internal/service/openai_compat_prompt_cache_key.go b/backend/internal/service/openai_compat_prompt_cache_key.go
index 88e16a4d..fcd27f19 100644
--- a/backend/internal/service/openai_compat_prompt_cache_key.go
+++ b/backend/internal/service/openai_compat_prompt_cache_key.go
@@ -10,8 +10,14 @@ import (
const compatPromptCacheKeyPrefix = "compat_cc_"
func shouldAutoInjectPromptCacheKeyForCompat(model string) bool {
- switch normalizeCodexModel(strings.TrimSpace(model)) {
- case "gpt-5.4", "gpt-5.3-codex":
+ trimmed := strings.TrimSpace(strings.ToLower(model))
+ // 仅对 Codex OAuth 路径支持的 GPT-5 族开启自动注入,避免 normalizeCodexModel
+ // 的默认兜底把任意模型(如 gpt-4o、claude-*)误判为 gpt-5.4。
+ if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") {
+ return false
+ }
+ switch normalizeCodexModel(trimmed) {
+ case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark":
return true
default:
return false
diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go
index ac7d28a7..663066a3 100644
--- a/backend/internal/service/openai_gateway_chat_completions.go
+++ b/backend/internal/service/openai_gateway_chat_completions.go
@@ -107,11 +107,15 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
responsesBody = stripped
}
}
+ responsesBody, normalizedServiceTier, err := normalizeResponsesBodyServiceTier(responsesBody)
+ if err != nil {
+ return nil, fmt.Errorf("normalize service_tier in responses-shape body: %w", err)
+ }
// Minimal stub populated from the raw body so downstream billing
// propagation (ServiceTier, ReasoningEffort) keeps working.
responsesReq = &apicompat.ResponsesRequest{
Model: upstreamModel,
- ServiceTier: gjson.GetBytes(responsesBody, "service_tier").String(),
+ ServiceTier: normalizedServiceTier,
}
if effort := gjson.GetBytes(responsesBody, "reasoning.effort").String(); effort != "" {
responsesReq.Reasoning = &apicompat.ResponsesReasoning{Effort: effort}
@@ -124,6 +128,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
}
responsesReq.Model = upstreamModel
+ normalizeResponsesRequestServiceTier(responsesReq)
responsesBody, err = json.Marshal(responsesReq)
if err != nil {
return nil, fmt.Errorf("marshal responses request: %w", err)
@@ -274,6 +279,41 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
return result, handleErr
}
+func normalizeResponsesRequestServiceTier(req *apicompat.ResponsesRequest) {
+ if req == nil {
+ return
+ }
+ req.ServiceTier = normalizedOpenAIServiceTierValue(req.ServiceTier)
+}
+
+func normalizeResponsesBodyServiceTier(body []byte) ([]byte, string, error) {
+ if len(body) == 0 {
+ return body, "", nil
+ }
+ rawServiceTier := gjson.GetBytes(body, "service_tier").String()
+ if rawServiceTier == "" {
+ return body, "", nil
+ }
+ normalizedServiceTier := normalizedOpenAIServiceTierValue(rawServiceTier)
+ if normalizedServiceTier == "" {
+ trimmed, err := sjson.DeleteBytes(body, "service_tier")
+ return trimmed, "", err
+ }
+ if normalizedServiceTier == rawServiceTier {
+ return body, normalizedServiceTier, nil
+ }
+ trimmed, err := sjson.SetBytes(body, "service_tier", normalizedServiceTier)
+ return trimmed, normalizedServiceTier, err
+}
+
+func normalizedOpenAIServiceTierValue(raw string) string {
+ normalized := normalizeOpenAIServiceTier(raw)
+ if normalized == nil {
+ return ""
+ }
+ return *normalized
+}
+
// handleChatCompletionsErrorResponse reads an upstream error and returns it in
// OpenAI Chat Completions error format.
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
diff --git a/backend/internal/service/openai_gateway_chat_completions_test.go b/backend/internal/service/openai_gateway_chat_completions_test.go
new file mode 100644
index 00000000..a00fb71c
--- /dev/null
+++ b/backend/internal/service/openai_gateway_chat_completions_test.go
@@ -0,0 +1,44 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
+ t.Parallel()
+
+ req := &apicompat.ResponsesRequest{ServiceTier: " fast "}
+ normalizeResponsesRequestServiceTier(req)
+ require.Equal(t, "priority", req.ServiceTier)
+
+ req.ServiceTier = "flex"
+ normalizeResponsesRequestServiceTier(req)
+ require.Equal(t, "flex", req.ServiceTier)
+
+ req.ServiceTier = "default"
+ normalizeResponsesRequestServiceTier(req)
+ require.Empty(t, req.ServiceTier)
+}
+
+func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
+ t.Parallel()
+
+ body, tier, err := normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"fast"}`))
+ require.NoError(t, err)
+ require.Equal(t, "priority", tier)
+ require.Equal(t, "priority", gjson.GetBytes(body, "service_tier").String())
+
+ body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"flex"}`))
+ require.NoError(t, err)
+ require.Equal(t, "flex", tier)
+ require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String())
+
+ body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`))
+ require.NoError(t, err)
+ require.Empty(t, tier)
+ require.False(t, gjson.GetBytes(body, "service_tier").Exists())
+}
diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go
index e6fa94aa..95e1bffa 100644
--- a/backend/internal/service/openai_gateway_record_usage_test.go
+++ b/backend/internal/service/openai_gateway_record_usage_test.go
@@ -1031,7 +1031,7 @@ func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFiel
Model: "gpt-5.1",
Duration: time.Second,
},
- APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription}},
+ APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription, RateMultiplier: 1.0}},
User: &User{ID: 200},
Account: &Account{ID: 300},
Subscription: subscription,
@@ -1070,3 +1070,31 @@ func TestOpenAIGatewayServiceRecordUsage_SimpleModeSkipsBillingAfterPersist(t *t
require.Equal(t, 0, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
}
+
+func TestOpenAIGatewayServiceRecordUsage_ImageOnlyUsageStillPersists(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_image_only_usage",
+ Model: "gpt-image-2",
+ ImageCount: 2,
+ ImageSize: "1K",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{ID: 1007},
+ User: &User{ID: 2007},
+ Account: &Account{ID: 3007},
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, usageRepo.lastLog)
+ require.Equal(t, 2, usageRepo.lastLog.ImageCount)
+ require.NotNil(t, usageRepo.lastLog.ImageSize)
+ require.Equal(t, "1K", *usageRepo.lastLog.ImageSize)
+ require.NotNil(t, usageRepo.lastLog.BillingMode)
+ require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
+}
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index 064191bd..a4a7ff1b 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -233,6 +233,8 @@ type OpenAIForwardResult struct {
ResponseHeaders http.Header
Duration time.Duration
FirstTokenMs *int
+ ImageCount int
+ ImageSize string
}
type OpenAIWSRetryMetricsSnapshot struct {
@@ -3889,6 +3891,7 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
usage.InputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens").Int())
usage.OutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens").Int())
usage.CacheReadInputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens_details.cached_tokens").Int())
+ usage.ImageOutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens_details.image_tokens").Int())
}
func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
@@ -3900,11 +3903,13 @@ func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
"usage.input_tokens",
"usage.output_tokens",
"usage.input_tokens_details.cached_tokens",
+ "usage.output_tokens_details.image_tokens",
)
return OpenAIUsage{
InputTokens: int(values[0].Int()),
OutputTokens: int(values[1].Int()),
CacheReadInputTokens: int(values[2].Int()),
+ ImageOutputTokens: int(values[3].Int()),
}, true
}
@@ -4397,7 +4402,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库
if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 &&
- result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 {
+ result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 &&
+ result.Usage.ImageOutputTokens == 0 && result.ImageCount == 0 {
return nil
}
@@ -4451,21 +4457,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
if result.ServiceTier != nil {
serviceTier = strings.TrimSpace(*result.ServiceTier)
}
- if s.resolver != nil && apiKey.Group != nil {
- gid := apiKey.Group.ID
- cost, err = s.billingService.CalculateCostUnified(CostInput{
- Ctx: ctx,
- Model: billingModel,
- GroupID: &gid,
- Tokens: tokens,
- RequestCount: 1,
- RateMultiplier: multiplier,
- ServiceTier: serviceTier,
- Resolver: s.resolver,
- })
- } else {
- cost, err = s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
- }
+ cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, tokens, serviceTier)
if err != nil {
cost = &CostBreakdown{ActualCost: 0}
}
@@ -4505,6 +4497,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
+ ImageCount: result.ImageCount,
+ ImageSize: optionalTrimmedStringPtr(result.ImageSize),
}
if cost != nil {
usageLog.InputCost = cost.InputCost
@@ -4530,6 +4524,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
+ } else if result.ImageCount > 0 {
+ billingMode := string(BillingModeImage)
+ usageLog.BillingMode = &billingMode
} else {
billingMode := string(BillingModeToken)
usageLog.BillingMode = &billingMode
@@ -4589,6 +4586,125 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
return nil
}
+func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost(
+ ctx context.Context,
+ result *OpenAIForwardResult,
+ apiKey *APIKey,
+ billingModel string,
+ multiplier float64,
+ tokens UsageTokens,
+ serviceTier string,
+) (*CostBreakdown, error) {
+ if result != nil && result.ImageCount > 0 {
+ if hasOpenAIImageUsageTokens(result) {
+ cost, err := s.calculateOpenAIImageTokenCost(ctx, apiKey, billingModel, multiplier, tokens, serviceTier, result.ImageSize)
+ if err == nil {
+ return cost, nil
+ }
+ }
+ return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil
+ }
+ if s.resolver != nil && apiKey.Group != nil {
+ gid := apiKey.Group.ID
+ return s.billingService.CalculateCostUnified(CostInput{
+ Ctx: ctx,
+ Model: billingModel,
+ GroupID: &gid,
+ Tokens: tokens,
+ RequestCount: 1,
+ RateMultiplier: multiplier,
+ ServiceTier: serviceTier,
+ Resolver: s.resolver,
+ })
+ }
+ return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
+}
+
+func (s *OpenAIGatewayService) calculateOpenAIImageTokenCost(
+ ctx context.Context,
+ apiKey *APIKey,
+ billingModel string,
+ multiplier float64,
+ tokens UsageTokens,
+ serviceTier string,
+ sizeTier string,
+) (*CostBreakdown, error) {
+ if s.resolver != nil && apiKey.Group != nil {
+ gid := apiKey.Group.ID
+ return s.billingService.CalculateCostUnified(CostInput{
+ Ctx: ctx,
+ Model: billingModel,
+ GroupID: &gid,
+ Tokens: tokens,
+ RequestCount: 1,
+ SizeTier: sizeTier,
+ RateMultiplier: multiplier,
+ ServiceTier: serviceTier,
+ Resolver: s.resolver,
+ })
+ }
+ return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
+}
+
+func (s *OpenAIGatewayService) calculateOpenAIImageCost(
+ ctx context.Context,
+ billingModel string,
+ apiKey *APIKey,
+ result *OpenAIForwardResult,
+ multiplier float64,
+) *CostBreakdown {
+ if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil {
+ gid := apiKey.Group.ID
+ cost, err := s.billingService.CalculateCostUnified(CostInput{
+ Ctx: ctx,
+ Model: billingModel,
+ GroupID: &gid,
+ RequestCount: 1,
+ SizeTier: result.ImageSize,
+ RateMultiplier: multiplier,
+ Resolver: s.resolver,
+ Resolved: resolved,
+ })
+ if err == nil {
+ return cost
+ }
+ logger.LegacyPrintf("service.openai_gateway", "Calculate image channel cost failed: %v", err)
+ }
+
+ var groupConfig *ImagePriceConfig
+ if apiKey != nil && apiKey.Group != nil {
+ groupConfig = &ImagePriceConfig{
+ Price1K: apiKey.Group.ImagePrice1K,
+ Price2K: apiKey.Group.ImagePrice2K,
+ Price4K: apiKey.Group.ImagePrice4K,
+ }
+ }
+ return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
+}
+
+func (s *OpenAIGatewayService) resolveOpenAIChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
+ if s.resolver == nil || apiKey == nil || apiKey.Group == nil {
+ return nil
+ }
+ gid := apiKey.Group.ID
+ resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
+ if resolved.Source == PricingSourceChannel {
+ return resolved
+ }
+ return nil
+}
+
+func hasOpenAIImageUsageTokens(result *OpenAIForwardResult) bool {
+ if result == nil {
+ return false
+ }
+ return result.Usage.InputTokens > 0 ||
+ result.Usage.OutputTokens > 0 ||
+ result.Usage.CacheCreationInputTokens > 0 ||
+ result.Usage.CacheReadInputTokens > 0 ||
+ result.Usage.ImageOutputTokens > 0
+}
+
// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
// Exported for use in ratelimit_service when handling OpenAI 429 responses.
func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go
new file mode 100644
index 00000000..fb6bdc7f
--- /dev/null
+++ b/backend/internal/service/openai_images.go
@@ -0,0 +1,2013 @@
+package service
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "crypto/sha256"
+ "crypto/sha3"
+ "encoding/base64"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "mime"
+ "mime/multipart"
+ "net/http"
+ "net/textproto"
+ "sort"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
+ "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
+ "github.com/gin-gonic/gin"
+ "github.com/google/uuid"
+ "github.com/imroc/req/v3"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+)
+
+const (
+ openAIImagesGenerationsEndpoint = "/v1/images/generations"
+ openAIImagesEditsEndpoint = "/v1/images/edits"
+
+ openAIImagesGenerationsURL = "https://api.openai.com/v1/images/generations"
+ openAIImagesEditsURL = "https://api.openai.com/v1/images/edits"
+
+ openAIChatGPTStartURL = "https://chatgpt.com/"
+ openAIChatGPTFilesURL = "https://chatgpt.com/backend-api/files"
+ openAIChatGPTConversationInitURL = "https://chatgpt.com/backend-api/conversation/init"
+ openAIChatGPTConversationURL = "https://chatgpt.com/backend-api/f/conversation"
+ openAIChatGPTConversationPrepareURL = "https://chatgpt.com/backend-api/f/conversation/prepare"
+ openAIChatGPTChatRequirementsURL = "https://chatgpt.com/backend-api/sentinel/chat-requirements"
+
+ openAIImageBackendUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
+ openAIImageRequirementsDiff = "0fffff"
+)
+
+type OpenAIImagesCapability string
+
+const (
+ OpenAIImagesCapabilityBasic OpenAIImagesCapability = "images-basic"
+ OpenAIImagesCapabilityNative OpenAIImagesCapability = "images-native"
+)
+
+type OpenAIImagesUpload struct {
+ FieldName string
+ FileName string
+ ContentType string
+ Data []byte
+ Width int
+ Height int
+}
+
+type OpenAIImagesRequest struct {
+ Endpoint string
+ ContentType string
+ Multipart bool
+ Model string
+ ExplicitModel bool
+ Prompt string
+ Stream bool
+ N int
+ Size string
+ ExplicitSize bool
+ SizeTier string
+ ResponseFormat string
+ HasMask bool
+ HasNativeOptions bool
+ RequiredCapability OpenAIImagesCapability
+ Uploads []OpenAIImagesUpload
+ Body []byte
+ bodyHash string
+}
+
+func (r *OpenAIImagesRequest) IsEdits() bool {
+ return r != nil && r.Endpoint == openAIImagesEditsEndpoint
+}
+
+func (r *OpenAIImagesRequest) StickySessionSeed() string {
+ if r == nil {
+ return ""
+ }
+ parts := []string{
+ "openai-images",
+ strings.TrimSpace(r.Endpoint),
+ strings.TrimSpace(r.Model),
+ strings.TrimSpace(r.Size),
+ strings.TrimSpace(r.Prompt),
+ }
+ seed := strings.Join(parts, "|")
+ if strings.TrimSpace(r.Prompt) == "" && r.bodyHash != "" {
+ seed += "|body=" + r.bodyHash
+ }
+ return seed
+}
+
+func (s *OpenAIGatewayService) ParseOpenAIImagesRequest(c *gin.Context, body []byte) (*OpenAIImagesRequest, error) {
+ if c == nil || c.Request == nil {
+ return nil, fmt.Errorf("missing request context")
+ }
+ endpoint := normalizeOpenAIImagesEndpointPath(c.Request.URL.Path)
+ if endpoint == "" {
+ return nil, fmt.Errorf("unsupported images endpoint")
+ }
+
+ contentType := strings.TrimSpace(c.GetHeader("Content-Type"))
+ req := &OpenAIImagesRequest{
+ Endpoint: endpoint,
+ ContentType: contentType,
+ N: 1,
+ Body: body,
+ }
+ if len(body) > 0 {
+ sum := sha256.Sum256(body)
+ req.bodyHash = hex.EncodeToString(sum[:8])
+ }
+
+ mediaType, _, err := mime.ParseMediaType(contentType)
+ if err == nil && strings.EqualFold(mediaType, "multipart/form-data") {
+ req.Multipart = true
+ if parseErr := parseOpenAIImagesMultipartRequest(body, contentType, req); parseErr != nil {
+ return nil, parseErr
+ }
+ } else {
+ if len(body) == 0 {
+ return nil, fmt.Errorf("request body is empty")
+ }
+ if !gjson.ValidBytes(body) {
+ return nil, fmt.Errorf("failed to parse request body")
+ }
+ if parseErr := parseOpenAIImagesJSONRequest(body, req); parseErr != nil {
+ return nil, parseErr
+ }
+ }
+
+ applyOpenAIImagesDefaults(req)
+ req.SizeTier = normalizeOpenAIImageSizeTier(req.Size)
+ req.RequiredCapability = classifyOpenAIImagesCapability(req)
+ return req, nil
+}
+
+func parseOpenAIImagesJSONRequest(body []byte, req *OpenAIImagesRequest) error {
+ if modelResult := gjson.GetBytes(body, "model"); modelResult.Exists() {
+ req.Model = strings.TrimSpace(modelResult.String())
+ req.ExplicitModel = req.Model != ""
+ }
+ req.Prompt = strings.TrimSpace(gjson.GetBytes(body, "prompt").String())
+
+ if streamResult := gjson.GetBytes(body, "stream"); streamResult.Exists() {
+ if streamResult.Type != gjson.True && streamResult.Type != gjson.False {
+ return fmt.Errorf("invalid stream field type")
+ }
+ req.Stream = streamResult.Bool()
+ }
+
+ if nResult := gjson.GetBytes(body, "n"); nResult.Exists() {
+ if nResult.Type != gjson.Number {
+ return fmt.Errorf("invalid n field type")
+ }
+ req.N = int(nResult.Int())
+ if req.N <= 0 {
+ return fmt.Errorf("n must be greater than 0")
+ }
+ }
+
+ if sizeResult := gjson.GetBytes(body, "size"); sizeResult.Exists() {
+ req.Size = strings.TrimSpace(sizeResult.String())
+ req.ExplicitSize = req.Size != ""
+ }
+ req.ResponseFormat = strings.ToLower(strings.TrimSpace(gjson.GetBytes(body, "response_format").String()))
+ req.HasMask = gjson.GetBytes(body, "mask").Exists()
+ req.HasNativeOptions = hasOpenAINativeImageOptions(func(path string) bool {
+ return gjson.GetBytes(body, path).Exists()
+ })
+ return nil
+}
+
+func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *OpenAIImagesRequest) error {
+ _, params, err := mime.ParseMediaType(contentType)
+ if err != nil {
+ return fmt.Errorf("invalid multipart content-type: %w", err)
+ }
+ boundary := strings.TrimSpace(params["boundary"])
+ if boundary == "" {
+ return fmt.Errorf("multipart boundary is required")
+ }
+
+ reader := multipart.NewReader(bytes.NewReader(body), boundary)
+ for {
+ part, err := reader.NextPart()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return fmt.Errorf("read multipart body: %w", err)
+ }
+ name := strings.TrimSpace(part.FormName())
+ if name == "" {
+ _ = part.Close()
+ continue
+ }
+
+ data, err := io.ReadAll(part)
+ _ = part.Close()
+ if err != nil {
+ return fmt.Errorf("read multipart field %s: %w", name, err)
+ }
+
+ fileName := strings.TrimSpace(part.FileName())
+ if fileName != "" {
+ partContentType := strings.TrimSpace(part.Header.Get("Content-Type"))
+ if name == "mask" && len(data) > 0 {
+ req.HasMask = true
+ }
+ if name == "image" || strings.HasPrefix(name, "image[") {
+ width, height := parseOpenAIImageDimensions(part.Header)
+ req.Uploads = append(req.Uploads, OpenAIImagesUpload{
+ FieldName: name,
+ FileName: fileName,
+ ContentType: partContentType,
+ Data: data,
+ Width: width,
+ Height: height,
+ })
+ }
+ continue
+ }
+
+ value := strings.TrimSpace(string(data))
+ switch name {
+ case "model":
+ req.Model = value
+ req.ExplicitModel = value != ""
+ case "prompt":
+ req.Prompt = value
+ case "size":
+ req.Size = value
+ req.ExplicitSize = value != ""
+ case "response_format":
+ req.ResponseFormat = strings.ToLower(value)
+ case "stream":
+ parsed, err := strconv.ParseBool(value)
+ if err != nil {
+ return fmt.Errorf("invalid stream field value")
+ }
+ req.Stream = parsed
+ case "n":
+ n, err := strconv.Atoi(value)
+ if err != nil || n <= 0 {
+ return fmt.Errorf("n must be a positive integer")
+ }
+ req.N = n
+ default:
+ if isOpenAINativeImageOption(name) && value != "" {
+ req.HasNativeOptions = true
+ }
+ }
+ }
+
+ if len(req.Uploads) == 0 && req.IsEdits() {
+ return fmt.Errorf("image file is required")
+ }
+ return nil
+}
+
+func parseOpenAIImageDimensions(_ textproto.MIMEHeader) (int, int) {
+ return 0, 0
+}
+
+func applyOpenAIImagesDefaults(req *OpenAIImagesRequest) {
+ if req == nil {
+ return
+ }
+ if req.N <= 0 {
+ req.N = 1
+ }
+ if strings.TrimSpace(req.Model) != "" {
+ req.Model = strings.TrimSpace(req.Model)
+ return
+ }
+ req.Model = "gpt-image-2"
+}
+
+func normalizeOpenAIImagesEndpointPath(path string) string {
+ trimmed := strings.TrimSpace(path)
+ switch {
+ case strings.Contains(trimmed, "/images/generations"):
+ return openAIImagesGenerationsEndpoint
+ case strings.Contains(trimmed, "/images/edits"):
+ return openAIImagesEditsEndpoint
+ default:
+ return ""
+ }
+}
+
+func classifyOpenAIImagesCapability(req *OpenAIImagesRequest) OpenAIImagesCapability {
+ if req == nil {
+ return OpenAIImagesCapabilityNative
+ }
+ if req.ExplicitModel || req.ExplicitSize {
+ return OpenAIImagesCapabilityNative
+ }
+ model := strings.ToLower(strings.TrimSpace(req.Model))
+ if !strings.HasPrefix(model, "gpt-image-") {
+ return OpenAIImagesCapabilityNative
+ }
+ if req.Stream || req.N != 1 || req.HasMask || req.HasNativeOptions {
+ return OpenAIImagesCapabilityNative
+ }
+ if req.IsEdits() && !req.Multipart {
+ return OpenAIImagesCapabilityNative
+ }
+ if req.ResponseFormat != "" && req.ResponseFormat != "b64_json" {
+ return OpenAIImagesCapabilityNative
+ }
+ return OpenAIImagesCapabilityBasic
+}
+
+func hasOpenAINativeImageOptions(exists func(path string) bool) bool {
+ for _, path := range []string{
+ "background",
+ "quality",
+ "style",
+ "output_format",
+ "output_compression",
+ "moderation",
+ } {
+ if exists(path) {
+ return true
+ }
+ }
+ return false
+}
+
+func isOpenAINativeImageOption(name string) bool {
+ switch strings.TrimSpace(strings.ToLower(name)) {
+ case "background", "quality", "style", "output_format", "output_compression", "moderation":
+ return true
+ default:
+ return false
+ }
+}
+
+func normalizeOpenAIImageSizeTier(size string) string {
+ switch strings.ToLower(strings.TrimSpace(size)) {
+ case "1024x1024":
+ return "1K"
+ case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "", "auto":
+ return "2K"
+ default:
+ return "2K"
+ }
+}
+
+func (s *OpenAIGatewayService) ForwardImages(
+ ctx context.Context,
+ c *gin.Context,
+ account *Account,
+ body []byte,
+ parsed *OpenAIImagesRequest,
+ channelMappedModel string,
+) (*OpenAIForwardResult, error) {
+ if parsed == nil {
+ return nil, fmt.Errorf("parsed images request is required")
+ }
+ switch account.Type {
+ case AccountTypeAPIKey:
+ return s.forwardOpenAIImagesAPIKey(ctx, c, account, body, parsed, channelMappedModel)
+ case AccountTypeOAuth:
+ return s.forwardOpenAIImagesOAuth(ctx, c, account, parsed, channelMappedModel)
+ default:
+ return nil, fmt.Errorf("unsupported account type: %s", account.Type)
+ }
+}
+
+func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
+ ctx context.Context,
+ c *gin.Context,
+ account *Account,
+ body []byte,
+ parsed *OpenAIImagesRequest,
+ channelMappedModel string,
+) (*OpenAIForwardResult, error) {
+ startTime := time.Now()
+ requestModel := strings.TrimSpace(parsed.Model)
+ if mapped := strings.TrimSpace(channelMappedModel); mapped != "" {
+ requestModel = mapped
+ }
+ upstreamModel := account.GetMappedModel(requestModel)
+ forwardBody, forwardContentType, err := rewriteOpenAIImagesModel(body, parsed.ContentType, upstreamModel)
+ if err != nil {
+ return nil, err
+ }
+ if !parsed.Multipart {
+ setOpsUpstreamRequestBody(c, forwardBody)
+ }
+
+ token, _, err := s.GetAccessToken(ctx, account)
+ if err != nil {
+ return nil, err
+ }
+ upstreamReq, err := s.buildOpenAIImagesRequest(ctx, c, account, forwardBody, forwardContentType, token, parsed.Endpoint)
+ if err != nil {
+ return nil, err
+ }
+
+ proxyURL := ""
+ if account.ProxyID != nil && account.Proxy != nil {
+ proxyURL = account.Proxy.URL()
+ }
+ upstreamStart := time.Now()
+ resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
+ SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
+ if err != nil {
+ safeErr := sanitizeUpstreamErrorMessage(err.Error())
+ setOpsUpstreamError(c, 0, safeErr, "")
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: 0,
+ UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
+ Kind: "request_error",
+ Message: safeErr,
+ })
+ return nil, fmt.Errorf("upstream request failed: %s", safeErr)
+ }
+ if resp.StatusCode >= 400 {
+ respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
+ _ = resp.Body.Close()
+ resp.Body = io.NopCloser(bytes.NewReader(respBody))
+ upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
+ upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
+ if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: resp.StatusCode,
+ UpstreamRequestID: resp.Header.Get("x-request-id"),
+ UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
+ Kind: "failover",
+ Message: upstreamMsg,
+ })
+ s.handleFailoverSideEffects(ctx, resp, account)
+ return nil, &UpstreamFailoverError{
+ StatusCode: resp.StatusCode,
+ ResponseBody: respBody,
+ RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
+ }
+ }
+ return s.handleErrorResponse(ctx, resp, c, account, forwardBody)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ var usage OpenAIUsage
+ imageCount := parsed.N
+ var firstTokenMs *int
+ if parsed.Stream {
+ streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime)
+ if err != nil {
+ return nil, err
+ }
+ usage = streamUsage
+ imageCount = streamCount
+ firstTokenMs = ttft
+ } else {
+ nonStreamUsage, nonStreamCount, err := s.handleOpenAIImagesNonStreamingResponse(resp, c)
+ if err != nil {
+ return nil, err
+ }
+ usage = nonStreamUsage
+ if nonStreamCount > 0 {
+ imageCount = nonStreamCount
+ }
+ }
+ return &OpenAIForwardResult{
+ RequestID: resp.Header.Get("x-request-id"),
+ Usage: usage,
+ Model: requestModel,
+ UpstreamModel: upstreamModel,
+ Stream: parsed.Stream,
+ ResponseHeaders: resp.Header.Clone(),
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ ImageCount: imageCount,
+ ImageSize: parsed.SizeTier,
+ }, nil
+}
+
+func (s *OpenAIGatewayService) buildOpenAIImagesRequest(
+ ctx context.Context,
+ c *gin.Context,
+ account *Account,
+ body []byte,
+ contentType string,
+ token string,
+ endpoint string,
+) (*http.Request, error) {
+ targetURL := openAIImagesGenerationsURL
+ if endpoint == openAIImagesEditsEndpoint {
+ targetURL = openAIImagesEditsURL
+ }
+ baseURL := account.GetOpenAIBaseURL()
+ if baseURL != "" {
+ validatedURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, err
+ }
+ targetURL = buildOpenAIImagesURL(validatedURL, endpoint)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Authorization", "Bearer "+token)
+ for key, values := range c.Request.Header {
+ if !openaiPassthroughAllowedHeaders[strings.ToLower(key)] {
+ continue
+ }
+ for _, value := range values {
+ req.Header.Add(key, value)
+ }
+ }
+ customUA := account.GetOpenAIUserAgent()
+ if customUA != "" {
+ req.Header.Set("User-Agent", customUA)
+ }
+ if strings.TrimSpace(contentType) != "" {
+ req.Header.Set("Content-Type", contentType)
+ }
+ return req, nil
+}
+
+func buildOpenAIImagesURL(base string, endpoint string) string {
+ normalized := strings.TrimRight(strings.TrimSpace(base), "/")
+ relative := strings.TrimPrefix(strings.TrimSpace(endpoint), "/v1")
+ if strings.HasSuffix(normalized, endpoint) || strings.HasSuffix(normalized, relative) {
+ return normalized
+ }
+ if strings.HasSuffix(normalized, "/v1") {
+ return normalized + relative
+ }
+ return normalized + endpoint
+}
+
+func rewriteOpenAIImagesModel(body []byte, contentType string, model string) ([]byte, string, error) {
+ model = strings.TrimSpace(model)
+ if model == "" {
+ return body, contentType, nil
+ }
+ mediaType, _, err := mime.ParseMediaType(contentType)
+ if err == nil && strings.EqualFold(mediaType, "multipart/form-data") {
+ rewrittenBody, rewrittenType, rewriteErr := rewriteOpenAIImagesMultipartModel(body, contentType, model)
+ return rewrittenBody, rewrittenType, rewriteErr
+ }
+ rewritten, err := sjson.SetBytes(body, "model", model)
+ if err != nil {
+ return nil, "", fmt.Errorf("rewrite image request model: %w", err)
+ }
+ return rewritten, contentType, nil
+}
+
+func rewriteOpenAIImagesMultipartModel(body []byte, contentType string, model string) ([]byte, string, error) {
+ _, params, err := mime.ParseMediaType(contentType)
+ if err != nil {
+ return nil, "", fmt.Errorf("parse multipart content-type: %w", err)
+ }
+ boundary := strings.TrimSpace(params["boundary"])
+ if boundary == "" {
+ return nil, "", fmt.Errorf("multipart boundary is required")
+ }
+
+ reader := multipart.NewReader(bytes.NewReader(body), boundary)
+ var buffer bytes.Buffer
+ writer := multipart.NewWriter(&buffer)
+ modelWritten := false
+
+ for {
+ part, err := reader.NextPart()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return nil, "", fmt.Errorf("read multipart body: %w", err)
+ }
+
+ formName := strings.TrimSpace(part.FormName())
+ partHeader := cloneMultipartHeader(part.Header)
+ target, err := writer.CreatePart(partHeader)
+ if err != nil {
+ _ = part.Close()
+ return nil, "", fmt.Errorf("create multipart part: %w", err)
+ }
+
+ if formName == "model" && part.FileName() == "" {
+ if _, err := target.Write([]byte(model)); err != nil {
+ _ = part.Close()
+ return nil, "", fmt.Errorf("rewrite multipart model: %w", err)
+ }
+ modelWritten = true
+ _ = part.Close()
+ continue
+ }
+ if _, err := io.Copy(target, part); err != nil {
+ _ = part.Close()
+ return nil, "", fmt.Errorf("copy multipart part: %w", err)
+ }
+ _ = part.Close()
+ }
+
+ if !modelWritten {
+ if err := writer.WriteField("model", model); err != nil {
+ return nil, "", fmt.Errorf("append multipart model field: %w", err)
+ }
+ }
+ if err := writer.Close(); err != nil {
+ return nil, "", fmt.Errorf("finalize multipart body: %w", err)
+ }
+ return buffer.Bytes(), writer.FormDataContentType(), nil
+}
+
+func cloneMultipartHeader(src textproto.MIMEHeader) textproto.MIMEHeader {
+ dst := make(textproto.MIMEHeader, len(src))
+ for key, values := range src {
+ copied := make([]string, len(values))
+ copy(copied, values)
+ dst[key] = copied
+ }
+ return dst
+}
+
+func (s *OpenAIGatewayService) handleOpenAIImagesNonStreamingResponse(resp *http.Response, c *gin.Context) (OpenAIUsage, int, error) {
+ body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
+ if err != nil {
+ return OpenAIUsage{}, 0, err
+ }
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
+ contentType := "application/json"
+ if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
+ if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
+ contentType = upstreamType
+ }
+ }
+ c.Data(resp.StatusCode, contentType, body)
+
+ usage, _ := extractOpenAIUsageFromJSONBytes(body)
+ return usage, extractOpenAIImageCountFromJSONBytes(body), nil
+}
+
+func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
+ resp *http.Response,
+ c *gin.Context,
+ startTime time.Time,
+) (OpenAIUsage, int, *int, error) {
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
+ contentType := strings.TrimSpace(resp.Header.Get("Content-Type"))
+ if contentType == "" {
+ contentType = "text/event-stream"
+ }
+ c.Status(resp.StatusCode)
+ c.Header("Content-Type", contentType)
+
+ flusher, ok := c.Writer.(http.Flusher)
+ if !ok {
+ return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer")
+ }
+
+ reader := bufio.NewReader(resp.Body)
+ usage := OpenAIUsage{}
+ imageCount := 0
+ var firstTokenMs *int
+
+ for {
+ line, err := reader.ReadBytes('\n')
+ if len(line) > 0 {
+ if firstTokenMs == nil {
+ ms := int(time.Since(startTime).Milliseconds())
+ firstTokenMs = &ms
+ }
+ if _, writeErr := c.Writer.Write(line); writeErr != nil {
+ return OpenAIUsage{}, 0, firstTokenMs, writeErr
+ }
+ flusher.Flush()
+
+ if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok && data != "" && data != "[DONE]" {
+ dataBytes := []byte(data)
+ mergeOpenAIUsage(&usage, dataBytes)
+ if count := extractOpenAIImageCountFromJSONBytes(dataBytes); count > imageCount {
+ imageCount = count
+ }
+ }
+ }
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return OpenAIUsage{}, 0, firstTokenMs, err
+ }
+ }
+ return usage, imageCount, firstTokenMs, nil
+}
+
+func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) {
+ if dst == nil {
+ return
+ }
+ if parsed, ok := extractOpenAIUsageFromJSONBytes(body); ok {
+ if parsed.InputTokens > 0 {
+ dst.InputTokens = parsed.InputTokens
+ }
+ if parsed.OutputTokens > 0 {
+ dst.OutputTokens = parsed.OutputTokens
+ }
+ if parsed.CacheReadInputTokens > 0 {
+ dst.CacheReadInputTokens = parsed.CacheReadInputTokens
+ }
+ if parsed.ImageOutputTokens > 0 {
+ dst.ImageOutputTokens = parsed.ImageOutputTokens
+ }
+ }
+}
+
+func extractOpenAIImageCountFromJSONBytes(body []byte) int {
+ if len(body) == 0 || !gjson.ValidBytes(body) {
+ return 0
+ }
+ data := gjson.GetBytes(body, "data")
+ if data.Exists() && data.IsArray() {
+ return len(data.Array())
+ }
+ return 0
+}
+
+func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
+ ctx context.Context,
+ c *gin.Context,
+ account *Account,
+ parsed *OpenAIImagesRequest,
+ channelMappedModel string,
+) (*OpenAIForwardResult, error) {
+ startTime := time.Now()
+ requestModel := strings.TrimSpace(parsed.Model)
+ if mapped := strings.TrimSpace(channelMappedModel); mapped != "" {
+ requestModel = mapped
+ }
+
+ token, _, err := s.GetAccessToken(ctx, account)
+ if err != nil {
+ return nil, err
+ }
+ client, err := newOpenAIBackendAPIClient(resolveOpenAIProxyURL(account))
+ if err != nil {
+ return nil, err
+ }
+ headers, err := s.buildOpenAIBackendAPIHeaders(account, token)
+ if err != nil {
+ return nil, err
+ }
+ if bootstrapErr := bootstrapOpenAIBackendAPI(ctx, client, headers); bootstrapErr != nil {
+ logger.LegacyPrintf("service.openai_gateway", "OpenAI image bootstrap failed: %v", bootstrapErr)
+ }
+
+ chatReqs, err := fetchOpenAIChatRequirements(ctx, client, headers)
+ if err != nil {
+ return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err)
+ }
+ if chatReqs.Arkose.Required {
+ return nil, s.wrapOpenAIImageBackendError(
+ ctx,
+ c,
+ account,
+ newOpenAIImageSyntheticStatusError(
+ http.StatusForbidden,
+ "chat-requirements requires unsupported challenge (arkose)",
+ openAIChatGPTChatRequirementsURL,
+ ),
+ )
+ }
+
+ parentMessageID := uuid.NewString()
+ proofToken := generateOpenAIProofToken(chatReqs.ProofOfWork.Required, chatReqs.ProofOfWork.Seed, chatReqs.ProofOfWork.Difficulty, headers.Get("User-Agent"))
+ _ = initializeOpenAIImageConversation(ctx, client, headers)
+ conduitToken, err := prepareOpenAIImageConversation(ctx, client, headers, parsed.Prompt, parentMessageID, chatReqs.Token, proofToken)
+ if err != nil {
+ return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err)
+ }
+
+ uploads, err := uploadOpenAIImageFiles(ctx, client, headers, parsed.Uploads)
+ if err != nil {
+ return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err)
+ }
+
+ convReq := buildOpenAIImageConversationRequest(parsed, parentMessageID, uploads)
+ if parsedContent, err := json.Marshal(convReq); err == nil {
+ setOpsUpstreamRequestBody(c, parsedContent)
+ }
+ convHeaders := cloneHTTPHeader(headers)
+ convHeaders.Set("Accept", "text/event-stream")
+ convHeaders.Set("Content-Type", "application/json")
+ convHeaders.Set("openai-sentinel-chat-requirements-token", chatReqs.Token)
+ if conduitToken != "" {
+ convHeaders.Set("x-conduit-token", conduitToken)
+ }
+ if proofToken != "" {
+ convHeaders.Set("openai-sentinel-proof-token", proofToken)
+ }
+
+ resp, err := client.R().
+ SetContext(ctx).
+ DisableAutoReadResponse().
+ SetHeaders(headerToMap(convHeaders)).
+ SetBodyJsonMarshal(convReq).
+ Post(openAIChatGPTConversationURL)
+ if err != nil {
+ return nil, fmt.Errorf("openai image conversation request failed: %w", err)
+ }
+ defer func() {
+ if resp != nil && resp.Body != nil {
+ _ = resp.Body.Close()
+ }
+ }()
+ if resp.StatusCode >= 400 {
+ return nil, s.wrapOpenAIImageBackendError(ctx, c, account, handleOpenAIImageBackendError(resp))
+ }
+
+ conversationID, pointerInfos, usage, firstTokenMs, err := readOpenAIImageConversationStream(resp, startTime)
+ if err != nil {
+ return nil, err
+ }
+ pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil)
+ if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) {
+ polledPointers, pollErr := pollOpenAIImageConversation(ctx, client, headers, conversationID)
+ if pollErr != nil {
+ return nil, s.wrapOpenAIImageBackendError(ctx, c, account, pollErr)
+ }
+ pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, polledPointers)
+ }
+ pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos)
+ if len(pointerInfos) == 0 {
+ return nil, fmt.Errorf("openai image conversation returned no downloadable images")
+ }
+
+ responseBody, imageCount, err := buildOpenAIImageResponse(ctx, client, headers, conversationID, pointerInfos)
+ if err != nil {
+ return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err)
+ }
+
+ c.Data(http.StatusOK, "application/json; charset=utf-8", responseBody)
+ return &OpenAIForwardResult{
+ RequestID: resp.Header.Get("x-request-id"),
+ Usage: usage,
+ Model: requestModel,
+ UpstreamModel: requestModel,
+ Stream: false,
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ ImageCount: imageCount,
+ ImageSize: parsed.SizeTier,
+ }, nil
+}
+
+func resolveOpenAIProxyURL(account *Account) string {
+ if account != nil && account.ProxyID != nil && account.Proxy != nil {
+ return account.Proxy.URL()
+ }
+ return ""
+}
+
+func newOpenAIBackendAPIClient(proxyURL string) (*req.Client, error) {
+ client := req.C().
+ SetTimeout(180 * time.Second).
+ ImpersonateChrome()
+ trimmed, _, err := proxyurl.Parse(proxyURL)
+ if err != nil {
+ return nil, err
+ }
+ if trimmed != "" {
+ client.SetProxyURL(trimmed)
+ }
+ return client, nil
+}
+
+func (s *OpenAIGatewayService) buildOpenAIBackendAPIHeaders(account *Account, token string) (http.Header, error) {
+ deviceID, sessionID := s.ensureOpenAIImageSessionCredentials(context.Background(), account)
+ headers := make(http.Header)
+ headers.Set("Authorization", "Bearer "+token)
+ headers.Set("Accept", "application/json")
+ headers.Set("Origin", "https://chatgpt.com")
+ headers.Set("Referer", "https://chatgpt.com/")
+ headers.Set("Sec-Fetch-Dest", "empty")
+ headers.Set("Sec-Fetch-Mode", "cors")
+ headers.Set("Sec-Fetch-Site", "same-origin")
+ headers.Set("User-Agent", openAIImageBackendUserAgent)
+ if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" {
+ headers.Set("User-Agent", customUA)
+ }
+ if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" {
+ headers.Set("chatgpt-account-id", chatgptAccountID)
+ }
+ if deviceID != "" {
+ headers.Set("oai-device-id", deviceID)
+ headers.Set("Cookie", "oai-did="+deviceID)
+ }
+ if sessionID != "" {
+ headers.Set("oai-session-id", sessionID)
+ }
+ return headers, nil
+}
+
+func (s *OpenAIGatewayService) ensureOpenAIImageSessionCredentials(ctx context.Context, account *Account) (string, string) {
+ if account == nil {
+ return "", ""
+ }
+ deviceID := account.GetOpenAIDeviceID()
+ sessionID := account.GetOpenAISessionID()
+ if deviceID != "" && sessionID != "" {
+ return deviceID, sessionID
+ }
+
+ updates := map[string]any{}
+ if deviceID == "" {
+ deviceID = uuid.NewString()
+ updates["openai_device_id"] = deviceID
+ }
+ if sessionID == "" {
+ sessionID = uuid.NewString()
+ updates["openai_session_id"] = sessionID
+ }
+ if account.Extra == nil {
+ account.Extra = map[string]any{}
+ }
+ for key, value := range updates {
+ account.Extra[key] = value
+ }
+ if len(updates) == 0 || s == nil || s.accountRepo == nil {
+ return deviceID, sessionID
+ }
+
+ updateCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+ defer cancel()
+ if err := s.accountRepo.UpdateExtra(updateCtx, account.ID, updates); err != nil {
+ logger.LegacyPrintf("service.openai_gateway", "persist openai image session creds failed: account=%d err=%v", account.ID, err)
+ }
+ return deviceID, sessionID
+}
+
+func bootstrapOpenAIBackendAPI(ctx context.Context, client *req.Client, headers http.Header) error {
+ resp, err := client.R().
+ SetContext(ctx).
+ DisableAutoReadResponse().
+ SetHeaders(headerToMap(headers)).
+ Get(openAIChatGPTStartURL)
+ if err != nil {
+ return err
+ }
+ if resp != nil && resp.Body != nil {
+ _, _ = io.Copy(io.Discard, resp.Body)
+ _ = resp.Body.Close()
+ }
+ return nil
+}
+
+func initializeOpenAIImageConversation(ctx context.Context, client *req.Client, headers http.Header) error {
+ payload := map[string]any{
+ "gizmo_id": nil,
+ "requested_default_model": nil,
+ "conversation_id": nil,
+ "timezone_offset_min": openAITimezoneOffsetMinutes(),
+ "system_hints": []string{"picture_v2"},
+ }
+ resp, err := client.R().
+ SetContext(ctx).
+ SetHeaders(headerToMap(headers)).
+ SetBodyJsonMarshal(payload).
+ Post(openAIChatGPTConversationInitURL)
+ if err != nil {
+ return err
+ }
+ if !resp.IsSuccessState() {
+ return newOpenAIImageStatusError(resp, "conversation init failed")
+ }
+ return nil
+}
+
+type openAIChatRequirements struct {
+ Token string `json:"token"`
+ Turnstile struct {
+ Required bool `json:"required"`
+ } `json:"turnstile"`
+ Arkose struct {
+ Required bool `json:"required"`
+ } `json:"arkose"`
+ ProofOfWork struct {
+ Required bool `json:"required"`
+ Seed string `json:"seed"`
+ Difficulty string `json:"difficulty"`
+ } `json:"proofofwork"`
+}
+
+func fetchOpenAIChatRequirements(ctx context.Context, client *req.Client, headers http.Header) (*openAIChatRequirements, error) {
+ var lastErr error
+ for _, payload := range []map[string]any{
+ {"p": nil},
+ {"p": generateOpenAIRequirementsToken(headers.Get("User-Agent"))},
+ } {
+ var result openAIChatRequirements
+ resp, err := client.R().
+ SetContext(ctx).
+ SetHeaders(headerToMap(headers)).
+ SetBodyJsonMarshal(payload).
+ SetSuccessResult(&result).
+ Post(openAIChatGPTChatRequirementsURL)
+ if err != nil {
+ lastErr = err
+ continue
+ }
+ if resp.IsSuccessState() && strings.TrimSpace(result.Token) != "" {
+ return &result, nil
+ }
+ lastErr = newOpenAIImageStatusError(resp, "chat-requirements failed")
+ }
+ if lastErr == nil {
+ lastErr = fmt.Errorf("chat-requirements failed")
+ }
+ return nil, lastErr
+}
+
+func prepareOpenAIImageConversation(
+ ctx context.Context,
+ client *req.Client,
+ headers http.Header,
+ prompt string,
+ parentMessageID string,
+ chatToken string,
+ proofToken string,
+) (string, error) {
+ messageID := uuid.NewString()
+ payload := map[string]any{
+ "action": "next",
+ "client_prepare_state": "success",
+ "fork_from_shared_post": false,
+ "parent_message_id": parentMessageID,
+ "model": "auto",
+ "timezone_offset_min": openAITimezoneOffsetMinutes(),
+ "timezone": openAITimezoneName(),
+ "conversation_mode": map[string]any{"kind": "primary_assistant"},
+ "system_hints": []string{"picture_v2"},
+ "supports_buffering": true,
+ "supported_encodings": []string{"v1"},
+ "partial_query": map[string]any{
+ "id": messageID,
+ "author": map[string]any{"role": "user"},
+ "content": map[string]any{
+ "content_type": "text",
+ "parts": []string{coalesceOpenAIFileName(prompt, "Generate an image.")},
+ },
+ },
+ "client_contextual_info": map[string]any{
+ "app_name": "chatgpt.com",
+ },
+ }
+ prepareHeaders := cloneHTTPHeader(headers)
+ prepareHeaders.Set("Accept", "*/*")
+ prepareHeaders.Set("Content-Type", "application/json")
+ if strings.TrimSpace(chatToken) != "" {
+ prepareHeaders.Set("openai-sentinel-chat-requirements-token", strings.TrimSpace(chatToken))
+ }
+ if strings.TrimSpace(proofToken) != "" {
+ prepareHeaders.Set("openai-sentinel-proof-token", strings.TrimSpace(proofToken))
+ }
+ var result struct {
+ ConduitToken string `json:"conduit_token"`
+ }
+ resp, err := client.R().
+ SetContext(ctx).
+ SetHeaders(headerToMap(prepareHeaders)).
+ SetBodyJsonMarshal(payload).
+ SetSuccessResult(&result).
+ Post(openAIChatGPTConversationPrepareURL)
+ if err != nil {
+ return "", err
+ }
+ if !resp.IsSuccessState() {
+ return "", newOpenAIImageStatusError(resp, "conversation prepare failed")
+ }
+ return strings.TrimSpace(result.ConduitToken), nil
+}
+
+type openAIUploadedImage struct {
+ FileID string
+ FileName string
+ FileSize int
+ MimeType string
+ Width int
+ Height int
+}
+
+func uploadOpenAIImageFiles(ctx context.Context, client *req.Client, headers http.Header, uploads []OpenAIImagesUpload) ([]openAIUploadedImage, error) {
+ if len(uploads) == 0 {
+ return nil, nil
+ }
+ results := make([]openAIUploadedImage, 0, len(uploads))
+ for i := range uploads {
+ item := uploads[i]
+ fileName := coalesceOpenAIFileName(item.FileName, "image.png")
+ payload := map[string]any{
+ "file_name": fileName,
+ "file_size": len(item.Data),
+ "use_case": "multimodal",
+ }
+ var created struct {
+ FileID string `json:"file_id"`
+ UploadURL string `json:"upload_url"`
+ }
+ resp, err := client.R().
+ SetContext(ctx).
+ SetHeaders(headerToMap(headers)).
+ SetBodyJsonMarshal(payload).
+ SetSuccessResult(&created).
+ Post(openAIChatGPTFilesURL)
+ if err != nil {
+ return nil, err
+ }
+ if !resp.IsSuccessState() || strings.TrimSpace(created.FileID) == "" || strings.TrimSpace(created.UploadURL) == "" {
+ return nil, newOpenAIImageStatusError(resp, "create upload slot failed")
+ }
+
+ uploadHeaders := map[string]string{
+ "Content-Type": coalesceOpenAIFileName(item.ContentType, "application/octet-stream"),
+ "Origin": "https://chatgpt.com",
+ "x-ms-blob-type": "BlockBlob",
+ "x-ms-version": "2020-04-08",
+ "User-Agent": headers.Get("User-Agent"),
+ }
+ putResp, err := client.R().
+ SetContext(ctx).
+ SetHeaders(uploadHeaders).
+ SetBody(item.Data).
+ DisableAutoReadResponse().
+ Put(created.UploadURL)
+ if err != nil {
+ return nil, err
+ }
+ if putResp.Response != nil && putResp.Body != nil {
+ _, _ = io.Copy(io.Discard, putResp.Body)
+ _ = putResp.Body.Close()
+ }
+ if putResp.StatusCode < 200 || putResp.StatusCode >= 300 {
+ return nil, newOpenAIImageStatusError(putResp, "upload image bytes failed")
+ }
+
+ uploadedResp, err := client.R().
+ SetContext(ctx).
+ SetHeaders(headerToMap(headers)).
+ SetBodyJsonMarshal(map[string]any{}).
+ Post(fmt.Sprintf("%s/%s/uploaded", openAIChatGPTFilesURL, created.FileID))
+ if err != nil {
+ return nil, err
+ }
+ if !uploadedResp.IsSuccessState() {
+ return nil, newOpenAIImageStatusError(uploadedResp, "mark upload complete failed")
+ }
+
+ results = append(results, openAIUploadedImage{
+ FileID: created.FileID,
+ FileName: fileName,
+ FileSize: len(item.Data),
+ MimeType: coalesceOpenAIFileName(item.ContentType, "application/octet-stream"),
+ Width: item.Width,
+ Height: item.Height,
+ })
+ }
+ return results, nil
+}
+
+func coalesceOpenAIFileName(value string, fallback string) string {
+ value = strings.TrimSpace(value)
+ if value == "" {
+ return fallback
+ }
+ return value
+}
+
+func buildOpenAIImageConversationRequest(parsed *OpenAIImagesRequest, parentMessageID string, uploads []openAIUploadedImage) map[string]any {
+ parts := []any{coalesceOpenAIFileName(parsed.Prompt, "Generate an image.")}
+ attachments := make([]map[string]any, 0, len(uploads))
+ if len(uploads) > 0 {
+ parts = make([]any, 0, len(uploads)+1)
+ for _, upload := range uploads {
+ parts = append(parts, map[string]any{
+ "content_type": "image_asset_pointer",
+ "asset_pointer": "file-service://" + upload.FileID,
+ "size_bytes": upload.FileSize,
+ "width": upload.Width,
+ "height": upload.Height,
+ })
+ attachment := map[string]any{
+ "id": upload.FileID,
+ "mimeType": upload.MimeType,
+ "name": upload.FileName,
+ "size": upload.FileSize,
+ }
+ if upload.Width > 0 {
+ attachment["width"] = upload.Width
+ }
+ if upload.Height > 0 {
+ attachment["height"] = upload.Height
+ }
+ attachments = append(attachments, attachment)
+ }
+ parts = append(parts, coalesceOpenAIFileName(parsed.Prompt, "Edit this image."))
+ }
+
+ contentType := "text"
+ if len(uploads) > 0 {
+ contentType = "multimodal_text"
+ }
+ metadata := map[string]any{
+ "developer_mode_connector_ids": []any{},
+ "selected_github_repos": []any{},
+ "selected_all_github_repos": false,
+ "system_hints": []string{"picture_v2"},
+ "serialization_metadata": map[string]any{
+ "custom_symbol_offsets": []any{},
+ },
+ }
+ message := map[string]any{
+ "id": uuid.NewString(),
+ "author": map[string]any{"role": "user"},
+ "content": map[string]any{
+ "content_type": contentType,
+ "parts": parts,
+ },
+ "metadata": metadata,
+ "create_time": float64(time.Now().UnixMilli()) / 1000,
+ }
+ if len(attachments) > 0 {
+ metadata["attachments"] = attachments
+ }
+
+ return map[string]any{
+ "action": "next",
+ "client_prepare_state": "sent",
+ "parent_message_id": parentMessageID,
+ "model": "auto",
+ "timezone_offset_min": openAITimezoneOffsetMinutes(),
+ "timezone": openAITimezoneName(),
+ "conversation_mode": map[string]any{"kind": "primary_assistant"},
+ "enable_message_followups": true,
+ "system_hints": []string{"picture_v2"},
+ "supports_buffering": true,
+ "supported_encodings": []string{"v1"},
+ "paragen_cot_summary_display_override": "allow",
+ "force_parallel_switch": "auto",
+ "client_contextual_info": map[string]any{
+ "is_dark_mode": false,
+ "time_since_loaded": 200,
+ "page_height": 900,
+ "page_width": 1440,
+ "pixel_ratio": 1,
+ "screen_height": 1080,
+ "screen_width": 1920,
+ "app_name": "chatgpt.com",
+ },
+ "messages": []any{message},
+ }
+}
+
+type openAIImagePointerInfo struct {
+ Pointer string
+ Prompt string
+}
+
+type openAIImageToolMessage struct {
+ MessageID string
+ CreateTime float64
+ PointerInfos []openAIImagePointerInfo
+}
+
+func readOpenAIImageConversationStream(resp *req.Response, startTime time.Time) (string, []openAIImagePointerInfo, OpenAIUsage, *int, error) {
+ if resp == nil || resp.Body == nil {
+ return "", nil, OpenAIUsage{}, nil, fmt.Errorf("empty conversation response")
+ }
+ reader := bufio.NewReader(resp.Body)
+ var (
+ conversationID string
+ firstTokenMs *int
+ usage OpenAIUsage
+ pointers []openAIImagePointerInfo
+ )
+
+ for {
+ line, err := reader.ReadString('\n')
+ if strings.TrimSpace(line) != "" && firstTokenMs == nil {
+ ms := int(time.Since(startTime).Milliseconds())
+ firstTokenMs = &ms
+ }
+ if data, ok := extractOpenAISSEDataLine(strings.TrimRight(line, "\r\n")); ok && data != "" && data != "[DONE]" {
+ dataBytes := []byte(data)
+ if conversationID == "" {
+ conversationID = strings.TrimSpace(gjson.GetBytes(dataBytes, "v.conversation_id").String())
+ if conversationID == "" {
+ conversationID = strings.TrimSpace(gjson.GetBytes(dataBytes, "conversation_id").String())
+ }
+ }
+ mergeOpenAIUsage(&usage, dataBytes)
+ pointers = mergeOpenAIImagePointerInfos(pointers, collectOpenAIImagePointers(dataBytes))
+ }
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return "", nil, OpenAIUsage{}, firstTokenMs, err
+ }
+ }
+ return conversationID, pointers, usage, firstTokenMs, nil
+}
+
+func collectOpenAIImagePointers(body []byte) []openAIImagePointerInfo {
+ if len(body) == 0 {
+ return nil
+ }
+ matches := openAIImagePointerMatches(body)
+ if len(matches) == 0 {
+ return nil
+ }
+ prompt := ""
+ for _, path := range []string{
+ "message.metadata.dalle.prompt",
+ "metadata.dalle.prompt",
+ "revised_prompt",
+ } {
+ if value := strings.TrimSpace(gjson.GetBytes(body, path).String()); value != "" {
+ prompt = value
+ break
+ }
+ }
+ out := make([]openAIImagePointerInfo, 0, len(matches))
+ for _, pointer := range matches {
+ out = append(out, openAIImagePointerInfo{Pointer: pointer, Prompt: prompt})
+ }
+ return out
+}
+
+func openAIImagePointerMatches(body []byte) []string {
+ raw := string(body)
+ matches := make([]string, 0, 4)
+ for _, prefix := range []string{"file-service://", "sediment://"} {
+ start := 0
+ for {
+ idx := strings.Index(raw[start:], prefix)
+ if idx < 0 {
+ break
+ }
+ idx += start
+ end := idx + len(prefix)
+ for end < len(raw) {
+ ch := raw[end]
+ if ch != '-' && ch != '_' &&
+ (ch < '0' || ch > '9') &&
+ (ch < 'a' || ch > 'z') &&
+ (ch < 'A' || ch > 'Z') {
+ break
+ }
+ end++
+ }
+ matches = append(matches, raw[idx:end])
+ start = end
+ }
+ }
+ return dedupeStrings(matches)
+}
+
+func mergeOpenAIImagePointerInfos(existing []openAIImagePointerInfo, next []openAIImagePointerInfo) []openAIImagePointerInfo {
+ if len(next) == 0 {
+ return existing
+ }
+ seen := make(map[string]openAIImagePointerInfo, len(existing)+len(next))
+ out := make([]openAIImagePointerInfo, 0, len(existing)+len(next))
+ for _, item := range existing {
+ seen[item.Pointer] = item
+ out = append(out, item)
+ }
+ for _, item := range next {
+ if existingItem, ok := seen[item.Pointer]; ok {
+ if existingItem.Prompt == "" && item.Prompt != "" {
+ for i := range out {
+ if out[i].Pointer == item.Pointer {
+ out[i].Prompt = item.Prompt
+ break
+ }
+ }
+ }
+ continue
+ }
+ seen[item.Pointer] = item
+ out = append(out, item)
+ }
+ return out
+}
+
+func hasOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) bool {
+ for _, item := range items {
+ if strings.HasPrefix(item.Pointer, "file-service://") {
+ return true
+ }
+ }
+ return false
+}
+
+func preferOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) []openAIImagePointerInfo {
+ if !hasOpenAIFileServicePointerInfos(items) {
+ return items
+ }
+ out := make([]openAIImagePointerInfo, 0, len(items))
+ for _, item := range items {
+ if strings.HasPrefix(item.Pointer, "file-service://") {
+ out = append(out, item)
+ }
+ }
+ return out
+}
+
+func extractOpenAIImageToolMessages(mapping map[string]any) []openAIImageToolMessage {
+ if len(mapping) == 0 {
+ return nil
+ }
+ out := make([]openAIImageToolMessage, 0, 4)
+ for messageID, raw := range mapping {
+ node, _ := raw.(map[string]any)
+ if node == nil {
+ continue
+ }
+ message, _ := node["message"].(map[string]any)
+ if message == nil {
+ continue
+ }
+ author, _ := message["author"].(map[string]any)
+ metadata, _ := message["metadata"].(map[string]any)
+ content, _ := message["content"].(map[string]any)
+ if author == nil || metadata == nil || content == nil {
+ continue
+ }
+ if role, _ := author["role"].(string); role != "tool" {
+ continue
+ }
+ if asyncTaskType, _ := metadata["async_task_type"].(string); asyncTaskType != "image_gen" {
+ continue
+ }
+ if contentType, _ := content["content_type"].(string); contentType != "multimodal_text" {
+ continue
+ }
+ prompt := ""
+ if title, _ := metadata["image_gen_title"].(string); strings.TrimSpace(title) != "" {
+ prompt = strings.TrimSpace(title)
+ }
+ item := openAIImageToolMessage{MessageID: messageID}
+ if createTime, ok := message["create_time"].(float64); ok {
+ item.CreateTime = createTime
+ }
+ parts, _ := content["parts"].([]any)
+ for _, part := range parts {
+ switch value := part.(type) {
+ case map[string]any:
+ if assetPointer, _ := value["asset_pointer"].(string); strings.TrimSpace(assetPointer) != "" {
+ for _, pointer := range openAIImagePointerMatches([]byte(assetPointer)) {
+ item.PointerInfos = append(item.PointerInfos, openAIImagePointerInfo{
+ Pointer: pointer,
+ Prompt: prompt,
+ })
+ }
+ }
+ case string:
+ for _, pointer := range openAIImagePointerMatches([]byte(value)) {
+ item.PointerInfos = append(item.PointerInfos, openAIImagePointerInfo{
+ Pointer: pointer,
+ Prompt: prompt,
+ })
+ }
+ }
+ }
+ if len(item.PointerInfos) == 0 {
+ continue
+ }
+ item.PointerInfos = mergeOpenAIImagePointerInfos(nil, item.PointerInfos)
+ out = append(out, item)
+ }
+ sort.Slice(out, func(i, j int) bool {
+ return out[i].CreateTime < out[j].CreateTime
+ })
+ return out
+}
+
+func pollOpenAIImageConversation(ctx context.Context, client *req.Client, headers http.Header, conversationID string) ([]openAIImagePointerInfo, error) {
+ conversationID = strings.TrimSpace(conversationID)
+ if conversationID == "" {
+ return nil, nil
+ }
+ deadline := time.Now().Add(90 * time.Second)
+ interval := 3 * time.Second
+ previewWait := 15 * time.Second
+ var (
+ lastErr error
+ firstToolAt time.Time
+ )
+ for time.Now().Before(deadline) {
+ resp, err := client.R().
+ SetContext(ctx).
+ SetHeaders(headerToMap(headers)).
+ DisableAutoReadResponse().
+ Get(fmt.Sprintf("https://chatgpt.com/backend-api/conversation/%s", conversationID))
+ if err != nil {
+ lastErr = err
+ } else {
+ if resp.StatusCode >= 200 && resp.StatusCode < 300 {
+ body, readErr := io.ReadAll(resp.Body)
+ _ = resp.Body.Close()
+ if readErr != nil {
+ lastErr = readErr
+ goto waitNextPoll
+ }
+ pointers := mergeOpenAIImagePointerInfos(nil, collectOpenAIImagePointers(body))
+ var decoded map[string]any
+ if err := json.Unmarshal(body, &decoded); err == nil {
+ if mapping, _ := decoded["mapping"].(map[string]any); len(mapping) > 0 {
+ toolMessages := extractOpenAIImageToolMessages(mapping)
+ if len(toolMessages) > 0 && firstToolAt.IsZero() {
+ firstToolAt = time.Now()
+ }
+ for _, msg := range toolMessages {
+ pointers = mergeOpenAIImagePointerInfos(pointers, msg.PointerInfos)
+ }
+ }
+ }
+ if hasOpenAIFileServicePointerInfos(pointers) {
+ return preferOpenAIFileServicePointerInfos(pointers), nil
+ }
+ if len(pointers) > 0 && !firstToolAt.IsZero() && time.Since(firstToolAt) >= previewWait {
+ return pointers, nil
+ }
+ } else {
+ statusErr := newOpenAIImageStatusError(resp, "conversation poll failed")
+ if isOpenAIImageTransientConversationNotFoundError(statusErr) {
+ lastErr = statusErr
+ goto waitNextPoll
+ }
+ return nil, statusErr
+ }
+ }
+
+ waitNextPoll:
+ timer := time.NewTimer(interval)
+ select {
+ case <-ctx.Done():
+ if !timer.Stop() {
+ <-timer.C
+ }
+ return nil, ctx.Err()
+ case <-timer.C:
+ }
+ }
+ return nil, lastErr
+}
+
+func buildOpenAIImageResponse(
+ ctx context.Context,
+ client *req.Client,
+ headers http.Header,
+ conversationID string,
+ pointers []openAIImagePointerInfo,
+) ([]byte, int, error) {
+ type responseItem struct {
+ B64JSON string `json:"b64_json"`
+ RevisedPrompt string `json:"revised_prompt,omitempty"`
+ }
+ items := make([]responseItem, 0, len(pointers))
+ for _, pointer := range pointers {
+ downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer)
+ if err != nil {
+ return nil, 0, err
+ }
+ data, err := downloadOpenAIImageBytes(ctx, client, headers, downloadURL)
+ if err != nil {
+ return nil, 0, err
+ }
+ items = append(items, responseItem{
+ B64JSON: base64.StdEncoding.EncodeToString(data),
+ RevisedPrompt: pointer.Prompt,
+ })
+ }
+ payload := map[string]any{
+ "created": time.Now().Unix(),
+ "data": items,
+ }
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return nil, 0, err
+ }
+ return body, len(items), nil
+}
+
+func fetchOpenAIImageDownloadURL(
+ ctx context.Context,
+ client *req.Client,
+ headers http.Header,
+ conversationID string,
+ pointer string,
+) (string, error) {
+ url := ""
+ allowConversationRetry := false
+ switch {
+ case strings.HasPrefix(pointer, "file-service://"):
+ fileID := strings.TrimPrefix(pointer, "file-service://")
+ url = fmt.Sprintf("%s/%s/download", openAIChatGPTFilesURL, fileID)
+ case strings.HasPrefix(pointer, "sediment://"):
+ attachmentID := strings.TrimPrefix(pointer, "sediment://")
+ url = fmt.Sprintf("https://chatgpt.com/backend-api/conversation/%s/attachment/%s/download", conversationID, attachmentID)
+ allowConversationRetry = true
+ default:
+ return "", fmt.Errorf("unsupported image pointer: %s", pointer)
+ }
+
+ var lastErr error
+ for attempt := 0; attempt < 8; attempt++ {
+ var result struct {
+ DownloadURL string `json:"download_url"`
+ }
+ resp, err := client.R().
+ SetContext(ctx).
+ SetHeaders(headerToMap(headers)).
+ SetSuccessResult(&result).
+ Get(url)
+ if err != nil {
+ lastErr = err
+ } else if resp.IsSuccessState() && strings.TrimSpace(result.DownloadURL) != "" {
+ return strings.TrimSpace(result.DownloadURL), nil
+ } else {
+ statusErr := newOpenAIImageStatusError(resp, "fetch image download url failed")
+ if !allowConversationRetry || !isOpenAIImageTransientConversationNotFoundError(statusErr) {
+ return "", statusErr
+ }
+ lastErr = statusErr
+ }
+ if attempt == 7 {
+ break
+ }
+ timer := time.NewTimer(750 * time.Millisecond)
+ select {
+ case <-ctx.Done():
+ if !timer.Stop() {
+ <-timer.C
+ }
+ return "", ctx.Err()
+ case <-timer.C:
+ }
+ }
+ if lastErr == nil {
+ lastErr = fmt.Errorf("fetch image download url failed")
+ }
+ return "", lastErr
+}
+
+func downloadOpenAIImageBytes(ctx context.Context, client *req.Client, headers http.Header, downloadURL string) ([]byte, error) {
+ request := client.R().
+ SetContext(ctx).
+ DisableAutoReadResponse()
+
+ if strings.HasPrefix(downloadURL, openAIChatGPTStartURL) {
+ downloadHeaders := cloneHTTPHeader(headers)
+ downloadHeaders.Set("Accept", "image/*,*/*;q=0.8")
+ downloadHeaders.Del("Content-Type")
+ request.SetHeaders(headerToMap(downloadHeaders))
+ } else {
+ userAgent := strings.TrimSpace(headers.Get("User-Agent"))
+ if userAgent == "" {
+ userAgent = openAIImageBackendUserAgent
+ }
+ request.SetHeader("User-Agent", userAgent)
+ }
+
+ resp, err := request.Get(downloadURL)
+ if err != nil {
+ return nil, err
+ }
+ defer func() {
+ if resp != nil && resp.Body != nil {
+ _ = resp.Body.Close()
+ }
+ }()
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return nil, newOpenAIImageStatusError(resp, "download image bytes failed")
+ }
+ return io.ReadAll(resp.Body)
+}
+
+func handleOpenAIImageBackendError(resp *req.Response) error {
+ return newOpenAIImageStatusError(resp, "backend-api request failed")
+}
+
+type openAIImageStatusError struct {
+ StatusCode int
+ Message string
+ ResponseBody []byte
+ ResponseHeaders http.Header
+ RequestID string
+ URL string
+}
+
+func (e *openAIImageStatusError) Error() string {
+ if e == nil {
+ return "openai image backend request failed"
+ }
+ if e.Message != "" {
+ return e.Message
+ }
+ if e.StatusCode > 0 {
+ return fmt.Sprintf("openai image backend request failed: status %d", e.StatusCode)
+ }
+ return "openai image backend request failed"
+}
+
+func newOpenAIImageStatusError(resp *req.Response, fallback string) error {
+ if resp == nil {
+ if strings.TrimSpace(fallback) == "" {
+ fallback = "openai image backend request failed"
+ }
+ return fmt.Errorf("%s", fallback)
+ }
+
+ statusCode := resp.StatusCode
+ headers := http.Header(nil)
+ requestID := ""
+ requestURL := ""
+ body := []byte(nil)
+
+ if resp.Response != nil {
+ headers = resp.Header.Clone()
+ requestID = strings.TrimSpace(resp.Header.Get("x-request-id"))
+ if resp.Request != nil && resp.Request.URL != nil {
+ requestURL = resp.Request.URL.String()
+ }
+ if resp.Body != nil {
+ body, _ = io.ReadAll(io.LimitReader(resp.Body, 2<<20))
+ _ = resp.Body.Close()
+ }
+ }
+
+ message := sanitizeUpstreamErrorMessage(extractUpstreamErrorMessage(body))
+ if message == "" {
+ prefix := strings.TrimSpace(fallback)
+ if prefix == "" {
+ prefix = "openai image backend request failed"
+ }
+ message = fmt.Sprintf("%s: status %d", prefix, statusCode)
+ }
+
+ return &openAIImageStatusError{
+ StatusCode: statusCode,
+ Message: message,
+ ResponseBody: body,
+ ResponseHeaders: headers,
+ RequestID: requestID,
+ URL: requestURL,
+ }
+}
+
+func newOpenAIImageSyntheticStatusError(statusCode int, message string, requestURL string) *openAIImageStatusError {
+ message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message))
+ if message == "" {
+ message = "openai image backend request failed"
+ }
+ var body []byte
+ if payload, err := json.Marshal(map[string]string{"detail": message}); err == nil {
+ body = payload
+ }
+ return &openAIImageStatusError{
+ StatusCode: statusCode,
+ Message: message,
+ ResponseBody: body,
+ URL: strings.TrimSpace(requestURL),
+ }
+}
+
+func isOpenAIImageTransientConversationNotFoundError(err error) bool {
+ statusErr, ok := err.(*openAIImageStatusError)
+ if !ok || statusErr == nil || statusErr.StatusCode != http.StatusNotFound {
+ return false
+ }
+ msg := strings.ToLower(strings.TrimSpace(statusErr.Message))
+ if strings.Contains(msg, "conversation_not_found") {
+ return true
+ }
+ if strings.Contains(msg, "conversation") && strings.Contains(msg, "not found") {
+ return true
+ }
+ bodyMsg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(statusErr.ResponseBody)))
+ if strings.Contains(bodyMsg, "conversation_not_found") {
+ return true
+ }
+ return strings.Contains(bodyMsg, "conversation") && strings.Contains(bodyMsg, "not found")
+}
+
+func (s *OpenAIGatewayService) wrapOpenAIImageBackendError(
+ ctx context.Context,
+ c *gin.Context,
+ account *Account,
+ err error,
+) error {
+ var statusErr *openAIImageStatusError
+ if !errors.As(err, &statusErr) || statusErr == nil {
+ return err
+ }
+
+ upstreamMsg := sanitizeUpstreamErrorMessage(statusErr.Message)
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: statusErr.StatusCode,
+ UpstreamRequestID: statusErr.RequestID,
+ UpstreamURL: safeUpstreamURL(statusErr.URL),
+ Kind: "request_error",
+ Message: upstreamMsg,
+ })
+ setOpsUpstreamError(c, statusErr.StatusCode, upstreamMsg, "")
+
+ if s.shouldFailoverOpenAIUpstreamResponse(statusErr.StatusCode, upstreamMsg, statusErr.ResponseBody) {
+ if s.rateLimitService != nil {
+ s.rateLimitService.HandleUpstreamError(ctx, account, statusErr.StatusCode, statusErr.ResponseHeaders, statusErr.ResponseBody)
+ }
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: statusErr.StatusCode,
+ UpstreamRequestID: statusErr.RequestID,
+ UpstreamURL: safeUpstreamURL(statusErr.URL),
+ Kind: "failover",
+ Message: upstreamMsg,
+ })
+ retryableOnSameAccount := account.IsPoolMode() && isPoolModeRetryableStatus(statusErr.StatusCode)
+ if strings.Contains(strings.ToLower(statusErr.Message), "unsupported challenge") {
+ retryableOnSameAccount = false
+ }
+ return &UpstreamFailoverError{
+ StatusCode: statusErr.StatusCode,
+ ResponseBody: statusErr.ResponseBody,
+ RetryableOnSameAccount: retryableOnSameAccount,
+ }
+ }
+
+ return statusErr
+}
+
+func cloneHTTPHeader(src http.Header) http.Header {
+ dst := make(http.Header, len(src))
+ for key, values := range src {
+ copied := make([]string, len(values))
+ copy(copied, values)
+ dst[key] = copied
+ }
+ return dst
+}
+
+func headerToMap(header http.Header) map[string]string {
+ if len(header) == 0 {
+ return nil
+ }
+ result := make(map[string]string, len(header))
+ for key, values := range header {
+ if len(values) == 0 {
+ continue
+ }
+ result[key] = values[0]
+ }
+ return result
+}
+
+func openAITimezoneOffsetMinutes() int {
+ _, offset := time.Now().Zone()
+ return offset / 60
+}
+
+func openAITimezoneName() string {
+ return time.Now().Location().String()
+}
+
+func generateOpenAIRequirementsToken(userAgent string) string {
+ config := []any{
+ "core" + strconv.Itoa(3008),
+ time.Now().UTC().Format(time.RFC1123),
+ nil,
+ 0.123456,
+ coalesceOpenAIFileName(strings.TrimSpace(userAgent), openAIImageBackendUserAgent),
+ nil,
+ "prod-openai-images",
+ "en-US",
+ "en-US,en",
+ 0,
+ "navigator.webdriver",
+ "location",
+ "document.body",
+ float64(time.Now().UnixMilli()) / 1000,
+ uuid.NewString(),
+ "",
+ 8,
+ time.Now().Unix(),
+ }
+ answer, solved := generateOpenAIChallengeAnswer(strconv.FormatInt(time.Now().UnixNano(), 10), openAIImageRequirementsDiff, config)
+ if solved {
+ return "gAAAAAC" + answer
+ }
+ return ""
+}
+
+func generateOpenAIChallengeAnswer(seed string, difficulty string, config []any) (string, bool) {
+ diffBytes, err := hex.DecodeString(difficulty)
+ if err != nil {
+ return "", false
+ }
+ p1 := []byte(jsonCompactSlice(config[:3], true))
+ p2 := []byte(jsonCompactSlice(config[4:9], false))
+ p3 := []byte(jsonCompactSlice(config[10:], false))
+ seedBytes := []byte(seed)
+
+ for i := 0; i < 100000; i++ {
+ payload := fmt.Sprintf("%s%d,%s,%d,%s", p1, i, p2, i>>1, p3)
+ encoded := base64.StdEncoding.EncodeToString([]byte(payload))
+ sum := sha3.Sum512(append(seedBytes, []byte(encoded)...))
+ if bytes.Compare(sum[:len(diffBytes)], diffBytes) <= 0 {
+ return encoded, true
+ }
+ }
+ return "", false
+}
+
+func jsonCompactSlice(values []any, trimSuffixComma bool) string {
+ raw, _ := json.Marshal(values)
+ text := string(raw)
+ if trimSuffixComma {
+ return strings.TrimSuffix(text, "]")
+ }
+ return strings.TrimPrefix(text, "[")
+}
+
+func generateOpenAIProofToken(required bool, seed string, difficulty string, userAgent string) string {
+ if !required || strings.TrimSpace(seed) == "" || strings.TrimSpace(difficulty) == "" {
+ return ""
+ }
+ screen := 3008
+ if len(seed)%2 == 0 {
+ screen = 4010
+ }
+ proofToken := []any{
+ screen,
+ time.Now().UTC().Format(time.RFC1123),
+ nil,
+ 0,
+ coalesceOpenAIFileName(strings.TrimSpace(userAgent), openAIImageBackendUserAgent),
+ "https://chatgpt.com/",
+ "dpl=openai-images",
+ "en",
+ "en-US",
+ nil,
+ "plugins[object PluginArray]",
+ "_reactListening",
+ "alert",
+ }
+ diffLen := len(difficulty)
+ for i := 0; i < 100000; i++ {
+ proofToken[3] = i
+ raw, _ := json.Marshal(proofToken)
+ encoded := base64.StdEncoding.EncodeToString(raw)
+ sum := sha3.Sum512([]byte(seed + encoded))
+ if strings.Compare(hex.EncodeToString(sum[:])[:diffLen], difficulty) <= 0 {
+ return "gAAAAAB" + encoded
+ }
+ }
+ fallbackBase := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%q", seed)))
+ return "gAAAAABwQ8Lk5FbGpA2NcR9dShT6gYjU7VxZ4D" + fallbackBase
+}
+
+func dedupeStrings(values []string) []string {
+ if len(values) == 0 {
+ return nil
+ }
+ seen := make(map[string]struct{}, len(values))
+ out := make([]string, 0, len(values))
+ for _, value := range values {
+ if _, ok := seen[value]; ok {
+ continue
+ }
+ seen[value] = struct{}{}
+ out = append(out, value)
+ }
+ return out
+}
diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go
new file mode 100644
index 00000000..173d69ba
--- /dev/null
+++ b/backend/internal/service/openai_images_test.go
@@ -0,0 +1,105 @@
+package service
+
+import (
+ "bytes"
+ "mime/multipart"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSON(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","stream":true}`)
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ svc := &OpenAIGatewayService{}
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+ require.NoError(t, err)
+ require.NotNil(t, parsed)
+ require.Equal(t, "/v1/images/generations", parsed.Endpoint)
+ require.Equal(t, "gpt-image-2", parsed.Model)
+ require.Equal(t, "draw a cat", parsed.Prompt)
+ require.True(t, parsed.Stream)
+ require.Equal(t, "1024x1024", parsed.Size)
+ require.Equal(t, "1K", parsed.SizeTier)
+ require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
+ require.False(t, parsed.Multipart)
+}
+
+func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ var body bytes.Buffer
+ writer := multipart.NewWriter(&body)
+ require.NoError(t, writer.WriteField("model", "gpt-image-2"))
+ require.NoError(t, writer.WriteField("prompt", "replace background"))
+ require.NoError(t, writer.WriteField("size", "1536x1024"))
+ part, err := writer.CreateFormFile("image", "source.png")
+ require.NoError(t, err)
+ _, err = part.Write([]byte("fake-image-bytes"))
+ require.NoError(t, err)
+ require.NoError(t, writer.Close())
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes()))
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ svc := &OpenAIGatewayService{}
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes())
+ require.NoError(t, err)
+ require.NotNil(t, parsed)
+ require.Equal(t, "/v1/images/edits", parsed.Endpoint)
+ require.True(t, parsed.Multipart)
+ require.Equal(t, "gpt-image-2", parsed.Model)
+ require.Equal(t, "replace background", parsed.Prompt)
+ require.Equal(t, "1536x1024", parsed.Size)
+ require.Equal(t, "2K", parsed.SizeTier)
+ require.Len(t, parsed.Uploads, 1)
+ require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
+}
+
+func TestOpenAIGatewayServiceParseOpenAIImagesRequest_PromptOnlyDefaultsRemainBasic(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ body := []byte(`{"prompt":"draw a cat"}`)
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ svc := &OpenAIGatewayService{}
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+ require.NoError(t, err)
+ require.NotNil(t, parsed)
+ require.Equal(t, "gpt-image-2", parsed.Model)
+ require.Equal(t, OpenAIImagesCapabilityBasic, parsed.RequiredCapability)
+}
+
+func TestOpenAIGatewayServiceParseOpenAIImagesRequest_ExplicitSizeRequiresNativeCapability(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ body := []byte(`{"prompt":"draw a cat","size":"1024x1024"}`)
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ svc := &OpenAIGatewayService{}
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+ require.NoError(t, err)
+ require.NotNil(t, parsed)
+ require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
+}
diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go
index cda7e369..35e7c250 100644
--- a/backend/internal/service/openai_model_mapping_test.go
+++ b/backend/internal/service/openai_model_mapping_test.go
@@ -69,14 +69,14 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
}
}
-func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *testing.T) {
+func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt54(t *testing.T) {
account := &Account{
Credentials: map[string]any{},
}
withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", ""))
- if withoutDefault != "gpt-5.1" {
- t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.1")
+ if withoutDefault != "gpt-5.4" {
+ t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.4")
}
withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4"))
@@ -87,9 +87,9 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *
func TestNormalizeCodexModel(t *testing.T) {
cases := map[string]string{
- "gpt-5.3-codex-spark": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-high": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
+ "gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
"gpt-5.3": "gpt-5.3-codex",
}
@@ -111,7 +111,7 @@ func TestNormalizeOpenAIModelForUpstream(t *testing.T) {
name: "oauth keeps codex normalization behavior",
account: &Account{Type: AccountTypeOAuth},
model: "gemini-3-flash-preview",
- want: "gpt-5.1",
+ want: "gpt-5.4",
},
{
name: "apikey preserves custom compatible model",
diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go
index c0e814ab..bd40d389 100644
--- a/backend/internal/service/ops_retry.go
+++ b/backend/internal/service/ops_retry.go
@@ -388,7 +388,7 @@ func (s *OpsService) executeRetry(ctx context.Context, errorLog *OpsErrorLogDeta
func detectOpsRetryType(path string) opsRetryRequestType {
p := strings.ToLower(strings.TrimSpace(path))
switch {
- case strings.Contains(p, "/responses"):
+ case strings.Contains(p, "/responses"), strings.Contains(p, "/images/"):
return opsRetryTypeOpenAI
case strings.Contains(p, "/v1beta/"):
return opsRetryTypeGeminiV1B
diff --git a/backend/internal/service/payment_config_limits.go b/backend/internal/service/payment_config_limits.go
index 56905278..973c601a 100644
--- a/backend/internal/service/payment_config_limits.go
+++ b/backend/internal/service/payment_config_limits.go
@@ -20,6 +20,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return nil, fmt.Errorf("query provider instances: %w", err)
}
typeInstances := pcGroupByPaymentType(instances)
+ typeInstances = s.pcApplyEnabledVisibleMethodInstances(ctx, typeInstances, instances)
resp := &MethodLimitsResponse{
Methods: make(map[string]MethodLimits, len(typeInstances)),
}
@@ -31,6 +32,41 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return resp, nil
}
+func (s *PaymentConfigService) pcApplyEnabledVisibleMethodInstances(ctx context.Context, typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance {
+ if len(typeInstances) == 0 {
+ return typeInstances
+ }
+
+ filtered := make(map[string][]*dbent.PaymentProviderInstance, len(typeInstances))
+ for paymentType, groupedInstances := range typeInstances {
+ filtered[paymentType] = groupedInstances
+ }
+
+ for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} {
+ matching := filterEnabledVisibleMethodInstances(instances, method)
+ providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching)
+ if err != nil {
+ delete(filtered, method)
+ continue
+ }
+ if providerKey == "" {
+ if len(matching) == 0 {
+ delete(filtered, method)
+ continue
+ }
+ filtered[method] = matching
+ continue
+ }
+ selectedInstances := filterVisibleMethodInstancesByProviderKey(instances, method, providerKey)
+ if len(selectedInstances) == 0 {
+ delete(filtered, method)
+ continue
+ }
+ filtered[method] = selectedInstances
+ }
+ return filtered
+}
+
// GetMethodLimits returns per-payment-type limits from enabled provider instances.
func (s *PaymentConfigService) GetMethodLimits(ctx context.Context, types []string) ([]MethodLimits, error) {
instances, err := s.entClient.PaymentProviderInstance.Query().
diff --git a/backend/internal/service/payment_config_limits_test.go b/backend/internal/service/payment_config_limits_test.go
index 73ad66ef..4df506d6 100644
--- a/backend/internal/service/payment_config_limits_test.go
+++ b/backend/internal/service/payment_config_limits_test.go
@@ -1,10 +1,12 @@
package service
import (
+ "context"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/stretchr/testify/require"
)
func TestUnionFloat(t *testing.T) {
@@ -299,3 +301,161 @@ func TestPcInstanceTypeLimits(t *testing.T) {
}
})
}
+
+func TestGetAvailableMethodLimitsUsesConfiguredVisibleMethodSource(t *testing.T) {
+ tests := []struct {
+ name string
+ sourceSetting string
+ wantAlipaySingleMin float64
+ wantAlipaySingleMax float64
+ wantGlobalMin float64
+ wantGlobalMax float64
+ }{
+ {
+ name: "official source",
+ sourceSetting: VisibleMethodSourceOfficialAlipay,
+ wantAlipaySingleMin: 10,
+ wantAlipaySingleMax: 100,
+ wantGlobalMin: 10,
+ wantGlobalMax: 300,
+ },
+ {
+ name: "easypay source",
+ sourceSetting: VisibleMethodSourceEasyPayAlipay,
+ wantAlipaySingleMin: 20,
+ wantAlipaySingleMax: 200,
+ wantGlobalMin: 20,
+ wantGlobalMax: 300,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("Official Alipay").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`).
+ SetEnabled(true).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create official alipay instance: %v", err)
+ }
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName("EasyPay Alipay").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`).
+ SetEnabled(true).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create easypay alipay instance: %v", err)
+ }
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("Official WeChat").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`).
+ SetEnabled(true).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create official wxpay instance: %v", err)
+ }
+
+ svc := &PaymentConfigService{
+ entClient: client,
+ settingRepo: &paymentConfigSettingRepoStub{
+ values: map[string]string{
+ SettingPaymentVisibleMethodAlipaySource: tt.sourceSetting,
+ },
+ },
+ }
+
+ resp, err := svc.GetAvailableMethodLimits(ctx)
+ if err != nil {
+ t.Fatalf("GetAvailableMethodLimits returned error: %v", err)
+ }
+
+ alipayLimits, ok := resp.Methods[payment.TypeAlipay]
+ if !ok {
+ t.Fatalf("expected alipay limits to remain visible, got %v", resp.Methods)
+ }
+ if alipayLimits.SingleMin != tt.wantAlipaySingleMin || alipayLimits.SingleMax != tt.wantAlipaySingleMax {
+ t.Fatalf("alipay limits = %+v, want min=%v max=%v", alipayLimits, tt.wantAlipaySingleMin, tt.wantAlipaySingleMax)
+ }
+
+ wxpayLimits, ok := resp.Methods[payment.TypeWxpay]
+ if !ok {
+ t.Fatalf("expected wxpay limits to remain visible, got %v", resp.Methods)
+ }
+ if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 {
+ t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits)
+ }
+ if resp.GlobalMin != tt.wantGlobalMin || resp.GlobalMax != tt.wantGlobalMax {
+ t.Fatalf("global range = (%v, %v), want (%v, %v)", resp.GlobalMin, resp.GlobalMax, tt.wantGlobalMin, tt.wantGlobalMax)
+ }
+ })
+ }
+}
+
+func TestGetAvailableMethodLimitsPreservesLegacyCrossProviderBehaviorWhenVisibleMethodSourceMissing(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("Official Alipay").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`).
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName("EasyPay Mixed").
+ SetConfig("{}").
+ SetSupportedTypes("alipay,wxpay").
+ SetLimits(`{"alipay":{"singleMin":20,"singleMax":200},"wxpay":{"singleMin":40,"singleMax":400}}`).
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("Official WeChat").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`).
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentConfigService{
+ entClient: client,
+ settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{}},
+ }
+
+ resp, err := svc.GetAvailableMethodLimits(ctx)
+ require.NoError(t, err)
+
+ alipayLimits, ok := resp.Methods[payment.TypeAlipay]
+ require.True(t, ok, "expected alipay limits to remain visible")
+ require.Equal(t, 10.0, alipayLimits.SingleMin)
+ require.Equal(t, 200.0, alipayLimits.SingleMax)
+
+ wxpayLimits, ok := resp.Methods[payment.TypeWxpay]
+ require.True(t, ok, "expected wxpay limits to remain visible")
+ require.Equal(t, 30.0, wxpayLimits.SingleMin)
+ require.Equal(t, 400.0, wxpayLimits.SingleMax)
+
+ require.Equal(t, 10.0, resp.GlobalMin)
+ require.Equal(t, 400.0, resp.GlobalMax)
+}
diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go
index 3c406b45..ff05e559 100644
--- a/backend/internal/service/payment_config_providers.go
+++ b/backend/internal/service/payment_config_providers.go
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "log/slog"
"strconv"
"strings"
@@ -11,9 +12,22 @@ import (
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/payment/provider"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
+// validateProviderConfig runs the provider's constructor to surface config-level
+// errors at save time (e.g. wxpay missing certSerial), instead of only failing
+// when an order is created. Returns the structured ApplicationError from the
+// constructor so the frontend i18n layer can localize it.
+//
+// Only validates enabled instances — a disabled instance may be a half-filled
+// draft the admin will complete later.
+func (s *PaymentConfigService) validateProviderConfig(providerKey string, config map[string]string) error {
+ _, err := provider.CreateProvider(providerKey, "_validate_", config)
+ return err
+}
+
// --- Provider Instance CRUD ---
func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*dbent.PaymentProviderInstance, error) {
@@ -47,11 +61,10 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
resp := ProviderInstanceResponse{
ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name,
SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits,
- Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled,
- AllowUserRefund: inst.AllowUserRefund,
- SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode,
+ Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, AllowUserRefund: inst.AllowUserRefund,
+ SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode,
}
- resp.Config, err = s.decryptAndMaskConfig(inst.Config)
+ resp.Config, err = s.decryptAndMaskConfig(inst.ProviderKey, inst.Config)
if err != nil {
return nil, fmt.Errorf("decrypt config for instance %d: %w", inst.ID, err)
}
@@ -60,8 +73,26 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
return result, nil
}
-func (s *PaymentConfigService) decryptAndMaskConfig(encrypted string) (map[string]string, error) {
- return s.decryptConfig(encrypted)
+// decryptAndMaskConfig returns the stored config with sensitive fields omitted.
+// Admin UIs display masked placeholders for these; the raw values never leave
+// the server. Callers that need the full config (e.g. payment runtime) must
+// use decryptConfig directly.
+func (s *PaymentConfigService) decryptAndMaskConfig(providerKey, encrypted string) (map[string]string, error) {
+ cfg, err := s.decryptConfig(encrypted)
+ if err != nil {
+ return nil, err
+ }
+ if cfg == nil {
+ return nil, nil
+ }
+ masked := make(map[string]string, len(cfg))
+ for k, v := range cfg {
+ if isSensitiveProviderConfigField(providerKey, k) {
+ continue
+ }
+ masked[k] = v
+ }
+ return masked, nil
}
// pendingOrderStatuses are order statuses considered "in progress".
@@ -71,18 +102,62 @@ var pendingOrderStatuses = []string{
payment.OrderStatusRecharging,
}
-var sensitiveConfigPatterns = []string{"key", "pkey", "secret", "private", "password"}
+// providerSensitiveConfigFields is the authoritative list of config keys that
+// are treated as secrets per provider. Must stay in sync with the frontend
+// definition at frontend/src/components/payment/providerConfig.ts
+// (PROVIDER_CONFIG_FIELDS, fields with sensitive: true).
+//
+// Key matching is case-insensitive. Non-listed keys (e.g. appId, notifyUrl,
+// stripe publishableKey) are returned in plaintext by the admin GET API.
+var providerSensitiveConfigFields = map[string]map[string]struct{}{
+ payment.TypeEasyPay: {"pkey": {}},
+ payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}},
+ payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}},
+ payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
+}
-func isSensitiveConfigField(fieldName string) bool {
- lower := strings.ToLower(fieldName)
- for _, p := range sensitiveConfigPatterns {
- if strings.Contains(lower, p) {
+// providerPendingOrderProtectedConfigFields lists config keys that cannot be
+// changed while the instance has in-progress orders. This includes secrets plus
+// all provider identity fields that are snapshotted into orders or used by
+// webhook/refund verification.
+var providerPendingOrderProtectedConfigFields = map[string]map[string]struct{}{
+ payment.TypeEasyPay: {"pkey": {}, "pid": {}},
+ payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}, "appid": {}},
+ payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}, "appid": {}, "mpappid": {}, "mchid": {}, "publickeyid": {}, "certserial": {}},
+ payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
+}
+
+func isSensitiveProviderConfigField(providerKey, fieldName string) bool {
+ fields, ok := providerSensitiveConfigFields[providerKey]
+ if !ok {
+ return false
+ }
+ _, found := fields[strings.ToLower(fieldName)]
+ return found
+}
+
+func hasPendingOrderProtectedConfigChange(providerKey string, currentConfig, nextConfig map[string]string) bool {
+ fields, ok := providerPendingOrderProtectedConfigFields[providerKey]
+ if !ok {
+ return false
+ }
+ for fieldName := range fields {
+ if providerConfigFieldValue(currentConfig, fieldName) != providerConfigFieldValue(nextConfig, fieldName) {
return true
}
}
return false
}
+func providerConfigFieldValue(config map[string]string, fieldName string) string {
+ for key, value := range config {
+ if strings.EqualFold(key, fieldName) {
+ return value
+ }
+ }
+ return ""
+}
+
func (s *PaymentConfigService) countPendingOrders(ctx context.Context, providerInstanceID int64) (int, error) {
return s.entClient.PaymentOrder.Query().
Where(
@@ -108,6 +183,14 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C
if err := validateProviderRequest(req.ProviderKey, req.Name, typesStr); err != nil {
return nil, err
}
+ if err := s.validateVisibleMethodEnablementConflicts(ctx, 0, req.ProviderKey, typesStr, req.Enabled); err != nil {
+ return nil, err
+ }
+ if req.Enabled {
+ if err := s.validateProviderConfig(req.ProviderKey, req.Config); err != nil {
+ return nil, err
+ }
+ }
enc, err := s.encryptConfig(req.Config)
if err != nil {
return nil, err
@@ -136,18 +219,47 @@ func validateProviderRequest(providerKey, name, supportedTypes string) error {
// NOTE: This function exceeds 30 lines due to per-field nil-check patch update
// boilerplate and pending-order safety checks.
func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
- if req.Config != nil {
- hasSensitive := false
- for k := range req.Config {
- if isSensitiveConfigField(k) && req.Config[k] != "" {
- hasSensitive = true
- break
- }
+ current, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
+ if err != nil {
+ return nil, fmt.Errorf("load provider instance: %w", err)
+ }
+ var pendingOrderCount *int
+ getPendingOrderCount := func() (int, error) {
+ if pendingOrderCount != nil {
+ return *pendingOrderCount, nil
}
- if hasSensitive {
- count, err := s.countPendingOrders(ctx, id)
+ count, err := s.countPendingOrders(ctx, id)
+ if err != nil {
+ return 0, fmt.Errorf("check pending orders: %w", err)
+ }
+ pendingOrderCount = &count
+ return count, nil
+ }
+ nextEnabled := current.Enabled
+ if req.Enabled != nil {
+ nextEnabled = *req.Enabled
+ }
+ nextSupportedTypes := current.SupportedTypes
+ if req.SupportedTypes != nil {
+ nextSupportedTypes = joinTypes(req.SupportedTypes)
+ }
+ if err := s.validateVisibleMethodEnablementConflicts(ctx, id, current.ProviderKey, nextSupportedTypes, nextEnabled); err != nil {
+ return nil, err
+ }
+ var mergedConfig map[string]string
+ if req.Config != nil {
+ currentConfig, err := s.decryptConfig(current.Config)
+ if err != nil {
+ return nil, fmt.Errorf("decrypt existing config: %w", err)
+ }
+ mergedConfig, err = s.mergeConfig(ctx, id, req.Config)
+ if err != nil {
+ return nil, err
+ }
+ if hasPendingOrderProtectedConfigChange(current.ProviderKey, currentConfig, mergedConfig) {
+ count, err := getPendingOrderCount()
if err != nil {
- return nil, fmt.Errorf("check pending orders: %w", err)
+ return nil, err
}
if count > 0 {
return nil, infraerrors.Conflict("PENDING_ORDERS", "instance has pending orders").
@@ -156,25 +268,40 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
}
}
if req.Enabled != nil && !*req.Enabled {
- count, err := s.countPendingOrders(ctx, id)
+ count, err := getPendingOrderCount()
if err != nil {
- return nil, fmt.Errorf("check pending orders: %w", err)
+ return nil, err
}
if count > 0 {
return nil, infraerrors.Conflict("PENDING_ORDERS", "instance has pending orders").
WithMetadata(map[string]string{"count": strconv.Itoa(count)})
}
}
+ // Validate merged config when the instance will end up enabled.
+ // This surfaces provider-level errors (e.g. wxpay missing certSerial) at save time,
+ // so admins see them in the dialog instead of only when an order is created.
+ finalEnabled := current.Enabled
+ if req.Enabled != nil {
+ finalEnabled = *req.Enabled
+ }
+ if finalEnabled {
+ configToValidate := mergedConfig
+ if configToValidate == nil {
+ configToValidate, err = s.decryptConfig(current.Config)
+ if err != nil {
+ return nil, fmt.Errorf("decrypt existing config: %w", err)
+ }
+ }
+ if err := s.validateProviderConfig(current.ProviderKey, configToValidate); err != nil {
+ return nil, err
+ }
+ }
u := s.entClient.PaymentProviderInstance.UpdateOneID(id)
if req.Name != nil {
u.SetName(*req.Name)
}
- if req.Config != nil {
- merged, err := s.mergeConfig(ctx, id, req.Config)
- if err != nil {
- return nil, err
- }
- enc, err := s.encryptConfig(merged)
+ if mergedConfig != nil {
+ enc, err := s.encryptConfig(mergedConfig)
if err != nil {
return nil, err
}
@@ -182,17 +309,13 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
}
if req.SupportedTypes != nil {
// Check pending orders before removing payment types
- count, err := s.countPendingOrders(ctx, id)
+ count, err := getPendingOrderCount()
if err != nil {
- return nil, fmt.Errorf("check pending orders: %w", err)
+ return nil, err
}
if count > 0 {
// Load current instance to compare types
- inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
- if err != nil {
- return nil, fmt.Errorf("load provider instance: %w", err)
- }
- oldTypes := strings.Split(inst.SupportedTypes, ",")
+ oldTypes := strings.Split(current.SupportedTypes, ",")
newTypes := req.SupportedTypes
for _, ot := range oldTypes {
ot = strings.TrimSpace(ot)
@@ -237,10 +360,7 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
if req.RefundEnabled != nil {
refundEnabled = *req.RefundEnabled
} else {
- inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
- if err == nil {
- refundEnabled = inst.RefundEnabled
- }
+ refundEnabled = current.RefundEnabled
}
if refundEnabled {
u.SetAllowUserRefund(true)
@@ -282,27 +402,48 @@ func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon
return nil, fmt.Errorf("decrypt existing config for instance %d: %w", id, err)
}
if existing == nil {
- return newConfig, nil
+ existing = map[string]string{}
}
for k, v := range newConfig {
+ // Preserve existing secrets when the client submits an empty value
+ // (admin UI omits the value to indicate "leave unchanged").
+ if v == "" && isSensitiveProviderConfigField(inst.ProviderKey, k) {
+ continue
+ }
existing[k] = v
}
return existing, nil
}
-func (s *PaymentConfigService) decryptConfig(encrypted string) (map[string]string, error) {
- if encrypted == "" {
+// decryptConfig parses a stored provider config.
+// New records are plaintext JSON; legacy records are AES-256-GCM ciphertext
+// ("iv:authTag:ciphertext"). Values that cannot be parsed as either — including
+// legacy ciphertext with no/invalid TOTP_ENCRYPTION_KEY — are treated as empty,
+// letting the admin re-enter the config via the UI to complete the migration.
+//
+// TODO(deprecated-legacy-ciphertext): The AES fallback branch is a transitional
+// shim for pre-plaintext records. Remove it (and the encryptionKey field) after
+// a few releases once all live deployments have re-saved their provider configs.
+func (s *PaymentConfigService) decryptConfig(stored string) (map[string]string, error) {
+ if stored == "" {
return nil, nil
}
- decrypted, err := payment.Decrypt(encrypted, s.encryptionKey)
- if err != nil {
- return nil, fmt.Errorf("decrypt config: %w", err)
+ var cfg map[string]string
+ if err := json.Unmarshal([]byte(stored), &cfg); err == nil {
+ return cfg, nil
}
- var raw map[string]string
- if err := json.Unmarshal([]byte(decrypted), &raw); err != nil {
- return nil, fmt.Errorf("unmarshal decrypted config: %w", err)
+ // Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal.
+ if len(s.encryptionKey) == payment.AES256KeySize {
+ //nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal
+ if plaintext, err := payment.Decrypt(stored, s.encryptionKey); err == nil {
+ if err := json.Unmarshal([]byte(plaintext), &cfg); err == nil {
+ return cfg, nil
+ }
+ }
}
- return raw, nil
+ slog.Warn("payment provider config unreadable, treating as empty for re-entry",
+ "stored_len", len(stored))
+ return nil, nil
}
func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id int64) error {
@@ -317,14 +458,13 @@ func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id in
return s.entClient.PaymentProviderInstance.DeleteOneID(id).Exec(ctx)
}
+// encryptConfig serialises a provider config for storage.
+// New records are written as plaintext JSON; the historical AES-GCM wrapping
+// has been dropped but decryptConfig still accepts old ciphertext during migration.
func (s *PaymentConfigService) encryptConfig(cfg map[string]string) (string, error) {
data, err := json.Marshal(cfg)
if err != nil {
return "", fmt.Errorf("marshal config: %w", err)
}
- enc, err := payment.Encrypt(string(data), s.encryptionKey)
- if err != nil {
- return "", fmt.Errorf("encrypt config: %w", err)
- }
- return enc, nil
+ return string(data), nil
}
diff --git a/backend/internal/service/payment_config_providers_test.go b/backend/internal/service/payment_config_providers_test.go
index 2aaa874f..e0d2908a 100644
--- a/backend/internal/service/payment_config_providers_test.go
+++ b/backend/internal/service/payment_config_providers_test.go
@@ -3,8 +3,18 @@
package service
import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/pem"
+ "strconv"
"testing"
+ "time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -97,41 +107,52 @@ func TestValidateProviderRequest(t *testing.T) {
}
}
-func TestIsSensitiveConfigField(t *testing.T) {
+func TestIsSensitiveProviderConfigField(t *testing.T) {
t.Parallel()
tests := []struct {
- field string
- wantSen bool
+ providerKey string
+ field string
+ wantSen bool
}{
- // Sensitive fields (contain key/secret/private/password/pkey patterns)
- {"secretKey", true},
- {"apiSecret", true},
- {"pkey", true},
- {"privateKey", true},
- {"apiPassword", true},
- {"appKey", true},
- {"SECRET_TOKEN", true},
- {"PrivateData", true},
- {"PASSWORD", true},
- {"mySecretValue", true},
+ // Stripe: publishableKey is public, only secretKey/webhookSecret are secrets
+ {"stripe", "secretKey", true},
+ {"stripe", "webhookSecret", true},
+ {"stripe", "SecretKey", true}, // case-insensitive
+ {"stripe", "publishableKey", false},
+ {"stripe", "appId", false},
- // Non-sensitive fields
- {"appId", false},
- {"mchId", false},
- {"apiBase", false},
- {"endpoint", false},
- {"merchantNo", false},
- {"paymentMode", false},
- {"notifyUrl", false},
+ // Alipay
+ {"alipay", "privateKey", true},
+ {"alipay", "publicKey", true},
+ {"alipay", "alipayPublicKey", true},
+ {"alipay", "appId", false},
+ {"alipay", "notifyUrl", false},
+
+ // Wxpay
+ {"wxpay", "privateKey", true},
+ {"wxpay", "apiV3Key", true},
+ {"wxpay", "publicKey", true},
+ {"wxpay", "publicKeyId", false},
+ {"wxpay", "certSerial", false},
+ {"wxpay", "mchId", false},
+
+ // EasyPay
+ {"easypay", "pkey", true},
+ {"easypay", "pid", false},
+ {"easypay", "apiBase", false},
+
+ // Unknown provider: never sensitive
+ {"unknown", "secretKey", false},
}
for _, tc := range tests {
- t.Run(tc.field, func(t *testing.T) {
+ tc := tc
+ t.Run(tc.providerKey+"/"+tc.field, func(t *testing.T) {
t.Parallel()
- got := isSensitiveConfigField(tc.field)
- assert.Equal(t, tc.wantSen, got, "isSensitiveConfigField(%q)", tc.field)
+ got := isSensitiveProviderConfigField(tc.providerKey, tc.field)
+ assert.Equal(t, tc.wantSen, got, "isSensitiveProviderConfigField(%q, %q)", tc.providerKey, tc.field)
})
}
}
@@ -185,3 +206,403 @@ func TestJoinTypes(t *testing.T) {
})
}
}
+
+func TestCreateProviderInstanceAllowsVisibleMethodProvidersFromDifferentSources(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ svc := &PaymentConfigService{
+ entClient: client,
+ encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
+ }
+
+ _, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
+ ProviderKey: "easypay",
+ Name: "EasyPay Alipay",
+ Config: map[string]string{
+ "pid": "1001",
+ "pkey": "pkey-1001",
+ "apiBase": "https://pay.example.com",
+ "notifyUrl": "https://merchant.example.com/notify",
+ "returnUrl": "https://merchant.example.com/return",
+ },
+ SupportedTypes: []string{"alipay"},
+ Enabled: true,
+ })
+ require.NoError(t, err)
+
+ _, err = svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
+ ProviderKey: "alipay",
+ Name: "Official Alipay",
+ Config: map[string]string{"appId": "app-1", "privateKey": "private-key"},
+ SupportedTypes: []string{"alipay"},
+ Enabled: true,
+ })
+ require.NoError(t, err)
+}
+
+func TestUpdateProviderInstanceAllowsEnablingVisibleMethodProviderFromDifferentSource(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ svc := &PaymentConfigService{
+ entClient: client,
+ encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
+ }
+
+ existing, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
+ ProviderKey: "easypay",
+ Name: "EasyPay WeChat",
+ Config: map[string]string{
+ "pid": "2001",
+ "pkey": "pkey-2001",
+ "apiBase": "https://pay.example.com",
+ "notifyUrl": "https://merchant.example.com/notify",
+ "returnUrl": "https://merchant.example.com/return",
+ },
+ SupportedTypes: []string{"wxpay"},
+ Enabled: true,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, existing)
+
+ candidate, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
+ ProviderKey: "wxpay",
+ Name: "Official WeChat",
+ Config: validWxpayProviderConfig(t),
+ SupportedTypes: []string{"wxpay"},
+ Enabled: false,
+ })
+ require.NoError(t, err)
+
+ _, err = svc.UpdateProviderInstance(ctx, candidate.ID, UpdateProviderInstanceRequest{
+ Enabled: boolPtrValue(true),
+ })
+ require.NoError(t, err)
+}
+
+func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ svc := &PaymentConfigService{
+ entClient: client,
+ encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
+ }
+
+ instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
+ ProviderKey: "easypay",
+ Name: "EasyPay",
+ Config: map[string]string{
+ "pid": "3001",
+ "pkey": "pkey-3001",
+ "apiBase": "https://pay.example.com",
+ "notifyUrl": "https://merchant.example.com/notify",
+ "returnUrl": "https://merchant.example.com/return",
+ },
+ SupportedTypes: []string{"alipay"},
+ Enabled: false,
+ })
+ require.NoError(t, err)
+
+ _, err = svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{
+ Enabled: boolPtrValue(true),
+ SupportedTypes: []string{"alipay", "wxpay"},
+ })
+ require.NoError(t, err)
+
+ saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID)
+ require.NoError(t, err)
+ require.True(t, saved.Enabled)
+ require.Equal(t, "alipay,wxpay", saved.SupportedTypes)
+}
+
+func TestUpdateProviderInstanceRejectsProtectedConfigChangesWhilePendingOrders(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ providerKey string
+ createConfig func(*testing.T) map[string]string
+ supportedType []string
+ updateConfig map[string]string
+ fieldName string
+ wantValue string
+ }{
+ {
+ name: "wxpay appId",
+ providerKey: payment.TypeWxpay,
+ createConfig: validWxpayProviderConfig,
+ supportedType: []string{payment.TypeWxpay},
+ updateConfig: map[string]string{"appId": "wx-app-updated"},
+ fieldName: "appId",
+ wantValue: "wx-app-test",
+ },
+ {
+ name: "wxpay mpAppId",
+ providerKey: payment.TypeWxpay,
+ createConfig: validWxpayProviderConfigWithJSAPIAppID,
+ supportedType: []string{payment.TypeWxpay},
+ updateConfig: map[string]string{"mpAppId": "wx-mp-app-updated"},
+ fieldName: "mpAppId",
+ wantValue: "wx-mp-app-test",
+ },
+ {
+ name: "wxpay mchId",
+ providerKey: payment.TypeWxpay,
+ createConfig: validWxpayProviderConfig,
+ supportedType: []string{payment.TypeWxpay},
+ updateConfig: map[string]string{"mchId": "mch-updated"},
+ fieldName: "mchId",
+ wantValue: "mch-test",
+ },
+ {
+ name: "wxpay publicKeyId",
+ providerKey: payment.TypeWxpay,
+ createConfig: validWxpayProviderConfig,
+ supportedType: []string{payment.TypeWxpay},
+ updateConfig: map[string]string{"publicKeyId": "public-key-id-updated"},
+ fieldName: "publicKeyId",
+ wantValue: "public-key-id-test",
+ },
+ {
+ name: "wxpay certSerial",
+ providerKey: payment.TypeWxpay,
+ createConfig: validWxpayProviderConfig,
+ supportedType: []string{payment.TypeWxpay},
+ updateConfig: map[string]string{"certSerial": "cert-serial-updated"},
+ fieldName: "certSerial",
+ wantValue: "cert-serial-test",
+ },
+ {
+ name: "alipay appId",
+ providerKey: payment.TypeAlipay,
+ createConfig: validAlipayProviderConfig,
+ supportedType: []string{payment.TypeAlipay},
+ updateConfig: map[string]string{"appId": "alipay-app-updated"},
+ fieldName: "appId",
+ wantValue: "alipay-app-test",
+ },
+ {
+ name: "easypay pid",
+ providerKey: payment.TypeEasyPay,
+ createConfig: validEasyPayProviderConfig,
+ supportedType: []string{payment.TypeAlipay},
+ updateConfig: map[string]string{"pid": "pid-updated"},
+ fieldName: "pid",
+ wantValue: "pid-test",
+ },
+ }
+
+ for _, tc := range tests {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ svc := &PaymentConfigService{
+ entClient: client,
+ encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
+ }
+
+ instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
+ ProviderKey: tc.providerKey,
+ Name: "protected-config-instance",
+ Config: tc.createConfig(t),
+ SupportedTypes: tc.supportedType,
+ Enabled: true,
+ })
+ require.NoError(t, err)
+
+ createPendingProviderConfigOrder(t, ctx, client, instance)
+
+ updated, err := svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{
+ Config: tc.updateConfig,
+ })
+ require.Nil(t, updated)
+ require.Error(t, err)
+ require.Equal(t, "PENDING_ORDERS", infraerrors.Reason(err))
+
+ saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID)
+ require.NoError(t, err)
+ cfg, err := svc.decryptConfig(saved.Config)
+ require.NoError(t, err)
+ require.Equal(t, tc.wantValue, cfg[tc.fieldName])
+ })
+ }
+}
+
+func TestUpdateProviderInstanceAllowsSafeConfigChangesWhilePendingOrders(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ providerKey string
+ createConfig func(*testing.T) map[string]string
+ supportedType []string
+ updateConfig map[string]string
+ fieldName string
+ wantValue string
+ }{
+ {
+ name: "wxpay notifyUrl",
+ providerKey: payment.TypeWxpay,
+ createConfig: validWxpayProviderConfig,
+ supportedType: []string{payment.TypeWxpay},
+ updateConfig: map[string]string{"notifyUrl": "https://merchant.example.com/wxpay/notify-v2"},
+ fieldName: "notifyUrl",
+ wantValue: "https://merchant.example.com/wxpay/notify-v2",
+ },
+ {
+ name: "alipay same appId",
+ providerKey: payment.TypeAlipay,
+ createConfig: validAlipayProviderConfig,
+ supportedType: []string{payment.TypeAlipay},
+ updateConfig: map[string]string{"appId": "alipay-app-test"},
+ fieldName: "appId",
+ wantValue: "alipay-app-test",
+ },
+ }
+
+ for _, tc := range tests {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ svc := &PaymentConfigService{
+ entClient: client,
+ encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
+ }
+
+ instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
+ ProviderKey: tc.providerKey,
+ Name: "safe-config-instance",
+ Config: tc.createConfig(t),
+ SupportedTypes: tc.supportedType,
+ Enabled: true,
+ })
+ require.NoError(t, err)
+
+ createPendingProviderConfigOrder(t, ctx, client, instance)
+
+ updated, err := svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{
+ Config: tc.updateConfig,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, updated)
+
+ saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID)
+ require.NoError(t, err)
+ cfg, err := svc.decryptConfig(saved.Config)
+ require.NoError(t, err)
+ require.Equal(t, tc.wantValue, cfg[tc.fieldName])
+ })
+ }
+}
+
+func createPendingProviderConfigOrder(t *testing.T, ctx context.Context, client *dbent.Client, instance *dbent.PaymentProviderInstance) {
+ t.Helper()
+
+ user, err := client.User.Create().
+ SetEmail("provider-config-pending@example.com").
+ SetPasswordHash("hash").
+ SetUsername("provider-config-pending-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ instanceID := strconv.FormatInt(instance.ID, 10)
+ _, err = client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("PENDING-PROVIDER-CONFIG-" + instanceID).
+ SetOutTradeNo("sub2_pending_provider_config_" + instanceID).
+ SetPaymentType(providerPendingOrderPaymentType(instance.ProviderKey)).
+ SetPaymentTradeNo("").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderInstanceID(instanceID).
+ SetProviderKey(instance.ProviderKey).
+ Save(ctx)
+ require.NoError(t, err)
+}
+
+func providerPendingOrderPaymentType(providerKey string) string {
+ switch providerKey {
+ case payment.TypeWxpay:
+ return payment.TypeWxpay
+ case payment.TypeAlipay:
+ return payment.TypeAlipay
+ default:
+ return payment.TypeAlipay
+ }
+}
+
+func boolPtrValue(v bool) *bool {
+ return &v
+}
+
+func validAlipayProviderConfig(t *testing.T) map[string]string {
+ t.Helper()
+
+ return map[string]string{
+ "appId": "alipay-app-test",
+ "privateKey": "alipay-private-key-test",
+ "notifyUrl": "https://merchant.example.com/alipay/notify",
+ "returnUrl": "https://merchant.example.com/alipay/return",
+ }
+}
+
+func validEasyPayProviderConfig(t *testing.T) map[string]string {
+ t.Helper()
+
+ return map[string]string{
+ "pid": "pid-test",
+ "pkey": "pkey-test",
+ "apiBase": "https://pay.example.com",
+ "notifyUrl": "https://merchant.example.com/easypay/notify",
+ "returnUrl": "https://merchant.example.com/easypay/return",
+ }
+}
+
+func validWxpayProviderConfig(t *testing.T) map[string]string {
+ t.Helper()
+
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ require.NoError(t, err)
+
+ privDER, err := x509.MarshalPKCS8PrivateKey(key)
+ require.NoError(t, err)
+ pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
+ require.NoError(t, err)
+
+ return map[string]string{
+ "appId": "wx-app-test",
+ "mchId": "mch-test",
+ "privateKey": string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})),
+ "apiV3Key": "12345678901234567890123456789012",
+ "publicKey": string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER})),
+ "publicKeyId": "public-key-id-test",
+ "certSerial": "cert-serial-test",
+ }
+}
+
+func validWxpayProviderConfigWithJSAPIAppID(t *testing.T) map[string]string {
+ t.Helper()
+
+ cfg := validWxpayProviderConfig(t)
+ cfg["mpAppId"] = "wx-mp-app-test"
+ return cfg
+}
diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go
index 59764b29..02d061ae 100644
--- a/backend/internal/service/payment_config_service.go
+++ b/backend/internal/service/payment_config_service.go
@@ -93,6 +93,11 @@ type UpdatePaymentConfigRequest struct {
CancelRateLimitWindow *int `json:"cancel_rate_limit_window"`
CancelRateLimitUnit *string `json:"cancel_rate_limit_unit"`
CancelRateLimitMode *string `json:"cancel_rate_limit_window_mode"`
+
+ VisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
+ VisibleMethodWxpaySource *string `json:"payment_visible_method_wxpay_source"`
+ VisibleMethodAlipayEnabled *bool `json:"payment_visible_method_alipay_enabled"`
+ VisibleMethodWxpayEnabled *bool `json:"payment_visible_method_wxpay_enabled"`
}
// MethodLimits holds per-payment-type limits.
@@ -196,6 +201,8 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo
SettingHelpImageURL, SettingHelpText,
SettingCancelRateLimitOn, SettingCancelRateLimitMax,
SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode,
+ SettingPaymentVisibleMethodAlipayEnabled, SettingPaymentVisibleMethodAlipaySource,
+ SettingPaymentVisibleMethodWxpayEnabled, SettingPaymentVisibleMethodWxpaySource,
}
vals, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
@@ -234,18 +241,23 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme
cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy
}
if raw := vals[SettingEnabledPaymentTypes]; raw != "" {
+ types := make([]string, 0, len(strings.Split(raw, ",")))
for _, t := range strings.Split(raw, ",") {
t = strings.TrimSpace(t)
if t != "" {
- cfg.EnabledTypes = append(cfg.EnabledTypes, t)
+ types = append(types, t)
}
}
+ cfg.EnabledTypes = NormalizeVisibleMethods(types)
}
return cfg
}
// getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance.
func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string {
+ if s.entClient == nil {
+ return ""
+ }
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(
paymentproviderinstance.EnabledEQ(true),
@@ -282,25 +294,29 @@ func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req Upda
}
}
m := map[string]string{
- SettingPaymentEnabled: formatBoolOrEmpty(req.Enabled),
- SettingMinRechargeAmount: formatPositiveFloat(req.MinAmount),
- SettingMaxRechargeAmount: formatPositiveFloat(req.MaxAmount),
- SettingDailyRechargeLimit: formatPositiveFloat(req.DailyLimit),
- SettingOrderTimeoutMinutes: formatPositiveInt(req.OrderTimeoutMin),
- SettingMaxPendingOrders: formatPositiveInt(req.MaxPendingOrders),
- SettingBalancePayDisabled: formatBoolOrEmpty(req.BalanceDisabled),
- SettingBalanceRechargeMult: formatPositiveFloat(req.BalanceRechargeMultiplier),
- SettingRechargeFeeRate: formatNonNegativeFloat(req.RechargeFeeRate),
- SettingLoadBalanceStrategy: derefStr(req.LoadBalanceStrategy),
- SettingProductNamePrefix: derefStr(req.ProductNamePrefix),
- SettingProductNameSuffix: derefStr(req.ProductNameSuffix),
- SettingHelpImageURL: derefStr(req.HelpImageURL),
- SettingHelpText: derefStr(req.HelpText),
- SettingCancelRateLimitOn: formatBoolOrEmpty(req.CancelRateLimitEnabled),
- SettingCancelRateLimitMax: formatPositiveInt(req.CancelRateLimitMax),
- SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow),
- SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit),
- SettingCancelWindowMode: derefStr(req.CancelRateLimitMode),
+ SettingPaymentEnabled: formatBoolOrEmpty(req.Enabled),
+ SettingMinRechargeAmount: formatPositiveFloat(req.MinAmount),
+ SettingMaxRechargeAmount: formatPositiveFloat(req.MaxAmount),
+ SettingDailyRechargeLimit: formatPositiveFloat(req.DailyLimit),
+ SettingOrderTimeoutMinutes: formatPositiveInt(req.OrderTimeoutMin),
+ SettingMaxPendingOrders: formatPositiveInt(req.MaxPendingOrders),
+ SettingBalancePayDisabled: formatBoolOrEmpty(req.BalanceDisabled),
+ SettingBalanceRechargeMult: formatPositiveFloat(req.BalanceRechargeMultiplier),
+ SettingRechargeFeeRate: formatNonNegativeFloat(req.RechargeFeeRate),
+ SettingLoadBalanceStrategy: derefStr(req.LoadBalanceStrategy),
+ SettingProductNamePrefix: derefStr(req.ProductNamePrefix),
+ SettingProductNameSuffix: derefStr(req.ProductNameSuffix),
+ SettingHelpImageURL: derefStr(req.HelpImageURL),
+ SettingHelpText: derefStr(req.HelpText),
+ SettingCancelRateLimitOn: formatBoolOrEmpty(req.CancelRateLimitEnabled),
+ SettingCancelRateLimitMax: formatPositiveInt(req.CancelRateLimitMax),
+ SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow),
+ SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit),
+ SettingCancelWindowMode: derefStr(req.CancelRateLimitMode),
+ SettingPaymentVisibleMethodAlipaySource: derefStr(req.VisibleMethodAlipaySource),
+ SettingPaymentVisibleMethodWxpaySource: derefStr(req.VisibleMethodWxpaySource),
+ SettingPaymentVisibleMethodAlipayEnabled: formatBoolOrEmpty(req.VisibleMethodAlipayEnabled),
+ SettingPaymentVisibleMethodWxpayEnabled: formatBoolOrEmpty(req.VisibleMethodWxpayEnabled),
}
if req.EnabledTypes != nil {
m[SettingEnabledPaymentTypes] = strings.Join(req.EnabledTypes, ",")
@@ -385,3 +401,79 @@ func pcParseInt(s string, defaultVal int) int {
}
return v
}
+
+func buildVisibleMethodSourceAvailability(instances []*dbent.PaymentProviderInstance) map[string]bool {
+ available := make(map[string]bool, 4)
+ for _, inst := range instances {
+ switch inst.ProviderKey {
+ case payment.TypeAlipay:
+ if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipayDirect) {
+ available[VisibleMethodSourceOfficialAlipay] = true
+ }
+ case payment.TypeWxpay:
+ if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpayDirect) {
+ available[VisibleMethodSourceOfficialWechat] = true
+ }
+ case payment.TypeEasyPay:
+ for _, supportedType := range splitTypes(inst.SupportedTypes) {
+ switch NormalizeVisibleMethod(supportedType) {
+ case payment.TypeAlipay:
+ available[VisibleMethodSourceEasyPayAlipay] = true
+ case payment.TypeWxpay:
+ available[VisibleMethodSourceEasyPayWechat] = true
+ }
+ }
+ }
+ }
+ return available
+}
+
+func applyVisibleMethodRoutingToEnabledTypes(base []string, vals map[string]string, available map[string]bool) []string {
+ shouldExpose := map[string]bool{
+ payment.TypeAlipay: visibleMethodShouldBeExposed(payment.TypeAlipay, vals, available),
+ payment.TypeWxpay: visibleMethodShouldBeExposed(payment.TypeWxpay, vals, available),
+ }
+
+ seen := make(map[string]struct{}, len(base)+2)
+ out := make([]string, 0, len(base)+2)
+ appendType := func(paymentType string) {
+ paymentType = NormalizeVisibleMethod(paymentType)
+ if paymentType == "" {
+ return
+ }
+ if _, ok := seen[paymentType]; ok {
+ return
+ }
+ seen[paymentType] = struct{}{}
+ out = append(out, paymentType)
+ }
+
+ for _, paymentType := range base {
+ visibleMethod := NormalizeVisibleMethod(paymentType)
+ switch visibleMethod {
+ case payment.TypeAlipay, payment.TypeWxpay:
+ if shouldExpose[visibleMethod] {
+ appendType(visibleMethod)
+ }
+ default:
+ appendType(visibleMethod)
+ }
+ }
+
+ for _, visibleMethod := range []string{payment.TypeAlipay, payment.TypeWxpay} {
+ if shouldExpose[visibleMethod] {
+ appendType(visibleMethod)
+ }
+ }
+ return out
+}
+
+func visibleMethodShouldBeExposed(method string, vals map[string]string, available map[string]bool) bool {
+ enabledKey := visibleMethodEnabledSettingKey(method)
+ sourceKey := visibleMethodSourceSettingKey(method)
+ if enabledKey == "" || sourceKey == "" || vals[enabledKey] != "true" {
+ return false
+ }
+ source := NormalizeVisibleMethodSource(method, vals[sourceKey])
+ return source != "" && available[source]
+}
diff --git a/backend/internal/service/payment_config_service_test.go b/backend/internal/service/payment_config_service_test.go
index 027bb796..f04f4697 100644
--- a/backend/internal/service/payment_config_service_test.go
+++ b/backend/internal/service/payment_config_service_test.go
@@ -1,9 +1,19 @@
package service
import (
+ "context"
+ "database/sql"
+ "fmt"
+ "strings"
"testing"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/payment"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
)
func TestPcParseFloat(t *testing.T) {
@@ -163,6 +173,20 @@ func TestParsePaymentConfig(t *testing.T) {
}
})
+ t.Run("enabled types are normalized to visible methods and deduplicated", func(t *testing.T) {
+ t.Parallel()
+ vals := map[string]string{
+ SettingEnabledPaymentTypes: "alipay_direct, alipay, wxpay_direct, wxpay",
+ }
+ cfg := svc.parsePaymentConfig(vals)
+ if len(cfg.EnabledTypes) != 2 {
+ t.Fatalf("EnabledTypes len = %d, want 2", len(cfg.EnabledTypes))
+ }
+ if cfg.EnabledTypes[0] != "alipay" || cfg.EnabledTypes[1] != "wxpay" {
+ t.Fatalf("EnabledTypes = %v, want [alipay wxpay]", cfg.EnabledTypes)
+ }
+ })
+
t.Run("empty enabled types string", func(t *testing.T) {
t.Parallel()
vals := map[string]string{
@@ -204,3 +228,210 @@ func TestGetBasePaymentType(t *testing.T) {
})
}
}
+
+func TestApplyVisibleMethodRoutingToEnabledTypes(t *testing.T) {
+ t.Parallel()
+
+ base := []string{"alipay", "wxpay", "stripe"}
+ vals := map[string]string{
+ SettingPaymentVisibleMethodAlipayEnabled: "true",
+ SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay,
+ SettingPaymentVisibleMethodWxpayEnabled: "true",
+ SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
+ }
+ available := map[string]bool{
+ VisibleMethodSourceOfficialAlipay: true,
+ VisibleMethodSourceOfficialWechat: false,
+ }
+
+ got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available)
+ want := []string{"alipay", "stripe"}
+ if len(got) != len(want) {
+ t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got)
+ }
+ for i := range want {
+ if got[i] != want[i] {
+ t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
+ }
+ }
+}
+
+func TestApplyVisibleMethodRoutingAddsConfiguredVisibleMethod(t *testing.T) {
+ t.Parallel()
+
+ base := []string{"stripe"}
+ vals := map[string]string{
+ SettingPaymentVisibleMethodAlipayEnabled: "true",
+ SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceEasyPayAlipay,
+ }
+ available := map[string]bool{
+ VisibleMethodSourceEasyPayAlipay: true,
+ }
+
+ got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available)
+ want := []string{"stripe", "alipay"}
+ if len(got) != len(want) {
+ t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got)
+ }
+ for i := range want {
+ if got[i] != want[i] {
+ t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
+ }
+ }
+}
+
+func TestBuildVisibleMethodSourceAvailability(t *testing.T) {
+ t.Parallel()
+
+ instances := []*dbent.PaymentProviderInstance{
+ {ProviderKey: payment.TypeAlipay, SupportedTypes: "alipay"},
+ {ProviderKey: payment.TypeEasyPay, SupportedTypes: "wxpay_direct, alipay"},
+ {ProviderKey: payment.TypeWxpay, SupportedTypes: "wxpay_direct"},
+ }
+
+ got := buildVisibleMethodSourceAvailability(instances)
+ if !got[VisibleMethodSourceOfficialAlipay] {
+ t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialAlipay)
+ }
+ if !got[VisibleMethodSourceEasyPayAlipay] {
+ t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayAlipay)
+ }
+ if !got[VisibleMethodSourceOfficialWechat] {
+ t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialWechat)
+ }
+ if !got[VisibleMethodSourceEasyPayWechat] {
+ t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayWechat)
+ }
+}
+
+func TestGetPaymentConfigKeepsStoredEnabledTypes(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName("EasyPay Alipay").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create easypay instance: %v", err)
+ }
+
+ svc := &PaymentConfigService{
+ entClient: client,
+ settingRepo: &paymentConfigSettingRepoStub{
+ values: map[string]string{
+ SettingEnabledPaymentTypes: "alipay,wxpay,stripe",
+ },
+ },
+ }
+
+ cfg, err := svc.GetPaymentConfig(ctx)
+ if err != nil {
+ t.Fatalf("GetPaymentConfig returned error: %v", err)
+ }
+
+ want := []string{payment.TypeAlipay, payment.TypeWxpay, payment.TypeStripe}
+ if len(cfg.EnabledTypes) != len(want) {
+ t.Fatalf("EnabledTypes len = %d, want %d (%v)", len(cfg.EnabledTypes), len(want), cfg.EnabledTypes)
+ }
+ for i := range want {
+ if cfg.EnabledTypes[i] != want[i] {
+ t.Fatalf("EnabledTypes[%d] = %q, want %q (full=%v)", i, cfg.EnabledTypes[i], want[i], cfg.EnabledTypes)
+ }
+ }
+}
+
+func newPaymentConfigServiceTestClient(t *testing.T) *dbent.Client {
+ t.Helper()
+
+ dbName := fmt.Sprintf(
+ "file:%s?mode=memory&cache=shared",
+ strings.NewReplacer("/", "_", " ", "_").Replace(t.Name()),
+ )
+ db, err := sql.Open("sqlite", dbName)
+ if err != nil {
+ t.Fatalf("open sqlite: %v", err)
+ }
+ t.Cleanup(func() { _ = db.Close() })
+
+ if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
+ t.Fatalf("enable foreign keys: %v", err)
+ }
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+ return client
+}
+
+type paymentConfigSettingRepoStub struct {
+ values map[string]string
+ updates map[string]string
+}
+
+func (s *paymentConfigSettingRepoStub) Get(context.Context, string) (*Setting, error) {
+ return nil, nil
+}
+func (s *paymentConfigSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ return s.values[key], nil
+}
+func (s *paymentConfigSettingRepoStub) Set(context.Context, string, string) error { return nil }
+func (s *paymentConfigSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ out[key] = s.values[key]
+ }
+ return out, nil
+}
+func (s *paymentConfigSettingRepoStub) SetMultiple(_ context.Context, values map[string]string) error {
+ s.updates = make(map[string]string, len(values))
+ for key, value := range values {
+ s.updates[key] = value
+ if s.values == nil {
+ s.values = map[string]string{}
+ }
+ s.values[key] = value
+ }
+ return nil
+}
+func (s *paymentConfigSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ return s.values, nil
+}
+func (s *paymentConfigSettingRepoStub) Delete(context.Context, string) error { return nil }
+
+func TestUpdatePaymentConfig_PersistsVisibleMethodRouting(t *testing.T) {
+ repo := &paymentConfigSettingRepoStub{values: map[string]string{}}
+ svc := &PaymentConfigService{settingRepo: repo}
+
+ alipayEnabled := true
+ wxpayEnabled := false
+ err := svc.UpdatePaymentConfig(context.Background(), UpdatePaymentConfigRequest{
+ VisibleMethodAlipayEnabled: &alipayEnabled,
+ VisibleMethodAlipaySource: paymentConfigStrPtr(VisibleMethodSourceEasyPayAlipay),
+ VisibleMethodWxpayEnabled: &wxpayEnabled,
+ VisibleMethodWxpaySource: paymentConfigStrPtr(VisibleMethodSourceOfficialWechat),
+ })
+ if err != nil {
+ t.Fatalf("UpdatePaymentConfig returned error: %v", err)
+ }
+
+ if repo.values[SettingPaymentVisibleMethodAlipayEnabled] != "true" {
+ t.Fatalf("alipay enabled = %q, want true", repo.values[SettingPaymentVisibleMethodAlipayEnabled])
+ }
+ if repo.values[SettingPaymentVisibleMethodAlipaySource] != VisibleMethodSourceEasyPayAlipay {
+ t.Fatalf("alipay source = %q, want %q", repo.values[SettingPaymentVisibleMethodAlipaySource], VisibleMethodSourceEasyPayAlipay)
+ }
+ if repo.values[SettingPaymentVisibleMethodWxpayEnabled] != "false" {
+ t.Fatalf("wxpay enabled = %q, want false", repo.values[SettingPaymentVisibleMethodWxpayEnabled])
+ }
+ if repo.values[SettingPaymentVisibleMethodWxpaySource] != VisibleMethodSourceOfficialWechat {
+ t.Fatalf("wxpay source = %q, want %q", repo.values[SettingPaymentVisibleMethodWxpaySource], VisibleMethodSourceOfficialWechat)
+ }
+}
+
+func paymentConfigStrPtr(value string) *string {
+ return &value
+}
diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go
index 44818b37..71f1eb2f 100644
--- a/backend/internal/service/payment_fulfillment.go
+++ b/backend/internal/service/payment_fulfillment.go
@@ -25,37 +25,99 @@ func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payme
// Look up order by out_trade_no (the external order ID we sent to the provider)
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(n.OrderID)).Only(ctx)
if err != nil {
- // Fallback: try legacy format (sub2_N where N is DB ID)
- trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix)
- if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil {
- return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk)
+ // Fallback only for true legacy "sub2_N" DB-ID payloads when the
+ // current out_trade_no lookup genuinely did not find an order.
+ if oid, ok := parseLegacyPaymentOrderID(n.OrderID, err); ok {
+ return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk, n.Metadata)
}
return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID)
}
- return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk)
+ return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk, n.Metadata)
}
-func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error {
+func parseLegacyPaymentOrderID(orderID string, lookupErr error) (int64, bool) {
+ if !dbent.IsNotFound(lookupErr) {
+ return 0, false
+ }
+ orderID = strings.TrimSpace(orderID)
+ if !strings.HasPrefix(orderID, orderIDPrefix) {
+ return 0, false
+ }
+ trimmed := strings.TrimPrefix(orderID, orderIDPrefix)
+ if trimmed == "" || trimmed == orderID {
+ return 0, false
+ }
+ oid, err := strconv.ParseInt(trimmed, 10, 64)
+ if err != nil || oid <= 0 {
+ return 0, false
+ }
+ return oid, true
+}
+
+func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string, metadata map[string]string) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
slog.Error("order not found", "orderID", oid)
return nil
}
- // Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount).
- // Also skip if paid is NaN/Inf (malformed provider data).
- if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) {
- if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
- s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
- return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
- }
+ instanceProviderKey := ""
+ if inst, instErr := s.getOrderProviderInstance(ctx, o); instErr == nil && inst != nil {
+ instanceProviderKey = inst.ProviderKey
}
- // Use order's expected amount when provider didn't report one
- if paid <= 0 || math.IsNaN(paid) || math.IsInf(paid, 0) {
- paid = o.PayAmount
+ expectedProviderKey := expectedNotificationProviderKeyForOrder(s.registry, o, instanceProviderKey)
+ if expectedProviderKey != "" && strings.TrimSpace(pk) != "" && !strings.EqualFold(expectedProviderKey, strings.TrimSpace(pk)) {
+ s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_MISMATCH", pk, map[string]any{
+ "expectedProvider": expectedProviderKey,
+ "actualProvider": pk,
+ "tradeNo": tradeNo,
+ })
+ return fmt.Errorf("provider mismatch: expected %s, got %s", expectedProviderKey, pk)
+ }
+ if err := validateProviderNotificationMetadata(o, pk, metadata); err != nil {
+ s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_METADATA_MISMATCH", pk, map[string]any{
+ "detail": err.Error(),
+ "tradeNo": tradeNo,
+ })
+ return err
+ }
+ if !isValidProviderAmount(paid) {
+ s.writeAuditLog(ctx, o.ID, "PAYMENT_INVALID_AMOUNT", pk, map[string]any{
+ "expected": o.PayAmount,
+ "paid": paid,
+ "tradeNo": tradeNo,
+ })
+ return fmt.Errorf("invalid paid amount from provider: %v", paid)
+ }
+ if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
+ s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
+ return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
}
return s.toPaid(ctx, o, tradeNo, paid, pk)
}
+func isValidProviderAmount(amount float64) bool {
+ return amount > 0 && !math.IsNaN(amount) && !math.IsInf(amount, 0)
+}
+
+func validateProviderNotificationMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error {
+ return validateProviderSnapshotMetadata(order, providerKey, metadata)
+}
+
+func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, orderProviderKey string, instanceProviderKey string) string {
+ if key := strings.TrimSpace(instanceProviderKey); key != "" {
+ return key
+ }
+ if key := strings.TrimSpace(orderProviderKey); key != "" {
+ return key
+ }
+ if registry != nil {
+ if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(orderPaymentType))); key != "" {
+ return key
+ }
+ }
+ return strings.TrimSpace(orderPaymentType)
+}
+
func (s *PaymentService) toPaid(ctx context.Context, o *dbent.PaymentOrder, tradeNo string, paid float64, pk string) error {
previousStatus := o.Status
now := time.Now()
diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go
index 625b0d9f..abdb59de 100644
--- a/backend/internal/service/payment_fulfillment_test.go
+++ b/backend/internal/service/payment_fulfillment_test.go
@@ -3,12 +3,39 @@
package service
import (
+ "context"
"errors"
+ "math"
"testing"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/stretchr/testify/assert"
)
+type paymentFulfillmentTestProvider struct {
+ key string
+ supportedTypes []payment.PaymentType
+}
+
+func (p paymentFulfillmentTestProvider) Name() string { return p.key }
+func (p paymentFulfillmentTestProvider) ProviderKey() string { return p.key }
+func (p paymentFulfillmentTestProvider) SupportedTypes() []payment.PaymentType {
+ return p.supportedTypes
+}
+func (p paymentFulfillmentTestProvider) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ panic("unexpected call")
+}
+func (p paymentFulfillmentTestProvider) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
+ panic("unexpected call")
+}
+func (p paymentFulfillmentTestProvider) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) {
+ panic("unexpected call")
+}
+func (p paymentFulfillmentTestProvider) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
+ panic("unexpected call")
+}
+
// ---------------------------------------------------------------------------
// resolveRedeemAction — pure idempotency decision logic
// ---------------------------------------------------------------------------
@@ -161,3 +188,181 @@ func TestResolveRedeemAction_IsUsedCanUseConsistency(t *testing.T) {
assert.True(t, unusedCode.CanUse())
assert.Equal(t, redeemActionRedeem, resolveRedeemAction(unusedCode, nil))
}
+
+func TestExpectedNotificationProviderKeyPrefersOrderInstanceProvider(t *testing.T) {
+ t.Parallel()
+
+ registry := payment.NewRegistry()
+ registry.Register(paymentFulfillmentTestProvider{
+ key: payment.TypeAlipay,
+ supportedTypes: []payment.PaymentType{payment.TypeAlipay},
+ })
+
+ assert.Equal(t,
+ payment.TypeEasyPay,
+ expectedNotificationProviderKey(registry, payment.TypeAlipay, "", payment.TypeEasyPay),
+ )
+}
+
+func TestExpectedNotificationProviderKeyUsesRegistryMappingForLegacyOrders(t *testing.T) {
+ t.Parallel()
+
+ registry := payment.NewRegistry()
+ registry.Register(paymentFulfillmentTestProvider{
+ key: payment.TypeEasyPay,
+ supportedTypes: []payment.PaymentType{payment.TypeAlipay},
+ })
+
+ assert.Equal(t,
+ payment.TypeEasyPay,
+ expectedNotificationProviderKey(registry, payment.TypeAlipay, "", ""),
+ )
+}
+
+func TestExpectedNotificationProviderKeyFallsBackToPaymentType(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t,
+ payment.TypeWxpay,
+ expectedNotificationProviderKey(nil, payment.TypeWxpay, "", ""),
+ )
+}
+
+func TestExpectedNotificationProviderKeyPrefersOrderSnapshotProviderKey(t *testing.T) {
+ t.Parallel()
+
+ registry := payment.NewRegistry()
+ registry.Register(paymentFulfillmentTestProvider{
+ key: payment.TypeAlipay,
+ supportedTypes: []payment.PaymentType{payment.TypeAlipay},
+ })
+
+ assert.Equal(t,
+ payment.TypeEasyPay,
+ expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay, ""),
+ )
+}
+
+func TestExpectedNotificationProviderKeyForOrderUsesSnapshotProviderKey(t *testing.T) {
+ t.Parallel()
+
+ registry := payment.NewRegistry()
+ registry.Register(paymentFulfillmentTestProvider{
+ key: payment.TypeAlipay,
+ supportedTypes: []payment.PaymentType{payment.TypeAlipay},
+ })
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 1,
+ "provider_key": payment.TypeEasyPay,
+ },
+ }
+
+ assert.Equal(t,
+ payment.TypeEasyPay,
+ expectedNotificationProviderKeyForOrder(registry, order, ""),
+ )
+}
+
+func TestValidateProviderNotificationMetadataRejectsWxpaySnapshotMismatch(t *testing.T) {
+ t.Parallel()
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeWxpay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 1,
+ "merchant_app_id": "wx-app-expected",
+ "merchant_id": "mch-expected",
+ "currency": "CNY",
+ },
+ }
+
+ err := validateProviderNotificationMetadata(order, payment.TypeWxpay, map[string]string{
+ "appid": "wx-app-other",
+ "mchid": "mch-expected",
+ "currency": "CNY",
+ "trade_state": "SUCCESS",
+ })
+ assert.ErrorContains(t, err, "wxpay appid mismatch")
+}
+
+func TestValidateProviderNotificationMetadataAllowsLegacyOrdersWithoutSnapshotFields(t *testing.T) {
+ t.Parallel()
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeWxpay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 1,
+ "provider_instance_id": "9",
+ "provider_key": payment.TypeWxpay,
+ },
+ }
+
+ err := validateProviderNotificationMetadata(order, payment.TypeWxpay, map[string]string{
+ "appid": "wx-app-runtime",
+ "mchid": "mch-runtime",
+ "currency": "CNY",
+ "trade_state": "SUCCESS",
+ })
+ assert.NoError(t, err)
+}
+
+func TestParseLegacyPaymentOrderID(t *testing.T) {
+ t.Parallel()
+
+ oid, ok := parseLegacyPaymentOrderID("sub2_42", &dbent.NotFoundError{})
+ assert.True(t, ok)
+ assert.EqualValues(t, 42, oid)
+
+ _, ok = parseLegacyPaymentOrderID("42", &dbent.NotFoundError{})
+ assert.False(t, ok)
+
+ _, ok = parseLegacyPaymentOrderID("sub2_42", errors.New("db down"))
+ assert.False(t, ok)
+}
+
+func TestIsValidProviderAmount(t *testing.T) {
+ t.Parallel()
+
+ assert.True(t, isValidProviderAmount(0.01))
+ assert.False(t, isValidProviderAmount(0))
+ assert.False(t, isValidProviderAmount(-1))
+ assert.False(t, isValidProviderAmount(math.NaN()))
+ assert.False(t, isValidProviderAmount(math.Inf(1)))
+}
+
+func TestValidateProviderNotificationMetadataRejectsAlipaySnapshotMismatch(t *testing.T) {
+ t.Parallel()
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 2,
+ "merchant_app_id": "alipay-app-expected",
+ },
+ }
+
+ err := validateProviderNotificationMetadata(order, payment.TypeAlipay, map[string]string{
+ "app_id": "alipay-app-other",
+ })
+ assert.ErrorContains(t, err, "alipay app_id mismatch")
+}
+
+func TestValidateProviderNotificationMetadataRejectsEasyPaySnapshotMismatch(t *testing.T) {
+ t.Parallel()
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 2,
+ "merchant_id": "pid-expected",
+ },
+ }
+
+ err := validateProviderNotificationMetadata(order, payment.TypeEasyPay, map[string]string{
+ "pid": "pid-other",
+ })
+ assert.ErrorContains(t, err, "easypay pid mismatch")
+}
diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go
index 128416e4..15d4509d 100644
--- a/backend/internal/service/payment_order.go
+++ b/backend/internal/service/payment_order.go
@@ -2,9 +2,11 @@ package service
import (
"context"
+ "errors"
"fmt"
"log/slog"
"math"
+ "net/url"
"strconv"
"strings"
"time"
@@ -22,6 +24,9 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest
if req.OrderType == "" {
req.OrderType = payment.OrderTypeBalance
}
+ if normalized := NormalizeVisibleMethod(req.PaymentType); normalized != "" {
+ req.PaymentType = normalized
+ }
cfg, err := s.configService.GetPaymentConfig(ctx)
if err != nil {
return nil, fmt.Errorf("get payment config: %w", err)
@@ -54,11 +59,25 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest
feeRate := cfg.RechargeFeeRate
payAmountStr := payment.CalculatePayAmount(limitAmount, feeRate)
payAmount, _ := strconv.ParseFloat(payAmountStr, 64)
- order, err := s.createOrderInTx(ctx, req, user, plan, cfg, orderAmount, limitAmount, feeRate, payAmount)
+ sel, err := s.selectCreateOrderInstance(ctx, req, cfg, payAmount)
if err != nil {
return nil, err
}
- resp, err := s.invokeProvider(ctx, order, req, cfg, limitAmount, payAmountStr, payAmount, plan)
+ if err := s.validateSelectedCreateOrderInstance(ctx, req, sel); err != nil {
+ return nil, err
+ }
+ oauthResp, err := s.maybeBuildWeChatOAuthRequiredResponseForSelection(ctx, req, limitAmount, payAmount, feeRate, sel)
+ if err != nil {
+ return nil, err
+ }
+ if oauthResp != nil {
+ return oauthResp, nil
+ }
+ order, err := s.createOrderInTx(ctx, req, user, plan, cfg, orderAmount, limitAmount, feeRate, payAmount, sel)
+ if err != nil {
+ return nil, err
+ }
+ resp, err := s.invokeProvider(ctx, order, req, cfg, limitAmount, payAmountStr, payAmount, plan, sel)
if err != nil {
_, _ = s.entClient.PaymentOrder.UpdateOneID(order.ID).
SetStatus(OrderStatusFailed).
@@ -103,7 +122,7 @@ func (s *PaymentService) validateSubOrder(ctx context.Context, req CreateOrderRe
return plan, nil
}
-func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderRequest, user *User, plan *dbent.SubscriptionPlan, cfg *PaymentConfig, orderAmount, limitAmount, feeRate, payAmount float64) (*dbent.PaymentOrder, error) {
+func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderRequest, user *User, plan *dbent.SubscriptionPlan, cfg *PaymentConfig, orderAmount, limitAmount, feeRate, payAmount float64, sel *payment.InstanceSelection) (*dbent.PaymentOrder, error) {
tx, err := s.entClient.Tx(ctx)
if err != nil {
return nil, fmt.Errorf("begin transaction: %w", err)
@@ -120,6 +139,17 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
tm = defaultOrderTimeoutMin
}
exp := time.Now().Add(time.Duration(tm) * time.Minute)
+ outTradeNo, err := s.allocateOutTradeNo(ctx, tx)
+ if err != nil {
+ return nil, err
+ }
+ providerSnapshot := buildPaymentOrderProviderSnapshot(sel, req)
+ selectedInstanceID := ""
+ selectedProviderKey := ""
+ if sel != nil {
+ selectedInstanceID = strings.TrimSpace(sel.InstanceID)
+ selectedProviderKey = strings.TrimSpace(sel.ProviderKey)
+ }
b := tx.PaymentOrder.Create().
SetUserID(req.UserID).
SetUserEmail(user.Email).
@@ -129,7 +159,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
SetPayAmount(payAmount).
SetFeeRate(feeRate).
SetRechargeCode("").
- SetOutTradeNo(generateOutTradeNo()).
+ SetOutTradeNo(outTradeNo).
SetPaymentType(req.PaymentType).
SetPaymentTradeNo("").
SetOrderType(req.OrderType).
@@ -140,6 +170,15 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
if req.SrcURL != "" {
b.SetSrcURL(req.SrcURL)
}
+ if selectedInstanceID != "" {
+ b.SetProviderInstanceID(selectedInstanceID)
+ }
+ if selectedProviderKey != "" {
+ b.SetProviderKey(selectedProviderKey)
+ }
+ if providerSnapshot != nil {
+ b.SetProviderSnapshot(providerSnapshot)
+ }
if plan != nil {
b.SetPlanID(plan.ID).SetSubscriptionGroupID(plan.GroupID).SetSubscriptionDays(psComputeValidityDays(plan.ValidityDays, plan.ValidityUnit))
}
@@ -158,6 +197,21 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
return order, nil
}
+func (s *PaymentService) allocateOutTradeNo(ctx context.Context, tx *dbent.Tx) (string, error) {
+ const maxAttempts = 5
+ for attempt := 0; attempt < maxAttempts; attempt++ {
+ candidate := generateOutTradeNo()
+ exists, err := tx.PaymentOrder.Query().Where(paymentorder.OutTradeNo(candidate)).Exist(ctx)
+ if err != nil {
+ return "", fmt.Errorf("check out_trade_no uniqueness: %w", err)
+ }
+ if !exists {
+ return candidate, nil
+ }
+ }
+ return "", fmt.Errorf("generate unique out_trade_no: exhausted %d attempts", maxAttempts)
+}
+
func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, userID int64, max int) error {
if max <= 0 {
max = defaultMaxPendingOrders
@@ -167,12 +221,71 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us
return fmt.Errorf("count pending orders: %w", err)
}
if c >= max {
- return infraerrors.TooManyRequests("TOO_MANY_PENDING", fmt.Sprintf("too many pending orders (max %d)", max)).
+ return infraerrors.TooManyRequests("TOO_MANY_PENDING", "too_many_pending").
WithMetadata(map[string]string{"max": strconv.Itoa(max)})
}
return nil
}
+func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection, req CreateOrderRequest) map[string]any {
+ if sel == nil {
+ return nil
+ }
+
+ snapshot := map[string]any{}
+ snapshot["schema_version"] = 2
+
+ instanceID := strings.TrimSpace(sel.InstanceID)
+ if instanceID != "" {
+ snapshot["provider_instance_id"] = instanceID
+ }
+
+ providerKey := strings.TrimSpace(sel.ProviderKey)
+ if providerKey != "" {
+ snapshot["provider_key"] = providerKey
+ }
+
+ paymentMode := strings.TrimSpace(sel.PaymentMode)
+ if paymentMode != "" {
+ snapshot["payment_mode"] = paymentMode
+ }
+
+ if providerKey == payment.TypeWxpay {
+ if merchantAppID := paymentOrderSnapshotWxpayAppID(sel, req); merchantAppID != "" {
+ snapshot["merchant_app_id"] = merchantAppID
+ }
+ if merchantID := strings.TrimSpace(sel.Config["mchId"]); merchantID != "" {
+ snapshot["merchant_id"] = merchantID
+ }
+ snapshot["currency"] = "CNY"
+ }
+ if providerKey == payment.TypeAlipay {
+ if merchantAppID := strings.TrimSpace(sel.Config["appId"]); merchantAppID != "" {
+ snapshot["merchant_app_id"] = merchantAppID
+ }
+ }
+ if providerKey == payment.TypeEasyPay {
+ if merchantID := strings.TrimSpace(sel.Config["pid"]); merchantID != "" {
+ snapshot["merchant_id"] = merchantID
+ }
+ }
+
+ if len(snapshot) == 1 {
+ return nil
+ }
+ return snapshot
+}
+
+func paymentOrderSnapshotWxpayAppID(sel *payment.InstanceSelection, req CreateOrderRequest) string {
+ if sel == nil || strings.TrimSpace(sel.ProviderKey) != payment.TypeWxpay {
+ return ""
+ }
+ if strings.TrimSpace(req.OpenID) != "" {
+ return strings.TrimSpace(provider.ResolveWxpayJSAPIAppID(sel.Config))
+ }
+ return strings.TrimSpace(sel.Config["appId"])
+}
+
func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error {
if limit <= 0 {
return nil
@@ -191,33 +304,127 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user
used += o.Amount
}
if used+amount > limit {
- return infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", fmt.Sprintf("daily recharge limit reached, remaining: %.2f", math.Max(0, limit-used)))
+ return infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily_limit_exceeded").
+ WithMetadata(map[string]string{"remaining": fmt.Sprintf("%.2f", math.Max(0, limit-used))})
}
return nil
}
-func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, limitAmount float64, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan) (*CreateOrderResponse, error) {
- // Select an instance across all providers that support the requested payment type.
- // This enables cross-provider load balancing (e.g. EasyPay + Alipay direct for "alipay").
- sel, err := s.loadBalancer.SelectInstance(ctx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
+func (s *PaymentService) selectCreateOrderInstance(ctx context.Context, req CreateOrderRequest, cfg *PaymentConfig, payAmount float64) (*payment.InstanceSelection, error) {
+ selectCtx, err := s.prepareCreateOrderSelectionContext(ctx, req)
if err != nil {
- return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType))
+ return nil, err
+ }
+ sel, err := s.loadBalancer.SelectInstance(selectCtx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
+ if err != nil {
+ return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "method_not_configured").
+ WithMetadata(map[string]string{"payment_type": req.PaymentType})
}
if sel == nil {
- return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no available payment instance")
+ return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no_available_instance")
}
+ return sel, nil
+}
+
+func (s *PaymentService) prepareCreateOrderSelectionContext(ctx context.Context, req CreateOrderRequest) (context.Context, error) {
+ if !requestNeedsWeChatJSAPICompatibility(req) {
+ return ctx, nil
+ }
+ if !s.usesOfficialWxpayVisibleMethod(ctx) {
+ return ctx, nil
+ }
+ expectedAppID, _, err := s.getWeChatPaymentOAuthCredential(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return payment.WithWxpayJSAPIAppID(ctx, expectedAppID), nil
+}
+
+func requestNeedsWeChatJSAPICompatibility(req CreateOrderRequest) bool {
+ if payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay {
+ return false
+ }
+ return req.IsWeChatBrowser || strings.TrimSpace(req.OpenID) != ""
+}
+
+func (s *PaymentService) usesOfficialWxpayVisibleMethod(ctx context.Context) bool {
+ if s == nil || s.configService == nil {
+ return false
+ }
+ inst, err := s.configService.resolveEnabledVisibleMethodInstance(ctx, payment.TypeWxpay)
+ if err != nil {
+ return false
+ }
+ if inst == nil {
+ return false
+ }
+ return inst.ProviderKey == payment.TypeWxpay
+}
+
+func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, limitAmount float64, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan, sel *payment.InstanceSelection) (*CreateOrderResponse, error) {
prov, err := provider.CreateProvider(sel.ProviderKey, sel.InstanceID, sel.Config)
if err != nil {
- return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "payment method is temporarily unavailable")
+ slog.Error("[PaymentService] CreateProvider failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err)
+ // If the provider returned a structured ApplicationError (e.g. WXPAY_CONFIG_MISSING_KEY),
+ // pass it through with provider context added to metadata. Otherwise wrap as PAYMENT_PROVIDER_MISCONFIGURED.
+ if appErr := new(infraerrors.ApplicationError); errors.As(err, &appErr) {
+ md := map[string]string{"provider": sel.ProviderKey, "instance_id": sel.InstanceID}
+ for k, v := range appErr.Metadata {
+ md[k] = v
+ }
+ return nil, appErr.WithMetadata(md)
+ }
+ return nil, infraerrors.ServiceUnavailable("PAYMENT_PROVIDER_MISCONFIGURED", "provider_misconfigured").
+ WithMetadata(map[string]string{"provider": sel.ProviderKey, "instance_id": sel.InstanceID})
}
subject := s.buildPaymentSubject(plan, limitAmount, cfg)
outTradeNo := order.OutTradeNo
- pr, err := prov.CreatePayment(ctx, payment.CreatePaymentRequest{OrderID: outTradeNo, Amount: payAmountStr, PaymentType: req.PaymentType, Subject: subject, ClientIP: req.ClientIP, IsMobile: req.IsMobile, InstanceSubMethods: sel.SupportedTypes})
+ canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL, req.SrcHost, req.SrcURL)
+ if err != nil {
+ return nil, err
+ }
+ resumeToken := ""
+ if resume := s.paymentResume(); resume != nil {
+ if canonicalReturnURL != "" && resume.isSigningConfigured() {
+ resumeToken, err = resume.CreateToken(ResumeTokenClaims{
+ OrderID: order.ID,
+ UserID: order.UserID,
+ ProviderInstanceID: sel.InstanceID,
+ ProviderKey: sel.ProviderKey,
+ PaymentType: req.PaymentType,
+ CanonicalReturnURL: canonicalReturnURL,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("create payment resume token: %w", err)
+ }
+ }
+ }
+ providerReturnURL, err := buildPaymentReturnURL(canonicalReturnURL, order.ID, outTradeNo, resumeToken)
+ if err != nil {
+ return nil, err
+ }
+ providerReq := buildProviderCreatePaymentRequest(CreateOrderRequest{
+ PaymentType: req.PaymentType,
+ OpenID: req.OpenID,
+ ClientIP: req.ClientIP,
+ IsMobile: req.IsMobile,
+ ReturnURL: providerReturnURL,
+ }, sel, outTradeNo, payAmountStr, subject)
+ pr, err := prov.CreatePayment(ctx, providerReq)
if err != nil {
slog.Error("[PaymentService] CreatePayment failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err)
- return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", err.Error()))
+ if appErr := new(infraerrors.ApplicationError); errors.As(err, &appErr) {
+ return nil, appErr
+ }
+ return nil, classifyCreatePaymentError(req, sel.ProviderKey, err)
}
- _, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).SetNillablePayURL(psNilIfEmpty(pr.PayURL)).SetNillableQrCode(psNilIfEmpty(pr.QRCode)).SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).Save(ctx)
+ _, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).
+ SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).
+ SetNillablePayURL(psNilIfEmpty(pr.PayURL)).
+ SetNillableQrCode(psNilIfEmpty(pr.QRCode)).
+ SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).
+ SetNillableProviderKey(psNilIfEmpty(sel.ProviderKey)).
+ Save(ctx)
if err != nil {
return nil, fmt.Errorf("update order with payment details: %w", err)
}
@@ -227,8 +434,36 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
"payAmount": order.PayAmount,
"paymentType": req.PaymentType,
"orderType": req.OrderType,
+ "paymentSource": NormalizePaymentSource(req.PaymentSource),
})
- return &CreateOrderResponse{OrderID: order.ID, Amount: order.Amount, PayAmount: payAmount, FeeRate: order.FeeRate, Status: OrderStatusPending, PaymentType: req.PaymentType, PayURL: pr.PayURL, QRCode: pr.QRCode, ClientSecret: pr.ClientSecret, ExpiresAt: order.ExpiresAt, PaymentMode: sel.PaymentMode}, nil
+ resultType := pr.ResultType
+ if resultType == "" {
+ resultType = payment.CreatePaymentResultOrderCreated
+ }
+ resp := buildCreateOrderResponse(order, req, payAmount, sel, pr, resultType)
+ resp.ResumeToken = resumeToken
+ return resp, nil
+}
+
+func buildProviderCreatePaymentRequest(req CreateOrderRequest, sel *payment.InstanceSelection, orderID, amount, subject string) payment.CreatePaymentRequest {
+ return payment.CreatePaymentRequest{
+ OrderID: orderID,
+ Amount: amount,
+ PaymentType: req.PaymentType,
+ Subject: subject,
+ ReturnURL: req.ReturnURL,
+ OpenID: strings.TrimSpace(req.OpenID),
+ ClientIP: req.ClientIP,
+ IsMobile: req.IsMobile,
+ InstanceSubMethods: selectedInstanceSupportedTypes(sel),
+ }
+}
+
+func selectedInstanceSupportedTypes(sel *payment.InstanceSelection) string {
+ if sel == nil {
+ return ""
+ }
+ return sel.SupportedTypes
}
func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, limitAmount float64, cfg *PaymentConfig) string {
@@ -247,6 +482,193 @@ func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, limit
return "Sub2API " + amountStr + " CNY"
}
+func (s *PaymentService) maybeBuildWeChatOAuthRequiredResponse(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64) (*CreateOrderResponse, error) {
+ return s.maybeBuildWeChatOAuthRequiredResponseForSelection(ctx, req, amount, payAmount, feeRate, nil)
+}
+
+func (s *PaymentService) maybeBuildWeChatOAuthRequiredResponseForSelection(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64, sel *payment.InstanceSelection) (*CreateOrderResponse, error) {
+ if sel != nil && sel.ProviderKey != "" && sel.ProviderKey != payment.TypeWxpay {
+ return nil, nil
+ }
+ if strings.TrimSpace(req.OpenID) != "" || !req.IsWeChatBrowser || payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay {
+ return nil, nil
+ }
+ return s.buildWeChatOAuthRequiredResponse(ctx, req, amount, payAmount, feeRate)
+}
+
+func (s *PaymentService) buildWeChatOAuthRequiredResponse(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64) (*CreateOrderResponse, error) {
+ appID, _, err := s.getWeChatPaymentOAuthCredential(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if err := s.paymentResume().ensureSigningKey(); err != nil {
+ return nil, err
+ }
+
+ authorizeURL, err := buildWeChatPaymentOAuthStartURL(req, "snsapi_base")
+ if err != nil {
+ return nil, err
+ }
+
+ return &CreateOrderResponse{
+ Amount: amount,
+ PayAmount: payAmount,
+ FeeRate: feeRate,
+ ResultType: payment.CreatePaymentResultOAuthRequired,
+ PaymentType: req.PaymentType,
+ OAuth: &payment.WechatOAuthInfo{
+ AuthorizeURL: authorizeURL,
+ AppID: appID,
+ Scope: "snsapi_base",
+ RedirectURL: "/auth/wechat/payment/callback",
+ },
+ }, nil
+}
+
+func (s *PaymentService) validateSelectedCreateOrderInstance(ctx context.Context, req CreateOrderRequest, sel *payment.InstanceSelection) error {
+ if !requiresWeChatJSAPICompatibleSelection(req, sel) {
+ return nil
+ }
+ expectedAppID, _, err := s.getWeChatPaymentOAuthCredential(ctx)
+ if err != nil {
+ return err
+ }
+ selectedAppID := provider.ResolveWxpayJSAPIAppID(sel.Config)
+ if selectedAppID == "" || selectedAppID != expectedAppID {
+ return infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "selected payment instance is not compatible with the current WeChat OAuth app")
+ }
+ return nil
+}
+
+func requiresWeChatJSAPICompatibleSelection(req CreateOrderRequest, sel *payment.InstanceSelection) bool {
+ if sel == nil || sel.ProviderKey != payment.TypeWxpay || payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay {
+ return false
+ }
+ return req.IsWeChatBrowser || strings.TrimSpace(req.OpenID) != ""
+}
+
+func (s *PaymentService) getWeChatPaymentOAuthCredential(ctx context.Context) (string, string, error) {
+ if s == nil || s.configService == nil || s.configService.settingRepo == nil {
+ return "", "", infraerrors.ServiceUnavailable(
+ "WECHAT_PAYMENT_MP_NOT_CONFIGURED",
+ "wechat in-app payment requires a complete WeChat MP OAuth credential",
+ )
+ }
+ cfg, err := (&SettingService{settingRepo: s.configService.settingRepo}).GetWeChatConnectOAuthConfig(ctx)
+ appID := strings.TrimSpace(cfg.AppIDForMode("mp"))
+ appSecret := strings.TrimSpace(cfg.AppSecretForMode("mp"))
+ if err != nil || !cfg.SupportsMode("mp") || appID == "" || appSecret == "" {
+ return "", "", infraerrors.ServiceUnavailable(
+ "WECHAT_PAYMENT_MP_NOT_CONFIGURED",
+ "wechat in-app payment requires a complete WeChat MP OAuth credential",
+ )
+ }
+ return appID, appSecret, nil
+}
+
+func classifyCreatePaymentError(req CreateOrderRequest, providerKey string, err error) error {
+ if err == nil {
+ return nil
+ }
+ if providerKey == payment.TypeWxpay &&
+ payment.GetBasePaymentType(req.PaymentType) == payment.TypeWxpay &&
+ strings.Contains(err.Error(), "wxpay h5 payments are not authorized for this merchant") {
+ return infraerrors.ServiceUnavailable(
+ "WECHAT_H5_NOT_AUTHORIZED",
+ "wechat h5 payment is not available for this merchant",
+ ).WithMetadata(map[string]string{
+ "action": "open_in_wechat_or_scan_qr",
+ })
+ }
+ return infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", err.Error()))
+}
+
+func buildCreateOrderResponse(order *dbent.PaymentOrder, req CreateOrderRequest, payAmount float64, sel *payment.InstanceSelection, pr *payment.CreatePaymentResponse, resultType payment.CreatePaymentResultType) *CreateOrderResponse {
+ return &CreateOrderResponse{
+ OrderID: order.ID,
+ Amount: order.Amount,
+ PayAmount: payAmount,
+ FeeRate: order.FeeRate,
+ Status: OrderStatusPending,
+ ResultType: resultType,
+ PaymentType: req.PaymentType,
+ OutTradeNo: order.OutTradeNo,
+ PayURL: pr.PayURL,
+ QRCode: pr.QRCode,
+ ClientSecret: pr.ClientSecret,
+ OAuth: pr.OAuth,
+ JSAPI: pr.JSAPI,
+ JSAPIPayload: pr.JSAPI,
+ ExpiresAt: order.ExpiresAt,
+ PaymentMode: sel.PaymentMode,
+ }
+}
+
+func buildWeChatPaymentOAuthStartURL(req CreateOrderRequest, scope string) (string, error) {
+ u, err := url.Parse("/api/v1/auth/oauth/wechat/payment/start")
+ if err != nil {
+ return "", fmt.Errorf("build wechat payment oauth start url: %w", err)
+ }
+ q := u.Query()
+ q.Set("payment_type", strings.TrimSpace(req.PaymentType))
+ if req.Amount > 0 {
+ q.Set("amount", strconv.FormatFloat(req.Amount, 'f', -1, 64))
+ }
+ if orderType := strings.TrimSpace(req.OrderType); orderType != "" {
+ q.Set("order_type", orderType)
+ }
+ if req.PlanID > 0 {
+ q.Set("plan_id", strconv.FormatInt(req.PlanID, 10))
+ }
+ if scope = strings.TrimSpace(scope); scope != "" {
+ q.Set("scope", scope)
+ }
+ if redirectTo := paymentRedirectPathFromURL(req.SrcURL); redirectTo != "" {
+ q.Set("redirect", redirectTo)
+ }
+ u.RawQuery = q.Encode()
+ return u.String(), nil
+}
+
+func paymentRedirectPathFromURL(rawURL string) string {
+ rawURL = strings.TrimSpace(rawURL)
+ if rawURL == "" {
+ return "/purchase"
+ }
+ if strings.HasPrefix(rawURL, "/") && !strings.HasPrefix(rawURL, "//") {
+ return normalizePaymentRedirectPath(rawURL)
+ }
+ u, err := url.Parse(rawURL)
+ if err != nil {
+ return "/purchase"
+ }
+ path := strings.TrimSpace(u.EscapedPath())
+ if path == "" {
+ path = strings.TrimSpace(u.Path)
+ }
+ if path == "" || !strings.HasPrefix(path, "/") || strings.HasPrefix(path, "//") {
+ return "/purchase"
+ }
+ if strings.TrimSpace(u.RawQuery) != "" {
+ path += "?" + u.RawQuery
+ }
+ return normalizePaymentRedirectPath(path)
+}
+
+func normalizePaymentRedirectPath(path string) string {
+ path = strings.TrimSpace(path)
+ if path == "" {
+ return "/purchase"
+ }
+ if path == "/payment" {
+ return "/purchase"
+ }
+ if strings.HasPrefix(path, "/payment?") {
+ return "/purchase" + strings.TrimPrefix(path, "/payment")
+ }
+ return path
+}
+
// --- Order Queries ---
func (s *PaymentService) GetOrder(ctx context.Context, orderID, userID int64) (*dbent.PaymentOrder, error) {
diff --git a/backend/internal/service/payment_order_jsapi_test.go b/backend/internal/service/payment_order_jsapi_test.go
new file mode 100644
index 00000000..8c5e4fc0
--- /dev/null
+++ b/backend/internal/service/payment_order_jsapi_test.go
@@ -0,0 +1,98 @@
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+)
+
+func TestUsesOfficialWxpayVisibleMethodDerivesFromEnabledProviderInstance(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("Official WeChat").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ SetSortOrder(1).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create official wxpay instance: %v", err)
+ }
+
+ svc := &PaymentService{
+ configService: &PaymentConfigService{entClient: client},
+ }
+
+ if !svc.usesOfficialWxpayVisibleMethod(ctx) {
+ t.Fatal("expected official wxpay visible method to be detected from enabled provider instance")
+ }
+}
+
+func TestUsesOfficialWxpayVisibleMethodRespectsConfiguredSourceWhenMultipleProvidersEnabled(t *testing.T) {
+ tests := []struct {
+ name string
+ source string
+ wantOfficial bool
+ }{
+ {
+ name: "official source selected",
+ source: VisibleMethodSourceOfficialWechat,
+ wantOfficial: true,
+ },
+ {
+ name: "easypay source selected",
+ source: VisibleMethodSourceEasyPayWechat,
+ wantOfficial: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("Official WeChat").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ SetSortOrder(1).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create official wxpay instance: %v", err)
+ }
+
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName("EasyPay WeChat").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ SetSortOrder(2).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create easypay wxpay instance: %v", err)
+ }
+
+ svc := &PaymentService{
+ configService: &PaymentConfigService{
+ entClient: client,
+ settingRepo: &paymentConfigSettingRepoStub{
+ values: map[string]string{
+ SettingPaymentVisibleMethodWxpaySource: tt.source,
+ },
+ },
+ },
+ }
+
+ if got := svc.usesOfficialWxpayVisibleMethod(ctx); got != tt.wantOfficial {
+ t.Fatalf("usesOfficialWxpayVisibleMethod() = %v, want %v", got, tt.wantOfficial)
+ }
+ })
+ }
+}
diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go
index 80147180..b627ced4 100644
--- a/backend/internal/service/payment_order_lifecycle.go
+++ b/backend/internal/service/payment_order_lifecycle.go
@@ -5,6 +5,7 @@ import (
"fmt"
"log/slog"
"strconv"
+ "strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -139,34 +140,123 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
if err != nil {
return ""
}
- // Use OutTradeNo as fallback when PaymentTradeNo is empty
- // (e.g. EasyPay popup mode where trade_no arrives only via notify callback)
- tradeNo := o.PaymentTradeNo
- if tradeNo == "" {
- tradeNo = o.OutTradeNo
+ queryRef := paymentOrderQueryReference(o, prov)
+ if queryRef == "" {
+ return ""
}
- resp, err := prov.QueryOrder(ctx, tradeNo)
+ resp, err := prov.QueryOrder(ctx, queryRef)
if err != nil {
slog.Warn("query upstream failed", "orderID", o.ID, "error", err)
return ""
}
if resp.Status == payment.ProviderStatusPaid {
- if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil {
+ if !isValidProviderAmount(resp.Amount) {
+ s.writeAuditLog(ctx, o.ID, "PAYMENT_INVALID_AMOUNT", prov.ProviderKey(), map[string]any{
+ "expected": o.PayAmount,
+ "paid": resp.Amount,
+ "tradeNo": resp.TradeNo,
+ "queryRef": queryRef,
+ })
+ slog.Warn("query upstream returned invalid paid amount", "orderID", o.ID, "queryRef", queryRef, "paid", resp.Amount)
+ retriedResp, retryOK := requeryPaidOrderOnce(ctx, prov, queryRef)
+ if !retryOK {
+ return ""
+ }
+ resp = retriedResp
+ }
+ notificationTradeNo := o.PaymentTradeNo
+ if upstreamTradeNo := strings.TrimSpace(resp.TradeNo); paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, notificationTradeNo) {
+ if _, updateErr := s.entClient.PaymentOrder.Update().
+ Where(paymentorder.IDEQ(o.ID)).
+ SetPaymentTradeNo(upstreamTradeNo).
+ Save(ctx); updateErr != nil {
+ slog.Error("persist upstream trade no during checkPaid failed", "orderID", o.ID, "tradeNo", upstreamTradeNo, "error", updateErr)
+ } else {
+ o.PaymentTradeNo = upstreamTradeNo
+ }
+ notificationTradeNo = upstreamTradeNo
+ }
+ if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: notificationTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess, Metadata: resp.Metadata}, prov.ProviderKey()); err != nil {
slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err)
// Still return already_paid — order was paid, fulfillment can be retried
}
return checkPaidResultAlreadyPaid
}
if cp, ok := prov.(payment.CancelableProvider); ok {
- _ = cp.CancelPayment(ctx, tradeNo)
+ _ = cp.CancelPayment(ctx, queryRef)
}
return ""
}
+func requeryPaidOrderOnce(ctx context.Context, prov payment.Provider, queryRef string) (*payment.QueryOrderResponse, bool) {
+ if prov == nil || strings.TrimSpace(queryRef) == "" {
+ return nil, false
+ }
+ resp, err := prov.QueryOrder(ctx, queryRef)
+ if err != nil {
+ slog.Warn("query upstream retry failed", "queryRef", queryRef, "error", err)
+ return nil, false
+ }
+ if resp == nil || resp.Status != payment.ProviderStatusPaid || !isValidProviderAmount(resp.Amount) {
+ return nil, false
+ }
+ return resp, true
+}
+
+func paymentOrderQueryReference(order *dbent.PaymentOrder, prov payment.Provider) string {
+ if order == nil {
+ return ""
+ }
+
+ providerKey := ""
+ if prov != nil {
+ providerKey = strings.TrimSpace(prov.ProviderKey())
+ }
+ if providerKey == "" {
+ if snapshot := psOrderProviderSnapshot(order); snapshot != nil {
+ providerKey = strings.TrimSpace(snapshot.ProviderKey)
+ }
+ }
+ if providerKey == "" {
+ providerKey = strings.TrimSpace(psStringValue(order.ProviderKey))
+ }
+ if providerKey == "" {
+ providerKey = strings.TrimSpace(order.PaymentType)
+ }
+
+ switch payment.GetBasePaymentType(providerKey) {
+ case payment.TypeAlipay, payment.TypeEasyPay, payment.TypeWxpay:
+ return strings.TrimSpace(order.OutTradeNo)
+ default:
+ if tradeNo := strings.TrimSpace(order.PaymentTradeNo); tradeNo != "" {
+ return tradeNo
+ }
+ return strings.TrimSpace(order.OutTradeNo)
+ }
+}
+
+func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, currentTradeNo string) bool {
+ upstreamTradeNo = strings.TrimSpace(upstreamTradeNo)
+ if upstreamTradeNo == "" {
+ return false
+ }
+ if strings.EqualFold(upstreamTradeNo, strings.TrimSpace(currentTradeNo)) {
+ return false
+ }
+ if strings.EqualFold(upstreamTradeNo, strings.TrimSpace(queryRef)) {
+ return false
+ }
+ return true
+}
+
// VerifyOrderByOutTradeNo actively queries the upstream provider to check
// if a payment was made, and processes it if so. This handles the case where
// the provider's notify callback was missed (e.g. EasyPay popup mode).
func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) {
+ outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo)
+ if err != nil {
+ return nil, err
+ }
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx)
@@ -190,25 +280,42 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo
return o, nil
}
-// VerifyOrderPublic verifies payment status without user authentication.
-// Used by the payment result page when the user's session has expired.
+// VerifyOrderPublic returns the currently persisted public order state without
+// triggering any upstream reconciliation. Signed resume-token recovery is the
+// only public recovery path allowed to query upstream state.
func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) {
+ outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo)
+ if err != nil {
+ return nil, err
+ }
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx)
if err != nil {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
- if o.Status == OrderStatusPending || o.Status == OrderStatusExpired {
- result := s.checkPaid(ctx, o)
- if result == checkPaidResultAlreadyPaid {
- o, err = s.entClient.PaymentOrder.Get(ctx, o.ID)
- if err != nil {
- return nil, fmt.Errorf("reload order: %w", err)
- }
+ return o, nil
+}
+
+func normalizeOrderLookupOutTradeNo(raw string) (string, error) {
+ outTradeNo := strings.TrimSpace(raw)
+ if outTradeNo == "" {
+ return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is required")
+ }
+ if len(outTradeNo) > 64 {
+ return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid")
+ }
+ for _, ch := range outTradeNo {
+ switch {
+ case ch >= 'a' && ch <= 'z':
+ case ch >= 'A' && ch <= 'Z':
+ case ch >= '0' && ch <= '9':
+ case ch == '_' || ch == '-':
+ default:
+ return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid")
}
}
- return o, nil
+ return outTradeNo, nil
}
func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) {
@@ -236,22 +343,79 @@ func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error)
// getOrderProvider creates a provider using the order's original instance config.
// Falls back to registry lookup if instance ID is missing (legacy orders).
func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
- if o.ProviderInstanceID != nil && *o.ProviderInstanceID != "" {
- instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
- if err == nil {
- cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID)
- if err == nil {
- providerKey := s.registry.GetProviderKey(o.PaymentType)
- if providerKey == "" {
- providerKey = o.PaymentType
- }
- p, err := provider.CreateProvider(providerKey, *o.ProviderInstanceID, cfg)
- if err == nil {
- return p, nil
- }
- }
- }
+ inst, err := s.getOrderProviderInstance(ctx, o)
+ if err != nil {
+ return nil, fmt.Errorf("load order provider instance: %w", err)
+ }
+ if inst != nil {
+ return s.createProviderFromInstance(ctx, inst)
+ }
+ if !paymentOrderAllowsRegistryFallback(o) {
+ return nil, fmt.Errorf("order %d provider instance is unresolved", o.ID)
+ }
+ providerKey := paymentOrderFallbackProviderKey(s.registry, o)
+ if providerKey == "" {
+ return nil, fmt.Errorf("order %d provider fallback key is missing", o.ID)
+ }
+ if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
+ return nil, fmt.Errorf("order %d provider fallback is ambiguous for %s", o.ID, providerKey)
}
s.EnsureProviders(ctx)
return s.registry.GetProvider(o.PaymentType)
}
+
+func paymentOrderAllowsRegistryFallback(order *dbent.PaymentOrder) bool {
+ if order == nil {
+ return false
+ }
+ if psOrderProviderSnapshot(order) != nil {
+ return false
+ }
+ if strings.TrimSpace(psStringValue(order.ProviderInstanceID)) != "" {
+ return false
+ }
+ if strings.TrimSpace(psStringValue(order.ProviderKey)) != "" {
+ return false
+ }
+ return true
+}
+
+func paymentOrderFallbackProviderKey(registry *payment.Registry, order *dbent.PaymentOrder) string {
+ if order == nil {
+ return ""
+ }
+ if registry != nil {
+ if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(order.PaymentType))); key != "" {
+ return key
+ }
+ }
+ return strings.TrimSpace(payment.GetBasePaymentType(strings.TrimSpace(order.PaymentType)))
+}
+
+func (s *PaymentService) createProviderFromInstance(ctx context.Context, inst *dbent.PaymentProviderInstance) (payment.Provider, error) {
+ if inst == nil {
+ return nil, fmt.Errorf("payment provider instance is missing")
+ }
+
+ cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID))
+ if err != nil {
+ return nil, fmt.Errorf("load provider instance config: %w", err)
+ }
+ if inst.PaymentMode != "" {
+ cfg["paymentMode"] = inst.PaymentMode
+ }
+
+ instID := strconv.FormatInt(int64(inst.ID), 10)
+ prov, err := provider.CreateProvider(inst.ProviderKey, instID, cfg)
+ if err != nil {
+ return nil, fmt.Errorf("create provider from instance: %w", err)
+ }
+ return prov, nil
+}
+
+func psStringValue(value *string) string {
+ if value == nil {
+ return ""
+ }
+ return *value
+}
diff --git a/backend/internal/service/payment_order_lifecycle_test.go b/backend/internal/service/payment_order_lifecycle_test.go
new file mode 100644
index 00000000..8dfd2e7e
--- /dev/null
+++ b/backend/internal/service/payment_order_lifecycle_test.go
@@ -0,0 +1,575 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+type paymentOrderLifecycleQueryProvider struct {
+ lastQueryTradeNo string
+ queryCalls int
+ responses []*payment.QueryOrderResponse
+ resp *payment.QueryOrderResponse
+}
+
+type paymentOrderLifecycleRedeemRepo struct {
+ codesByCode map[string]*RedeemCode
+ useCalls []struct {
+ id int64
+ userID int64
+ }
+}
+
+func (p *paymentOrderLifecycleQueryProvider) Name() string {
+ return "payment-order-lifecycle-query-provider"
+}
+
+func (p *paymentOrderLifecycleQueryProvider) ProviderKey() string { return payment.TypeAlipay }
+
+func (p *paymentOrderLifecycleQueryProvider) SupportedTypes() []payment.PaymentType {
+ return []payment.PaymentType{payment.TypeAlipay}
+}
+
+func (p *paymentOrderLifecycleQueryProvider) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ panic("unexpected call")
+}
+
+func (p *paymentOrderLifecycleQueryProvider) QueryOrder(_ context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
+ p.lastQueryTradeNo = tradeNo
+ p.queryCalls++
+ if len(p.responses) > 0 {
+ resp := p.responses[0]
+ if len(p.responses) > 1 {
+ p.responses = p.responses[1:]
+ }
+ return resp, nil
+ }
+ return p.resp, nil
+}
+
+func (p *paymentOrderLifecycleQueryProvider) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
+ panic("unexpected call")
+}
+
+func (p *paymentOrderLifecycleQueryProvider) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) Create(context.Context, *RedeemCode) error {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) CreateBatch(context.Context, []RedeemCode) error {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) GetByID(_ context.Context, id int64) (*RedeemCode, error) {
+ for _, code := range r.codesByCode {
+ if code.ID != id {
+ continue
+ }
+ cloned := *code
+ return &cloned, nil
+ }
+ return nil, ErrRedeemCodeNotFound
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) GetByCode(_ context.Context, code string) (*RedeemCode, error) {
+ redeemCode, ok := r.codesByCode[code]
+ if !ok {
+ return nil, ErrRedeemCodeNotFound
+ }
+ cloned := *redeemCode
+ return &cloned, nil
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) Update(context.Context, *RedeemCode) error {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) Delete(context.Context, int64) error {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) Use(_ context.Context, id, userID int64) error {
+ for code, redeemCode := range r.codesByCode {
+ if redeemCode.ID != id {
+ continue
+ }
+ now := time.Now().UTC()
+ redeemCode.Status = StatusUsed
+ redeemCode.UsedBy = &userID
+ redeemCode.UsedAt = &now
+ r.codesByCode[code] = redeemCode
+ r.useCalls = append(r.useCalls, struct {
+ id int64
+ userID int64
+ }{id: id, userID: userID})
+ return nil
+ }
+ return ErrRedeemCodeNotFound
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) ListByUser(context.Context, int64, int) ([]RedeemCode, error) {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
+ panic("unexpected call")
+}
+
+func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentOrderLifecycleTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("checkpaid@example.com").
+ SetPasswordHash("hash").
+ SetUsername("checkpaid-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("CHECKPAID-UPSTREAM-TRADE-NO").
+ SetOutTradeNo("sub2_checkpaid_trade_no_missing").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ userRepo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: user.ID,
+ Email: user.Email,
+ Username: user.Username,
+ Balance: 0,
+ },
+ }
+ userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error {
+ require.Equal(t, user.ID, id)
+ if userRepo.getByIDUser != nil {
+ userRepo.getByIDUser.Balance += amount
+ }
+ return nil
+ }
+ redeemRepo := &paymentOrderLifecycleRedeemRepo{
+ codesByCode: map[string]*RedeemCode{
+ order.RechargeCode: {
+ ID: 1,
+ Code: order.RechargeCode,
+ Type: RedeemTypeBalance,
+ Value: order.Amount,
+ Status: StatusUnused,
+ },
+ },
+ }
+ redeemService := NewRedeemService(
+ redeemRepo,
+ userRepo,
+ nil,
+ nil,
+ nil,
+ client,
+ nil,
+ )
+ registry := payment.NewRegistry()
+ provider := &paymentOrderLifecycleQueryProvider{
+ resp: &payment.QueryOrderResponse{
+ TradeNo: "upstream-trade-123",
+ Status: payment.ProviderStatusPaid,
+ Amount: 88,
+ },
+ }
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ redeemService: redeemService,
+ userRepo: userRepo,
+ providersLoaded: true,
+ }
+
+ got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
+ require.Equal(t, OrderStatusCompleted, got.Status)
+ require.Equal(t, "upstream-trade-123", got.PaymentTradeNo)
+
+ reloaded, err := client.PaymentOrder.Get(ctx, order.ID)
+ require.NoError(t, err)
+ require.Equal(t, OrderStatusCompleted, reloaded.Status)
+ require.Equal(t, "upstream-trade-123", reloaded.PaymentTradeNo)
+
+ require.Equal(t, 88.0, userRepo.getByIDUser.Balance)
+ require.Len(t, redeemRepo.useCalls, 1)
+ require.Equal(t, int64(1), redeemRepo.useCalls[0].id)
+ require.Equal(t, user.ID, redeemRepo.useCalls[0].userID)
+}
+
+func TestVerifyOrderByOutTradeNoRetriesZeroAmountPaidQueryOnce(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentOrderLifecycleTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("checkpaid-retry@example.com").
+ SetPasswordHash("hash").
+ SetUsername("checkpaid-retry-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("CHECKPAID-UPSTREAM-RETRY").
+ SetOutTradeNo("sub2_checkpaid_retry_zero_amount").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ userRepo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: user.ID,
+ Email: user.Email,
+ Username: user.Username,
+ Balance: 0,
+ },
+ }
+ userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error {
+ require.Equal(t, user.ID, id)
+ if userRepo.getByIDUser != nil {
+ userRepo.getByIDUser.Balance += amount
+ }
+ return nil
+ }
+ redeemRepo := &paymentOrderLifecycleRedeemRepo{
+ codesByCode: map[string]*RedeemCode{
+ order.RechargeCode: {
+ ID: 1,
+ Code: order.RechargeCode,
+ Type: RedeemTypeBalance,
+ Value: order.Amount,
+ Status: StatusUnused,
+ },
+ },
+ }
+ redeemService := NewRedeemService(
+ redeemRepo,
+ userRepo,
+ nil,
+ nil,
+ nil,
+ client,
+ nil,
+ )
+ registry := payment.NewRegistry()
+ provider := &paymentOrderLifecycleQueryProvider{
+ responses: []*payment.QueryOrderResponse{
+ {
+ TradeNo: "upstream-trade-zero",
+ Status: payment.ProviderStatusPaid,
+ Amount: 0,
+ },
+ {
+ TradeNo: "upstream-trade-retry",
+ Status: payment.ProviderStatusPaid,
+ Amount: 88,
+ },
+ },
+ }
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ redeemService: redeemService,
+ userRepo: userRepo,
+ providersLoaded: true,
+ }
+
+ got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 2, provider.queryCalls)
+ require.Equal(t, OrderStatusCompleted, got.Status)
+ require.Equal(t, "upstream-trade-retry", got.PaymentTradeNo)
+}
+
+func TestVerifyOrderByOutTradeNoRejectsPaidQueryWithZeroAmount(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentOrderLifecycleTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("checkpaid-zero-amount@example.com").
+ SetPasswordHash("hash").
+ SetUsername("checkpaid-zero-amount-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("CHECKPAID-ZERO-AMOUNT").
+ SetOutTradeNo("sub2_checkpaid_zero_amount").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ userRepo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: user.ID,
+ Email: user.Email,
+ Username: user.Username,
+ Balance: 0,
+ },
+ }
+ redeemRepo := &paymentOrderLifecycleRedeemRepo{
+ codesByCode: map[string]*RedeemCode{
+ order.RechargeCode: {
+ ID: 1,
+ Code: order.RechargeCode,
+ Type: RedeemTypeBalance,
+ Value: order.Amount,
+ Status: StatusUnused,
+ },
+ },
+ }
+ redeemService := NewRedeemService(
+ redeemRepo,
+ userRepo,
+ nil,
+ nil,
+ nil,
+ client,
+ nil,
+ )
+ registry := payment.NewRegistry()
+ provider := &paymentOrderLifecycleQueryProvider{
+ resp: &payment.QueryOrderResponse{
+ TradeNo: "upstream-trade-zero",
+ Status: payment.ProviderStatusPaid,
+ Amount: 0,
+ },
+ }
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ redeemService: redeemService,
+ userRepo: userRepo,
+ providersLoaded: true,
+ }
+
+ got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
+ require.Equal(t, OrderStatusPending, got.Status)
+ require.Empty(t, got.PaymentTradeNo)
+
+ reloaded, err := client.PaymentOrder.Get(ctx, order.ID)
+ require.NoError(t, err)
+ require.Equal(t, OrderStatusPending, reloaded.Status)
+ require.Empty(t, reloaded.PaymentTradeNo)
+
+ require.Equal(t, 0.0, userRepo.getByIDUser.Balance)
+ require.Empty(t, redeemRepo.useCalls)
+}
+
+func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsForAlipay(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentOrderLifecycleTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("checkpaid-existing-trade@example.com").
+ SetPasswordHash("hash").
+ SetUsername("checkpaid-existing-trade-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("CHECKPAID-EXISTING-TRADE-NO").
+ SetOutTradeNo("sub2_checkpaid_use_out_trade_no").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("upstream-trade-existing").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ userRepo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: user.ID,
+ Email: user.Email,
+ Username: user.Username,
+ Balance: 0,
+ },
+ }
+ userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error {
+ require.Equal(t, user.ID, id)
+ if userRepo.getByIDUser != nil {
+ userRepo.getByIDUser.Balance += amount
+ }
+ return nil
+ }
+ redeemRepo := &paymentOrderLifecycleRedeemRepo{
+ codesByCode: map[string]*RedeemCode{
+ order.RechargeCode: {
+ ID: 1,
+ Code: order.RechargeCode,
+ Type: RedeemTypeBalance,
+ Value: order.Amount,
+ Status: StatusUnused,
+ },
+ },
+ }
+ redeemService := NewRedeemService(
+ redeemRepo,
+ userRepo,
+ nil,
+ nil,
+ nil,
+ client,
+ nil,
+ )
+ registry := payment.NewRegistry()
+ provider := &paymentOrderLifecycleQueryProvider{
+ resp: &payment.QueryOrderResponse{
+ TradeNo: "upstream-trade-existing",
+ Status: payment.ProviderStatusPaid,
+ Amount: 88,
+ },
+ }
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ redeemService: redeemService,
+ userRepo: userRepo,
+ providersLoaded: true,
+ }
+
+ got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
+ require.Equal(t, "upstream-trade-existing", got.PaymentTradeNo)
+}
+
+func TestPaymentOrderAllowsRegistryFallbackOnlyForLegacyOrdersWithoutPinnedProviderState(t *testing.T) {
+ t.Parallel()
+
+ require.True(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ }))
+
+ instanceID := "12"
+ require.False(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ ProviderInstanceID: &instanceID,
+ }))
+
+ require.False(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 2,
+ "provider_instance_id": "12",
+ },
+ }))
+}
+
+func TestPaymentOrderQueryReferenceUsesOutTradeNoForOfficialProviders(t *testing.T) {
+ t.Parallel()
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeWxpay,
+ OutTradeNo: "sub2_out_trade_no",
+ PaymentTradeNo: "wx-transaction-id",
+ }
+
+ require.Equal(t, "sub2_out_trade_no", paymentOrderQueryReference(order, &paymentOrderLifecycleQueryProvider{}))
+ require.Equal(t, "sub2_out_trade_no", paymentOrderQueryReference(order, paymentFulfillmentTestProvider{
+ key: payment.TypeWxpay,
+ }))
+}
+
+func newPaymentOrderLifecycleTestClient(t *testing.T) *dbent.Client {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:payment_order_lifecycle?mode=memory&cache=shared&_fk=1")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+ return client
+}
diff --git a/backend/internal/service/payment_order_provider_snapshot.go b/backend/internal/service/payment_order_provider_snapshot.go
new file mode 100644
index 00000000..bb60f9e2
--- /dev/null
+++ b/backend/internal/service/payment_order_provider_snapshot.go
@@ -0,0 +1,205 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "strings"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+)
+
+type paymentOrderProviderSnapshot struct {
+ SchemaVersion int
+ ProviderInstanceID string
+ ProviderKey string
+ PaymentMode string
+ MerchantAppID string
+ MerchantID string
+ Currency string
+}
+
+func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSnapshot {
+ if order == nil || len(order.ProviderSnapshot) == 0 {
+ return nil
+ }
+
+ snapshot := &paymentOrderProviderSnapshot{
+ SchemaVersion: psSnapshotIntValue(order.ProviderSnapshot["schema_version"]),
+ ProviderInstanceID: psSnapshotStringValue(order.ProviderSnapshot["provider_instance_id"]),
+ ProviderKey: psSnapshotStringValue(order.ProviderSnapshot["provider_key"]),
+ PaymentMode: psSnapshotStringValue(order.ProviderSnapshot["payment_mode"]),
+ MerchantAppID: psSnapshotStringValue(order.ProviderSnapshot["merchant_app_id"]),
+ MerchantID: psSnapshotStringValue(order.ProviderSnapshot["merchant_id"]),
+ Currency: psSnapshotStringValue(order.ProviderSnapshot["currency"]),
+ }
+ if snapshot.SchemaVersion == 0 &&
+ snapshot.ProviderInstanceID == "" &&
+ snapshot.ProviderKey == "" &&
+ snapshot.PaymentMode == "" &&
+ snapshot.MerchantAppID == "" &&
+ snapshot.MerchantID == "" &&
+ snapshot.Currency == "" {
+ return nil
+ }
+ return snapshot
+}
+
+func psSnapshotStringValue(value any) string {
+ switch typed := value.(type) {
+ case string:
+ return strings.TrimSpace(typed)
+ default:
+ return ""
+ }
+}
+
+func psSnapshotIntValue(value any) int {
+ switch typed := value.(type) {
+ case int:
+ return typed
+ case int32:
+ return int(typed)
+ case int64:
+ return int(typed)
+ case float32:
+ return int(typed)
+ case float64:
+ return int(typed)
+ case string:
+ n, err := strconv.Atoi(strings.TrimSpace(typed))
+ if err == nil {
+ return n
+ }
+ }
+ return 0
+}
+
+func (s *PaymentService) resolveSnapshotOrderProviderInstance(ctx context.Context, order *dbent.PaymentOrder, snapshot *paymentOrderProviderSnapshot) (*dbent.PaymentProviderInstance, error) {
+ if s == nil || s.entClient == nil || order == nil || snapshot == nil {
+ return nil, nil
+ }
+
+ snapshotInstanceID := strings.TrimSpace(snapshot.ProviderInstanceID)
+ columnInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
+ if snapshotInstanceID == "" {
+ snapshotInstanceID = columnInstanceID
+ }
+ if snapshotInstanceID == "" {
+ return nil, fmt.Errorf("order %d provider snapshot is missing provider_instance_id", order.ID)
+ }
+ if columnInstanceID != "" && snapshot.ProviderInstanceID != "" && !strings.EqualFold(columnInstanceID, snapshot.ProviderInstanceID) {
+ return nil, fmt.Errorf("order %d provider snapshot instance mismatch: snapshot=%s order=%s", order.ID, snapshot.ProviderInstanceID, columnInstanceID)
+ }
+
+ instID, err := strconv.ParseInt(snapshotInstanceID, 10, 64)
+ if err != nil {
+ return nil, fmt.Errorf("order %d provider snapshot instance id is invalid: %s", order.ID, snapshotInstanceID)
+ }
+
+ inst, err := s.entClient.PaymentProviderInstance.Get(ctx, instID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, fmt.Errorf("order %d provider snapshot instance %s is missing", order.ID, snapshotInstanceID)
+ }
+ return nil, err
+ }
+
+ if snapshot.ProviderKey != "" && !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), snapshot.ProviderKey) {
+ return nil, fmt.Errorf("order %d provider snapshot key mismatch: snapshot=%s instance=%s", order.ID, snapshot.ProviderKey, inst.ProviderKey)
+ }
+
+ return inst, nil
+}
+
+func expectedNotificationProviderKeyForOrder(registry *payment.Registry, order *dbent.PaymentOrder, instanceProviderKey string) string {
+ if order == nil {
+ return strings.TrimSpace(instanceProviderKey)
+ }
+
+ orderProviderKey := psStringValue(order.ProviderKey)
+ if snapshot := psOrderProviderSnapshot(order); snapshot != nil && snapshot.ProviderKey != "" {
+ orderProviderKey = snapshot.ProviderKey
+ }
+
+ return expectedNotificationProviderKey(registry, order.PaymentType, orderProviderKey, instanceProviderKey)
+}
+
+func validateProviderSnapshotMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error {
+ if order == nil || len(metadata) == 0 {
+ return nil
+ }
+
+ snapshot := psOrderProviderSnapshot(order)
+ if snapshot == nil {
+ return nil
+ }
+
+ switch strings.TrimSpace(providerKey) {
+ case payment.TypeWxpay:
+ if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" {
+ actual := strings.TrimSpace(metadata["appid"])
+ if actual == "" {
+ return fmt.Errorf("wxpay notification missing appid")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("wxpay appid mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
+ actual := strings.TrimSpace(metadata["mchid"])
+ if actual == "" {
+ return fmt.Errorf("wxpay notification missing mchid")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("wxpay mchid mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ if expected := strings.TrimSpace(snapshot.Currency); expected != "" {
+ actual := strings.ToUpper(strings.TrimSpace(metadata["currency"]))
+ if actual == "" {
+ return fmt.Errorf("wxpay notification missing currency")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("wxpay currency mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ if actual := strings.TrimSpace(metadata["trade_state"]); actual != "" && !strings.EqualFold(actual, "SUCCESS") {
+ return fmt.Errorf("wxpay trade_state mismatch: expected SUCCESS, got %s", actual)
+ }
+ case payment.TypeAlipay:
+ if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" {
+ actual := strings.TrimSpace(metadata["app_id"])
+ if actual == "" {
+ return fmt.Errorf("alipay app_id missing")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("alipay app_id mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ case payment.TypeEasyPay:
+ if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
+ actual := strings.TrimSpace(metadata["pid"])
+ if actual == "" {
+ return fmt.Errorf("easypay pid missing")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("easypay pid mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ }
+
+ return nil
+}
+
+func providerMerchantIdentityMetadata(prov payment.Provider) map[string]string {
+ if prov == nil {
+ return nil
+ }
+ reporter, ok := prov.(payment.MerchantIdentityProvider)
+ if !ok {
+ return nil
+ }
+ return reporter.MerchantIdentityMetadata()
+}
diff --git a/backend/internal/service/payment_order_provider_snapshot_test.go b/backend/internal/service/payment_order_provider_snapshot_test.go
new file mode 100644
index 00000000..efa013b5
--- /dev/null
+++ b/backend/internal/service/payment_order_provider_snapshot_test.go
@@ -0,0 +1,172 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "strconv"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/stretchr/testify/require"
+)
+
+func TestBuildPaymentOrderProviderSnapshot_ExcludesSensitiveConfig(t *testing.T) {
+ t.Parallel()
+
+ sel := &payment.InstanceSelection{
+ InstanceID: "12",
+ ProviderKey: payment.TypeWxpay,
+ SupportedTypes: "wxpay,wxpay_direct",
+ PaymentMode: "popup",
+ Config: map[string]string{
+ "privateKey": "secret",
+ "apiV3Key": "secret-v3",
+ "appId": "wx-app-id",
+ },
+ }
+
+ snapshot := buildPaymentOrderProviderSnapshot(sel, CreateOrderRequest{})
+ require.Equal(t, map[string]any{
+ "schema_version": 2,
+ "provider_instance_id": "12",
+ "provider_key": payment.TypeWxpay,
+ "payment_mode": "popup",
+ "merchant_app_id": "wx-app-id",
+ "currency": "CNY",
+ }, snapshot)
+ require.NotContains(t, snapshot, "config")
+ require.NotContains(t, snapshot, "privateKey")
+ require.NotContains(t, snapshot, "apiV3Key")
+ require.NotContains(t, snapshot, "supported_types")
+ require.NotContains(t, snapshot, "instance_name")
+ require.NotContains(t, snapshot, "merchant_id")
+}
+
+func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("snapshot@example.com").
+ SetPasswordHash("hash").
+ SetUsername("snapshot-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ instance, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("Primary Alipay").
+ SetConfig(`{"secretKey":"do-not-copy"}`).
+ SetSupportedTypes("alipay,alipay_direct").
+ SetPaymentMode("redirect").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{entClient: client}
+ order, err := svc.createOrderInTx(
+ ctx,
+ CreateOrderRequest{
+ UserID: user.ID,
+ PaymentType: payment.TypeAlipay,
+ OrderType: payment.OrderTypeBalance,
+ ClientIP: "127.0.0.1",
+ SrcHost: "app.example.com",
+ },
+ &User{
+ ID: user.ID,
+ Email: user.Email,
+ Username: user.Username,
+ },
+ nil,
+ &PaymentConfig{
+ MaxPendingOrders: 3,
+ OrderTimeoutMin: 30,
+ },
+ 88,
+ 88,
+ 0,
+ 88,
+ &payment.InstanceSelection{
+ InstanceID: strconv.FormatInt(instance.ID, 10),
+ ProviderKey: payment.TypeAlipay,
+ SupportedTypes: "alipay,alipay_direct",
+ PaymentMode: "redirect",
+ Config: map[string]string{
+ "secretKey": "do-not-copy",
+ },
+ },
+ )
+ require.NoError(t, err)
+ require.Equal(t, strconv.FormatInt(instance.ID, 10), valueOrEmpty(order.ProviderInstanceID))
+ require.Equal(t, payment.TypeAlipay, valueOrEmpty(order.ProviderKey))
+ require.Equal(t, float64(2), order.ProviderSnapshot["schema_version"])
+ require.Equal(t, strconv.FormatInt(instance.ID, 10), order.ProviderSnapshot["provider_instance_id"])
+ require.Equal(t, payment.TypeAlipay, order.ProviderSnapshot["provider_key"])
+ require.Equal(t, "redirect", order.ProviderSnapshot["payment_mode"])
+ require.NotContains(t, order.ProviderSnapshot, "config")
+ require.NotContains(t, order.ProviderSnapshot, "secretKey")
+ require.NotContains(t, order.ProviderSnapshot, "supported_types")
+ require.NotContains(t, order.ProviderSnapshot, "instance_name")
+}
+
+func TestBuildPaymentOrderProviderSnapshot_UsesWxpayJSAPIAppIDForOpenIDOrders(t *testing.T) {
+ t.Parallel()
+
+ snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
+ InstanceID: "88",
+ ProviderKey: payment.TypeWxpay,
+ Config: map[string]string{
+ "appId": "wx-open-app",
+ "mpAppId": "wx-mp-app",
+ "mchId": "mch-88",
+ },
+ PaymentMode: "jsapi",
+ }, CreateOrderRequest{OpenID: "openid-123"})
+
+ require.Equal(t, "wx-mp-app", snapshot["merchant_app_id"])
+ require.Equal(t, "mch-88", snapshot["merchant_id"])
+ require.Equal(t, "CNY", snapshot["currency"])
+}
+
+func TestBuildPaymentOrderProviderSnapshot_IncludesAlipayMerchantIdentity(t *testing.T) {
+ t.Parallel()
+
+ snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
+ InstanceID: "21",
+ ProviderKey: payment.TypeAlipay,
+ Config: map[string]string{
+ "appId": "alipay-app-21",
+ "privateKey": "secret",
+ },
+ PaymentMode: "redirect",
+ }, CreateOrderRequest{})
+
+ require.Equal(t, "alipay-app-21", snapshot["merchant_app_id"])
+ require.NotContains(t, snapshot, "privateKey")
+}
+
+func TestBuildPaymentOrderProviderSnapshot_IncludesEasyPayMerchantIdentity(t *testing.T) {
+ t.Parallel()
+
+ snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
+ InstanceID: "66",
+ ProviderKey: payment.TypeEasyPay,
+ Config: map[string]string{
+ "pid": "easypay-merchant-66",
+ "pkey": "secret",
+ },
+ PaymentMode: "popup",
+ }, CreateOrderRequest{PaymentType: payment.TypeAlipay})
+
+ require.Equal(t, "easypay-merchant-66", snapshot["merchant_id"])
+ require.NotContains(t, snapshot, "pkey")
+}
+
+func valueOrEmpty(v *string) string {
+ if v == nil {
+ return ""
+ }
+ return *v
+}
diff --git a/backend/internal/service/payment_order_result_test.go b/backend/internal/service/payment_order_result_test.go
new file mode 100644
index 00000000..2d7412e0
--- /dev/null
+++ b/backend/internal/service/payment_order_result_test.go
@@ -0,0 +1,276 @@
+package service
+
+import (
+ "context"
+ "strings"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+func TestBuildCreateOrderResponseDefaultsToOrderCreated(t *testing.T) {
+ t.Parallel()
+
+ expiresAt := time.Date(2026, 4, 16, 12, 0, 0, 0, time.UTC)
+ resp := buildCreateOrderResponse(
+ &dbent.PaymentOrder{
+ ID: 42,
+ Amount: 12.34,
+ FeeRate: 0.03,
+ ExpiresAt: expiresAt,
+ OutTradeNo: "sub2_42",
+ },
+ CreateOrderRequest{PaymentType: payment.TypeWxpay},
+ 12.71,
+ &payment.InstanceSelection{PaymentMode: "qrcode"},
+ &payment.CreatePaymentResponse{
+ TradeNo: "sub2_42",
+ QRCode: "weixin://wxpay/bizpayurl?pr=test",
+ },
+ payment.CreatePaymentResultOrderCreated,
+ )
+
+ if resp.ResultType != payment.CreatePaymentResultOrderCreated {
+ t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOrderCreated)
+ }
+ if resp.OutTradeNo != "sub2_42" {
+ t.Fatalf("out_trade_no = %q, want %q", resp.OutTradeNo, "sub2_42")
+ }
+ if resp.QRCode != "weixin://wxpay/bizpayurl?pr=test" {
+ t.Fatalf("qr_code = %q, want %q", resp.QRCode, "weixin://wxpay/bizpayurl?pr=test")
+ }
+ if resp.JSAPI != nil || resp.JSAPIPayload != nil {
+ t.Fatal("order_created response should not include jsapi payload")
+ }
+ if !resp.ExpiresAt.Equal(expiresAt) {
+ t.Fatalf("expires_at = %v, want %v", resp.ExpiresAt, expiresAt)
+ }
+}
+
+func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) {
+ t.Parallel()
+
+ jsapiPayload := &payment.WechatJSAPIPayload{
+ AppID: "wx123",
+ TimeStamp: "1712345678",
+ NonceStr: "nonce-123",
+ Package: "prepay_id=wx123",
+ SignType: "RSA",
+ PaySign: "signed-payload",
+ }
+ resp := buildCreateOrderResponse(
+ &dbent.PaymentOrder{
+ ID: 88,
+ Amount: 66.88,
+ FeeRate: 0.01,
+ ExpiresAt: time.Date(2026, 4, 16, 13, 0, 0, 0, time.UTC),
+ OutTradeNo: "sub2_88",
+ },
+ CreateOrderRequest{PaymentType: payment.TypeWxpay},
+ 67.55,
+ &payment.InstanceSelection{PaymentMode: "popup"},
+ &payment.CreatePaymentResponse{
+ TradeNo: "sub2_88",
+ ResultType: payment.CreatePaymentResultJSAPIReady,
+ JSAPI: jsapiPayload,
+ },
+ payment.CreatePaymentResultJSAPIReady,
+ )
+
+ if resp.ResultType != payment.CreatePaymentResultJSAPIReady {
+ t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultJSAPIReady)
+ }
+ if resp.JSAPI == nil || resp.JSAPIPayload == nil {
+ t.Fatal("expected jsapi payload aliases to be populated")
+ }
+ if resp.JSAPI != jsapiPayload || resp.JSAPIPayload != jsapiPayload {
+ t.Fatal("expected jsapi aliases to preserve the original pointer")
+ }
+}
+
+func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) {
+ t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
+
+ svc := newWeChatPaymentOAuthTestService(map[string]string{
+ SettingKeyWeChatConnectEnabled: "true",
+ SettingKeyWeChatConnectAppID: "wx123456",
+ SettingKeyWeChatConnectAppSecret: "wechat-secret",
+ SettingKeyWeChatConnectMode: "mp",
+ SettingKeyWeChatConnectScopes: "snsapi_base",
+ SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ })
+
+ resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
+ Amount: 12.5,
+ PaymentType: payment.TypeWxpay,
+ IsWeChatBrowser: true,
+ SrcURL: "https://merchant.example/payment?from=wechat",
+ OrderType: payment.OrderTypeBalance,
+ }, 12.5, 12.88, 0.03)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if resp == nil {
+ t.Fatal("expected oauth_required response, got nil")
+ }
+ if resp.ResultType != payment.CreatePaymentResultOAuthRequired {
+ t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOAuthRequired)
+ }
+ if resp.OAuth == nil {
+ t.Fatal("expected oauth payload, got nil")
+ }
+ if resp.OAuth.AppID != "wx123456" {
+ t.Fatalf("appid = %q, want %q", resp.OAuth.AppID, "wx123456")
+ }
+ if resp.OAuth.Scope != "snsapi_base" {
+ t.Fatalf("scope = %q, want %q", resp.OAuth.Scope, "snsapi_base")
+ }
+ if resp.OAuth.RedirectURL != "/auth/wechat/payment/callback" {
+ t.Fatalf("redirect_url = %q, want %q", resp.OAuth.RedirectURL, "/auth/wechat/payment/callback")
+ }
+ if resp.OAuth.AuthorizeURL != "/api/v1/auth/oauth/wechat/payment/start?amount=12.5&order_type=balance&payment_type=wxpay&redirect=%2Fpurchase%3Ffrom%3Dwechat&scope=snsapi_base" {
+ t.Fatalf("authorize_url = %q", resp.OAuth.AuthorizeURL)
+ }
+}
+
+func TestMaybeBuildWeChatOAuthRequiredResponseRequiresMPConfigInWeChat(t *testing.T) {
+ t.Parallel()
+
+ svc := newWeChatPaymentOAuthTestService(nil)
+
+ resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
+ Amount: 12.5,
+ PaymentType: payment.TypeWxpay,
+ IsWeChatBrowser: true,
+ SrcURL: "https://merchant.example/payment?from=wechat",
+ OrderType: payment.OrderTypeBalance,
+ }, 12.5, 12.88, 0.03)
+ if resp != nil {
+ t.Fatalf("expected nil response, got %+v", resp)
+ }
+ if err == nil {
+ t.Fatal("expected error, got nil")
+ }
+
+ appErr := infraerrors.FromError(err)
+ if appErr.Reason != "WECHAT_PAYMENT_MP_NOT_CONFIGURED" {
+ t.Fatalf("reason = %q, want %q", appErr.Reason, "WECHAT_PAYMENT_MP_NOT_CONFIGURED")
+ }
+}
+
+func TestMaybeBuildWeChatOAuthRequiredResponseRequiresResumeSigningKey(t *testing.T) {
+ t.Parallel()
+
+ svc := &PaymentService{
+ configService: &PaymentConfigService{
+ settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{
+ SettingKeyWeChatConnectEnabled: "true",
+ SettingKeyWeChatConnectAppID: "wx123456",
+ SettingKeyWeChatConnectAppSecret: "wechat-secret",
+ SettingKeyWeChatConnectMode: "mp",
+ SettingKeyWeChatConnectScopes: "snsapi_base",
+ SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ }},
+ // Intentionally missing payment resume signing key.
+ encryptionKey: nil,
+ },
+ }
+
+ resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
+ Amount: 12.5,
+ PaymentType: payment.TypeWxpay,
+ IsWeChatBrowser: true,
+ SrcURL: "https://merchant.example/payment?from=wechat",
+ OrderType: payment.OrderTypeBalance,
+ }, 12.5, 12.88, 0.03)
+ if resp != nil {
+ t.Fatalf("expected nil response, got %+v", resp)
+ }
+ if err == nil {
+ t.Fatal("expected error, got nil")
+ }
+
+ appErr := infraerrors.FromError(err)
+ if appErr.Reason != "PAYMENT_RESUME_NOT_CONFIGURED" {
+ t.Fatalf("reason = %q, want %q", appErr.Reason, "PAYMENT_RESUME_NOT_CONFIGURED")
+ }
+}
+
+func TestMaybeBuildWeChatOAuthRequiredResponseFallsBackToConfiguredLegacySigningKey(t *testing.T) {
+ svc := &PaymentService{
+ configService: &PaymentConfigService{
+ settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{
+ SettingKeyWeChatConnectEnabled: "true",
+ SettingKeyWeChatConnectAppID: "wx123456",
+ SettingKeyWeChatConnectAppSecret: "wechat-secret",
+ SettingKeyWeChatConnectMode: "mp",
+ SettingKeyWeChatConnectScopes: "snsapi_base",
+ SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ }},
+ // Legacy stable signing key remains available for no-config upgrade compatibility.
+ encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
+ },
+ }
+
+ resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
+ Amount: 12.5,
+ PaymentType: payment.TypeWxpay,
+ IsWeChatBrowser: true,
+ SrcURL: "https://merchant.example/payment?from=wechat",
+ OrderType: payment.OrderTypeBalance,
+ }, 12.5, 12.88, 0.03)
+ if err != nil {
+ t.Fatalf("expected nil error, got %v", err)
+ }
+ if resp == nil {
+ t.Fatal("expected oauth-required response, got nil")
+ }
+ if resp.ResultType != payment.CreatePaymentResultOAuthRequired {
+ t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOAuthRequired)
+ }
+ if resp.OAuth == nil || strings.TrimSpace(resp.OAuth.AuthorizeURL) == "" {
+ t.Fatalf("expected oauth redirect payload, got %+v", resp.OAuth)
+ }
+}
+
+func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) {
+ svc := newWeChatPaymentOAuthTestService(map[string]string{
+ SettingKeyWeChatConnectEnabled: "true",
+ SettingKeyWeChatConnectAppID: "wx123456",
+ SettingKeyWeChatConnectAppSecret: "wechat-secret",
+ SettingKeyWeChatConnectMode: "mp",
+ SettingKeyWeChatConnectScopes: "snsapi_base",
+ SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ })
+
+ resp, err := svc.maybeBuildWeChatOAuthRequiredResponseForSelection(context.Background(), CreateOrderRequest{
+ Amount: 12.5,
+ PaymentType: payment.TypeWxpay,
+ IsWeChatBrowser: true,
+ OrderType: payment.OrderTypeBalance,
+ }, 12.5, 12.88, 0.03, &payment.InstanceSelection{
+ ProviderKey: payment.TypeEasyPay,
+ })
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if resp != nil {
+ t.Fatalf("expected nil response, got %+v", resp)
+ }
+}
+
+func newWeChatPaymentOAuthTestService(values map[string]string) *PaymentService {
+ return &PaymentService{
+ configService: &PaymentConfigService{
+ settingRepo: &paymentConfigSettingRepoStub{values: values},
+ encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
+ },
+ }
+}
diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go
index c5bda763..7521878c 100644
--- a/backend/internal/service/payment_refund.go
+++ b/backend/internal/service/payment_refund.go
@@ -12,6 +12,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
@@ -19,18 +20,133 @@ import (
// --- Refund Flow ---
// getOrderProviderInstance looks up the provider instance that processed this order.
-// Returns nil, nil for legacy orders without provider_instance_id.
+// For legacy orders without provider_instance_id, it resolves only when the
+// historical instance is uniquely identifiable from the stored order fields.
func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
- if o.ProviderInstanceID == nil || *o.ProviderInstanceID == "" {
+ if s == nil || s.entClient == nil || o == nil {
return nil, nil
}
- instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
+
+ if snapshot := psOrderProviderSnapshot(o); snapshot != nil {
+ return s.resolveSnapshotOrderProviderInstance(ctx, o, snapshot)
+ }
+
+ instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID))
+ if instIDStr == "" {
+ return s.resolveUniqueLegacyOrderProviderInstance(ctx, o)
+ }
+
+ instID, err := strconv.ParseInt(instIDStr, 10, 64)
if err != nil {
return nil, nil
}
return s.entClient.PaymentProviderInstance.Get(ctx, instID)
}
+// getRefundOrderProviderInstance resolves the provider instance for refund paths.
+// Refunds must be pinned to an explicit historical binding, so legacy
+// "best-effort" provider guessing is intentionally not allowed here.
+func (s *PaymentService) getRefundOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
+ if s == nil || s.entClient == nil || o == nil {
+ return nil, nil
+ }
+
+ if snapshot := psOrderProviderSnapshot(o); snapshot != nil {
+ return s.resolveSnapshotOrderProviderInstance(ctx, o, snapshot)
+ }
+
+ instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID))
+ if instIDStr == "" {
+ return nil, nil
+ }
+
+ instID, err := strconv.ParseInt(instIDStr, 10, 64)
+ if err != nil {
+ return nil, fmt.Errorf("order %d refund provider instance id is invalid: %s", o.ID, instIDStr)
+ }
+ inst, err := s.entClient.PaymentProviderInstance.Get(ctx, instID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, fmt.Errorf("order %d refund provider instance %s is missing", o.ID, instIDStr)
+ }
+ return nil, err
+ }
+ return inst, nil
+}
+
+func (s *PaymentService) resolveUniqueLegacyOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
+ paymentType := payment.GetBasePaymentType(strings.TrimSpace(o.PaymentType))
+ providerKey := strings.TrimSpace(psStringValue(o.ProviderKey))
+ if providerKey != "" {
+ instances, err := s.entClient.PaymentProviderInstance.Query().
+ Where(paymentproviderinstance.ProviderKeyEQ(providerKey)).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+ matched := psFilterLegacyOrderProviderInstances(paymentType, instances)
+ if len(matched) == 1 {
+ return matched[0], nil
+ }
+ return nil, nil
+ }
+
+ if paymentType == "" {
+ return nil, nil
+ }
+
+ instances, err := s.entClient.PaymentProviderInstance.Query().
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ matched := psFilterLegacyOrderProviderInstances(paymentType, instances)
+ if len(matched) == 1 {
+ return matched[0], nil
+ }
+ return nil, nil
+}
+
+func psFilterLegacyOrderProviderInstances(orderPaymentType string, instances []*dbent.PaymentProviderInstance) []*dbent.PaymentProviderInstance {
+ if len(instances) == 0 {
+ return nil
+ }
+ if strings.TrimSpace(orderPaymentType) == "" {
+ return instances
+ }
+ var matched []*dbent.PaymentProviderInstance
+ for _, inst := range instances {
+ if psLegacyOrderMatchesInstance(orderPaymentType, inst) {
+ matched = append(matched, inst)
+ }
+ }
+ return matched
+}
+
+func psLegacyOrderMatchesInstance(orderPaymentType string, inst *dbent.PaymentProviderInstance) bool {
+ if inst == nil {
+ return false
+ }
+
+ baseType := payment.GetBasePaymentType(strings.TrimSpace(orderPaymentType))
+ instanceProviderKey := strings.TrimSpace(inst.ProviderKey)
+ if baseType == "" {
+ return false
+ }
+
+ if baseType == payment.TypeStripe {
+ return instanceProviderKey == payment.TypeStripe
+ }
+ if instanceProviderKey == payment.TypeStripe {
+ return false
+ }
+ if instanceProviderKey == baseType {
+ return true
+ }
+ return payment.InstanceSupportsType(inst.SupportedTypes, baseType)
+}
+
func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error {
o, err := s.validateRefundRequest(ctx, oid, uid)
if err != nil {
@@ -72,7 +188,7 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int
return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund")
}
// Check provider instance allows user refund
- inst, err := s.getOrderProviderInstance(ctx, o)
+ inst, err := s.getRefundOrderProviderInstance(ctx, o)
if err != nil || inst == nil {
return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "refund is not available for this order")
}
@@ -92,7 +208,7 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund")
}
// Check provider instance allows admin refund
- inst, instErr := s.getOrderProviderInstance(ctx, o)
+ inst, instErr := s.getRefundOrderProviderInstance(ctx, o)
if instErr != nil {
slog.Warn("refund: provider instance lookup failed", "orderID", oid, "error", instErr)
return nil, nil, infraerrors.InternalServer("PROVIDER_LOOKUP_FAILED", "failed to look up payment provider for this order")
@@ -217,6 +333,12 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
if err != nil {
return fmt.Errorf("get refund provider: %w", err)
}
+ if err := validateProviderSnapshotMetadata(p.Order, prov.ProviderKey(), providerMerchantIdentityMetadata(prov)); err != nil {
+ s.writeAuditLog(ctx, p.Order.ID, "REFUND_PROVIDER_METADATA_MISMATCH", "admin", map[string]any{
+ "detail": err.Error(),
+ })
+ return err
+ }
_, err = prov.Refund(ctx, payment.RefundRequest{
TradeNo: p.Order.PaymentTradeNo,
OrderID: p.Order.OutTradeNo,
@@ -229,7 +351,14 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
// getRefundProvider creates a provider using the order's original instance config.
// Delegates to getOrderProvider which handles instance lookup and fallback.
func (s *PaymentService) getRefundProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
- return s.getOrderProvider(ctx, o)
+ inst, err := s.getRefundOrderProviderInstance(ctx, o)
+ if err != nil {
+ return nil, err
+ }
+ if inst == nil {
+ return nil, fmt.Errorf("refund provider instance is unavailable for order %d", o.ID)
+ }
+ return s.createProviderFromInstance(ctx, inst)
}
func (s *PaymentService) handleGwFail(ctx context.Context, p *RefundPlan, gErr error) (*RefundResult, error) {
diff --git a/backend/internal/service/payment_refund_test.go b/backend/internal/service/payment_refund_test.go
new file mode 100644
index 00000000..ca5b62cb
--- /dev/null
+++ b/backend/internal/service/payment_refund_test.go
@@ -0,0 +1,186 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/stretchr/testify/require"
+)
+
+func TestValidateRefundRequestRejectsLegacyGuessedProviderInstance(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("refund-legacy@example.com").
+ SetPasswordHash("hash").
+ SetUsername("refund-legacy-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("alipay-refund-instance").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ SetAllowUserRefund(true).
+ SetRefundEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("REFUND-LEGACY-ORDER").
+ SetOutTradeNo("sub2_refund_legacy_order").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-legacy-refund").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusCompleted).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetPaidAt(time.Now()).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ }
+
+ _, err = svc.validateRefundRequest(ctx, order.ID, user.ID)
+ require.Error(t, err)
+ require.Equal(t, "USER_REFUND_DISABLED", infraerrors.Reason(err))
+}
+
+func TestPrepareRefundRejectsLegacyGuessedProviderInstance(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("refund-legacy-admin@example.com").
+ SetPasswordHash("hash").
+ SetUsername("refund-legacy-admin-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("alipay-refund-admin-instance").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ SetAllowUserRefund(true).
+ SetRefundEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(188).
+ SetPayAmount(188).
+ SetFeeRate(0).
+ SetRechargeCode("REFUND-LEGACY-ADMIN-ORDER").
+ SetOutTradeNo("sub2_refund_legacy_admin_order").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-legacy-admin-refund").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusCompleted).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetPaidAt(time.Now()).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ }
+
+ plan, result, err := svc.PrepareRefund(ctx, order.ID, 0, "", false, false)
+ require.Nil(t, plan)
+ require.Nil(t, result)
+ require.Error(t, err)
+ require.Equal(t, "REFUND_DISABLED", infraerrors.Reason(err))
+}
+
+func TestGwRefundRejectsAlipayMerchantIdentitySnapshotMismatch(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("refund-snapshot-mismatch@example.com").
+ SetPasswordHash("hash").
+ SetUsername("refund-snapshot-mismatch-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ inst, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("alipay-refund-mismatch-instance").
+ SetConfig(encryptWebhookProviderConfig(t, map[string]string{
+ "appId": "runtime-alipay-app",
+ "privateKey": "runtime-private-key",
+ })).
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ SetRefundEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ instID := strconv.FormatInt(inst.ID, 10)
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("REFUND-SNAPSHOT-MISMATCH-ORDER").
+ SetOutTradeNo("sub2_refund_snapshot_mismatch_order").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-refund-snapshot-mismatch").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusCompleted).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetPaidAt(time.Now()).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderInstanceID(instID).
+ SetProviderKey(payment.TypeAlipay).
+ SetProviderSnapshot(map[string]any{
+ "schema_version": 2,
+ "provider_instance_id": instID,
+ "provider_key": payment.TypeAlipay,
+ "merchant_app_id": "expected-alipay-app",
+ }).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ err = svc.gwRefund(ctx, &RefundPlan{
+ OrderID: order.ID,
+ Order: order,
+ RefundAmount: order.Amount,
+ GatewayAmount: order.Amount,
+ Reason: "snapshot mismatch",
+ })
+ require.ErrorContains(t, err, "alipay app_id mismatch")
+}
diff --git a/backend/internal/service/payment_resume_lookup.go b/backend/internal/service/payment_resume_lookup.go
new file mode 100644
index 00000000..1ff061e8
--- /dev/null
+++ b/backend/internal/service/payment_resume_lookup.go
@@ -0,0 +1,67 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) {
+ claims, err := s.paymentResume().ParseToken(strings.TrimSpace(token))
+ if err != nil {
+ return nil, err
+ }
+
+ order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
+ }
+ return nil, fmt.Errorf("get order by resume token: %w", err)
+ }
+ if claims.UserID > 0 && order.UserID != claims.UserID {
+ return nil, invalidResumeTokenMatchError()
+ }
+ snapshot := psOrderProviderSnapshot(order)
+ orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
+ orderProviderKey := strings.TrimSpace(psStringValue(order.ProviderKey))
+ if snapshot != nil {
+ if snapshot.ProviderInstanceID != "" {
+ orderProviderInstanceID = snapshot.ProviderInstanceID
+ }
+ if snapshot.ProviderKey != "" {
+ orderProviderKey = snapshot.ProviderKey
+ }
+ }
+ if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID {
+ return nil, invalidResumeTokenMatchError()
+ }
+ if claims.ProviderKey != "" && !strings.EqualFold(orderProviderKey, claims.ProviderKey) {
+ return nil, invalidResumeTokenMatchError()
+ }
+ if claims.PaymentType != "" && NormalizeVisibleMethod(order.PaymentType) != NormalizeVisibleMethod(claims.PaymentType) {
+ return nil, invalidResumeTokenMatchError()
+ }
+ if order.Status == OrderStatusPending || order.Status == OrderStatusExpired {
+ result := s.checkPaid(ctx, order)
+ if result == checkPaidResultAlreadyPaid {
+ order, err = s.entClient.PaymentOrder.Get(ctx, order.ID)
+ if err != nil {
+ return nil, fmt.Errorf("reload order by resume token: %w", err)
+ }
+ }
+ }
+
+ return order, nil
+}
+
+func invalidResumeTokenMatchError() error {
+ return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token does not match the payment order")
+}
+
+func (s *PaymentService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
+ return s.paymentResume().ParseWeChatPaymentResumeToken(strings.TrimSpace(token))
+}
diff --git a/backend/internal/service/payment_resume_lookup_test.go b/backend/internal/service/payment_resume_lookup_test.go
new file mode 100644
index 00000000..a7b5b737
--- /dev/null
+++ b/backend/internal/service/payment_resume_lookup_test.go
@@ -0,0 +1,315 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/stretchr/testify/require"
+)
+
+type paymentResumeLookupProvider struct {
+ queryCount int
+}
+
+func (p *paymentResumeLookupProvider) Name() string { return "resume-lookup-provider" }
+
+func (p *paymentResumeLookupProvider) ProviderKey() string { return payment.TypeAlipay }
+
+func (p *paymentResumeLookupProvider) SupportedTypes() []payment.PaymentType {
+ return []payment.PaymentType{payment.TypeAlipay}
+}
+
+func (p *paymentResumeLookupProvider) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ panic("unexpected call")
+}
+
+func (p *paymentResumeLookupProvider) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) {
+ p.queryCount++
+ return &payment.QueryOrderResponse{Status: payment.ProviderStatusPending}, nil
+}
+
+func (p *paymentResumeLookupProvider) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
+ panic("unexpected call")
+}
+
+func (p *paymentResumeLookupProvider) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
+ panic("unexpected call")
+}
+
+func TestGetPublicOrderByResumeTokenReturnsMatchingOrder(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("resume@example.com").
+ SetPasswordHash("hash").
+ SetUsername("resume-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ instanceID := "12"
+ providerKey := payment.TypeEasyPay
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("RESUME-ORDER").
+ SetOutTradeNo("sub2_resume_lookup").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-1").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderInstanceID(instanceID).
+ SetProviderKey(providerKey).
+ Save(ctx)
+ require.NoError(t, err)
+
+ resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := resumeSvc.CreateToken(ResumeTokenClaims{
+ OrderID: order.ID,
+ UserID: user.ID,
+ ProviderInstanceID: instanceID,
+ ProviderKey: providerKey,
+ PaymentType: payment.TypeAlipay,
+ CanonicalReturnURL: "https://app.example.com/payment/result",
+ })
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ resumeService: resumeSvc,
+ }
+
+ got, err := svc.GetPublicOrderByResumeToken(ctx, token)
+ require.NoError(t, err)
+ require.Equal(t, order.ID, got.ID)
+}
+
+func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("resume-mismatch@example.com").
+ SetPasswordHash("hash").
+ SetUsername("resume-mismatch-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("RESUME-MISMATCH").
+ SetOutTradeNo("sub2_resume_lookup_mismatch").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-2").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderInstanceID("12").
+ SetProviderKey(payment.TypeEasyPay).
+ Save(ctx)
+ require.NoError(t, err)
+
+ resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := resumeSvc.CreateToken(ResumeTokenClaims{
+ OrderID: order.ID,
+ UserID: user.ID,
+ ProviderInstanceID: "99",
+ ProviderKey: payment.TypeEasyPay,
+ PaymentType: payment.TypeAlipay,
+ CanonicalReturnURL: "https://app.example.com/payment/result",
+ })
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ resumeService: resumeSvc,
+ }
+
+ _, err = svc.GetPublicOrderByResumeToken(ctx, token)
+ require.Error(t, err)
+ require.Equal(t, "INVALID_RESUME_TOKEN", infraerrors.Reason(err))
+}
+
+func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("resume-snapshot-authority@example.com").
+ SetPasswordHash("hash").
+ SetUsername("resume-snapshot-authority-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("RESUME-SNAPSHOT-AUTHORITY").
+ SetOutTradeNo("sub2_resume_snapshot_authority").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-snapshot-authority").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderInstanceID("legacy-column-instance").
+ SetProviderKey(payment.TypeAlipay).
+ SetProviderSnapshot(map[string]any{
+ "schema_version": 2,
+ "provider_instance_id": "snapshot-instance",
+ "provider_key": payment.TypeEasyPay,
+ }).
+ Save(ctx)
+ require.NoError(t, err)
+
+ resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := resumeSvc.CreateToken(ResumeTokenClaims{
+ OrderID: order.ID,
+ UserID: user.ID,
+ ProviderInstanceID: "snapshot-instance",
+ ProviderKey: payment.TypeEasyPay,
+ PaymentType: payment.TypeAlipay,
+ CanonicalReturnURL: "https://app.example.com/payment/result",
+ })
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ resumeService: resumeSvc,
+ }
+
+ got, err := svc.GetPublicOrderByResumeToken(ctx, token)
+ require.NoError(t, err)
+ require.Equal(t, order.ID, got.ID)
+}
+
+func TestGetPublicOrderByResumeTokenChecksUpstreamForPendingOrder(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("resume-refresh@example.com").
+ SetPasswordHash("hash").
+ SetUsername("resume-refresh-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("RESUME-PENDING").
+ SetOutTradeNo("sub2_resume_lookup_pending").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-pending").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := resumeSvc.CreateToken(ResumeTokenClaims{
+ OrderID: order.ID,
+ UserID: user.ID,
+ PaymentType: payment.TypeAlipay,
+ CanonicalReturnURL: "https://app.example.com/payment/result",
+ })
+ require.NoError(t, err)
+
+ registry := payment.NewRegistry()
+ provider := &paymentResumeLookupProvider{}
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ resumeService: resumeSvc,
+ providersLoaded: true,
+ }
+
+ got, err := svc.GetPublicOrderByResumeToken(ctx, token)
+ require.NoError(t, err)
+ require.Equal(t, order.ID, got.ID)
+ require.Equal(t, 1, provider.queryCount)
+}
+
+func TestVerifyOrderPublicDoesNotCheckUpstreamForPendingOrder(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("public-verify@example.com").
+ SetPasswordHash("hash").
+ SetUsername("public-verify-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("PUBLIC-VERIFY").
+ SetOutTradeNo("sub2_public_verify_pending").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-public-verify").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ registry := payment.NewRegistry()
+ provider := &paymentResumeLookupProvider{}
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ providersLoaded: true,
+ }
+
+ got, err := svc.VerifyOrderPublic(ctx, order.OutTradeNo)
+ require.NoError(t, err)
+ require.Equal(t, order.ID, got.ID)
+ require.Equal(t, 0, provider.queryCount)
+}
+
+func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
+ svc := &PaymentService{
+ entClient: newPaymentConfigServiceTestClient(t),
+ }
+
+ _, err := svc.VerifyOrderPublic(context.Background(), " ")
+ require.Error(t, err)
+ require.Equal(t, "INVALID_OUT_TRADE_NO", infraerrors.Reason(err))
+}
diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go
new file mode 100644
index 00000000..9ae62fde
--- /dev/null
+++ b/backend/internal/service/payment_resume_service.go
@@ -0,0 +1,476 @@
+package service
+
+import (
+ "bytes"
+ "context"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "net"
+ "net/url"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+const paymentResultReturnPath = "/payment/result"
+
+const (
+ PaymentSourceHostedRedirect = "hosted_redirect"
+ PaymentSourceWechatInAppResume = "wechat_in_app_resume"
+
+ SettingPaymentVisibleMethodAlipaySource = "payment_visible_method_alipay_source"
+ SettingPaymentVisibleMethodWxpaySource = "payment_visible_method_wxpay_source"
+ SettingPaymentVisibleMethodAlipayEnabled = "payment_visible_method_alipay_enabled"
+ SettingPaymentVisibleMethodWxpayEnabled = "payment_visible_method_wxpay_enabled"
+
+ VisibleMethodSourceOfficialAlipay = "official_alipay"
+ VisibleMethodSourceEasyPayAlipay = "easypay_alipay"
+ VisibleMethodSourceOfficialWechat = "official_wxpay"
+ VisibleMethodSourceEasyPayWechat = "easypay_wxpay"
+
+ wechatPaymentResumeTokenType = "wechat_payment_resume"
+
+ paymentResumeNotConfiguredCode = "PAYMENT_RESUME_NOT_CONFIGURED"
+ paymentResumeNotConfiguredMessage = "payment resume tokens require a configured signing key"
+
+ paymentResumeTokenTTL = 24 * time.Hour
+ wechatPaymentResumeTokenTTL = 15 * time.Minute
+)
+
+type ResumeTokenClaims struct {
+ OrderID int64 `json:"oid"`
+ UserID int64 `json:"uid,omitempty"`
+ ProviderInstanceID string `json:"pi,omitempty"`
+ ProviderKey string `json:"pk,omitempty"`
+ PaymentType string `json:"pt,omitempty"`
+ CanonicalReturnURL string `json:"ru,omitempty"`
+ IssuedAt int64 `json:"iat"`
+ ExpiresAt int64 `json:"exp,omitempty"`
+}
+
+type WeChatPaymentResumeClaims struct {
+ TokenType string `json:"tk,omitempty"`
+ OpenID string `json:"openid"`
+ PaymentType string `json:"pt,omitempty"`
+ Amount string `json:"amt,omitempty"`
+ OrderType string `json:"ot,omitempty"`
+ PlanID int64 `json:"pid,omitempty"`
+ RedirectTo string `json:"rd,omitempty"`
+ Scope string `json:"scp,omitempty"`
+ IssuedAt int64 `json:"iat"`
+ ExpiresAt int64 `json:"exp,omitempty"`
+}
+
+type PaymentResumeService struct {
+ signingKey []byte
+ verifyKeys [][]byte
+}
+
+type visibleMethodLoadBalancer struct {
+ inner payment.LoadBalancer
+ configService *PaymentConfigService
+}
+
+func NewPaymentResumeService(signingKey []byte, verifyFallbacks ...[]byte) *PaymentResumeService {
+ svc := &PaymentResumeService{}
+ if len(signingKey) > 0 {
+ svc.signingKey = append([]byte(nil), signingKey...)
+ svc.verifyKeys = append(svc.verifyKeys, svc.signingKey)
+ }
+ for _, fallback := range verifyFallbacks {
+ if len(fallback) == 0 {
+ continue
+ }
+ cloned := append([]byte(nil), fallback...)
+ duplicate := false
+ for _, existing := range svc.verifyKeys {
+ if bytes.Equal(existing, cloned) {
+ duplicate = true
+ break
+ }
+ }
+ if !duplicate {
+ svc.verifyKeys = append(svc.verifyKeys, cloned)
+ }
+ }
+ return svc
+}
+
+func (s *PaymentResumeService) isSigningConfigured() bool {
+ return s != nil && len(s.signingKey) > 0
+}
+
+func (s *PaymentResumeService) ensureSigningKey() error {
+ if s.isSigningConfigured() {
+ return nil
+ }
+ return infraerrors.ServiceUnavailable(paymentResumeNotConfiguredCode, paymentResumeNotConfiguredMessage)
+}
+
+func NormalizeVisibleMethod(method string) string {
+ return payment.GetBasePaymentType(strings.TrimSpace(method))
+}
+
+func NormalizeVisibleMethods(methods []string) []string {
+ if len(methods) == 0 {
+ return nil
+ }
+ seen := make(map[string]struct{}, len(methods))
+ out := make([]string, 0, len(methods))
+ for _, method := range methods {
+ normalized := NormalizeVisibleMethod(method)
+ if normalized == "" {
+ continue
+ }
+ if _, ok := seen[normalized]; ok {
+ continue
+ }
+ seen[normalized] = struct{}{}
+ out = append(out, normalized)
+ }
+ return out
+}
+
+func NormalizePaymentSource(source string) string {
+ switch strings.TrimSpace(strings.ToLower(source)) {
+ case "", PaymentSourceHostedRedirect:
+ return PaymentSourceHostedRedirect
+ case "wechat_in_app", "wxpay_resume", PaymentSourceWechatInAppResume:
+ return PaymentSourceWechatInAppResume
+ default:
+ return strings.TrimSpace(strings.ToLower(source))
+ }
+}
+
+func NormalizeVisibleMethodSource(method, source string) string {
+ switch NormalizeVisibleMethod(method) {
+ case payment.TypeAlipay:
+ switch strings.TrimSpace(strings.ToLower(source)) {
+ case VisibleMethodSourceOfficialAlipay, payment.TypeAlipay, payment.TypeAlipayDirect, "official":
+ return VisibleMethodSourceOfficialAlipay
+ case VisibleMethodSourceEasyPayAlipay, payment.TypeEasyPay:
+ return VisibleMethodSourceEasyPayAlipay
+ }
+ case payment.TypeWxpay:
+ switch strings.TrimSpace(strings.ToLower(source)) {
+ case VisibleMethodSourceOfficialWechat, payment.TypeWxpay, payment.TypeWxpayDirect, "wechat", "official":
+ return VisibleMethodSourceOfficialWechat
+ case VisibleMethodSourceEasyPayWechat, payment.TypeEasyPay:
+ return VisibleMethodSourceEasyPayWechat
+ }
+ }
+ return ""
+}
+
+func VisibleMethodProviderKeyForSource(method, source string) (string, bool) {
+ switch NormalizeVisibleMethodSource(method, source) {
+ case VisibleMethodSourceOfficialAlipay:
+ return payment.TypeAlipay, NormalizeVisibleMethod(method) == payment.TypeAlipay
+ case VisibleMethodSourceEasyPayAlipay:
+ return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeAlipay
+ case VisibleMethodSourceOfficialWechat:
+ return payment.TypeWxpay, NormalizeVisibleMethod(method) == payment.TypeWxpay
+ case VisibleMethodSourceEasyPayWechat:
+ return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeWxpay
+ default:
+ return "", false
+ }
+}
+
+func newVisibleMethodLoadBalancer(inner payment.LoadBalancer, configService *PaymentConfigService) payment.LoadBalancer {
+ if inner == nil || configService == nil || configService.entClient == nil {
+ return inner
+ }
+ return &visibleMethodLoadBalancer{inner: inner, configService: configService}
+}
+
+func (lb *visibleMethodLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) {
+ return lb.inner.GetInstanceConfig(ctx, instanceID)
+}
+
+func (lb *visibleMethodLoadBalancer) SelectInstance(ctx context.Context, providerKey string, paymentType payment.PaymentType, strategy payment.Strategy, orderAmount float64) (*payment.InstanceSelection, error) {
+ visibleMethod := NormalizeVisibleMethod(paymentType)
+ if providerKey != "" || (visibleMethod != payment.TypeAlipay && visibleMethod != payment.TypeWxpay) {
+ return lb.inner.SelectInstance(ctx, providerKey, paymentType, strategy, orderAmount)
+ }
+
+ inst, err := lb.configService.resolveEnabledVisibleMethodInstance(ctx, visibleMethod)
+ if err != nil {
+ return nil, err
+ }
+ if inst == nil {
+ return nil, fmt.Errorf("visible payment method %s has no enabled provider instance", visibleMethod)
+ }
+ return lb.inner.SelectInstance(ctx, inst.ProviderKey, paymentType, strategy, orderAmount)
+}
+
+func visibleMethodEnabledSettingKey(method string) string {
+ switch NormalizeVisibleMethod(method) {
+ case payment.TypeAlipay:
+ return SettingPaymentVisibleMethodAlipayEnabled
+ case payment.TypeWxpay:
+ return SettingPaymentVisibleMethodWxpayEnabled
+ default:
+ return ""
+ }
+}
+
+func visibleMethodSourceSettingKey(method string) string {
+ switch NormalizeVisibleMethod(method) {
+ case payment.TypeAlipay:
+ return SettingPaymentVisibleMethodAlipaySource
+ case payment.TypeWxpay:
+ return SettingPaymentVisibleMethodWxpaySource
+ default:
+ return ""
+ }
+}
+
+func CanonicalizeReturnURL(raw string, srcHost string, srcURL string) (string, error) {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return "", nil
+ }
+ parsed, err := url.Parse(raw)
+ if err != nil || !parsed.IsAbs() || parsed.Host == "" {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be an absolute http/https URL")
+ }
+ if parsed.Scheme != "http" && parsed.Scheme != "https" {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use http or https")
+ }
+ parsed.Fragment = ""
+ if parsed.Path == "" {
+ parsed.Path = "/"
+ }
+ if parsed.Path != paymentResultReturnPath {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must target the canonical internal payment result page")
+ }
+ if !allowedReturnURLHost(parsed.Host, srcHost, srcURL) {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use the same host as the current site or browser origin")
+ }
+ return parsed.String(), nil
+}
+
+func allowedReturnURLHost(returnURLHost string, requestHost string, refererURL string) bool {
+ if sameOriginHost(returnURLHost, requestHost) {
+ return true
+ }
+
+ refererURL = strings.TrimSpace(refererURL)
+ if refererURL == "" {
+ return false
+ }
+ parsedReferer, err := url.Parse(refererURL)
+ if err != nil || parsedReferer.Host == "" {
+ return false
+ }
+ return sameOriginHost(returnURLHost, parsedReferer.Host)
+}
+
+func buildPaymentReturnURL(base string, orderID int64, outTradeNo string, resumeToken string) (string, error) {
+ canonical := strings.TrimSpace(base)
+ if canonical == "" {
+ return "", nil
+ }
+
+ parsed, err := url.Parse(canonical)
+ if err != nil {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be a valid URL")
+ }
+ if !parsed.IsAbs() || parsed.Host == "" {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be a valid absolute URL")
+ }
+ parsed.Fragment = ""
+
+ query := parsed.Query()
+ if orderID > 0 {
+ query.Set("order_id", strconv.FormatInt(orderID, 10))
+ }
+ if strings.TrimSpace(outTradeNo) != "" {
+ query.Set("out_trade_no", strings.TrimSpace(outTradeNo))
+ }
+ if strings.TrimSpace(resumeToken) != "" {
+ query.Set("resume_token", strings.TrimSpace(resumeToken))
+ }
+ query.Set("status", "success")
+ parsed.RawQuery = query.Encode()
+
+ return parsed.String(), nil
+}
+
+func sameOriginHost(returnURLHost string, requestHost string) bool {
+ returnHost := strings.TrimSpace(returnURLHost)
+ reqHost := strings.TrimSpace(requestHost)
+ if returnHost == "" || reqHost == "" {
+ return false
+ }
+ if strings.EqualFold(returnHost, reqHost) {
+ return true
+ }
+
+ returnName, returnPort := splitHostPortDefault(returnHost)
+ reqName, reqPort := splitHostPortDefault(reqHost)
+ if returnName == "" || reqName == "" {
+ return false
+ }
+ return strings.EqualFold(returnName, reqName) && returnPort == reqPort
+}
+
+func splitHostPortDefault(raw string) (string, string) {
+ if host, port, err := net.SplitHostPort(raw); err == nil {
+ return host, port
+ }
+ return raw, ""
+}
+
+func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) {
+ if err := s.ensureSigningKey(); err != nil {
+ return "", err
+ }
+ if claims.OrderID <= 0 {
+ return "", fmt.Errorf("resume token requires order id")
+ }
+ if claims.IssuedAt == 0 {
+ claims.IssuedAt = time.Now().Unix()
+ }
+ if claims.ExpiresAt == 0 {
+ claims.ExpiresAt = time.Now().Add(paymentResumeTokenTTL).Unix()
+ }
+ return s.createSignedToken(claims)
+}
+
+func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) {
+ if err := s.ensureSigningKey(); err != nil {
+ return nil, err
+ }
+ var claims ResumeTokenClaims
+ if err := s.parseSignedToken(token, &claims); err != nil {
+ return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid")
+ }
+ if claims.OrderID <= 0 {
+ return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id")
+ }
+ if err := validatePaymentResumeExpiry(claims.ExpiresAt, "INVALID_RESUME_TOKEN", "resume token has expired"); err != nil {
+ return nil, err
+ }
+ return &claims, nil
+}
+
+func (s *PaymentResumeService) CreateWeChatPaymentResumeToken(claims WeChatPaymentResumeClaims) (string, error) {
+ if err := s.ensureSigningKey(); err != nil {
+ return "", err
+ }
+ claims.OpenID = strings.TrimSpace(claims.OpenID)
+ if claims.OpenID == "" {
+ return "", fmt.Errorf("wechat payment resume token requires openid")
+ }
+ if claims.IssuedAt == 0 {
+ claims.IssuedAt = time.Now().Unix()
+ }
+ if claims.ExpiresAt == 0 {
+ claims.ExpiresAt = time.Now().Add(wechatPaymentResumeTokenTTL).Unix()
+ }
+ if normalized := NormalizeVisibleMethod(claims.PaymentType); normalized != "" {
+ claims.PaymentType = normalized
+ }
+ if claims.PaymentType == "" {
+ claims.PaymentType = payment.TypeWxpay
+ }
+ if claims.OrderType == "" {
+ claims.OrderType = payment.OrderTypeBalance
+ }
+ claims.TokenType = wechatPaymentResumeTokenType
+ return s.createSignedToken(claims)
+}
+
+func (s *PaymentResumeService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
+ if err := s.ensureSigningKey(); err != nil {
+ return nil, err
+ }
+ var claims WeChatPaymentResumeClaims
+ if err := s.parseSignedToken(token, &claims); err != nil {
+ return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payload is invalid")
+ }
+ if claims.TokenType != wechatPaymentResumeTokenType {
+ return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token type mismatch")
+ }
+ claims.OpenID = strings.TrimSpace(claims.OpenID)
+ if claims.OpenID == "" {
+ return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token missing openid")
+ }
+ if err := validatePaymentResumeExpiry(claims.ExpiresAt, "INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token has expired"); err != nil {
+ return nil, err
+ }
+ if normalized := NormalizeVisibleMethod(claims.PaymentType); normalized != "" {
+ claims.PaymentType = normalized
+ }
+ if claims.PaymentType == "" {
+ claims.PaymentType = payment.TypeWxpay
+ }
+ if claims.OrderType == "" {
+ claims.OrderType = payment.OrderTypeBalance
+ }
+ return &claims, nil
+}
+
+func (s *PaymentResumeService) createSignedToken(claims any) (string, error) {
+ payload, err := json.Marshal(claims)
+ if err != nil {
+ return "", fmt.Errorf("marshal resume claims: %w", err)
+ }
+ encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
+ return encodedPayload + "." + s.sign(encodedPayload), nil
+}
+
+func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
+ parts := strings.Split(token, ".")
+ if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
+ return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
+ }
+ if !s.verifySignature(parts[0], parts[1]) {
+ return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
+ }
+ payload, err := base64.RawURLEncoding.DecodeString(parts[0])
+ if err != nil {
+ return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is malformed")
+ }
+ return json.Unmarshal(payload, dest)
+}
+
+func (s *PaymentResumeService) verifySignature(payload string, signature string) bool {
+ if s == nil {
+ return false
+ }
+ for _, key := range s.verifyKeys {
+ if hmac.Equal([]byte(signature), []byte(signPaymentResumePayload(payload, key))) {
+ return true
+ }
+ }
+ return false
+}
+
+func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
+ if expiresAt <= 0 {
+ return nil
+ }
+ if time.Now().Unix() > expiresAt {
+ return infraerrors.BadRequest(code, message)
+ }
+ return nil
+}
+
+func (s *PaymentResumeService) sign(payload string) string {
+ return signPaymentResumePayload(payload, s.signingKey)
+}
+
+func signPaymentResumePayload(payload string, key []byte) string {
+ mac := hmac.New(sha256.New, key)
+ _, _ = mac.Write([]byte(payload))
+ return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
+}
diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go
new file mode 100644
index 00000000..7e0adc2d
--- /dev/null
+++ b/backend/internal/service/payment_resume_service_test.go
@@ -0,0 +1,808 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json"
+ "net/url"
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+func TestNormalizeVisibleMethods(t *testing.T) {
+ t.Parallel()
+
+ got := NormalizeVisibleMethods([]string{
+ "alipay_direct",
+ "alipay",
+ " wxpay_direct ",
+ "wxpay",
+ "stripe",
+ })
+
+ want := []string{"alipay", "wxpay", "stripe"}
+ if len(got) != len(want) {
+ t.Fatalf("NormalizeVisibleMethods len = %d, want %d (%v)", len(got), len(want), got)
+ }
+ for i := range want {
+ if got[i] != want[i] {
+ t.Fatalf("NormalizeVisibleMethods[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
+ }
+ }
+}
+
+func TestNormalizePaymentSource(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ input string
+ expect string
+ }{
+ {name: "empty uses default", input: "", expect: PaymentSourceHostedRedirect},
+ {name: "wechat alias normalized", input: "wechat_in_app", expect: PaymentSourceWechatInAppResume},
+ {name: "canonical value preserved", input: PaymentSourceWechatInAppResume, expect: PaymentSourceWechatInAppResume},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ if got := NormalizePaymentSource(tt.input); got != tt.expect {
+ t.Fatalf("NormalizePaymentSource(%q) = %q, want %q", tt.input, got, tt.expect)
+ }
+ })
+ }
+}
+
+func TestCanonicalizeReturnURL(t *testing.T) {
+ t.Parallel()
+
+ got, err := CanonicalizeReturnURL("https://example.com/payment/result?b=2#a", "example.com", "")
+ if err != nil {
+ t.Fatalf("CanonicalizeReturnURL returned error: %v", err)
+ }
+ if got != "https://example.com/payment/result?b=2" {
+ t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/payment/result?b=2")
+ }
+}
+
+func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) {
+ t.Parallel()
+
+ if _, err := CanonicalizeReturnURL("/payment/result", "example.com", ""); err == nil {
+ t.Fatal("CanonicalizeReturnURL should reject relative URLs")
+ }
+}
+
+func TestCanonicalizeReturnURLRejectsExternalHost(t *testing.T) {
+ t.Parallel()
+
+ if _, err := CanonicalizeReturnURL("https://evil.example/payment/result", "app.example.com", ""); err == nil {
+ t.Fatal("CanonicalizeReturnURL should reject external hosts")
+ }
+}
+
+func TestCanonicalizeReturnURLAllowsConfiguredFrontendHost(t *testing.T) {
+ t.Parallel()
+
+ got, err := CanonicalizeReturnURL(
+ "https://app.example.com/payment/result?from=checkout",
+ "api.example.com",
+ "https://app.example.com/purchase",
+ )
+ if err != nil {
+ t.Fatalf("CanonicalizeReturnURL returned error: %v", err)
+ }
+ if got != "https://app.example.com/payment/result?from=checkout" {
+ t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://app.example.com/payment/result?from=checkout")
+ }
+}
+
+func TestCanonicalizeReturnURLRejectsNonCanonicalPath(t *testing.T) {
+ t.Parallel()
+
+ if _, err := CanonicalizeReturnURL("https://app.example.com/orders/42", "app.example.com", ""); err == nil {
+ t.Fatal("CanonicalizeReturnURL should reject non-canonical result paths")
+ }
+}
+
+func TestBuildPaymentReturnURL(t *testing.T) {
+ t.Parallel()
+
+ got, err := buildPaymentReturnURL("https://example.com/payment/result?from=checkout#fragment", 42, "sub2_42", "resume-token")
+ if err != nil {
+ t.Fatalf("buildPaymentReturnURL returned error: %v", err)
+ }
+
+ parsed, err := url.Parse(got)
+ if err != nil {
+ t.Fatalf("url.Parse returned error: %v", err)
+ }
+ if parsed.Fragment != "" {
+ t.Fatalf("buildPaymentReturnURL should strip fragments, got %q", parsed.Fragment)
+ }
+ query := parsed.Query()
+ if query.Get("from") != "checkout" {
+ t.Fatalf("expected original query to be preserved, got %q", query.Get("from"))
+ }
+ if query.Get("order_id") != strconv.FormatInt(42, 10) {
+ t.Fatalf("order_id = %q", query.Get("order_id"))
+ }
+ if query.Get("out_trade_no") != "sub2_42" {
+ t.Fatalf("out_trade_no = %q", query.Get("out_trade_no"))
+ }
+ if query.Get("resume_token") != "resume-token" {
+ t.Fatalf("resume_token = %q", query.Get("resume_token"))
+ }
+ if query.Get("status") != "success" {
+ t.Fatalf("status = %q", query.Get("status"))
+ }
+}
+
+func TestBuildPaymentReturnURLWithoutResumeTokenStillIncludesOutTradeNo(t *testing.T) {
+ t.Parallel()
+
+ got, err := buildPaymentReturnURL("https://example.com/payment/result", 42, "sub2_42", "")
+ if err != nil {
+ t.Fatalf("buildPaymentReturnURL returned error: %v", err)
+ }
+
+ parsed, err := url.Parse(got)
+ if err != nil {
+ t.Fatalf("url.Parse returned error: %v", err)
+ }
+ query := parsed.Query()
+ if query.Get("order_id") != "42" {
+ t.Fatalf("order_id = %q", query.Get("order_id"))
+ }
+ if query.Get("out_trade_no") != "sub2_42" {
+ t.Fatalf("out_trade_no = %q", query.Get("out_trade_no"))
+ }
+ if query.Get("resume_token") != "" {
+ t.Fatalf("resume_token = %q, want empty", query.Get("resume_token"))
+ }
+}
+
+func TestBuildPaymentReturnURLEmptyBase(t *testing.T) {
+ t.Parallel()
+
+ got, err := buildPaymentReturnURL("", 42, "sub2_42", "resume-token")
+ if err != nil {
+ t.Fatalf("buildPaymentReturnURL returned error: %v", err)
+ }
+ if got != "" {
+ t.Fatalf("buildPaymentReturnURL = %q, want empty string", got)
+ }
+}
+
+func TestPaymentResumeTokenRoundTrip(t *testing.T) {
+ t.Parallel()
+
+ svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := svc.CreateToken(ResumeTokenClaims{
+ OrderID: 42,
+ UserID: 7,
+ ProviderInstanceID: "19",
+ ProviderKey: "easypay",
+ PaymentType: "wxpay",
+ CanonicalReturnURL: "https://example.com/payment/result",
+ IssuedAt: 1234567890,
+ })
+ if err != nil {
+ t.Fatalf("CreateToken returned error: %v", err)
+ }
+
+ claims, err := svc.ParseToken(token)
+ if err != nil {
+ t.Fatalf("ParseToken returned error: %v", err)
+ }
+ if claims.OrderID != 42 || claims.UserID != 7 {
+ t.Fatalf("claims mismatch: %+v", claims)
+ }
+ if claims.ProviderInstanceID != "19" || claims.ProviderKey != "easypay" || claims.PaymentType != "wxpay" {
+ t.Fatalf("claims provider snapshot mismatch: %+v", claims)
+ }
+ if claims.CanonicalReturnURL != "https://example.com/payment/result" {
+ t.Fatalf("claims return URL = %q", claims.CanonicalReturnURL)
+ }
+}
+
+func TestCreateTokenRejectsMissingSigningKey(t *testing.T) {
+ t.Parallel()
+
+ svc := NewPaymentResumeService(nil)
+ _, err := svc.CreateToken(ResumeTokenClaims{OrderID: 42})
+ if err == nil {
+ t.Fatal("CreateToken should reject missing signing key")
+ }
+}
+
+func TestParseTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) {
+ t.Parallel()
+
+ token := mustCreateFallbackSignedToken(t, ResumeTokenClaims{OrderID: 42, UserID: 7})
+ svc := NewPaymentResumeService(nil)
+ _, err := svc.ParseToken(token)
+ if err == nil {
+ t.Fatal("ParseToken should reject tokens when signing key is missing")
+ }
+}
+
+func TestParseTokenRejectsExpiredToken(t *testing.T) {
+ t.Parallel()
+
+ svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := svc.CreateToken(ResumeTokenClaims{
+ OrderID: 42,
+ UserID: 7,
+ IssuedAt: time.Now().Add(-25 * time.Hour).Unix(),
+ ExpiresAt: time.Now().Add(-1 * time.Hour).Unix(),
+ })
+ if err != nil {
+ t.Fatalf("CreateToken returned error: %v", err)
+ }
+
+ _, err = svc.ParseToken(token)
+ if err == nil {
+ t.Fatal("ParseToken should reject expired tokens")
+ }
+}
+
+func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) {
+ t.Parallel()
+
+ svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
+ OpenID: "openid-123",
+ PaymentType: payment.TypeWxpay,
+ Amount: "12.50",
+ OrderType: payment.OrderTypeSubscription,
+ PlanID: 7,
+ RedirectTo: "/purchase?from=wechat",
+ Scope: "snsapi_base",
+ IssuedAt: 1234567890,
+ })
+ if err != nil {
+ t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
+ }
+
+ claims, err := svc.ParseWeChatPaymentResumeToken(token)
+ if err != nil {
+ t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
+ }
+ if claims.OpenID != "openid-123" || claims.PaymentType != payment.TypeWxpay {
+ t.Fatalf("claims mismatch: %+v", claims)
+ }
+ if claims.Amount != "12.50" || claims.OrderType != payment.OrderTypeSubscription || claims.PlanID != 7 {
+ t.Fatalf("claims payment context mismatch: %+v", claims)
+ }
+ if claims.RedirectTo != "/purchase?from=wechat" || claims.Scope != "snsapi_base" {
+ t.Fatalf("claims redirect/scope mismatch: %+v", claims)
+ }
+}
+
+func TestCreateWeChatPaymentResumeTokenRejectsMissingSigningKey(t *testing.T) {
+ t.Parallel()
+
+ svc := NewPaymentResumeService(nil)
+ _, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{OpenID: "openid-123"})
+ if err == nil {
+ t.Fatal("CreateWeChatPaymentResumeToken should reject missing signing key")
+ }
+}
+
+func TestParseWeChatPaymentResumeTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) {
+ t.Parallel()
+
+ token := mustCreateFallbackSignedToken(t, WeChatPaymentResumeClaims{
+ TokenType: wechatPaymentResumeTokenType,
+ OpenID: "openid-123",
+ PaymentType: payment.TypeWxpay,
+ })
+ svc := NewPaymentResumeService(nil)
+ _, err := svc.ParseWeChatPaymentResumeToken(token)
+ if err == nil {
+ t.Fatal("ParseWeChatPaymentResumeToken should reject tokens when signing key is missing")
+ }
+}
+
+func TestParseWeChatPaymentResumeTokenRejectsExpiredToken(t *testing.T) {
+ t.Parallel()
+
+ svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
+ OpenID: "openid-123",
+ PaymentType: payment.TypeWxpay,
+ IssuedAt: time.Now().Add(-30 * time.Minute).Unix(),
+ ExpiresAt: time.Now().Add(-1 * time.Minute).Unix(),
+ })
+ if err != nil {
+ t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
+ }
+
+ _, err = svc.ParseWeChatPaymentResumeToken(token)
+ if err == nil {
+ t.Fatal("ParseWeChatPaymentResumeToken should reject expired tokens")
+ }
+}
+
+func TestPaymentServiceParseWeChatPaymentResumeTokenUsesExplicitSigningKey(t *testing.T) {
+ t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
+
+ token, err := NewPaymentResumeService([]byte("explicit-payment-resume-signing-key")).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
+ OpenID: "openid-explicit-key",
+ PaymentType: payment.TypeWxpay,
+ })
+ if err != nil {
+ t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
+ }
+
+ svc := &PaymentService{
+ configService: &PaymentConfigService{
+ encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
+ },
+ }
+
+ claims, err := svc.ParseWeChatPaymentResumeToken(token)
+ if err != nil {
+ t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
+ }
+ if claims.OpenID != "openid-explicit-key" {
+ t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-explicit-key")
+ }
+}
+
+func TestPaymentServiceParseWeChatPaymentResumeTokenAcceptsLegacyEncryptionKeyDuringMigration(t *testing.T) {
+ t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
+
+ legacyKey := []byte("0123456789abcdef0123456789abcdef")
+ token, err := NewPaymentResumeService(legacyKey).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
+ OpenID: "openid-legacy-key",
+ PaymentType: payment.TypeWxpay,
+ })
+ if err != nil {
+ t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
+ }
+
+ svc := &PaymentService{
+ configService: &PaymentConfigService{
+ encryptionKey: legacyKey,
+ },
+ }
+
+ claims, err := svc.ParseWeChatPaymentResumeToken(token)
+ if err != nil {
+ t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
+ }
+ if claims.OpenID != "openid-legacy-key" {
+ t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-legacy-key")
+ }
+}
+
+func TestNewConfiguredPaymentResumeServicePrefersExplicitSigningKeyAndKeepsLegacyVerificationFallback(t *testing.T) {
+ t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
+
+ legacyKey := []byte("0123456789abcdef0123456789abcdef")
+ svc := newLegacyAwarePaymentResumeService(legacyKey)
+
+ explicitToken, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
+ OpenID: "openid-explicit-key",
+ PaymentType: payment.TypeWxpay,
+ })
+ if err != nil {
+ t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
+ }
+
+ explicitClaims, err := NewPaymentResumeService([]byte("explicit-payment-resume-signing-key")).ParseWeChatPaymentResumeToken(explicitToken)
+ if err != nil {
+ t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
+ }
+ if explicitClaims.OpenID != "openid-explicit-key" {
+ t.Fatalf("openid = %q, want %q", explicitClaims.OpenID, "openid-explicit-key")
+ }
+
+ legacyToken, err := NewPaymentResumeService(legacyKey).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
+ OpenID: "openid-legacy-key",
+ PaymentType: payment.TypeWxpay,
+ })
+ if err != nil {
+ t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
+ }
+
+ legacyClaims, err := svc.ParseWeChatPaymentResumeToken(legacyToken)
+ if err != nil {
+ t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
+ }
+ if legacyClaims.OpenID != "openid-legacy-key" {
+ t.Fatalf("openid = %q, want %q", legacyClaims.OpenID, "openid-legacy-key")
+ }
+}
+
+func TestNormalizeVisibleMethodSource(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ method string
+ input string
+ want string
+ }{
+ {name: "alipay official alias", method: payment.TypeAlipay, input: "alipay", want: VisibleMethodSourceOfficialAlipay},
+ {name: "alipay easypay alias", method: payment.TypeAlipay, input: "easypay", want: VisibleMethodSourceEasyPayAlipay},
+ {name: "wxpay official alias", method: payment.TypeWxpay, input: "wxpay", want: VisibleMethodSourceOfficialWechat},
+ {name: "wxpay easypay alias", method: payment.TypeWxpay, input: "easypay", want: VisibleMethodSourceEasyPayWechat},
+ {name: "unsupported source", method: payment.TypeWxpay, input: "stripe", want: ""},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ if got := NormalizeVisibleMethodSource(tt.method, tt.input); got != tt.want {
+ t.Fatalf("NormalizeVisibleMethodSource(%q, %q) = %q, want %q", tt.method, tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestVisibleMethodProviderKeyForSource(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ method string
+ source string
+ want string
+ ok bool
+ }{
+ {name: "official alipay", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialAlipay, want: payment.TypeAlipay, ok: true},
+ {name: "easypay alipay", method: payment.TypeAlipay, source: VisibleMethodSourceEasyPayAlipay, want: payment.TypeEasyPay, ok: true},
+ {name: "official wechat", method: payment.TypeWxpay, source: VisibleMethodSourceOfficialWechat, want: payment.TypeWxpay, ok: true},
+ {name: "easypay wechat", method: payment.TypeWxpay, source: VisibleMethodSourceEasyPayWechat, want: payment.TypeEasyPay, ok: true},
+ {name: "mismatched method and source", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialWechat, want: "", ok: false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got, ok := VisibleMethodProviderKeyForSource(tt.method, tt.source)
+ if got != tt.want || ok != tt.ok {
+ t.Fatalf("VisibleMethodProviderKeyForSource(%q, %q) = (%q, %v), want (%q, %v)", tt.method, tt.source, got, ok, tt.want, tt.ok)
+ }
+ })
+ }
+}
+
+func TestVisibleMethodLoadBalancerUsesEnabledProviderInstance(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("Official Alipay").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ SetSortOrder(1).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create alipay provider: %v", err)
+ }
+
+ inner := &captureLoadBalancer{}
+ configService := &PaymentConfigService{
+ entClient: client,
+ }
+ lb := newVisibleMethodLoadBalancer(inner, configService)
+
+ _, err = lb.SelectInstance(ctx, "", payment.TypeAlipay, payment.StrategyRoundRobin, 12.5)
+ if err != nil {
+ t.Fatalf("SelectInstance returned error: %v", err)
+ }
+ if inner.lastProviderKey != payment.TypeAlipay {
+ t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, payment.TypeAlipay)
+ }
+}
+
+func TestVisibleMethodLoadBalancerUsesConfiguredSourceWhenMultipleProvidersEnabled(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ method payment.PaymentType
+ officialName string
+ officialTypes string
+ easyPayName string
+ easyPayTypes string
+ sourceSetting string
+ wantProvider string
+ }{
+ {
+ name: "alipay uses official source",
+ method: payment.TypeAlipay,
+ officialName: "Official Alipay",
+ officialTypes: "alipay",
+ easyPayName: "EasyPay Alipay",
+ easyPayTypes: "alipay",
+ sourceSetting: VisibleMethodSourceOfficialAlipay,
+ wantProvider: payment.TypeAlipay,
+ },
+ {
+ name: "alipay uses easypay source",
+ method: payment.TypeAlipay,
+ officialName: "Official Alipay",
+ officialTypes: "alipay",
+ easyPayName: "EasyPay Alipay",
+ easyPayTypes: "alipay",
+ sourceSetting: VisibleMethodSourceEasyPayAlipay,
+ wantProvider: payment.TypeEasyPay,
+ },
+ {
+ name: "wxpay uses official source",
+ method: payment.TypeWxpay,
+ officialName: "Official WeChat",
+ officialTypes: "wxpay",
+ easyPayName: "EasyPay WeChat",
+ easyPayTypes: "wxpay",
+ sourceSetting: VisibleMethodSourceOfficialWechat,
+ wantProvider: payment.TypeWxpay,
+ },
+ {
+ name: "wxpay uses easypay source",
+ method: payment.TypeWxpay,
+ officialName: "Official WeChat",
+ officialTypes: "wxpay",
+ easyPayName: "EasyPay WeChat",
+ easyPayTypes: "wxpay",
+ sourceSetting: VisibleMethodSourceEasyPayWechat,
+ wantProvider: payment.TypeEasyPay,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ officialProviderKey := payment.TypeAlipay
+ if tt.method == payment.TypeWxpay {
+ officialProviderKey = payment.TypeWxpay
+ }
+
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(officialProviderKey).
+ SetName(tt.officialName).
+ SetConfig("{}").
+ SetSupportedTypes(tt.officialTypes).
+ SetEnabled(true).
+ SetSortOrder(1).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create official provider: %v", err)
+ }
+
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName(tt.easyPayName).
+ SetConfig("{}").
+ SetSupportedTypes(tt.easyPayTypes).
+ SetEnabled(true).
+ SetSortOrder(2).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create easypay provider: %v", err)
+ }
+
+ inner := &captureLoadBalancer{}
+ configService := &PaymentConfigService{
+ entClient: client,
+ settingRepo: &paymentConfigSettingRepoStub{
+ values: map[string]string{
+ visibleMethodSourceSettingKey(tt.method): tt.sourceSetting,
+ },
+ },
+ }
+ lb := newVisibleMethodLoadBalancer(inner, configService)
+
+ _, err = lb.SelectInstance(ctx, "", tt.method, payment.StrategyRoundRobin, 12.5)
+ if err != nil {
+ t.Fatalf("SelectInstance returned error: %v", err)
+ }
+ if inner.lastProviderKey != tt.wantProvider {
+ t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, tt.wantProvider)
+ }
+ })
+ }
+}
+
+func TestVisibleMethodLoadBalancerPreservesLegacyCrossProviderRoutingWhenSourceMissing(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("Official Alipay").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ SetSortOrder(1).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create official provider: %v", err)
+ }
+
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName("EasyPay Alipay").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ SetSortOrder(2).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create easypay provider: %v", err)
+ }
+
+ inner := &captureLoadBalancer{}
+ configService := &PaymentConfigService{
+ entClient: client,
+ settingRepo: &paymentConfigSettingRepoStub{
+ values: map[string]string{
+ visibleMethodSourceSettingKey(payment.TypeAlipay): "",
+ },
+ },
+ }
+ lb := newVisibleMethodLoadBalancer(inner, configService)
+
+ _, err = lb.SelectInstance(ctx, "", payment.TypeAlipay, payment.StrategyRoundRobin, 9.9)
+ if err != nil {
+ t.Fatalf("SelectInstance returned error: %v", err)
+ }
+ if inner.lastProviderKey != "" {
+ t.Fatalf("lastProviderKey = %q, want legacy cross-provider empty key", inner.lastProviderKey)
+ }
+ if inner.lastPaymentType != payment.TypeAlipay {
+ t.Fatalf("lastPaymentType = %q, want %q", inner.lastPaymentType, payment.TypeAlipay)
+ }
+}
+
+func TestVisibleMethodLoadBalancerRejectsInvalidSourceWhenMultipleProvidersEnabled(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ method payment.PaymentType
+ sourceValue string
+ wantMessage string
+ }{
+ {
+ name: "invalid wxpay source",
+ method: payment.TypeWxpay,
+ sourceValue: "stripe",
+ wantMessage: "wxpay source must be one of the supported payment providers",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ officialProviderKey := payment.TypeAlipay
+ officialSupportedTypes := "alipay"
+ officialName := "Official Alipay"
+ easyPaySupportedTypes := "alipay"
+ easyPayName := "EasyPay Alipay"
+ if tt.method == payment.TypeWxpay {
+ officialProviderKey = payment.TypeWxpay
+ officialSupportedTypes = "wxpay"
+ officialName = "Official WeChat"
+ easyPaySupportedTypes = "wxpay"
+ easyPayName = "EasyPay WeChat"
+ }
+
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(officialProviderKey).
+ SetName(officialName).
+ SetConfig("{}").
+ SetSupportedTypes(officialSupportedTypes).
+ SetEnabled(true).
+ SetSortOrder(1).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create official provider: %v", err)
+ }
+
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName(easyPayName).
+ SetConfig("{}").
+ SetSupportedTypes(easyPaySupportedTypes).
+ SetEnabled(true).
+ SetSortOrder(2).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create easypay provider: %v", err)
+ }
+
+ inner := &captureLoadBalancer{}
+ configService := &PaymentConfigService{
+ entClient: client,
+ settingRepo: &paymentConfigSettingRepoStub{
+ values: map[string]string{
+ visibleMethodSourceSettingKey(tt.method): tt.sourceValue,
+ },
+ },
+ }
+ lb := newVisibleMethodLoadBalancer(inner, configService)
+
+ _, err = lb.SelectInstance(ctx, "", tt.method, payment.StrategyRoundRobin, 9.9)
+ if err == nil {
+ t.Fatal("SelectInstance should reject invalid visible method source configuration")
+ }
+ if infraerrors.Reason(err) != "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE" {
+ t.Fatalf("Reason(err) = %q, want %q", infraerrors.Reason(err), "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE")
+ }
+ if infraerrors.Message(err) != tt.wantMessage {
+ t.Fatalf("Message(err) = %q, want %q", infraerrors.Message(err), tt.wantMessage)
+ }
+ })
+ }
+}
+
+func TestVisibleMethodLoadBalancerRejectsMissingEnabledVisibleMethodProvider(t *testing.T) {
+ t.Parallel()
+
+ inner := &captureLoadBalancer{}
+ configService := &PaymentConfigService{
+ entClient: newPaymentConfigServiceTestClient(t),
+ }
+ lb := newVisibleMethodLoadBalancer(inner, configService)
+
+ if _, err := lb.SelectInstance(context.Background(), "", payment.TypeWxpay, payment.StrategyRoundRobin, 9.9); err == nil {
+ t.Fatal("SelectInstance should reject when no enabled provider instance exists")
+ }
+}
+
+type captureLoadBalancer struct {
+ lastProviderKey string
+ lastPaymentType string
+}
+
+func (c *captureLoadBalancer) GetInstanceConfig(context.Context, int64) (map[string]string, error) {
+ return map[string]string{}, nil
+}
+
+func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey string, paymentType payment.PaymentType, _ payment.Strategy, _ float64) (*payment.InstanceSelection, error) {
+ c.lastProviderKey = providerKey
+ c.lastPaymentType = paymentType
+ return &payment.InstanceSelection{ProviderKey: providerKey, SupportedTypes: paymentType}, nil
+}
+
+func mustCreateFallbackSignedToken(t *testing.T, claims any) string {
+ t.Helper()
+
+ payload, err := json.Marshal(claims)
+ if err != nil {
+ t.Fatalf("marshal claims: %v", err)
+ }
+ encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
+ mac := hmac.New(sha256.New, []byte("sub2api-payment-resume"))
+ _, _ = mac.Write([]byte(encodedPayload))
+ signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
+ return encodedPayload + "." + signature
+}
diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go
index 6fc23f97..97fd76a0 100644
--- a/backend/internal/service/payment_service.go
+++ b/backend/internal/service/payment_service.go
@@ -1,15 +1,18 @@
package service
import (
+ "bytes"
"context"
+ "encoding/hex"
"fmt"
"log/slog"
"math/rand/v2"
+ "os"
+ "strings"
"sync"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
- "github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/payment/provider"
@@ -45,6 +48,8 @@ const (
orderIDPrefix = "sub2_"
)
+const paymentResumeSigningKeyEnv = "PAYMENT_RESUME_SIGNING_KEY"
+
// --- Types ---
// generateOutTradeNo creates a unique external order ID for payment providers.
@@ -65,29 +70,39 @@ func generateRandomString(n int) string {
}
type CreateOrderRequest struct {
- UserID int64
- Amount float64
- PaymentType string
- ClientIP string
- IsMobile bool
- SrcHost string
- SrcURL string
- OrderType string
- PlanID int64
+ UserID int64
+ Amount float64
+ PaymentType string
+ OpenID string
+ ClientIP string
+ IsMobile bool
+ IsWeChatBrowser bool
+ SrcHost string
+ SrcURL string
+ ReturnURL string
+ PaymentSource string
+ OrderType string
+ PlanID int64
}
type CreateOrderResponse struct {
- OrderID int64 `json:"order_id"`
- Amount float64 `json:"amount"`
- PayAmount float64 `json:"pay_amount"`
- FeeRate float64 `json:"fee_rate"`
- Status string `json:"status"`
- PaymentType string `json:"payment_type"`
- PayURL string `json:"pay_url,omitempty"`
- QRCode string `json:"qr_code,omitempty"`
- ClientSecret string `json:"client_secret,omitempty"`
- ExpiresAt time.Time `json:"expires_at"`
- PaymentMode string `json:"payment_mode,omitempty"`
+ OrderID int64 `json:"order_id"`
+ Amount float64 `json:"amount"`
+ PayAmount float64 `json:"pay_amount"`
+ FeeRate float64 `json:"fee_rate"`
+ Status string `json:"status"`
+ ResultType payment.CreatePaymentResultType `json:"result_type,omitempty"`
+ PaymentType string `json:"payment_type"`
+ OutTradeNo string `json:"out_trade_no,omitempty"`
+ PayURL string `json:"pay_url,omitempty"`
+ QRCode string `json:"qr_code,omitempty"`
+ ClientSecret string `json:"client_secret,omitempty"`
+ OAuth *payment.WechatOAuthInfo `json:"oauth,omitempty"`
+ JSAPI *payment.WechatJSAPIPayload `json:"jsapi,omitempty"`
+ JSAPIPayload *payment.WechatJSAPIPayload `json:"jsapi_payload,omitempty"`
+ ExpiresAt time.Time `json:"expires_at"`
+ PaymentMode string `json:"payment_mode,omitempty"`
+ ResumeToken string `json:"resume_token,omitempty"`
}
type OrderListParams struct {
@@ -165,10 +180,13 @@ type PaymentService struct {
configService *PaymentConfigService
userRepo UserRepository
groupRepo GroupRepository
+ resumeService *PaymentResumeService
}
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
- return &PaymentService{entClient: entClient, registry: registry, loadBalancer: loadBalancer, redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
+ svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
+ svc.resumeService = psNewPaymentResumeService(configService)
+ return svc
}
// --- Provider Registry ---
@@ -219,25 +237,6 @@ func (s *PaymentService) loadProviders(ctx context.Context) {
}
}
-// GetWebhookProvider returns the provider instance that should verify a webhook.
-// It extracts out_trade_no from the raw body, looks up the order to find the
-// original provider instance, and creates a provider with that instance's credentials.
-// Falls back to the registry provider when the order cannot be found.
-func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) {
- if outTradeNo != "" {
- order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx)
- if err == nil {
- p, pErr := s.getOrderProvider(ctx, order)
- if pErr == nil {
- return p, nil
- }
- slog.Warn("[Webhook] order provider creation failed, falling back to registry", "outTradeNo", outTradeNo, "error", pErr)
- }
- }
- s.EnsureProviders(ctx)
- return s.registry.GetProviderByKey(providerKey)
-}
-
// --- Helpers ---
func psIsRefundStatus(s string) bool {
@@ -262,6 +261,60 @@ func psNilIfEmpty(s string) *string {
return &s
}
+func (s *PaymentService) paymentResume() *PaymentResumeService {
+ if s.resumeService != nil {
+ return s.resumeService
+ }
+ return psNewPaymentResumeService(s.configService)
+}
+
+func NewLegacyAwarePaymentResumeService(legacyKey []byte) *PaymentResumeService {
+ return newLegacyAwarePaymentResumeService(legacyKey)
+}
+
+func psNewPaymentResumeService(configService *PaymentConfigService) *PaymentResumeService {
+ return newLegacyAwarePaymentResumeService(psResumeLegacyVerificationKey(configService))
+}
+
+func newLegacyAwarePaymentResumeService(legacyKey []byte) *PaymentResumeService {
+ signingKey, verifyFallbacks := resolvePaymentResumeSigningKeys(legacyKey)
+ return NewPaymentResumeService(signingKey, verifyFallbacks...)
+}
+
+func psResumeLegacyVerificationKey(configService *PaymentConfigService) []byte {
+ if configService == nil {
+ return nil
+ }
+ return configService.encryptionKey
+}
+
+func resolvePaymentResumeSigningKeys(legacyKey []byte) ([]byte, [][]byte) {
+ signingKey := parsePaymentResumeSigningKey(os.Getenv(paymentResumeSigningKeyEnv))
+ if len(signingKey) == 0 {
+ if len(legacyKey) == 0 {
+ return nil, nil
+ }
+ return legacyKey, nil
+ }
+ if len(legacyKey) == 0 || bytes.Equal(legacyKey, signingKey) {
+ return signingKey, nil
+ }
+ return signingKey, [][]byte{legacyKey}
+}
+
+func parsePaymentResumeSigningKey(raw string) []byte {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return nil
+ }
+ if len(raw) >= 64 && len(raw)%2 == 0 {
+ if decoded, err := hex.DecodeString(raw); err == nil && len(decoded) > 0 {
+ return decoded
+ }
+ }
+ return []byte(raw)
+}
+
func psSliceContains(sl []string, s string) bool {
for _, v := range sl {
if v == s {
diff --git a/backend/internal/service/payment_visible_method_instances.go b/backend/internal/service/payment_visible_method_instances.go
new file mode 100644
index 00000000..899bd7a0
--- /dev/null
+++ b/backend/internal/service/payment_visible_method_instances.go
@@ -0,0 +1,242 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+func enabledVisibleMethodsForProvider(providerKey, supportedTypes string) []string {
+ methodSet := make(map[string]struct{}, 2)
+ addMethod := func(method string) {
+ method = NormalizeVisibleMethod(method)
+ switch method {
+ case payment.TypeAlipay, payment.TypeWxpay:
+ methodSet[method] = struct{}{}
+ }
+ }
+
+ switch strings.TrimSpace(providerKey) {
+ case payment.TypeAlipay:
+ if strings.TrimSpace(supportedTypes) == "" {
+ addMethod(payment.TypeAlipay)
+ break
+ }
+ for _, supportedType := range splitTypes(supportedTypes) {
+ if NormalizeVisibleMethod(supportedType) == payment.TypeAlipay {
+ addMethod(payment.TypeAlipay)
+ break
+ }
+ }
+ case payment.TypeWxpay:
+ if strings.TrimSpace(supportedTypes) == "" {
+ addMethod(payment.TypeWxpay)
+ break
+ }
+ for _, supportedType := range splitTypes(supportedTypes) {
+ if NormalizeVisibleMethod(supportedType) == payment.TypeWxpay {
+ addMethod(payment.TypeWxpay)
+ break
+ }
+ }
+ case payment.TypeEasyPay:
+ for _, supportedType := range splitTypes(supportedTypes) {
+ addMethod(supportedType)
+ }
+ }
+
+ methods := make([]string, 0, len(methodSet))
+ for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} {
+ if _, ok := methodSet[method]; ok {
+ methods = append(methods, method)
+ }
+ }
+ return methods
+}
+
+func providerSupportsVisibleMethod(inst *dbent.PaymentProviderInstance, method string) bool {
+ if inst == nil || !inst.Enabled {
+ return false
+ }
+ method = NormalizeVisibleMethod(method)
+ for _, candidate := range enabledVisibleMethodsForProvider(inst.ProviderKey, inst.SupportedTypes) {
+ if candidate == method {
+ return true
+ }
+ }
+ return false
+}
+
+func filterEnabledVisibleMethodInstances(instances []*dbent.PaymentProviderInstance, method string) []*dbent.PaymentProviderInstance {
+ filtered := make([]*dbent.PaymentProviderInstance, 0, len(instances))
+ for _, inst := range instances {
+ if providerSupportsVisibleMethod(inst, method) {
+ filtered = append(filtered, inst)
+ }
+ }
+ return filtered
+}
+
+func filterVisibleMethodInstancesByProviderKey(instances []*dbent.PaymentProviderInstance, method string, providerKey string) []*dbent.PaymentProviderInstance {
+ filtered := make([]*dbent.PaymentProviderInstance, 0, len(instances))
+ for _, inst := range instances {
+ if !providerSupportsVisibleMethod(inst, method) {
+ continue
+ }
+ if !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), strings.TrimSpace(providerKey)) {
+ continue
+ }
+ filtered = append(filtered, inst)
+ }
+ return filtered
+}
+
+func distinctVisibleMethodProviderKeys(instances []*dbent.PaymentProviderInstance) []string {
+ seen := make(map[string]struct{}, len(instances))
+ keys := make([]string, 0, len(instances))
+ for _, inst := range instances {
+ if inst == nil {
+ continue
+ }
+ key := strings.TrimSpace(inst.ProviderKey)
+ if key == "" {
+ continue
+ }
+ normalized := strings.ToLower(key)
+ if _, ok := seen[normalized]; ok {
+ continue
+ }
+ seen[normalized] = struct{}{}
+ keys = append(keys, key)
+ }
+ return keys
+}
+
+func selectVisibleMethodInstanceByProviderKey(instances []*dbent.PaymentProviderInstance, providerKey string) *dbent.PaymentProviderInstance {
+ providerKey = strings.TrimSpace(providerKey)
+ if providerKey == "" {
+ return nil
+ }
+ for _, inst := range instances {
+ if strings.EqualFold(strings.TrimSpace(inst.ProviderKey), providerKey) {
+ return inst
+ }
+ }
+ return nil
+}
+
+func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts(
+ ctx context.Context,
+ excludeID int64,
+ providerKey string,
+ supportedTypes string,
+ enabled bool,
+) error {
+ // Visible methods are selected by configured source (official/easypay),
+ // so multiple enabled providers can intentionally claim the same user-facing
+ // method. Order creation and limits will route through the configured source.
+ _, _, _, _, _ = ctx, excludeID, providerKey, supportedTypes, enabled
+ return nil
+}
+
+func (s *PaymentConfigService) resolveVisibleMethodSourceProviderKey(ctx context.Context, method string) (string, error) {
+ method = NormalizeVisibleMethod(method)
+ sourceKey := visibleMethodSourceSettingKey(method)
+ rawSource := ""
+ if s != nil && s.settingRepo != nil && sourceKey != "" {
+ value, err := s.settingRepo.GetValue(ctx, sourceKey)
+ if err != nil {
+ if !errors.Is(err, ErrSettingNotFound) {
+ return "", fmt.Errorf("get %s: %w", sourceKey, err)
+ }
+ } else {
+ rawSource = value
+ }
+ }
+
+ normalizedSource, err := normalizeVisibleMethodSettingSource(method, rawSource, true)
+ if err != nil {
+ return "", err
+ }
+ if normalizedSource == "" {
+ return "", nil
+ }
+ providerKey, ok := VisibleMethodProviderKeyForSource(method, normalizedSource)
+ if !ok {
+ return "", infraerrors.BadRequest(
+ "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
+ fmt.Sprintf("%s source must be one of the supported payment providers", method),
+ )
+ }
+ return providerKey, nil
+}
+
+func (s *PaymentConfigService) resolveVisibleMethodProviderKey(
+ ctx context.Context,
+ method string,
+ matching []*dbent.PaymentProviderInstance,
+) (string, error) {
+ switch providerKeys := distinctVisibleMethodProviderKeys(matching); len(providerKeys) {
+ case 0:
+ return "", nil
+ case 1:
+ return strings.TrimSpace(providerKeys[0]), nil
+ default:
+ providerKey, err := s.resolveVisibleMethodSourceProviderKey(ctx, method)
+ if err != nil {
+ return "", err
+ }
+ if providerKey == "" {
+ return "", nil
+ }
+ selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey)
+ if selected == nil {
+ return "", infraerrors.BadRequest(
+ "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
+ fmt.Sprintf("%s source has no enabled provider instance", method),
+ )
+ }
+ return strings.TrimSpace(selected.ProviderKey), nil
+ }
+}
+
+func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
+ ctx context.Context,
+ method string,
+) (*dbent.PaymentProviderInstance, error) {
+ if s == nil || s.entClient == nil {
+ return nil, nil
+ }
+
+ method = NormalizeVisibleMethod(method)
+ if method != payment.TypeAlipay && method != payment.TypeWxpay {
+ return nil, nil
+ }
+
+ instances, err := s.entClient.PaymentProviderInstance.Query().
+ Where(paymentproviderinstance.EnabledEQ(true)).
+ Order(paymentproviderinstance.BySortOrder()).
+ All(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("query enabled payment providers: %w", err)
+ }
+
+ matching := filterEnabledVisibleMethodInstances(instances, method)
+ providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching)
+ if err != nil {
+ return nil, err
+ }
+ if providerKey == "" {
+ if len(matching) == 0 {
+ return nil, nil
+ }
+ return &dbent.PaymentProviderInstance{ProviderKey: ""}, nil
+ }
+ return selectVisibleMethodInstanceByProviderKey(matching, providerKey), nil
+}
diff --git a/backend/internal/service/payment_webhook_provider.go b/backend/internal/service/payment_webhook_provider.go
new file mode 100644
index 00000000..f2da40d9
--- /dev/null
+++ b/backend/internal/service/payment_webhook_provider.go
@@ -0,0 +1,148 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+ "strings"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+)
+
+// GetWebhookProvider returns the provider instance that should verify a webhook.
+// It resolves the original provider instance from the order whenever possible and
+// only falls back to a registry provider for legacy/single-instance scenarios.
+func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) {
+ providers, err := s.GetWebhookProviders(ctx, providerKey, outTradeNo)
+ if err != nil {
+ return nil, err
+ }
+ if len(providers) == 0 {
+ return nil, payment.ErrProviderNotFound
+ }
+ return providers[0], nil
+}
+
+// GetWebhookProviders returns provider candidates that can verify the webhook.
+// Official WeChat Pay may require multiple candidates because the callback body
+// cannot be bound to a merchant before decryption.
+func (s *PaymentService) GetWebhookProviders(ctx context.Context, providerKey, outTradeNo string) ([]payment.Provider, error) {
+ if outTradeNo != "" {
+ order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx)
+ if err == nil {
+ if psHasPinnedProviderInstance(order) {
+ prov, err := s.getPinnedOrderProvider(ctx, order)
+ if err != nil {
+ return nil, err
+ }
+ return []payment.Provider{prov}, nil
+ }
+ inst, err := s.getOrderProviderInstance(ctx, order)
+ if err != nil {
+ return nil, fmt.Errorf("load order provider instance: %w", err)
+ }
+ if inst != nil {
+ prov, err := s.createProviderFromInstance(ctx, inst)
+ if err != nil {
+ return nil, err
+ }
+ return []payment.Provider{prov}, nil
+ }
+ if strings.TrimSpace(providerKey) == payment.TypeWxpay {
+ return s.getEnabledWebhookProvidersByKey(ctx, providerKey)
+ }
+ if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
+ return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey)
+ }
+ s.EnsureProviders(ctx)
+ prov, err := s.registry.GetProviderByKey(providerKey)
+ if err != nil {
+ return nil, err
+ }
+ return []payment.Provider{prov}, nil
+ }
+ }
+
+ if strings.TrimSpace(providerKey) == payment.TypeWxpay {
+ return s.getEnabledWebhookProvidersByKey(ctx, providerKey)
+ }
+
+ if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
+ return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey)
+ }
+
+ s.EnsureProviders(ctx)
+ prov, err := s.registry.GetProviderByKey(providerKey)
+ if err != nil {
+ return nil, err
+ }
+ return []payment.Provider{prov}, nil
+}
+
+func (s *PaymentService) getPinnedOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
+ inst, err := s.getOrderProviderInstance(ctx, o)
+ if err != nil {
+ return nil, fmt.Errorf("load order provider instance: %w", err)
+ }
+ if inst == nil {
+ return nil, fmt.Errorf("order %d provider instance is missing", o.ID)
+ }
+ return s.createProviderFromInstance(ctx, inst)
+}
+
+func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, providerKey string) bool {
+ providerKey = strings.TrimSpace(providerKey)
+ if providerKey == "" || s == nil || s.entClient == nil {
+ return false
+ }
+
+ count, err := s.entClient.PaymentProviderInstance.Query().
+ Where(
+ paymentproviderinstance.ProviderKeyEQ(providerKey),
+ paymentproviderinstance.EnabledEQ(true),
+ ).
+ Count(ctx)
+ if err != nil {
+ slog.Warn("payment webhook fallback instance count failed", "provider", providerKey, "error", err)
+ return false
+ }
+ return count <= 1
+}
+
+func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool {
+ return order != nil && (psOrderProviderSnapshot(order) != nil || (order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != ""))
+}
+
+func (s *PaymentService) getEnabledWebhookProvidersByKey(ctx context.Context, providerKey string) ([]payment.Provider, error) {
+ providerKey = strings.TrimSpace(providerKey)
+ instances, err := s.entClient.PaymentProviderInstance.Query().
+ Where(
+ paymentproviderinstance.ProviderKeyEQ(providerKey),
+ paymentproviderinstance.EnabledEQ(true),
+ ).
+ Order(dbent.Asc(paymentproviderinstance.FieldSortOrder)).
+ All(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("query webhook provider instances: %w", err)
+ }
+ if len(instances) == 0 {
+ return nil, payment.ErrProviderNotFound
+ }
+
+ providers := make([]payment.Provider, 0, len(instances))
+ for _, inst := range instances {
+ prov, provErr := s.createProviderFromInstance(ctx, inst)
+ if provErr != nil {
+ slog.Warn("skip webhook provider instance", "provider", providerKey, "instanceID", inst.ID, "error", provErr)
+ continue
+ }
+ providers = append(providers, prov)
+ }
+ if len(providers) == 0 {
+ return nil, payment.ErrProviderNotFound
+ }
+ return providers, nil
+}
diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go
new file mode 100644
index 00000000..0f3efa1f
--- /dev/null
+++ b/backend/internal/service/payment_webhook_provider_test.go
@@ -0,0 +1,510 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/json"
+ "encoding/pem"
+ "strconv"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/stretchr/testify/require"
+)
+
+const webhookProviderTestEncryptionKey = "0123456789abcdef0123456789abcdef"
+
+type webhookProviderTestDouble struct {
+ key string
+ types []payment.PaymentType
+}
+
+func (p webhookProviderTestDouble) Name() string { return p.key }
+func (p webhookProviderTestDouble) ProviderKey() string { return p.key }
+func (p webhookProviderTestDouble) SupportedTypes() []payment.PaymentType { return p.types }
+func (p webhookProviderTestDouble) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ panic("unexpected call")
+}
+func (p webhookProviderTestDouble) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) {
+ panic("unexpected call")
+}
+func (p webhookProviderTestDouble) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
+ panic("unexpected call")
+}
+func (p webhookProviderTestDouble) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
+ panic("unexpected call")
+}
+
+func encryptWebhookProviderConfig(t *testing.T, config map[string]string) string {
+ t.Helper()
+
+ data, err := json.Marshal(config)
+ require.NoError(t, err)
+
+ encrypted, err := payment.Encrypt(string(data), []byte(webhookProviderTestEncryptionKey))
+ require.NoError(t, err)
+ return encrypted
+}
+
+func newWebhookProviderTestLoadBalancer(client *dbent.Client) payment.LoadBalancer {
+ return payment.NewDefaultLoadBalancer(client, []byte(webhookProviderTestEncryptionKey))
+}
+
+func encryptValidWebhookWxpayConfig(t *testing.T, suffix string) string {
+ t.Helper()
+
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ require.NoError(t, err)
+
+ privDER, err := x509.MarshalPKCS8PrivateKey(key)
+ require.NoError(t, err)
+ pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
+ require.NoError(t, err)
+
+ return encryptWebhookProviderConfig(t, map[string]string{
+ "appId": "wx-app-" + suffix,
+ "mchId": "mch-" + suffix,
+ "privateKey": string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})),
+ "apiV3Key": webhookProviderTestEncryptionKey,
+ "publicKey": string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER})),
+ "publicKeyId": "public-key-id-" + suffix,
+ "certSerial": "cert-serial-" + suffix,
+ })
+}
+
+func TestGetOrderProviderInstanceResolvesUniqueLegacyProviderKey(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ inst, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeStripe).
+ SetName("stripe-a").
+ SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_test_legacy_provider_key"})).
+ SetSupportedTypes("stripe").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ providerKey := payment.TypeStripe
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeStripe,
+ ProviderKey: &providerKey,
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.NoError(t, err)
+ require.NotNil(t, got)
+ require.Equal(t, inst.ID, got.ID)
+}
+
+func TestGetOrderProviderInstanceResolvesUniqueLegacyPaymentType(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ inst, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-a").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeWxpayDirect,
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.NoError(t, err)
+ require.NotNil(t, got)
+ require.Equal(t, inst.ID, got.ID)
+}
+
+func TestGetOrderProviderInstanceLeavesAmbiguousLegacyOrderUnresolved(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName("easypay-a").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-a").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeWxpay,
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.NoError(t, err)
+ require.Nil(t, got)
+}
+
+func TestGetOrderProviderInstanceLeavesLegacyProviderKeyUnresolvedWhenHistoricalInstancesConflict(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeStripe).
+ SetName("stripe-disabled-legacy").
+ SetConfig("{}").
+ SetSupportedTypes("stripe").
+ SetEnabled(false).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeStripe).
+ SetName("stripe-enabled-current").
+ SetConfig("{}").
+ SetSupportedTypes("stripe").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ providerKey := payment.TypeStripe
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeStripe,
+ ProviderKey: &providerKey,
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.NoError(t, err)
+ require.Nil(t, got)
+}
+
+func TestGetOrderProviderInstanceLeavesProviderKeyMatchUnresolvedWhenTypeNotSupported(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-only").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ providerKey := payment.TypeWxpay
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipayDirect,
+ ProviderKey: &providerKey,
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.NoError(t, err)
+ require.Nil(t, got)
+}
+
+func TestGetOrderProviderInstanceUsesProviderSnapshotWhenPinnedColumnMissing(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ inst, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeStripe).
+ SetName("stripe-snapshot").
+ SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_snapshot"})).
+ SetSupportedTypes("stripe").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order := &dbent.PaymentOrder{
+ ID: 42,
+ PaymentType: payment.TypeStripe,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 1,
+ "provider_instance_id": strconv.FormatInt(inst.ID, 10),
+ "provider_key": payment.TypeStripe,
+ },
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.NoError(t, err)
+ require.NotNil(t, got)
+ require.Equal(t, inst.ID, got.ID)
+}
+
+func TestGetOrderProviderInstanceRejectsMissingSnapshotInstanceWithoutLegacyFallback(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeStripe).
+ SetName("stripe-legacy-fallback").
+ SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_legacy"})).
+ SetSupportedTypes("stripe").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order := &dbent.PaymentOrder{
+ ID: 43,
+ PaymentType: payment.TypeStripe,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 1,
+ "provider_instance_id": "999999",
+ "provider_key": payment.TypeStripe,
+ },
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.Nil(t, got)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "provider snapshot instance 999999 is missing")
+}
+
+func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ wxpayConfigA := encryptValidWebhookWxpayConfig(t, "a")
+ wxpayConfigB := encryptValidWebhookWxpayConfig(t, "b")
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-a").
+ SetConfig(wxpayConfigA).
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-b").
+ SetConfig(wxpayConfigB).
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ registry: payment.NewRegistry(),
+ providersLoaded: true,
+ }
+
+ providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "")
+ require.NoError(t, err)
+ require.Len(t, providers, 2)
+}
+
+func TestGetWebhookProvidersRejectAmbiguousFallbackForNonWxpay(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("alipay-a").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("alipay-b").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: payment.NewRegistry(),
+ providersLoaded: true,
+ }
+
+ _, err = svc.GetWebhookProviders(ctx, payment.TypeAlipay, "")
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "ambiguous")
+}
+
+func TestGetWebhookProviderAllowsSingleInstanceRegistryFallback(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeStripe).
+ SetName("stripe-a").
+ SetConfig("{}").
+ SetSupportedTypes("stripe").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ registry := payment.NewRegistry()
+ registry.Register(webhookProviderTestDouble{
+ key: payment.TypeStripe,
+ types: []payment.PaymentType{payment.TypeStripe},
+ })
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ providersLoaded: true,
+ }
+
+ providers, err := svc.GetWebhookProviders(ctx, payment.TypeStripe, "")
+ require.NoError(t, err)
+ require.Len(t, providers, 1)
+ prov := providers[0]
+ require.Equal(t, payment.TypeStripe, prov.ProviderKey())
+}
+
+func TestGetWebhookProviderRejectsRegistryFallbackForPinnedOrder(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("webhook@example.com").
+ SetPasswordHash("hash").
+ SetUsername("webhook").
+ Save(ctx)
+ require.NoError(t, err)
+
+ pinnedInstanceID := "999"
+ _, err = client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("TEST-RECHARGE").
+ SetOutTradeNo("sub2_test_pinned_order").
+ SetPaymentType(payment.TypeWxpay).
+ SetPaymentTradeNo("").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderInstanceID(pinnedInstanceID).
+ Save(ctx)
+ require.NoError(t, err)
+
+ registry := payment.NewRegistry()
+ registry.Register(webhookProviderTestDouble{
+ key: payment.TypeWxpay,
+ types: []payment.PaymentType{payment.TypeWxpay},
+ })
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ providersLoaded: true,
+ }
+
+ _, err = svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_pinned_order")
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "provider instance")
+}
+
+func TestGetWebhookProviderUsesProviderSnapshotBeforeWxpayFallback(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("snapshot-webhook@example.com").
+ SetPasswordHash("hash").
+ SetUsername("snapshot-webhook").
+ Save(ctx)
+ require.NoError(t, err)
+
+ wxpayConfigA := encryptValidWebhookWxpayConfig(t, "snapshot-a")
+ wxpayConfigB := encryptValidWebhookWxpayConfig(t, "snapshot-b")
+ instA, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-snapshot-a").
+ SetConfig(wxpayConfigA).
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-snapshot-b").
+ SetConfig(wxpayConfigB).
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(66).
+ SetPayAmount(66).
+ SetFeeRate(0).
+ SetRechargeCode("SNAPSHOT-WEBHOOK").
+ SetOutTradeNo("sub2_test_snapshot_webhook_order").
+ SetPaymentType(payment.TypeWxpay).
+ SetPaymentTradeNo("").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderSnapshot(map[string]any{
+ "schema_version": 1,
+ "provider_instance_id": strconv.FormatInt(instA.ID, 10),
+ "provider_key": payment.TypeWxpay,
+ "payment_mode": "native",
+ }).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ registry: payment.NewRegistry(),
+ providersLoaded: true,
+ }
+
+ providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_snapshot_webhook_order")
+ require.NoError(t, err)
+ require.Len(t, providers, 1)
+ require.Equal(t, payment.TypeWxpay, providers[0].ProviderKey())
+}
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index 7f4a2eb1..f2b644be 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -114,6 +114,253 @@ type SettingService struct {
webSearchManagerBuilder WebSearchManagerBuilder
}
+type ProviderDefaultGrantSettings struct {
+ Balance float64
+ Concurrency int
+ Subscriptions []DefaultSubscriptionSetting
+ GrantOnSignup bool
+ GrantOnFirstBind bool
+}
+
+type AuthSourceDefaultSettings struct {
+ Email ProviderDefaultGrantSettings
+ LinuxDo ProviderDefaultGrantSettings
+ OIDC ProviderDefaultGrantSettings
+ WeChat ProviderDefaultGrantSettings
+ ForceEmailOnThirdPartySignup bool
+}
+
+type authSourceDefaultKeySet struct {
+ balance string
+ concurrency string
+ subscriptions string
+ grantOnSignup string
+ grantOnFirstBind string
+}
+
+var (
+ emailAuthSourceDefaultKeys = authSourceDefaultKeySet{
+ balance: SettingKeyAuthSourceDefaultEmailBalance,
+ concurrency: SettingKeyAuthSourceDefaultEmailConcurrency,
+ subscriptions: SettingKeyAuthSourceDefaultEmailSubscriptions,
+ grantOnSignup: SettingKeyAuthSourceDefaultEmailGrantOnSignup,
+ grantOnFirstBind: SettingKeyAuthSourceDefaultEmailGrantOnFirstBind,
+ }
+ linuxDoAuthSourceDefaultKeys = authSourceDefaultKeySet{
+ balance: SettingKeyAuthSourceDefaultLinuxDoBalance,
+ concurrency: SettingKeyAuthSourceDefaultLinuxDoConcurrency,
+ subscriptions: SettingKeyAuthSourceDefaultLinuxDoSubscriptions,
+ grantOnSignup: SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup,
+ grantOnFirstBind: SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind,
+ }
+ oidcAuthSourceDefaultKeys = authSourceDefaultKeySet{
+ balance: SettingKeyAuthSourceDefaultOIDCBalance,
+ concurrency: SettingKeyAuthSourceDefaultOIDCConcurrency,
+ subscriptions: SettingKeyAuthSourceDefaultOIDCSubscriptions,
+ grantOnSignup: SettingKeyAuthSourceDefaultOIDCGrantOnSignup,
+ grantOnFirstBind: SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind,
+ }
+ weChatAuthSourceDefaultKeys = authSourceDefaultKeySet{
+ balance: SettingKeyAuthSourceDefaultWeChatBalance,
+ concurrency: SettingKeyAuthSourceDefaultWeChatConcurrency,
+ subscriptions: SettingKeyAuthSourceDefaultWeChatSubscriptions,
+ grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
+ grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
+ }
+)
+
+const (
+ defaultAuthSourceBalance = 0
+ defaultAuthSourceConcurrency = 5
+ defaultWeChatConnectMode = "open"
+ defaultWeChatConnectScopes = "snsapi_login"
+ defaultWeChatConnectFrontend = "/auth/wechat/callback"
+)
+
+func normalizeWeChatConnectModeSetting(raw string) string {
+ switch strings.ToLower(strings.TrimSpace(raw)) {
+ case "mp":
+ return "mp"
+ case "mobile":
+ return "mobile"
+ default:
+ return "open"
+ }
+}
+
+func defaultWeChatConnectScopeForMode(mode string) string {
+ switch normalizeWeChatConnectModeSetting(mode) {
+ case "mp":
+ return "snsapi_userinfo"
+ case "mobile":
+ return ""
+ }
+ return defaultWeChatConnectScopes
+}
+
+func normalizeWeChatConnectScopeSetting(raw, mode string) string {
+ switch normalizeWeChatConnectModeSetting(mode) {
+ case "mp":
+ switch strings.TrimSpace(raw) {
+ case "snsapi_base":
+ return "snsapi_base"
+ case "snsapi_userinfo":
+ return "snsapi_userinfo"
+ default:
+ return defaultWeChatConnectScopeForMode(mode)
+ }
+ case "mobile":
+ return ""
+ default:
+ return defaultWeChatConnectScopes
+ }
+}
+
+func parseWeChatConnectCapabilitySettings(settings map[string]string, enabled bool, mode string) (bool, bool, bool) {
+ mode = normalizeWeChatConnectModeSetting(mode)
+ rawOpen, hasOpen := settings[SettingKeyWeChatConnectOpenEnabled]
+ rawMP, hasMP := settings[SettingKeyWeChatConnectMPEnabled]
+ rawMobile, hasMobile := settings[SettingKeyWeChatConnectMobileEnabled]
+ openConfigured := hasOpen && strings.TrimSpace(rawOpen) != ""
+ mpConfigured := hasMP && strings.TrimSpace(rawMP) != ""
+ mobileConfigured := hasMobile && strings.TrimSpace(rawMobile) != ""
+
+ if openConfigured || mpConfigured || mobileConfigured {
+ openEnabled := strings.TrimSpace(rawOpen) == "true"
+ mpEnabled := strings.TrimSpace(rawMP) == "true"
+ mobileEnabled := strings.TrimSpace(rawMobile) == "true"
+ return openEnabled, mpEnabled, mobileEnabled
+ }
+
+ if !enabled {
+ return false, false, false
+ }
+ if mode == "mp" {
+ return false, true, false
+ }
+ if mode == "mobile" {
+ return false, false, true
+ }
+ return true, false, false
+}
+
+func normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled bool, mode string) string {
+ mode = normalizeWeChatConnectModeSetting(mode)
+ switch mode {
+ case "open":
+ if openEnabled {
+ return "open"
+ }
+ case "mp":
+ if mpEnabled {
+ return "mp"
+ }
+ case "mobile":
+ if mobileEnabled {
+ return "mobile"
+ }
+ }
+ switch {
+ case openEnabled:
+ return "open"
+ case mpEnabled:
+ return "mp"
+ case mobileEnabled:
+ return "mobile"
+ default:
+ return mode
+ }
+}
+
+func mergeWeChatConnectCapabilitySettings(settings map[string]string, base config.WeChatConnectConfig, enabled bool, mode string) (bool, bool, bool) {
+ mode = normalizeWeChatConnectModeSetting(firstNonEmpty(mode, base.Mode))
+ rawOpen, hasOpen := settings[SettingKeyWeChatConnectOpenEnabled]
+ rawMP, hasMP := settings[SettingKeyWeChatConnectMPEnabled]
+ rawMobile, hasMobile := settings[SettingKeyWeChatConnectMobileEnabled]
+ openConfigured := hasOpen && strings.TrimSpace(rawOpen) != ""
+ mpConfigured := hasMP && strings.TrimSpace(rawMP) != ""
+ mobileConfigured := hasMobile && strings.TrimSpace(rawMobile) != ""
+
+ if openConfigured || mpConfigured || mobileConfigured {
+ openEnabled := strings.TrimSpace(rawOpen) == "true"
+ mpEnabled := strings.TrimSpace(rawMP) == "true"
+ mobileEnabled := strings.TrimSpace(rawMobile) == "true"
+ _, enabledConfigured := settings[SettingKeyWeChatConnectEnabled]
+ if !enabledConfigured &&
+ enabled &&
+ !openEnabled &&
+ !mpEnabled &&
+ !mobileEnabled &&
+ (base.OpenEnabled || base.MPEnabled || base.MobileEnabled) {
+ return base.OpenEnabled, base.MPEnabled, base.MobileEnabled
+ }
+ return openEnabled, mpEnabled, mobileEnabled
+ }
+ if !enabled {
+ return false, false, false
+ }
+ if base.OpenEnabled || base.MPEnabled || base.MobileEnabled {
+ return base.OpenEnabled, base.MPEnabled, base.MobileEnabled
+ }
+ return parseWeChatConnectCapabilitySettings(settings, enabled, mode)
+}
+
+func (s *SettingService) effectiveWeChatConnectOAuthConfig(settings map[string]string) WeChatConnectOAuthConfig {
+ base := config.WeChatConnectConfig{}
+ if s != nil && s.cfg != nil {
+ base = s.cfg.WeChat
+ }
+
+ enabled := base.Enabled
+ if raw, ok := settings[SettingKeyWeChatConnectEnabled]; ok {
+ enabled = strings.TrimSpace(raw) == "true"
+ }
+
+ legacyAppID := strings.TrimSpace(firstNonEmpty(
+ settings[SettingKeyWeChatConnectAppID],
+ base.AppID,
+ base.OpenAppID,
+ base.MPAppID,
+ base.MobileAppID,
+ ))
+ legacyAppSecret := strings.TrimSpace(firstNonEmpty(
+ settings[SettingKeyWeChatConnectAppSecret],
+ base.AppSecret,
+ base.OpenAppSecret,
+ base.MPAppSecret,
+ base.MobileAppSecret,
+ ))
+ openAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppID], base.OpenAppID, legacyAppID))
+ openAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppSecret], base.OpenAppSecret, legacyAppSecret))
+ mpAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppID], base.MPAppID, legacyAppID))
+ mpAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppSecret], base.MPAppSecret, legacyAppSecret))
+ mobileAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppID], base.MobileAppID, legacyAppID))
+ mobileAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppSecret], base.MobileAppSecret, legacyAppSecret))
+
+ modeRaw := firstNonEmpty(settings[SettingKeyWeChatConnectMode], base.Mode)
+ openEnabled, mpEnabled, mobileEnabled := mergeWeChatConnectCapabilitySettings(settings, base, enabled, modeRaw)
+ mode := normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled, modeRaw)
+
+ return WeChatConnectOAuthConfig{
+ Enabled: enabled,
+ LegacyAppID: legacyAppID,
+ LegacyAppSecret: legacyAppSecret,
+ OpenAppID: openAppID,
+ OpenAppSecret: openAppSecret,
+ MPAppID: mpAppID,
+ MPAppSecret: mpAppSecret,
+ MobileAppID: mobileAppID,
+ MobileAppSecret: mobileAppSecret,
+ OpenEnabled: openEnabled,
+ MPEnabled: mpEnabled,
+ MobileEnabled: mobileEnabled,
+ Mode: mode,
+ Scopes: normalizeWeChatConnectScopeSetting(firstNonEmpty(settings[SettingKeyWeChatConnectScopes], base.Scopes), mode),
+ RedirectURL: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectRedirectURL], base.RedirectURL)),
+ FrontendRedirectURL: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectFrontendRedirectURL], base.FrontendRedirectURL, defaultWeChatConnectFrontend)),
+ }
+}
+
// NewSettingService 创建系统设置服务实例
func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *SettingService {
return &SettingService{
@@ -156,6 +403,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
keys := []string{
SettingKeyRegistrationEnabled,
SettingKeyEmailVerifyEnabled,
+ SettingKeyForceEmailOnThirdPartySignup,
SettingKeyRegistrationEmailSuffixWhitelist,
SettingKeyPromoCodeEnabled,
SettingKeyPasswordResetEnabled,
@@ -178,6 +426,22 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyCustomMenuItems,
SettingKeyCustomEndpoints,
SettingKeyLinuxDoConnectEnabled,
+ SettingKeyWeChatConnectEnabled,
+ SettingKeyWeChatConnectAppID,
+ SettingKeyWeChatConnectAppSecret,
+ SettingKeyWeChatConnectOpenAppID,
+ SettingKeyWeChatConnectOpenAppSecret,
+ SettingKeyWeChatConnectMPAppID,
+ SettingKeyWeChatConnectMPAppSecret,
+ SettingKeyWeChatConnectMobileAppID,
+ SettingKeyWeChatConnectMobileAppSecret,
+ SettingKeyWeChatConnectOpenEnabled,
+ SettingKeyWeChatConnectMPEnabled,
+ SettingKeyWeChatConnectMobileEnabled,
+ SettingKeyWeChatConnectMode,
+ SettingKeyWeChatConnectScopes,
+ SettingKeyWeChatConnectRedirectURL,
+ SettingKeyWeChatConnectFrontendRedirectURL,
SettingKeyBackendModeEnabled,
SettingPaymentEnabled,
SettingKeyOIDCConnectEnabled,
@@ -212,6 +476,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
if oidcProviderName == "" {
oidcProviderName = "OIDC"
}
+ weChatEnabled, weChatOpenEnabled, weChatMPEnabled, weChatMobileEnabled := s.weChatOAuthCapabilitiesFromSettings(settings)
// Password reset requires email verification to be enabled
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
@@ -232,6 +497,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
return &PublicSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: emailVerifyEnabled,
+ ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
RegistrationEmailSuffixWhitelist: registrationEmailSuffixWhitelist,
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
PasswordResetEnabled: passwordResetEnabled,
@@ -254,6 +520,10 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled,
+ WeChatOAuthEnabled: weChatEnabled,
+ WeChatOAuthOpenEnabled: weChatOpenEnabled,
+ WeChatOAuthMPEnabled: weChatMPEnabled,
+ WeChatOAuthMobileEnabled: weChatMobileEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
OIDCOAuthEnabled: oidcEnabled,
@@ -310,6 +580,10 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
+ WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
+ WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
+ WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
+ WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
@@ -344,6 +618,10 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
+ WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
+ WeChatOAuthOpenEnabled: settings.WeChatOAuthOpenEnabled,
+ WeChatOAuthMPEnabled: settings.WeChatOAuthMPEnabled,
+ WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
PaymentEnabled: settings.PaymentEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
@@ -356,6 +634,64 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
}, nil
}
+func DefaultWeChatConnectScopesForMode(mode string) string {
+ return defaultWeChatConnectScopeForMode(mode)
+}
+
+func (s *SettingService) parseWeChatConnectOAuthConfig(settings map[string]string) (WeChatConnectOAuthConfig, error) {
+ cfg := s.effectiveWeChatConnectOAuthConfig(settings)
+
+ if !cfg.Enabled || (!cfg.OpenEnabled && !cfg.MPEnabled) {
+ return WeChatConnectOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled")
+ }
+ if cfg.OpenEnabled {
+ if cfg.AppIDForMode("open") == "" {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth pc app id not configured")
+ }
+ if cfg.AppSecretForMode("open") == "" {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth pc app secret not configured")
+ }
+ }
+ if cfg.MPEnabled {
+ if cfg.AppIDForMode("mp") == "" {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth official account app id not configured")
+ }
+ if cfg.AppSecretForMode("mp") == "" {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth official account app secret not configured")
+ }
+ }
+ if cfg.MobileEnabled {
+ if cfg.AppIDForMode("mobile") == "" {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth mobile app id not configured")
+ }
+ if cfg.AppSecretForMode("mobile") == "" {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth mobile app secret not configured")
+ }
+ }
+ if v := strings.TrimSpace(cfg.RedirectURL); v != "" {
+ if err := config.ValidateAbsoluteHTTPURL(v); err != nil {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url invalid")
+ }
+ }
+ if err := config.ValidateFrontendRedirectURL(cfg.FrontendRedirectURL); err != nil {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth frontend redirect url invalid")
+ }
+ return cfg, nil
+}
+
+func (s *SettingService) weChatOAuthCapabilitiesFromSettings(settings map[string]string) (bool, bool, bool, bool) {
+ cfg := s.effectiveWeChatConnectOAuthConfig(settings)
+ if !cfg.Enabled {
+ return false, false, false, false
+ }
+
+ openReady := cfg.OpenEnabled && cfg.AppIDForMode("open") != "" && cfg.AppSecretForMode("open") != ""
+ mpReady := cfg.MPEnabled && cfg.AppIDForMode("mp") != "" && cfg.AppSecretForMode("mp") != ""
+ mobileReady := cfg.MobileEnabled && cfg.AppIDForMode("mobile") != "" && cfg.AppSecretForMode("mobile") != ""
+
+ return openReady || mpReady, openReady, mpReady, mobileReady
+}
+
// filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON
// array string, returning only items with visibility != "admin".
func filterUserVisibleMenuItems(raw string) json.RawMessage {
@@ -478,19 +814,130 @@ func parseCustomMenuItemURLs(raw string) []string {
return urls
}
+func oidcUsePKCECompatibilityDefault(base config.OIDCConnectConfig) bool {
+ if base.UsePKCEExplicit {
+ return base.UsePKCE
+ }
+ return true
+}
+
+func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool {
+ if base.ValidateIDTokenExplicit {
+ return base.ValidateIDToken
+ }
+ return true
+}
+
+func oidcCompatibilityWriteDefault(base config.OIDCConnectConfig, configured bool, raw string, explicit bool, explicitValue bool) bool {
+ if configured {
+ return strings.TrimSpace(raw) == "true"
+ }
+ if explicit {
+ return explicitValue
+ }
+ return false
+}
+
// UpdateSettings 更新系统设置
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
- if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil {
+ updates, err := s.buildSystemSettingsUpdates(ctx, settings)
+ if err != nil {
return err
}
+
+ err = s.settingRepo.SetMultiple(ctx, updates)
+ if err == nil {
+ s.refreshCachedSettings(settings)
+ }
+ return err
+}
+
+func (s *SettingService) OIDCSecurityWriteDefaults(ctx context.Context) (bool, bool, error) {
+ rawSettings, err := s.settingRepo.GetMultiple(ctx, []string{
+ SettingKeyOIDCConnectUsePKCE,
+ SettingKeyOIDCConnectValidateIDToken,
+ })
+ if err != nil {
+ return false, false, fmt.Errorf("get oidc security write defaults: %w", err)
+ }
+
+ base := config.OIDCConnectConfig{}
+ if s != nil && s.cfg != nil {
+ base = s.cfg.OIDC
+ }
+
+ rawUsePKCE, hasUsePKCE := rawSettings[SettingKeyOIDCConnectUsePKCE]
+ rawValidateIDToken, hasValidateIDToken := rawSettings[SettingKeyOIDCConnectValidateIDToken]
+
+ return oidcCompatibilityWriteDefault(base, hasUsePKCE, rawUsePKCE, base.UsePKCEExplicit, base.UsePKCE),
+ oidcCompatibilityWriteDefault(base, hasValidateIDToken, rawValidateIDToken, base.ValidateIDTokenExplicit, base.ValidateIDToken),
+ nil
+}
+
+// UpdateSettingsWithAuthSourceDefaults persists system settings and auth-source defaults in a single write.
+func (s *SettingService) UpdateSettingsWithAuthSourceDefaults(ctx context.Context, settings *SystemSettings, authDefaults *AuthSourceDefaultSettings) error {
+ updates, err := s.buildSystemSettingsUpdates(ctx, settings)
+ if err != nil {
+ return err
+ }
+
+ authSourceUpdates, err := s.buildAuthSourceDefaultUpdates(ctx, authDefaults)
+ if err != nil {
+ return err
+ }
+ for key, value := range authSourceUpdates {
+ updates[key] = value
+ }
+
+ err = s.settingRepo.SetMultiple(ctx, updates)
+ if err == nil {
+ s.refreshCachedSettings(settings)
+ }
+ return err
+}
+
+func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, settings *SystemSettings) (map[string]string, error) {
+ if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil {
+ return nil, err
+ }
normalizedWhitelist, err := NormalizeRegistrationEmailSuffixWhitelist(settings.RegistrationEmailSuffixWhitelist)
if err != nil {
- return infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error())
+ return nil, infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error())
}
if normalizedWhitelist == nil {
normalizedWhitelist = []string{}
}
settings.RegistrationEmailSuffixWhitelist = normalizedWhitelist
+ alipaySource, err := normalizeVisibleMethodSettingSource("alipay", settings.PaymentVisibleMethodAlipaySource, settings.PaymentVisibleMethodAlipayEnabled)
+ if err != nil {
+ return nil, err
+ }
+ wxpaySource, err := normalizeVisibleMethodSettingSource("wxpay", settings.PaymentVisibleMethodWxpaySource, settings.PaymentVisibleMethodWxpayEnabled)
+ if err != nil {
+ return nil, err
+ }
+ settings.PaymentVisibleMethodAlipaySource = alipaySource
+ settings.PaymentVisibleMethodWxpaySource = wxpaySource
+ settings.WeChatConnectAppID = strings.TrimSpace(settings.WeChatConnectAppID)
+ settings.WeChatConnectAppSecret = strings.TrimSpace(settings.WeChatConnectAppSecret)
+ settings.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectOpenAppID, settings.WeChatConnectAppID))
+ settings.WeChatConnectOpenAppSecret = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectOpenAppSecret, settings.WeChatConnectAppSecret))
+ settings.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMPAppID, settings.WeChatConnectAppID))
+ settings.WeChatConnectMPAppSecret = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMPAppSecret, settings.WeChatConnectAppSecret))
+ settings.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMobileAppID, settings.WeChatConnectAppID))
+ settings.WeChatConnectMobileAppSecret = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMobileAppSecret, settings.WeChatConnectAppSecret))
+ settings.WeChatConnectMode = normalizeWeChatConnectStoredMode(
+ settings.WeChatConnectOpenEnabled,
+ settings.WeChatConnectMPEnabled,
+ settings.WeChatConnectMobileEnabled,
+ settings.WeChatConnectMode,
+ )
+ settings.WeChatConnectScopes = normalizeWeChatConnectScopeSetting(settings.WeChatConnectScopes, settings.WeChatConnectMode)
+ settings.WeChatConnectRedirectURL = strings.TrimSpace(settings.WeChatConnectRedirectURL)
+ settings.WeChatConnectFrontendRedirectURL = strings.TrimSpace(settings.WeChatConnectFrontendRedirectURL)
+ if settings.WeChatConnectFrontendRedirectURL == "" {
+ settings.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend
+ }
updates := make(map[string]string)
@@ -499,7 +946,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
registrationEmailSuffixWhitelistJSON, err := json.Marshal(settings.RegistrationEmailSuffixWhitelist)
if err != nil {
- return fmt.Errorf("marshal registration email suffix whitelist: %w", err)
+ return nil, fmt.Errorf("marshal registration email suffix whitelist: %w", err)
}
updates[SettingKeyRegistrationEmailSuffixWhitelist] = string(registrationEmailSuffixWhitelistJSON)
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
@@ -560,6 +1007,32 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyOIDCConnectClientSecret] = settings.OIDCConnectClientSecret
}
+ // WeChat Connect OAuth 登录
+ updates[SettingKeyWeChatConnectEnabled] = strconv.FormatBool(settings.WeChatConnectEnabled)
+ updates[SettingKeyWeChatConnectAppID] = settings.WeChatConnectAppID
+ updates[SettingKeyWeChatConnectOpenAppID] = settings.WeChatConnectOpenAppID
+ updates[SettingKeyWeChatConnectMPAppID] = settings.WeChatConnectMPAppID
+ updates[SettingKeyWeChatConnectMobileAppID] = settings.WeChatConnectMobileAppID
+ updates[SettingKeyWeChatConnectOpenEnabled] = strconv.FormatBool(settings.WeChatConnectOpenEnabled)
+ updates[SettingKeyWeChatConnectMPEnabled] = strconv.FormatBool(settings.WeChatConnectMPEnabled)
+ updates[SettingKeyWeChatConnectMobileEnabled] = strconv.FormatBool(settings.WeChatConnectMobileEnabled)
+ updates[SettingKeyWeChatConnectMode] = settings.WeChatConnectMode
+ updates[SettingKeyWeChatConnectScopes] = settings.WeChatConnectScopes
+ updates[SettingKeyWeChatConnectRedirectURL] = settings.WeChatConnectRedirectURL
+ updates[SettingKeyWeChatConnectFrontendRedirectURL] = settings.WeChatConnectFrontendRedirectURL
+ if settings.WeChatConnectAppSecret != "" {
+ updates[SettingKeyWeChatConnectAppSecret] = settings.WeChatConnectAppSecret
+ }
+ if settings.WeChatConnectOpenAppSecret != "" {
+ updates[SettingKeyWeChatConnectOpenAppSecret] = settings.WeChatConnectOpenAppSecret
+ }
+ if settings.WeChatConnectMPAppSecret != "" {
+ updates[SettingKeyWeChatConnectMPAppSecret] = settings.WeChatConnectMPAppSecret
+ }
+ if settings.WeChatConnectMobileAppSecret != "" {
+ updates[SettingKeyWeChatConnectMobileAppSecret] = settings.WeChatConnectMobileAppSecret
+ }
+
// OEM设置
updates[SettingKeySiteName] = settings.SiteName
updates[SettingKeySiteLogo] = settings.SiteLogo
@@ -578,7 +1051,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyTableDefaultPageSize] = strconv.Itoa(tableDefaultPageSize)
tablePageSizeOptionsJSON, err := json.Marshal(tablePageSizeOptions)
if err != nil {
- return fmt.Errorf("marshal table page size options: %w", err)
+ return nil, fmt.Errorf("marshal table page size options: %w", err)
}
updates[SettingKeyTablePageSizeOptions] = string(tablePageSizeOptionsJSON)
updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems
@@ -589,7 +1062,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
if err != nil {
- return fmt.Errorf("marshal default subscriptions: %w", err)
+ return nil, fmt.Errorf("marshal default subscriptions: %w", err)
}
updates[SettingKeyDefaultSubscriptions] = string(defaultSubsJSON)
@@ -626,6 +1099,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEnableFingerprintUnification] = strconv.FormatBool(settings.EnableFingerprintUnification)
updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough)
updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning)
+ updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource
+ updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource
+ updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled)
+ updates[SettingPaymentVisibleMethodWxpayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodWxpayEnabled)
+ updates[openAIAdvancedSchedulerSettingKey] = strconv.FormatBool(settings.OpenAIAdvancedSchedulerEnabled)
// Balance low notification
updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled)
@@ -634,32 +1112,66 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled)
updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails)
- err = s.settingRepo.SetMultiple(ctx, updates)
- if err == nil {
- // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
- versionBoundsSF.Forget("version_bounds")
- versionBoundsCache.Store(&cachedVersionBounds{
- min: settings.MinClaudeCodeVersion,
- max: settings.MaxClaudeCodeVersion,
- expiresAt: time.Now().Add(versionBoundsCacheTTL).UnixNano(),
- })
- backendModeSF.Forget("backend_mode")
- backendModeCache.Store(&cachedBackendMode{
- value: settings.BackendModeEnabled,
- expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
- })
- gatewayForwardingSF.Forget("gateway_forwarding")
- gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
- fingerprintUnification: settings.EnableFingerprintUnification,
- metadataPassthrough: settings.EnableMetadataPassthrough,
- cchSigning: settings.EnableCCHSigning,
- expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
- })
- if s.onUpdate != nil {
- s.onUpdate() // Invalidate cache after settings update
+ return updates, nil
+}
+
+func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, settings *AuthSourceDefaultSettings) (map[string]string, error) {
+ if settings == nil {
+ return nil, nil
+ }
+
+ for _, subscriptions := range [][]DefaultSubscriptionSetting{
+ settings.Email.Subscriptions,
+ settings.LinuxDo.Subscriptions,
+ settings.OIDC.Subscriptions,
+ settings.WeChat.Subscriptions,
+ } {
+ if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
+ return nil, err
}
}
- return err
+
+ updates := make(map[string]string, 21)
+ writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
+ writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
+ writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
+ writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
+ updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
+ return updates, nil
+}
+
+func (s *SettingService) refreshCachedSettings(settings *SystemSettings) {
+ if settings == nil {
+ return
+ }
+
+ // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
+ versionBoundsSF.Forget("version_bounds")
+ versionBoundsCache.Store(&cachedVersionBounds{
+ min: settings.MinClaudeCodeVersion,
+ max: settings.MaxClaudeCodeVersion,
+ expiresAt: time.Now().Add(versionBoundsCacheTTL).UnixNano(),
+ })
+ backendModeSF.Forget("backend_mode")
+ backendModeCache.Store(&cachedBackendMode{
+ value: settings.BackendModeEnabled,
+ expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
+ })
+ gatewayForwardingSF.Forget("gateway_forwarding")
+ gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
+ fingerprintUnification: settings.EnableFingerprintUnification,
+ metadataPassthrough: settings.EnableMetadataPassthrough,
+ cchSigning: settings.EnableCCHSigning,
+ expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
+ })
+ openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
+ openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
+ enabled: settings.OpenAIAdvancedSchedulerEnabled,
+ expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
+ })
+ if s.onUpdate != nil {
+ s.onUpdate() // Invalidate cache after settings update
+ }
}
func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error {
@@ -919,6 +1431,88 @@ func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultS
return parseDefaultSubscriptions(value)
}
+func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*AuthSourceDefaultSettings, error) {
+ keys := []string{
+ SettingKeyAuthSourceDefaultEmailBalance,
+ SettingKeyAuthSourceDefaultEmailConcurrency,
+ SettingKeyAuthSourceDefaultEmailSubscriptions,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup,
+ SettingKeyAuthSourceDefaultEmailGrantOnFirstBind,
+ SettingKeyAuthSourceDefaultLinuxDoBalance,
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency,
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions,
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup,
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind,
+ SettingKeyAuthSourceDefaultOIDCBalance,
+ SettingKeyAuthSourceDefaultOIDCConcurrency,
+ SettingKeyAuthSourceDefaultOIDCSubscriptions,
+ SettingKeyAuthSourceDefaultOIDCGrantOnSignup,
+ SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind,
+ SettingKeyAuthSourceDefaultWeChatBalance,
+ SettingKeyAuthSourceDefaultWeChatConcurrency,
+ SettingKeyAuthSourceDefaultWeChatSubscriptions,
+ SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
+ SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
+ SettingKeyForceEmailOnThirdPartySignup,
+ }
+
+ settings, err := s.settingRepo.GetMultiple(ctx, keys)
+ if err != nil {
+ return nil, fmt.Errorf("get auth source default settings: %w", err)
+ }
+
+ return &AuthSourceDefaultSettings{
+ Email: parseProviderDefaultGrantSettings(settings, emailAuthSourceDefaultKeys),
+ LinuxDo: parseProviderDefaultGrantSettings(settings, linuxDoAuthSourceDefaultKeys),
+ OIDC: parseProviderDefaultGrantSettings(settings, oidcAuthSourceDefaultKeys),
+ WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys),
+ ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
+ }, nil
+}
+
+func (s *SettingService) ResolveAuthSourceGrantSettings(ctx context.Context, signupSource string, firstBind bool) (ProviderDefaultGrantSettings, bool, error) {
+ result := ProviderDefaultGrantSettings{
+ Balance: s.GetDefaultBalance(ctx),
+ Concurrency: s.GetDefaultConcurrency(ctx),
+ Subscriptions: s.GetDefaultSubscriptions(ctx),
+ }
+
+ defaults, err := s.GetAuthSourceDefaultSettings(ctx)
+ if err != nil {
+ return result, false, err
+ }
+
+ providerDefaults, ok := authSourceSignupSettings(defaults, signupSource)
+ if !ok {
+ return result, false, nil
+ }
+
+ enabled := providerDefaults.GrantOnSignup
+ if firstBind {
+ enabled = providerDefaults.GrantOnFirstBind
+ }
+ if !enabled {
+ return result, false, nil
+ }
+
+ return mergeProviderDefaultGrantSettings(result, providerDefaults), true, nil
+}
+
+func (s *SettingService) UpdateAuthSourceDefaultSettings(ctx context.Context, settings *AuthSourceDefaultSettings) error {
+ updates, err := s.buildAuthSourceDefaultUpdates(ctx, settings)
+ if err != nil {
+ return err
+ }
+ if len(updates) == 0 {
+ return nil
+ }
+
+ if err := s.settingRepo.SetMultiple(ctx, updates); err != nil {
+ return fmt.Errorf("update auth source default settings: %w", err)
+ }
+ return nil
+}
+
// InitializeDefaultSettings 初始化默认设置
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 检查是否已有设置
@@ -931,27 +1525,95 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
return fmt.Errorf("check existing settings: %w", err)
}
+ oidcUsePKCEDefault := true
+ oidcValidateIDTokenDefault := true
+ if s != nil && s.cfg != nil {
+ if s.cfg.OIDC.UsePKCEExplicit {
+ oidcUsePKCEDefault = s.cfg.OIDC.UsePKCE
+ }
+ if s.cfg.OIDC.ValidateIDTokenExplicit {
+ oidcValidateIDTokenDefault = s.cfg.OIDC.ValidateIDToken
+ }
+ }
+
// 初始化默认设置
defaults := map[string]string{
- SettingKeyRegistrationEnabled: "true",
- SettingKeyEmailVerifyEnabled: "false",
- SettingKeyRegistrationEmailSuffixWhitelist: "[]",
- SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
- SettingKeySiteName: "Sub2API",
- SettingKeySiteLogo: "",
- SettingKeyPurchaseSubscriptionEnabled: "false",
- SettingKeyPurchaseSubscriptionURL: "",
- SettingKeyTableDefaultPageSize: "20",
- SettingKeyTablePageSizeOptions: "[10,20,50,100]",
- SettingKeyCustomMenuItems: "[]",
- SettingKeyCustomEndpoints: "[]",
- SettingKeyOIDCConnectEnabled: "false",
- SettingKeyOIDCConnectProviderName: "OIDC",
- SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
- SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
- SettingKeyDefaultSubscriptions: "[]",
- SettingKeySMTPPort: "587",
- SettingKeySMTPUseTLS: "false",
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyEmailVerifyEnabled: "false",
+ SettingKeyRegistrationEmailSuffixWhitelist: "[]",
+ SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
+ SettingKeySiteName: "Sub2API",
+ SettingKeySiteLogo: "",
+ SettingKeyPurchaseSubscriptionEnabled: "false",
+ SettingKeyPurchaseSubscriptionURL: "",
+ SettingKeyTableDefaultPageSize: "20",
+ SettingKeyTablePageSizeOptions: "[10,20,50,100]",
+ SettingKeyCustomMenuItems: "[]",
+ SettingKeyCustomEndpoints: "[]",
+ SettingKeyWeChatConnectEnabled: "false",
+ SettingKeyWeChatConnectAppID: "",
+ SettingKeyWeChatConnectAppSecret: "",
+ SettingKeyWeChatConnectOpenAppID: "",
+ SettingKeyWeChatConnectOpenAppSecret: "",
+ SettingKeyWeChatConnectMPAppID: "",
+ SettingKeyWeChatConnectMPAppSecret: "",
+ SettingKeyWeChatConnectMobileAppID: "",
+ SettingKeyWeChatConnectMobileAppSecret: "",
+ SettingKeyWeChatConnectOpenEnabled: "false",
+ SettingKeyWeChatConnectMPEnabled: "false",
+ SettingKeyWeChatConnectMobileEnabled: "false",
+ SettingKeyWeChatConnectMode: "open",
+ SettingKeyWeChatConnectScopes: "snsapi_login",
+ SettingKeyWeChatConnectRedirectURL: "",
+ SettingKeyWeChatConnectFrontendRedirectURL: defaultWeChatConnectFrontend,
+ SettingKeyOIDCConnectEnabled: "false",
+ SettingKeyOIDCConnectProviderName: "OIDC",
+ SettingKeyOIDCConnectClientID: "",
+ SettingKeyOIDCConnectClientSecret: "",
+ SettingKeyOIDCConnectIssuerURL: "",
+ SettingKeyOIDCConnectDiscoveryURL: "",
+ SettingKeyOIDCConnectAuthorizeURL: "",
+ SettingKeyOIDCConnectTokenURL: "",
+ SettingKeyOIDCConnectUserInfoURL: "",
+ SettingKeyOIDCConnectJWKSURL: "",
+ SettingKeyOIDCConnectScopes: "openid email profile",
+ SettingKeyOIDCConnectRedirectURL: "",
+ SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
+ SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
+ SettingKeyOIDCConnectUsePKCE: strconv.FormatBool(oidcUsePKCEDefault),
+ SettingKeyOIDCConnectValidateIDToken: strconv.FormatBool(oidcValidateIDTokenDefault),
+ SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
+ SettingKeyOIDCConnectClockSkewSeconds: "120",
+ SettingKeyOIDCConnectRequireEmailVerified: "false",
+ SettingKeyOIDCConnectUserInfoEmailPath: "",
+ SettingKeyOIDCConnectUserInfoIDPath: "",
+ SettingKeyOIDCConnectUserInfoUsernamePath: "",
+ SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
+ SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
+ SettingKeyDefaultSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultEmailBalance: "0",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "5",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
+ SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false",
+ SettingKeyAuthSourceDefaultLinuxDoBalance: "0",
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency: "5",
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "false",
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "false",
+ SettingKeyAuthSourceDefaultOIDCBalance: "0",
+ SettingKeyAuthSourceDefaultOIDCConcurrency: "5",
+ SettingKeyAuthSourceDefaultOIDCSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultOIDCGrantOnSignup: "false",
+ SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "false",
+ SettingKeyAuthSourceDefaultWeChatBalance: "0",
+ SettingKeyAuthSourceDefaultWeChatConcurrency: "5",
+ SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "false",
+ SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false",
+ SettingKeyForceEmailOnThirdPartySignup: "false",
+ SettingKeySMTPPort: "587",
+ SettingKeySMTPUseTLS: "false",
// Model fallback defaults
SettingKeyEnableModelFallback: "false",
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
@@ -973,7 +1635,12 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyMaxClaudeCodeVersion: "",
// 分组隔离(默认不允许未分组 Key 调度)
- SettingKeyAllowUngroupedKeyScheduling: "false",
+ SettingKeyAllowUngroupedKeyScheduling: "false",
+ SettingPaymentVisibleMethodAlipaySource: "",
+ SettingPaymentVisibleMethodWxpaySource: "",
+ SettingPaymentVisibleMethodAlipayEnabled: "false",
+ SettingPaymentVisibleMethodWxpayEnabled: "false",
+ openAIAdvancedSchedulerSettingKey: "false",
}
return s.settingRepo.SetMultiple(ctx, defaults)
@@ -1157,12 +1824,12 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
result.OIDCConnectUsePKCE = raw == "true"
} else {
- result.OIDCConnectUsePKCE = oidcBase.UsePKCE
+ result.OIDCConnectUsePKCE = oidcUsePKCECompatibilityDefault(oidcBase)
}
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
result.OIDCConnectValidateIDToken = raw == "true"
} else {
- result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken
+ result.OIDCConnectValidateIDToken = oidcValidateIDTokenCompatibilityDefault(oidcBase)
}
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v)
@@ -1208,6 +1875,31 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
}
result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != ""
+ // WeChat Connect 设置:
+ // - 优先读取 DB 系统设置
+ // - 缺失时回退到 config/env,保持升级兼容
+ weChatEffective := s.effectiveWeChatConnectOAuthConfig(settings)
+ result.WeChatConnectEnabled = weChatEffective.Enabled
+ result.WeChatConnectAppID = weChatEffective.LegacyAppID
+ result.WeChatConnectAppSecret = weChatEffective.LegacyAppSecret
+ result.WeChatConnectAppSecretConfigured = weChatEffective.LegacyAppSecret != ""
+ result.WeChatConnectOpenAppID = weChatEffective.OpenAppID
+ result.WeChatConnectOpenAppSecret = weChatEffective.OpenAppSecret
+ result.WeChatConnectOpenAppSecretConfigured = weChatEffective.OpenAppSecret != ""
+ result.WeChatConnectMPAppID = weChatEffective.MPAppID
+ result.WeChatConnectMPAppSecret = weChatEffective.MPAppSecret
+ result.WeChatConnectMPAppSecretConfigured = weChatEffective.MPAppSecret != ""
+ result.WeChatConnectMobileAppID = weChatEffective.MobileAppID
+ result.WeChatConnectMobileAppSecret = weChatEffective.MobileAppSecret
+ result.WeChatConnectMobileAppSecretConfigured = weChatEffective.MobileAppSecret != ""
+ result.WeChatConnectOpenEnabled = weChatEffective.OpenEnabled
+ result.WeChatConnectMPEnabled = weChatEffective.MPEnabled
+ result.WeChatConnectMobileEnabled = weChatEffective.MobileEnabled
+ result.WeChatConnectMode = weChatEffective.Mode
+ result.WeChatConnectScopes = weChatEffective.Scopes
+ result.WeChatConnectRedirectURL = weChatEffective.RedirectURL
+ result.WeChatConnectFrontendRedirectURL = weChatEffective.FrontendRedirectURL
+
// Model fallback settings
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
@@ -1263,6 +1955,11 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0
}
}
+ result.PaymentVisibleMethodAlipaySource = NormalizeVisibleMethodSource("alipay", settings[SettingPaymentVisibleMethodAlipaySource])
+ result.PaymentVisibleMethodWxpaySource = NormalizeVisibleMethodSource("wxpay", settings[SettingPaymentVisibleMethodWxpaySource])
+ result.PaymentVisibleMethodAlipayEnabled = settings[SettingPaymentVisibleMethodAlipayEnabled] == "true"
+ result.PaymentVisibleMethodWxpayEnabled = settings[SettingPaymentVisibleMethodWxpayEnabled] == "true"
+ result.OpenAIAdvancedSchedulerEnabled = settings[openAIAdvancedSchedulerSettingKey] == "true"
// Balance low notification
result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
@@ -1292,6 +1989,23 @@ func isFalseSettingValue(value string) bool {
}
}
+func normalizeVisibleMethodSettingSource(method, source string, enabled bool) (string, error) {
+ _ = enabled
+ source = strings.TrimSpace(source)
+ if source == "" {
+ return "", nil
+ }
+
+ normalized := NormalizeVisibleMethodSource(method, source)
+ if normalized == "" {
+ return "", infraerrors.BadRequest(
+ "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
+ fmt.Sprintf("%s source must be one of the supported payment providers", method),
+ )
+ }
+ return normalized, nil
+}
+
func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting {
raw = strings.TrimSpace(raw)
if raw == "" {
@@ -1317,6 +2031,73 @@ func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting {
return normalized
}
+func parseProviderDefaultGrantSettings(settings map[string]string, keys authSourceDefaultKeySet) ProviderDefaultGrantSettings {
+ result := ProviderDefaultGrantSettings{
+ Balance: defaultAuthSourceBalance,
+ Concurrency: defaultAuthSourceConcurrency,
+ Subscriptions: []DefaultSubscriptionSetting{},
+ GrantOnSignup: false,
+ GrantOnFirstBind: false,
+ }
+
+ if v, err := strconv.ParseFloat(strings.TrimSpace(settings[keys.balance]), 64); err == nil {
+ result.Balance = v
+ }
+ if v, err := strconv.Atoi(strings.TrimSpace(settings[keys.concurrency])); err == nil {
+ result.Concurrency = v
+ }
+ if items := parseDefaultSubscriptions(settings[keys.subscriptions]); items != nil {
+ result.Subscriptions = items
+ }
+ if raw, ok := settings[keys.grantOnSignup]; ok {
+ result.GrantOnSignup = raw == "true"
+ }
+ if raw, ok := settings[keys.grantOnFirstBind]; ok {
+ result.GrantOnFirstBind = raw == "true"
+ }
+
+ return result
+}
+
+func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSourceDefaultKeySet, settings ProviderDefaultGrantSettings) {
+ updates[keys.balance] = strconv.FormatFloat(settings.Balance, 'f', 8, 64)
+ updates[keys.concurrency] = strconv.Itoa(settings.Concurrency)
+
+ subscriptions := settings.Subscriptions
+ if subscriptions == nil {
+ subscriptions = []DefaultSubscriptionSetting{}
+ }
+ raw, err := json.Marshal(subscriptions)
+ if err != nil {
+ raw = []byte("[]")
+ }
+ updates[keys.subscriptions] = string(raw)
+ updates[keys.grantOnSignup] = strconv.FormatBool(settings.GrantOnSignup)
+ updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind)
+}
+
+func mergeProviderDefaultGrantSettings(globalDefaults ProviderDefaultGrantSettings, providerDefaults ProviderDefaultGrantSettings) ProviderDefaultGrantSettings {
+ result := ProviderDefaultGrantSettings{
+ Balance: globalDefaults.Balance,
+ Concurrency: globalDefaults.Concurrency,
+ Subscriptions: append([]DefaultSubscriptionSetting(nil), globalDefaults.Subscriptions...),
+ GrantOnSignup: providerDefaults.GrantOnSignup,
+ GrantOnFirstBind: providerDefaults.GrantOnFirstBind,
+ }
+
+ if providerDefaults.Balance != defaultAuthSourceBalance {
+ result.Balance = providerDefaults.Balance
+ }
+ if providerDefaults.Concurrency > 0 && providerDefaults.Concurrency != defaultAuthSourceConcurrency {
+ result.Concurrency = providerDefaults.Concurrency
+ }
+ if len(providerDefaults.Subscriptions) > 0 {
+ result.Subscriptions = append([]DefaultSubscriptionSetting(nil), providerDefaults.Subscriptions...)
+ }
+
+ return result
+}
+
func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) {
defaultPageSize := 20
if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil {
@@ -1539,7 +2320,6 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
effective.RedirectURL = strings.TrimSpace(v)
}
-
if !effective.Enabled {
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
}
@@ -1587,9 +2367,6 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
}
case "none":
- if !effective.UsePKCE {
- return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
- }
default:
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
}
@@ -1597,6 +2374,35 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
return effective, nil
}
+// GetWeChatConnectOAuthConfig 返回用于登录的最终生效 WeChat Connect 配置。
+//
+// WeChat Connect 已回归 DB 系统设置模型,不再回退到 config/env。
+func (s *SettingService) GetWeChatConnectOAuthConfig(ctx context.Context) (WeChatConnectOAuthConfig, error) {
+ keys := []string{
+ SettingKeyWeChatConnectEnabled,
+ SettingKeyWeChatConnectAppID,
+ SettingKeyWeChatConnectAppSecret,
+ SettingKeyWeChatConnectOpenAppID,
+ SettingKeyWeChatConnectOpenAppSecret,
+ SettingKeyWeChatConnectMPAppID,
+ SettingKeyWeChatConnectMPAppSecret,
+ SettingKeyWeChatConnectMobileAppID,
+ SettingKeyWeChatConnectMobileAppSecret,
+ SettingKeyWeChatConnectOpenEnabled,
+ SettingKeyWeChatConnectMPEnabled,
+ SettingKeyWeChatConnectMobileEnabled,
+ SettingKeyWeChatConnectMode,
+ SettingKeyWeChatConnectScopes,
+ SettingKeyWeChatConnectRedirectURL,
+ SettingKeyWeChatConnectFrontendRedirectURL,
+ }
+ settings, err := s.settingRepo.GetMultiple(ctx, keys)
+ if err != nil {
+ return WeChatConnectOAuthConfig{}, fmt.Errorf("get wechat connect settings: %w", err)
+ }
+ return s.parseWeChatConnectOAuthConfig(settings)
+}
+
// GetOverloadCooldownSettings 获取529过载冷却配置
func (s *SettingService) GetOverloadCooldownSettings(ctx context.Context) (*OverloadCooldownSettings, error) {
value, err := s.settingRepo.GetValue(ctx, SettingKeyOverloadCooldownSettings)
@@ -1733,9 +2539,13 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
}
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
effective.UsePKCE = raw == "true"
+ } else {
+ effective.UsePKCE = oidcUsePKCECompatibilityDefault(effective)
}
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
effective.ValidateIDToken = raw == "true"
+ } else {
+ effective.ValidateIDToken = oidcValidateIDTokenCompatibilityDefault(effective)
}
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
effective.AllowedSigningAlgs = strings.TrimSpace(v)
@@ -1864,9 +2674,6 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
}
case "none":
- if !effective.UsePKCE {
- return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
- }
default:
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
}
diff --git a/backend/internal/service/setting_service_auth_source_defaults_test.go b/backend/internal/service/setting_service_auth_source_defaults_test.go
new file mode 100644
index 00000000..1ff49740
--- /dev/null
+++ b/backend/internal/service/setting_service_auth_source_defaults_test.go
@@ -0,0 +1,138 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+type authSourceDefaultsRepoStub struct {
+ values map[string]string
+ updates map[string]string
+}
+
+func (s *authSourceDefaultsRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *authSourceDefaultsRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ panic("unexpected GetValue call")
+}
+
+func (s *authSourceDefaultsRepoStub) Set(ctx context.Context, key, value string) error {
+ panic("unexpected Set call")
+}
+
+func (s *authSourceDefaultsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *authSourceDefaultsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ s.updates = make(map[string]string, len(settings))
+ for key, value := range settings {
+ s.updates[key] = value
+ if s.values == nil {
+ s.values = map[string]string{}
+ }
+ s.values[key] = value
+ }
+ return nil
+}
+
+func (s *authSourceDefaultsRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *authSourceDefaultsRepoStub) Delete(ctx context.Context, key string) error {
+ panic("unexpected Delete call")
+}
+
+func TestSettingService_GetAuthSourceDefaultSettings_ParsesValuesAndDefaults(t *testing.T) {
+ repo := &authSourceDefaultsRepoStub{
+ values: map[string]string{
+ SettingKeyAuthSourceDefaultEmailBalance: "12.5",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "7",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "true",
+ SettingKeyForceEmailOnThirdPartySignup: "true",
+ },
+ }
+ svc := NewSettingService(repo, &config.Config{})
+
+ got, err := svc.GetAuthSourceDefaultSettings(context.Background())
+ require.NoError(t, err)
+ require.Equal(t, 12.5, got.Email.Balance)
+ require.Equal(t, 7, got.Email.Concurrency)
+ require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 11, ValidityDays: 30}}, got.Email.Subscriptions)
+ require.False(t, got.Email.GrantOnSignup)
+ require.False(t, got.Email.GrantOnFirstBind)
+ require.Equal(t, 0.0, got.LinuxDo.Balance)
+ require.Equal(t, 5, got.LinuxDo.Concurrency)
+ require.Equal(t, []DefaultSubscriptionSetting{}, got.LinuxDo.Subscriptions)
+ require.False(t, got.LinuxDo.GrantOnSignup)
+ require.True(t, got.LinuxDo.GrantOnFirstBind)
+ require.Equal(t, 5, got.OIDC.Concurrency)
+ require.Equal(t, 5, got.WeChat.Concurrency)
+ require.False(t, got.OIDC.GrantOnSignup)
+ require.False(t, got.WeChat.GrantOnSignup)
+ require.True(t, got.ForceEmailOnThirdPartySignup)
+}
+
+func TestSettingService_UpdateAuthSourceDefaultSettings_PersistsAllKeys(t *testing.T) {
+ repo := &authSourceDefaultsRepoStub{}
+ svc := NewSettingService(repo, &config.Config{})
+
+ err := svc.UpdateAuthSourceDefaultSettings(context.Background(), &AuthSourceDefaultSettings{
+ Email: ProviderDefaultGrantSettings{
+ Balance: 1.25,
+ Concurrency: 3,
+ Subscriptions: []DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 14}},
+ GrantOnSignup: false,
+ GrantOnFirstBind: true,
+ },
+ LinuxDo: ProviderDefaultGrantSettings{
+ Balance: 2,
+ Concurrency: 4,
+ Subscriptions: []DefaultSubscriptionSetting{{GroupID: 22, ValidityDays: 30}},
+ GrantOnSignup: true,
+ GrantOnFirstBind: false,
+ },
+ OIDC: ProviderDefaultGrantSettings{
+ Balance: 3,
+ Concurrency: 5,
+ Subscriptions: []DefaultSubscriptionSetting{{GroupID: 23, ValidityDays: 60}},
+ GrantOnSignup: true,
+ GrantOnFirstBind: true,
+ },
+ WeChat: ProviderDefaultGrantSettings{
+ Balance: 4,
+ Concurrency: 6,
+ Subscriptions: []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}},
+ GrantOnSignup: false,
+ GrantOnFirstBind: false,
+ },
+ ForceEmailOnThirdPartySignup: true,
+ })
+ require.NoError(t, err)
+ require.Equal(t, "1.25000000", repo.updates[SettingKeyAuthSourceDefaultEmailBalance])
+ require.Equal(t, "3", repo.updates[SettingKeyAuthSourceDefaultEmailConcurrency])
+ require.Equal(t, "false", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnSignup])
+ require.Equal(t, "true", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnFirstBind])
+ require.Equal(t, "true", repo.updates[SettingKeyForceEmailOnThirdPartySignup])
+
+ var got []DefaultSubscriptionSetting
+ require.NoError(t, json.Unmarshal([]byte(repo.updates[SettingKeyAuthSourceDefaultWeChatSubscriptions]), &got))
+ require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}}, got)
+}
diff --git a/backend/internal/service/setting_service_oidc_config_test.go b/backend/internal/service/setting_service_oidc_config_test.go
index 3809b332..61324204 100644
--- a/backend/internal/service/setting_service_oidc_config_test.go
+++ b/backend/internal/service/setting_service_oidc_config_test.go
@@ -101,3 +101,151 @@ func TestGetOIDCConnectOAuthConfig_ResolvesEndpointsFromIssuerDiscovery(t *testi
require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/userinfo", got.UserInfoURL)
require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/certs", got.JWKSURL)
}
+
+func TestSettingService_ParseSettings_PreservesOptionalOIDCCompatibilityFlags(t *testing.T) {
+ svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{})
+
+ got := svc.parseSettings(map[string]string{
+ SettingKeyOIDCConnectEnabled: "true",
+ SettingKeyOIDCConnectUsePKCE: "false",
+ SettingKeyOIDCConnectValidateIDToken: "false",
+ })
+
+ require.False(t, got.OIDCConnectUsePKCE)
+ require.False(t, got.OIDCConnectValidateIDToken)
+}
+
+func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValues(t *testing.T) {
+ svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{
+ OIDC: config.OIDCConnectConfig{
+ UsePKCE: true,
+ UsePKCEExplicit: true,
+ ValidateIDToken: true,
+ ValidateIDTokenExplicit: true,
+ },
+ })
+
+ got := svc.parseSettings(map[string]string{
+ SettingKeyOIDCConnectEnabled: "true",
+ })
+
+ require.True(t, got.OIDCConnectUsePKCE)
+ require.True(t, got.OIDCConnectValidateIDToken)
+}
+
+func TestSettingService_ParseSettings_DefaultsOIDCCompatibilityFlagsToSafeDefaultsWhenSettingsMissing(t *testing.T) {
+ svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{
+ OIDC: config.OIDCConnectConfig{
+ UsePKCE: true,
+ ValidateIDToken: true,
+ },
+ })
+
+ got := svc.parseSettings(map[string]string{
+ SettingKeyOIDCConnectEnabled: "true",
+ })
+
+ require.True(t, got.OIDCConnectUsePKCE)
+ require.True(t, got.OIDCConnectValidateIDToken)
+}
+
+func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) {
+ cfg := &config.Config{
+ OIDC: config.OIDCConnectConfig{
+ Enabled: true,
+ ProviderName: "OIDC",
+ ClientID: "oidc-client",
+ ClientSecret: "oidc-secret",
+ IssuerURL: "https://issuer.example.com",
+ AuthorizeURL: "https://issuer.example.com/auth",
+ TokenURL: "https://issuer.example.com/token",
+ UserInfoURL: "https://issuer.example.com/userinfo",
+ RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ Scopes: "openid email profile",
+ TokenAuthMethod: "client_secret_post",
+ },
+ }
+
+ repo := &settingOIDCRepoStub{values: map[string]string{
+ SettingKeyOIDCConnectEnabled: "true",
+ SettingKeyOIDCConnectUsePKCE: "false",
+ SettingKeyOIDCConnectValidateIDToken: "false",
+ }}
+ svc := NewSettingService(repo, cfg)
+
+ got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
+ require.NoError(t, err)
+ require.False(t, got.UsePKCE)
+ require.False(t, got.ValidateIDToken)
+}
+
+func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *testing.T) {
+ cfg := &config.Config{
+ OIDC: config.OIDCConnectConfig{
+ Enabled: true,
+ ProviderName: "OIDC",
+ ClientID: "oidc-client",
+ ClientSecret: "oidc-secret",
+ IssuerURL: "https://issuer.example.com",
+ AuthorizeURL: "https://issuer.example.com/auth",
+ TokenURL: "https://issuer.example.com/token",
+ UserInfoURL: "https://issuer.example.com/userinfo",
+ JWKSURL: "https://issuer.example.com/jwks",
+ RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ Scopes: "openid email profile",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ UsePKCEExplicit: true,
+ ValidateIDToken: true,
+ ValidateIDTokenExplicit: true,
+ AllowedSigningAlgs: "RS256",
+ ClockSkewSeconds: 120,
+ },
+ }
+
+ repo := &settingOIDCRepoStub{values: map[string]string{
+ SettingKeyOIDCConnectEnabled: "true",
+ }}
+ svc := NewSettingService(repo, cfg)
+
+ got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
+ require.NoError(t, err)
+ require.True(t, got.UsePKCE)
+ require.True(t, got.ValidateIDToken)
+}
+
+func TestGetOIDCConnectOAuthConfig_DefaultsCompatibilityFlagsToSafeValuesWhenSettingsMissing(t *testing.T) {
+ cfg := &config.Config{
+ OIDC: config.OIDCConnectConfig{
+ Enabled: true,
+ ProviderName: "OIDC",
+ ClientID: "oidc-client",
+ ClientSecret: "oidc-secret",
+ IssuerURL: "https://issuer.example.com",
+ AuthorizeURL: "https://issuer.example.com/auth",
+ TokenURL: "https://issuer.example.com/token",
+ UserInfoURL: "https://issuer.example.com/userinfo",
+ JWKSURL: "https://issuer.example.com/jwks",
+ RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ Scopes: "openid email profile",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ ValidateIDToken: true,
+ AllowedSigningAlgs: "RS256",
+ ClockSkewSeconds: 120,
+ },
+ }
+
+ repo := &settingOIDCRepoStub{values: map[string]string{
+ SettingKeyOIDCConnectEnabled: "true",
+ }}
+ svc := NewSettingService(repo, cfg)
+
+ got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
+ require.NoError(t, err)
+ require.True(t, got.UsePKCE)
+ require.True(t, got.ValidateIDToken)
+}
diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go
index 5cf1e860..1ecd4e6f 100644
--- a/backend/internal/service/setting_service_public_test.go
+++ b/backend/internal/service/setting_service_public_test.go
@@ -77,3 +77,77 @@ func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T)
require.Equal(t, 50, settings.TableDefaultPageSize)
require.Equal(t, []int{20, 50, 100}, settings.TablePageSizeOptions)
}
+
+func TestSettingService_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) {
+ repo := &settingPublicRepoStub{
+ values: map[string]string{
+ SettingKeyForceEmailOnThirdPartySignup: "true",
+ },
+ }
+ svc := NewSettingService(repo, &config.Config{})
+
+ settings, err := svc.GetPublicSettings(context.Background())
+ require.NoError(t, err)
+ require.True(t, settings.ForceEmailOnThirdPartySignup)
+}
+
+func TestSettingService_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t *testing.T) {
+ svc := NewSettingService(&settingPublicRepoStub{
+ values: map[string]string{
+ SettingKeyWeChatConnectEnabled: "true",
+ SettingKeyWeChatConnectAppID: "wx-mp-app",
+ SettingKeyWeChatConnectAppSecret: "wx-mp-secret",
+ SettingKeyWeChatConnectMode: "mp",
+ SettingKeyWeChatConnectScopes: "snsapi_base",
+ SettingKeyWeChatConnectOpenEnabled: "true",
+ SettingKeyWeChatConnectMPEnabled: "true",
+ SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ },
+ }, &config.Config{})
+
+ settings, err := svc.GetPublicSettings(context.Background())
+ require.NoError(t, err)
+ require.True(t, settings.WeChatOAuthEnabled)
+ require.True(t, settings.WeChatOAuthOpenEnabled)
+ require.True(t, settings.WeChatOAuthMPEnabled)
+}
+
+func TestSettingService_GetPublicSettings_DoesNotExposeMobileOnlyWeChatAsWebOAuthAvailable(t *testing.T) {
+ svc := NewSettingService(&settingPublicRepoStub{
+ values: map[string]string{
+ SettingKeyWeChatConnectEnabled: "true",
+ SettingKeyWeChatConnectMobileEnabled: "true",
+ SettingKeyWeChatConnectMode: "mobile",
+ SettingKeyWeChatConnectMobileAppID: "wx-mobile-app",
+ SettingKeyWeChatConnectMobileAppSecret: "wx-mobile-secret",
+ SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ },
+ }, &config.Config{})
+
+ settings, err := svc.GetPublicSettings(context.Background())
+ require.NoError(t, err)
+ require.False(t, settings.WeChatOAuthEnabled)
+ require.False(t, settings.WeChatOAuthOpenEnabled)
+ require.False(t, settings.WeChatOAuthMPEnabled)
+ require.True(t, settings.WeChatOAuthMobileEnabled)
+}
+
+func TestSettingService_GetPublicSettings_FallsBackToConfigForWeChatOAuthCapabilities(t *testing.T) {
+ svc := NewSettingService(&settingPublicRepoStub{values: map[string]string{}}, &config.Config{
+ WeChat: config.WeChatConnectConfig{
+ Enabled: true,
+ OpenEnabled: true,
+ OpenAppID: "wx-open-config",
+ OpenAppSecret: "wx-open-secret",
+ FrontendRedirectURL: "/auth/wechat/config-callback",
+ },
+ })
+
+ settings, err := svc.GetPublicSettings(context.Background())
+ require.NoError(t, err)
+ require.True(t, settings.WeChatOAuthEnabled)
+ require.True(t, settings.WeChatOAuthOpenEnabled)
+ require.False(t, settings.WeChatOAuthMPEnabled)
+ require.False(t, settings.WeChatOAuthMobileEnabled)
+}
diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go
index e62218b4..9dc0ca59 100644
--- a/backend/internal/service/setting_service_update_test.go
+++ b/backend/internal/service/setting_service_update_test.go
@@ -223,3 +223,34 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) {
require.Equal(t, "1000", repo.updates[SettingKeyTableDefaultPageSize])
require.Equal(t, "[20,100]", repo.updates[SettingKeyTablePageSizeOptions])
}
+
+func TestSettingService_UpdateSettings_PaymentVisibleMethodsAndAdvancedScheduler(t *testing.T) {
+ repo := &settingUpdateRepoStub{}
+ svc := NewSettingService(repo, &config.Config{})
+
+ err := svc.UpdateSettings(context.Background(), &SystemSettings{
+ PaymentVisibleMethodAlipaySource: "alipay",
+ PaymentVisibleMethodWxpaySource: "easypay",
+ PaymentVisibleMethodAlipayEnabled: true,
+ PaymentVisibleMethodWxpayEnabled: false,
+ OpenAIAdvancedSchedulerEnabled: true,
+ })
+ require.NoError(t, err)
+ require.Equal(t, VisibleMethodSourceOfficialAlipay, repo.updates[SettingPaymentVisibleMethodAlipaySource])
+ require.Equal(t, VisibleMethodSourceEasyPayWechat, repo.updates[SettingPaymentVisibleMethodWxpaySource])
+ require.Equal(t, "true", repo.updates[SettingPaymentVisibleMethodAlipayEnabled])
+ require.Equal(t, "false", repo.updates[SettingPaymentVisibleMethodWxpayEnabled])
+ require.Equal(t, "true", repo.updates[openAIAdvancedSchedulerSettingKey])
+}
+
+func TestSettingService_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) {
+ repo := &settingUpdateRepoStub{}
+ svc := NewSettingService(repo, &config.Config{})
+
+ err := svc.UpdateSettings(context.Background(), &SystemSettings{
+ PaymentVisibleMethodAlipaySource: "not-a-provider",
+ })
+ require.Error(t, err)
+ require.Equal(t, "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", infraerrors.Reason(err))
+ require.Nil(t, repo.updates)
+}
diff --git a/backend/internal/service/setting_service_wechat_config_test.go b/backend/internal/service/setting_service_wechat_config_test.go
new file mode 100644
index 00000000..a2de614b
--- /dev/null
+++ b/backend/internal/service/setting_service_wechat_config_test.go
@@ -0,0 +1,162 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+type settingWeChatRepoStub struct {
+ values map[string]string
+}
+
+func (s *settingWeChatRepoStub) Get(context.Context, string) (*Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *settingWeChatRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ if value, ok := s.values[key]; ok {
+ return value, nil
+ }
+ return "", ErrSettingNotFound
+}
+
+func (s *settingWeChatRepoStub) Set(context.Context, string, string) error {
+ panic("unexpected Set call")
+}
+
+func (s *settingWeChatRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *settingWeChatRepoStub) SetMultiple(context.Context, map[string]string) error {
+ panic("unexpected SetMultiple call")
+}
+
+func (s *settingWeChatRepoStub) GetAll(context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *settingWeChatRepoStub) Delete(context.Context, string) error {
+ panic("unexpected Delete call")
+}
+
+func TestSettingService_GetWeChatConnectOAuthConfig_UsesDatabaseOverrides(t *testing.T) {
+ repo := &settingWeChatRepoStub{
+ values: map[string]string{
+ SettingKeyWeChatConnectEnabled: "true",
+ SettingKeyWeChatConnectAppID: "wx-db-app",
+ SettingKeyWeChatConnectAppSecret: "wx-db-secret",
+ SettingKeyWeChatConnectMode: "mp",
+ SettingKeyWeChatConnectScopes: "snsapi_base",
+ SettingKeyWeChatConnectOpenEnabled: "true",
+ SettingKeyWeChatConnectMPEnabled: "true",
+ SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ },
+ }
+ svc := NewSettingService(repo, &config.Config{})
+
+ got, err := svc.GetWeChatConnectOAuthConfig(context.Background())
+ require.NoError(t, err)
+ require.True(t, got.Enabled)
+ require.Equal(t, "wx-db-app", got.AppIDForMode("mp"))
+ require.Equal(t, "wx-db-secret", got.AppSecretForMode("mp"))
+ require.True(t, got.OpenEnabled)
+ require.True(t, got.MPEnabled)
+ require.Equal(t, "mp", got.Mode)
+ require.Equal(t, "snsapi_base", got.Scopes)
+ require.Equal(t, "https://api.example.com/api/v1/auth/oauth/wechat/callback", got.RedirectURL)
+ require.Equal(t, "/auth/wechat/callback", got.FrontendRedirectURL)
+}
+
+func TestSettingService_GetWeChatConnectOAuthConfig_FallsBackToConfigWhenDatabaseEmpty(t *testing.T) {
+ repo := &settingWeChatRepoStub{values: map[string]string{}}
+ svc := NewSettingService(repo, &config.Config{
+ WeChat: config.WeChatConnectConfig{
+ Enabled: true,
+ OpenEnabled: true,
+ MPEnabled: true,
+ Mode: "open",
+ OpenAppID: "wx-open-config",
+ OpenAppSecret: "wx-open-secret",
+ MPAppID: "wx-mp-config",
+ MPAppSecret: "wx-mp-secret",
+ FrontendRedirectURL: "/auth/wechat/config-callback",
+ },
+ })
+
+ got, err := svc.GetWeChatConnectOAuthConfig(context.Background())
+ require.NoError(t, err)
+ require.True(t, got.Enabled)
+ require.True(t, got.OpenEnabled)
+ require.True(t, got.MPEnabled)
+ require.Equal(t, "wx-open-config", got.AppIDForMode("open"))
+ require.Equal(t, "wx-open-secret", got.AppSecretForMode("open"))
+ require.Equal(t, "wx-mp-config", got.AppIDForMode("mp"))
+ require.Equal(t, "wx-mp-secret", got.AppSecretForMode("mp"))
+ require.Equal(t, "/auth/wechat/config-callback", got.FrontendRedirectURL)
+ require.Empty(t, got.RedirectURL)
+}
+
+func TestSettingService_GetWeChatConnectOAuthConfig_IgnoresSyntheticDisabledCapabilitiesFromMigration118(t *testing.T) {
+ repo := &settingWeChatRepoStub{
+ values: map[string]string{
+ SettingKeyWeChatConnectOpenEnabled: "false",
+ SettingKeyWeChatConnectMPEnabled: "false",
+ },
+ }
+ svc := NewSettingService(repo, &config.Config{
+ WeChat: config.WeChatConnectConfig{
+ Enabled: true,
+ OpenEnabled: true,
+ MPEnabled: true,
+ Mode: "open",
+ OpenAppID: "wx-open-config",
+ OpenAppSecret: "wx-open-secret",
+ MPAppID: "wx-mp-config",
+ MPAppSecret: "wx-mp-secret",
+ FrontendRedirectURL: "/auth/wechat/config-callback",
+ },
+ })
+
+ got, err := svc.GetWeChatConnectOAuthConfig(context.Background())
+ require.NoError(t, err)
+ require.True(t, got.Enabled)
+ require.True(t, got.OpenEnabled)
+ require.True(t, got.MPEnabled)
+ require.Equal(t, "wx-open-config", got.AppIDForMode("open"))
+ require.Equal(t, "wx-mp-config", got.AppIDForMode("mp"))
+}
+
+func TestSettingService_ParseSettings_FallsBackToConfigForWeChatAdminView(t *testing.T) {
+ svc := NewSettingService(&settingWeChatRepoStub{values: map[string]string{}}, &config.Config{
+ WeChat: config.WeChatConnectConfig{
+ Enabled: true,
+ OpenEnabled: true,
+ Mode: "open",
+ OpenAppID: "wx-open-config",
+ OpenAppSecret: "wx-open-secret",
+ FrontendRedirectURL: "/auth/wechat/config-callback",
+ },
+ })
+
+ got := svc.parseSettings(map[string]string{})
+ require.True(t, got.WeChatConnectEnabled)
+ require.True(t, got.WeChatConnectOpenEnabled)
+ require.Equal(t, "wx-open-config", got.WeChatConnectOpenAppID)
+ require.True(t, got.WeChatConnectOpenAppSecretConfigured)
+ require.Equal(t, "/auth/wechat/config-callback", got.WeChatConnectFrontendRedirectURL)
+ require.Equal(t, "open", got.WeChatConnectMode)
+ require.Equal(t, "snsapi_login", got.WeChatConnectScopes)
+}
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index ab2eb274..d2ef8fae 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -1,5 +1,16 @@
package service
+import "strings"
+
+func firstNonEmpty(values ...string) string {
+ for _, value := range values {
+ if trimmed := strings.TrimSpace(value); trimmed != "" {
+ return trimmed
+ }
+ }
+ return ""
+}
+
type SystemSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
@@ -31,6 +42,28 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool
LinuxDoConnectRedirectURL string
+ // WeChat Connect OAuth 登录
+ WeChatConnectEnabled bool
+ WeChatConnectAppID string
+ WeChatConnectAppSecret string
+ WeChatConnectAppSecretConfigured bool
+ WeChatConnectOpenAppID string
+ WeChatConnectOpenAppSecret string
+ WeChatConnectOpenAppSecretConfigured bool
+ WeChatConnectMPAppID string
+ WeChatConnectMPAppSecret string
+ WeChatConnectMPAppSecretConfigured bool
+ WeChatConnectMobileAppID string
+ WeChatConnectMobileAppSecret string
+ WeChatConnectMobileAppSecretConfigured bool
+ WeChatConnectOpenEnabled bool
+ WeChatConnectMPEnabled bool
+ WeChatConnectMobileEnabled bool
+ WeChatConnectMode string
+ WeChatConnectScopes string
+ WeChatConnectRedirectURL string
+ WeChatConnectFrontendRedirectURL string
+
// Generic OIDC OAuth 登录
OIDCConnectEnabled bool
OIDCConnectProviderName string
@@ -110,6 +143,15 @@ type SystemSettings struct {
// Web Search Emulation
WebSearchEmulationEnabled bool // 是否启用 web search 模拟
+ // Payment visible method routing
+ PaymentVisibleMethodAlipaySource string
+ PaymentVisibleMethodWxpaySource string
+ PaymentVisibleMethodAlipayEnabled bool
+ PaymentVisibleMethodWxpayEnabled bool
+
+ // OpenAI account scheduling
+ OpenAIAdvancedSchedulerEnabled bool
+
// Balance low notification
BalanceLowNotifyEnabled bool
BalanceLowNotifyThreshold float64
@@ -128,6 +170,7 @@ type DefaultSubscriptionSetting struct {
type PublicSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
+ ForceEmailOnThirdPartySignup bool
RegistrationEmailSuffixWhitelist []string
PromoCodeEnabled bool
PasswordResetEnabled bool
@@ -151,12 +194,16 @@ type PublicSettings struct {
CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
- LinuxDoOAuthEnabled bool
- BackendModeEnabled bool
- PaymentEnabled bool
- OIDCOAuthEnabled bool
- OIDCOAuthProviderName string
- Version string
+ LinuxDoOAuthEnabled bool
+ WeChatOAuthEnabled bool
+ WeChatOAuthOpenEnabled bool
+ WeChatOAuthMPEnabled bool
+ WeChatOAuthMobileEnabled bool
+ BackendModeEnabled bool
+ PaymentEnabled bool
+ OIDCOAuthEnabled bool
+ OIDCOAuthProviderName string
+ Version string
BalanceLowNotifyEnabled bool
AccountQuotaNotifyEnabled bool
@@ -164,6 +211,66 @@ type PublicSettings struct {
BalanceLowNotifyRechargeURL string
}
+type WeChatConnectOAuthConfig struct {
+ Enabled bool
+ LegacyAppID string
+ LegacyAppSecret string
+ OpenAppID string
+ OpenAppSecret string
+ MPAppID string
+ MPAppSecret string
+ MobileAppID string
+ MobileAppSecret string
+ OpenEnabled bool
+ MPEnabled bool
+ MobileEnabled bool
+ Mode string
+ Scopes string
+ RedirectURL string
+ FrontendRedirectURL string
+}
+
+func (cfg WeChatConnectOAuthConfig) SupportsMode(mode string) bool {
+ switch normalizeWeChatConnectModeSetting(mode) {
+ case "mp":
+ return cfg.MPEnabled
+ case "mobile":
+ return cfg.MobileEnabled
+ default:
+ return cfg.OpenEnabled
+ }
+}
+
+func (cfg WeChatConnectOAuthConfig) ScopeForMode(mode string) string {
+ switch normalizeWeChatConnectModeSetting(mode) {
+ case "mp":
+ return normalizeWeChatConnectScopeSetting(cfg.Scopes, "mp")
+ case "mobile":
+ return ""
+ }
+ return defaultWeChatConnectScopeForMode("open")
+}
+
+func (cfg WeChatConnectOAuthConfig) AppIDForMode(mode string) string {
+ switch normalizeWeChatConnectModeSetting(mode) {
+ case "mp":
+ return strings.TrimSpace(firstNonEmpty(cfg.MPAppID, cfg.LegacyAppID))
+ case "mobile":
+ return strings.TrimSpace(firstNonEmpty(cfg.MobileAppID, cfg.LegacyAppID))
+ }
+ return strings.TrimSpace(firstNonEmpty(cfg.OpenAppID, cfg.LegacyAppID))
+}
+
+func (cfg WeChatConnectOAuthConfig) AppSecretForMode(mode string) string {
+ switch normalizeWeChatConnectModeSetting(mode) {
+ case "mp":
+ return strings.TrimSpace(firstNonEmpty(cfg.MPAppSecret, cfg.LegacyAppSecret))
+ case "mobile":
+ return strings.TrimSpace(firstNonEmpty(cfg.MobileAppSecret, cfg.LegacyAppSecret))
+ }
+ return strings.TrimSpace(firstNonEmpty(cfg.OpenAppSecret, cfg.LegacyAppSecret))
+}
+
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
type StreamTimeoutSettings struct {
// Enabled 是否启用流超时处理
diff --git a/backend/internal/service/sql_errors.go b/backend/internal/service/sql_errors.go
new file mode 100644
index 00000000..7c0155a4
--- /dev/null
+++ b/backend/internal/service/sql_errors.go
@@ -0,0 +1,14 @@
+package service
+
+import (
+ "database/sql"
+ "errors"
+ "strings"
+)
+
+func isSQLNoRowsError(err error) bool {
+ if err == nil {
+ return false
+ }
+ return errors.Is(err, sql.ErrNoRows) || strings.Contains(err.Error(), "no rows in result set")
+}
diff --git a/backend/internal/service/sticky_session_test.go b/backend/internal/service/sticky_session_test.go
index e7ef8982..11ace7bd 100644
--- a/backend/internal/service/sticky_session_test.go
+++ b/backend/internal/service/sticky_session_test.go
@@ -15,20 +15,8 @@ import (
"github.com/stretchr/testify/require"
)
-// TestShouldClearStickySession 测试粘性会话清理判断逻辑。
-// 验证在以下情况下是否正确判断需要清理粘性会话:
-// - nil 账号:不清理(返回 false)
-// - 状态为错误或禁用:清理
-// - 不可调度:清理
-// - 临时不可调度且未过期:清理
-// - 临时不可调度已过期:不清理
-// - 正常可调度状态:不清理
-// - 模型限流(任意时长):清理
-//
-// TestShouldClearStickySession tests the sticky session clearing logic.
-// Verifies correct behavior for various account states including:
-// nil account, error/disabled status, unschedulable, temporary unschedulable,
-// and model rate limiting scenarios.
+// TestShouldClearStickySession tests sticky session clearing via IsSchedulable() delegation
+// plus model-level rate limiting.
func TestShouldClearStickySession(t *testing.T) {
now := time.Now()
future := now.Add(1 * time.Hour)
@@ -101,6 +89,56 @@ func TestShouldClearStickySession(t *testing.T) {
requestedModel: "claude-opus-4", // 请求不同模型
want: false, // 不同模型不受影响
},
+ {
+ name: "apikey quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ requestedModel: "",
+ want: true,
+ },
+ {
+ name: "oauth quota exceeded not cleared",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ requestedModel: "",
+ want: false,
+ },
+ {
+ name: "overloaded account",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ OverloadUntil: &future,
+ },
+ requestedModel: "",
+ want: true,
+ },
+ {
+ name: "account-level rate limited",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ RateLimitResetAt: &future,
+ },
+ requestedModel: "",
+ want: true,
+ },
}
for _, tt := range tests {
diff --git a/backend/internal/service/totp_service.go b/backend/internal/service/totp_service.go
index 5192fe3d..052739ed 100644
--- a/backend/internal/service/totp_service.go
+++ b/backend/internal/service/totp_service.go
@@ -58,9 +58,15 @@ type TotpSetupSession struct {
// TotpLoginSession represents a pending 2FA login session
type TotpLoginSession struct {
- UserID int64
- Email string
- TokenExpiry time.Time
+ UserID int64
+ Email string
+ TokenExpiry time.Time
+ PendingOAuthBind *PendingOAuthBindLoginSession `json:"pending_oauth_bind,omitempty"`
+}
+
+type PendingOAuthBindLoginSession struct {
+ PendingSessionToken string `json:"pending_session_token,omitempty"`
+ BrowserSessionKey string `json:"browser_session_key,omitempty"`
}
// TotpStatus represents the TOTP status for a user
@@ -397,6 +403,30 @@ func (s *TotpService) VerifyCode(ctx context.Context, userID int64, code string)
// CreateLoginSession creates a temporary login session for 2FA
func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, email string) (string, error) {
+ return s.createLoginSession(ctx, userID, email, nil)
+}
+
+// CreatePendingOAuthBindLoginSession creates a temporary 2FA session that will
+// finalize a pending OAuth bind after the TOTP code is verified.
+func (s *TotpService) CreatePendingOAuthBindLoginSession(
+ ctx context.Context,
+ userID int64,
+ email string,
+ pendingSessionToken string,
+ browserSessionKey string,
+) (string, error) {
+ return s.createLoginSession(ctx, userID, email, &PendingOAuthBindLoginSession{
+ PendingSessionToken: pendingSessionToken,
+ BrowserSessionKey: browserSessionKey,
+ })
+}
+
+func (s *TotpService) createLoginSession(
+ ctx context.Context,
+ userID int64,
+ email string,
+ pendingOAuthBind *PendingOAuthBindLoginSession,
+) (string, error) {
// Generate a random temp token
tempToken, err := generateRandomToken(32)
if err != nil {
@@ -404,9 +434,10 @@ func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, emai
}
session := &TotpLoginSession{
- UserID: userID,
- Email: email,
- TokenExpiry: time.Now().Add(totpLoginTTL),
+ UserID: userID,
+ Email: email,
+ TokenExpiry: time.Now().Add(totpLoginTTL),
+ PendingOAuthBind: pendingOAuthBind,
}
if err := s.cache.SetLoginSession(ctx, tempToken, session, totpLoginTTL); err != nil {
diff --git a/backend/internal/service/upstream_response_limit.go b/backend/internal/service/upstream_response_limit.go
index a0444d52..ddf0e818 100644
--- a/backend/internal/service/upstream_response_limit.go
+++ b/backend/internal/service/upstream_response_limit.go
@@ -12,7 +12,9 @@ import (
var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large")
-const defaultUpstreamResponseReadMaxBytes int64 = 8 * 1024 * 1024
+// defaultUpstreamResponseReadMaxBytes 源自 config.DefaultUpstreamResponseReadMaxBytes,
+// 仅在 cfg 为 nil 时作为兜底(测试或极端场景)。
+const defaultUpstreamResponseReadMaxBytes = config.DefaultUpstreamResponseReadMaxBytes
func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 {
if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 {
diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go
index 59f8aa6b..9dc13381 100644
--- a/backend/internal/service/user.go
+++ b/backend/internal/service/user.go
@@ -7,19 +7,31 @@ import (
)
type User struct {
- ID int64
- Email string
- Username string
- Notes string
- PasswordHash string
- Role string
- Balance float64
- Concurrency int
- Status string
- AllowedGroups []int64
- TokenVersion int64 // Incremented on password change to invalidate existing tokens
- CreatedAt time.Time
- UpdatedAt time.Time
+ ID int64
+ Email string
+ Username string
+ Notes string
+ AvatarURL string
+ AvatarSource string
+ AvatarMIME string
+ AvatarByteSize int
+ AvatarSHA256 string
+ PasswordHash string
+ Role string
+ Balance float64
+ Concurrency int
+ Status string
+ AllowedGroups []int64
+ TokenVersion int64 // Incremented on password change to invalidate existing tokens
+ // TokenVersionResolved indicates TokenVersion already contains the fingerprint-derived
+ // value expected in JWT claims and refresh-token state.
+ TokenVersionResolved bool
+ SignupSource string
+ LastLoginAt *time.Time
+ LastActiveAt *time.Time
+ LastUsedAt *time.Time
+ CreatedAt time.Time
+ UpdatedAt time.Time
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index 3490e804..a7279e6a 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -1,30 +1,66 @@
package service
import (
+ "bytes"
"context"
+ "crypto/sha256"
"crypto/subtle"
+ "encoding/base64"
+ "encoding/hex"
"fmt"
- "log/slog"
- "strings"
- "time"
-
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "image"
+ "image/color"
+ stddraw "image/draw"
+ _ "image/gif"
+ "image/jpeg"
+ _ "image/png"
+ "log/slog"
+ "net/url"
+ "sort"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ xdraw "golang.org/x/image/draw"
+ "golang.org/x/sync/singleflight"
)
var (
- ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
- ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
- ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
- ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later")
+ ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
+ ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
+ ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
+ ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later")
+ ErrAvatarInvalid = infraerrors.BadRequest("AVATAR_INVALID", "avatar must be a valid image data URL or http(s) URL")
+ ErrAvatarTooLarge = infraerrors.BadRequest("AVATAR_TOO_LARGE", "avatar image must be 100KB or smaller")
+ ErrAvatarNotImage = infraerrors.BadRequest("AVATAR_NOT_IMAGE", "avatar content must be an image")
+ ErrIdentityProviderInvalid = infraerrors.BadRequest("IDENTITY_PROVIDER_INVALID", "identity provider is invalid")
+ ErrIdentityRedirectInvalid = infraerrors.BadRequest("IDENTITY_REDIRECT_INVALID", "identity redirect path is invalid")
+ ErrIdentityUnbindLastMethod = infraerrors.Conflict(
+ "IDENTITY_UNBIND_LAST_METHOD",
+ "bind another sign-in method before unbinding this provider",
+ )
)
const (
- maxNotifyEmails = 3 // Maximum number of notification emails per user
+ maxNotifyEmails = 3 // Maximum number of notification emails per user
+ maxInlineAvatarBytes = 100 * 1024
+ targetAvatarBytes = 20 * 1024
// User-level rate limiting for notify email verification codes
notifyCodeUserRateLimit = 5
notifyCodeUserRateWindow = 10 * time.Minute
+
+ defaultUserIdentityRedirect = "/settings/profile"
+ userLastActiveMinTouch = 10 * time.Minute
+ userLastActiveFailBackoff = 30 * time.Second
+)
+
+var (
+ avatarScaleSteps = []float64{1, 0.92, 0.84, 0.76, 0.68, 0.6, 0.52, 0.44, 0.36}
+ avatarQualitySteps = []int{88, 80, 72, 64, 56, 48, 40, 32}
)
// UserListFilters contains all filter options for listing users
@@ -47,9 +83,15 @@ type UserRepository interface {
GetFirstAdmin(ctx context.Context) (*User, error)
Update(ctx context.Context, user *User) error
Delete(ctx context.Context, id int64) error
+ GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error)
+ UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error)
+ DeleteUserAvatar(ctx context.Context, userID int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error)
+ GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error)
+ GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error)
+ UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error
UpdateBalance(ctx context.Context, id int64, amount float64) error
DeductBalance(ctx context.Context, id int64, amount float64) error
@@ -60,6 +102,8 @@ type UserRepository interface {
AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error
// RemoveGroupFromUserAllowedGroups 移除单个用户的指定分组权限
RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error
+ ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error)
+ UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) error
// TOTP 双因素认证
UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error
@@ -67,15 +111,90 @@ type UserRepository interface {
DisableTotp(ctx context.Context, userID int64) error
}
+type UserAuthIdentityRecord struct {
+ ProviderType string
+ ProviderKey string
+ ProviderSubject string
+ VerifiedAt *time.Time
+ Issuer *string
+ Metadata map[string]any
+ CreatedAt time.Time
+ UpdatedAt time.Time
+}
+
+type UserIdentitySummary struct {
+ Provider string `json:"provider"`
+ Bound bool `json:"bound"`
+ BoundCount int `json:"bound_count"`
+ DisplayName string `json:"display_name,omitempty"`
+ AvatarURL string `json:"-"`
+ SubjectHint string `json:"subject_hint,omitempty"`
+ ProviderKey string `json:"provider_key,omitempty"`
+ VerifiedAt *time.Time `json:"verified_at,omitempty"`
+ BindStartPath string `json:"bind_start_path,omitempty"`
+ CanBind bool `json:"can_bind"`
+ CanUnbind bool `json:"can_unbind"`
+ NoteKey string `json:"note_key,omitempty"`
+ Note string `json:"note,omitempty"`
+}
+
+type UserIdentitySummarySet struct {
+ Email UserIdentitySummary `json:"email"`
+ LinuxDo UserIdentitySummary `json:"linuxdo"`
+ OIDC UserIdentitySummary `json:"oidc"`
+ WeChat UserIdentitySummary `json:"wechat"`
+}
+
+type StartUserIdentityBindingRequest struct {
+ Provider string
+ RedirectTo string
+}
+
+type StartUserIdentityBindingResult struct {
+ Provider string `json:"provider"`
+ AuthorizeURL string `json:"authorize_url"`
+ Method string `json:"method"`
+ UseBrowserRedirect bool `json:"use_browser_redirect"`
+}
+
+const (
+ userIdentityNoteEmailManagedFromProfile = "profile.authBindings.notes.emailManagedFromProfile"
+ userIdentityNoteCanUnbind = "profile.authBindings.notes.canUnbind"
+ userIdentityNoteBindAnotherBeforeUnbind = "profile.authBindings.notes.bindAnotherBeforeUnbind"
+)
+
// UpdateProfileRequest 更新用户资料请求
type UpdateProfileRequest struct {
Email *string `json:"email"`
Username *string `json:"username"`
+ AvatarURL *string `json:"avatar_url"`
Concurrency *int `json:"concurrency"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
+type UserAvatar struct {
+ StorageProvider string
+ StorageKey string
+ URL string
+ ContentType string
+ ByteSize int
+ SHA256 string
+}
+
+type UpsertUserAvatarInput struct {
+ StorageProvider string
+ StorageKey string
+ URL string
+ ContentType string
+ ByteSize int
+ SHA256 string
+}
+
+type userProfileIdentityTxRunner interface {
+ WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error
+}
+
// ChangePasswordRequest 修改密码请求
type ChangePasswordRequest struct {
CurrentPassword string `json:"current_password"`
@@ -88,6 +207,8 @@ type UserService struct {
settingRepo SettingRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
billingCache BillingCache
+ lastActiveTouchL1 sync.Map
+ lastActiveTouchSF singleflight.Group
}
// NewUserService 创建用户服务实例
@@ -115,14 +236,176 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
+ normalizeLoadedUserTokenVersion(user)
+ if err := s.hydrateUserAvatar(ctx, user); err != nil {
+ return nil, fmt.Errorf("get user avatar: %w", err)
+ }
return user, nil
}
+func (s *UserService) GetProfileIdentitySummaries(ctx context.Context, userID int64, user *User) (UserIdentitySummarySet, error) {
+ if user == nil {
+ var err error
+ user, err = s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return UserIdentitySummarySet{}, fmt.Errorf("get user: %w", err)
+ }
+ }
+
+ records, err := s.listUserAuthIdentities(ctx, userID)
+ if err != nil {
+ return UserIdentitySummarySet{}, err
+ }
+
+ summaries := UserIdentitySummarySet{
+ Email: s.buildEmailIdentitySummary(user, records),
+ LinuxDo: s.buildProviderIdentitySummary("linuxdo", user, records),
+ OIDC: s.buildProviderIdentitySummary("oidc", user, records),
+ WeChat: s.buildProviderIdentitySummary("wechat", user, records),
+ }
+
+ s.applyExplicitProviderAvailability(ctx, &summaries)
+ return summaries, nil
+}
+
+func (s *UserService) applyExplicitProviderAvailability(ctx context.Context, summaries *UserIdentitySummarySet) {
+ if s == nil || summaries == nil || s.settingRepo == nil {
+ return
+ }
+
+ settings, err := s.settingRepo.GetMultiple(ctx, []string{
+ SettingKeyLinuxDoConnectEnabled,
+ SettingKeyOIDCConnectEnabled,
+ SettingKeyWeChatConnectEnabled,
+ SettingKeyWeChatConnectOpenEnabled,
+ SettingKeyWeChatConnectMPEnabled,
+ SettingKeyWeChatConnectMobileEnabled,
+ SettingKeyWeChatConnectMode,
+ })
+ if err != nil {
+ return
+ }
+
+ if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok && strings.TrimSpace(raw) != "" && raw != "true" {
+ disableIdentityBindAction(&summaries.LinuxDo)
+ }
+ if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok && strings.TrimSpace(raw) != "" && raw != "true" {
+ disableIdentityBindAction(&summaries.OIDC)
+ }
+ if raw, ok := settings[SettingKeyWeChatConnectEnabled]; ok && strings.TrimSpace(raw) != "" {
+ if raw != "true" {
+ disableIdentityBindAction(&summaries.WeChat)
+ return
+ }
+ openEnabled, mpEnabled, _ := parseWeChatConnectCapabilitySettings(settings, true, settings[SettingKeyWeChatConnectMode])
+ if !openEnabled && !mpEnabled {
+ disableIdentityBindAction(&summaries.WeChat)
+ }
+ }
+}
+
+func disableIdentityBindAction(summary *UserIdentitySummary) {
+ if summary == nil || summary.Bound {
+ return
+ }
+ summary.CanBind = false
+ summary.BindStartPath = ""
+}
+
+func (s *UserService) PrepareIdentityBindingStart(_ context.Context, req StartUserIdentityBindingRequest) (*StartUserIdentityBindingResult, error) {
+ provider := normalizeUserIdentityProvider(req.Provider)
+ if provider == "" {
+ return nil, ErrIdentityProviderInvalid
+ }
+
+ authorizeURL, err := buildUserIdentityBindAuthorizeURL(provider, req.RedirectTo)
+ if err != nil {
+ return nil, err
+ }
+
+ return &StartUserIdentityBindingResult{
+ Provider: provider,
+ AuthorizeURL: authorizeURL,
+ Method: "GET",
+ UseBrowserRedirect: true,
+ }, nil
+}
+
+func (s *UserService) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) (*User, error) {
+ user, _, err := s.UnbindUserAuthProviderWithResult(ctx, userID, provider)
+ return user, err
+}
+
+func (s *UserService) UnbindUserAuthProviderWithResult(ctx context.Context, userID int64, provider string) (*User, bool, error) {
+ provider = normalizeUserIdentityProvider(provider)
+ if provider == "" || provider == "email" {
+ return nil, false, ErrIdentityProviderInvalid
+ }
+
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return nil, false, fmt.Errorf("get user: %w", err)
+ }
+
+ records, err := s.listUserAuthIdentities(ctx, userID)
+ if err != nil {
+ return nil, false, err
+ }
+ if len(filterUserAuthIdentities(records, provider)) == 0 {
+ return user, false, nil
+ }
+ if !s.canUnbindProvider(provider, user, records) {
+ return nil, false, ErrIdentityUnbindLastMethod
+ }
+
+ if err := s.userRepo.UnbindUserAuthProvider(ctx, userID, provider); err != nil {
+ return nil, false, err
+ }
+ if s.authCacheInvalidator != nil {
+ s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
+ }
+
+ updatedUser, err := s.GetProfile(ctx, userID)
+ if err != nil {
+ return nil, false, err
+ }
+ return updatedUser, true, nil
+}
+
// UpdateProfile 更新用户资料
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, error) {
+ if txRunner, ok := s.userRepo.(userProfileIdentityTxRunner); ok {
+ var (
+ updated *User
+ oldConcurrency int
+ )
+ if err := txRunner.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
+ var err error
+ updated, oldConcurrency, err = s.updateProfile(txCtx, userID, req)
+ return err
+ }); err != nil {
+ return nil, err
+ }
+ if s.authCacheInvalidator != nil && updated != nil && updated.Concurrency != oldConcurrency {
+ s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
+ }
+ return updated, nil
+ }
+
+ updated, oldConcurrency, err := s.updateProfile(ctx, userID, req)
+ if err != nil {
+ return nil, err
+ }
+ if s.authCacheInvalidator != nil && updated.Concurrency != oldConcurrency {
+ s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
+ }
+ return updated, nil
+}
+
+func (s *UserService) updateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, int, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
- return nil, fmt.Errorf("get user: %w", err)
+ return nil, 0, fmt.Errorf("get user: %w", err)
}
oldConcurrency := user.Concurrency
@@ -131,10 +414,10 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
// 检查新邮箱是否已被使用
exists, err := s.userRepo.ExistsByEmail(ctx, *req.Email)
if err != nil {
- return nil, fmt.Errorf("check email exists: %w", err)
+ return nil, oldConcurrency, fmt.Errorf("check email exists: %w", err)
}
if exists && *req.Email != user.Email {
- return nil, ErrEmailExists
+ return nil, oldConcurrency, ErrEmailExists
}
user.Email = *req.Email
}
@@ -143,6 +426,14 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
user.Username = *req.Username
}
+ if req.AvatarURL != nil {
+ avatar, err := s.SetAvatar(ctx, userID, *req.AvatarURL)
+ if err != nil {
+ return nil, oldConcurrency, err
+ }
+ applyUserAvatar(user, avatar)
+ }
+
if req.Concurrency != nil {
user.Concurrency = *req.Concurrency
}
@@ -159,13 +450,465 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
}
if err := s.userRepo.Update(ctx, user); err != nil {
- return nil, fmt.Errorf("update user: %w", err)
- }
- if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency {
- s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
+ return nil, oldConcurrency, fmt.Errorf("update user: %w", err)
}
- return user, nil
+ return user, oldConcurrency, nil
+}
+
+func (s *UserService) SetAvatar(ctx context.Context, userID int64, raw string) (*UserAvatar, error) {
+ avatarValue := strings.TrimSpace(raw)
+ if avatarValue == "" {
+ if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil {
+ return nil, fmt.Errorf("delete avatar: %w", err)
+ }
+ return nil, nil
+ }
+
+ avatarInput, err := normalizeUserAvatarInput(avatarValue)
+ if err != nil {
+ return nil, err
+ }
+
+ avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput)
+ if err != nil {
+ return nil, fmt.Errorf("upsert avatar: %w", err)
+ }
+ return avatar, nil
+}
+
+func applyUserAvatar(user *User, avatar *UserAvatar) {
+ if user == nil {
+ return
+ }
+ if avatar == nil {
+ user.AvatarURL = ""
+ user.AvatarSource = ""
+ user.AvatarMIME = ""
+ user.AvatarByteSize = 0
+ user.AvatarSHA256 = ""
+ return
+ }
+
+ user.AvatarURL = avatar.URL
+ user.AvatarSource = avatar.StorageProvider
+ user.AvatarMIME = avatar.ContentType
+ user.AvatarByteSize = avatar.ByteSize
+ user.AvatarSHA256 = avatar.SHA256
+}
+
+func normalizeUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+ if strings.HasPrefix(raw, "data:") {
+ return normalizeInlineUserAvatarInput(raw)
+ }
+
+ parsed, err := url.Parse(raw)
+ if err != nil || parsed == nil {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+ if !strings.EqualFold(parsed.Scheme, "http") && !strings.EqualFold(parsed.Scheme, "https") {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+ if strings.TrimSpace(parsed.Host) == "" {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+
+ return UpsertUserAvatarInput{
+ StorageProvider: "remote_url",
+ URL: raw,
+ }, nil
+}
+
+func ValidateUserAvatar(raw string) error {
+ _, err := normalizeUserAvatarInput(raw)
+ return err
+}
+
+func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
+ body := strings.TrimPrefix(raw, "data:")
+ meta, encoded, ok := strings.Cut(body, ",")
+ if !ok {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+ meta = strings.TrimSpace(meta)
+ encoded = strings.TrimSpace(encoded)
+ if !strings.HasSuffix(strings.ToLower(meta), ";base64") {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+
+ contentType := strings.TrimSpace(meta[:len(meta)-len(";base64")])
+ if contentType == "" || !strings.HasPrefix(strings.ToLower(contentType), "image/") {
+ return UpsertUserAvatarInput{}, ErrAvatarNotImage
+ }
+
+ decoded, err := base64.StdEncoding.DecodeString(encoded)
+ if err != nil {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+ if len(decoded) > maxInlineAvatarBytes {
+ return UpsertUserAvatarInput{}, ErrAvatarTooLarge
+ }
+
+ if len(decoded) > targetAvatarBytes {
+ decoded, contentType, err = compressInlineAvatar(decoded)
+ if err != nil {
+ return UpsertUserAvatarInput{}, err
+ }
+ raw = "data:" + contentType + ";base64," + base64.StdEncoding.EncodeToString(decoded)
+ }
+
+ sum := sha256.Sum256(decoded)
+ return UpsertUserAvatarInput{
+ StorageProvider: "inline",
+ URL: raw,
+ ContentType: contentType,
+ ByteSize: len(decoded),
+ SHA256: hex.EncodeToString(sum[:]),
+ }, nil
+}
+
+func compressInlineAvatar(decoded []byte) ([]byte, string, error) {
+ src, _, err := image.Decode(bytes.NewReader(decoded))
+ if err != nil {
+ return nil, "", ErrAvatarInvalid
+ }
+
+ srcBounds := src.Bounds()
+ if srcBounds.Empty() {
+ return nil, "", ErrAvatarInvalid
+ }
+
+ for _, scale := range avatarScaleSteps {
+ width := max(1, int(float64(srcBounds.Dx())*scale))
+ height := max(1, int(float64(srcBounds.Dy())*scale))
+ dst := image.NewRGBA(image.Rect(0, 0, width, height))
+ stddraw.Draw(dst, dst.Bounds(), &image.Uniform{C: color.White}, image.Point{}, stddraw.Src)
+ xdraw.CatmullRom.Scale(dst, dst.Bounds(), src, srcBounds, stddraw.Over, nil)
+
+ for _, quality := range avatarQualitySteps {
+ var buf bytes.Buffer
+ if err := jpeg.Encode(&buf, dst, &jpeg.Options{Quality: quality}); err != nil {
+ return nil, "", ErrAvatarInvalid
+ }
+ if buf.Len() <= targetAvatarBytes {
+ return buf.Bytes(), "image/jpeg", nil
+ }
+ }
+ }
+
+ return nil, "", ErrAvatarTooLarge
+}
+
+func (s *UserService) buildEmailIdentitySummary(user *User, records []UserAuthIdentityRecord) UserIdentitySummary {
+ summary := UserIdentitySummary{
+ Provider: "email",
+ CanBind: false,
+ CanUnbind: false,
+ NoteKey: userIdentityNoteEmailManagedFromProfile,
+ Note: "Primary account email is managed from the profile form.",
+ }
+ if user == nil {
+ return summary
+ }
+
+ filtered := filterUserAuthIdentities(records, "email")
+ if len(filtered) > 0 {
+ primary := selectPrimaryUserAuthIdentity(filtered)
+ email := strings.TrimSpace(firstStringIdentityValue(primary.Metadata, "email"))
+ if email == "" {
+ email = strings.TrimSpace(primary.ProviderSubject)
+ }
+ if email == "" || isReservedEmail(email) {
+ email = strings.TrimSpace(user.Email)
+ }
+ if email == "" || isReservedEmail(email) {
+ email = strings.TrimSpace(primary.ProviderKey)
+ }
+
+ summary.Bound = true
+ summary.BoundCount = len(filtered)
+ summary.DisplayName = email
+ summary.SubjectHint = maskEmailIdentity(email)
+ summary.ProviderKey = strings.TrimSpace(primary.ProviderKey)
+ summary.VerifiedAt = primary.VerifiedAt
+ return summary
+ }
+
+ // Compatibility fallback for legacy normal-email users that predate auth_identities backfill.
+ email := strings.TrimSpace(user.Email)
+ if email == "" || isReservedEmail(email) {
+ return summary
+ }
+ summary.Bound = true
+ summary.BoundCount = 1
+ summary.DisplayName = email
+ summary.SubjectHint = maskEmailIdentity(email)
+ summary.ProviderKey = "email"
+ return summary
+}
+
+func (s *UserService) buildProviderIdentitySummary(provider string, user *User, records []UserAuthIdentityRecord) UserIdentitySummary {
+ summary := UserIdentitySummary{
+ Provider: provider,
+ CanUnbind: false,
+ }
+ filtered := filterUserAuthIdentities(records, provider)
+ if len(filtered) == 0 {
+ summary.CanBind = true
+ bindStartPath, err := buildUserIdentityBindAuthorizeURL(provider, "")
+ if err == nil {
+ summary.BindStartPath = bindStartPath
+ }
+ return summary
+ }
+
+ primary := selectPrimaryUserAuthIdentity(filtered)
+ summary.Bound = true
+ summary.BoundCount = len(filtered)
+ summary.DisplayName = userAuthIdentityDisplayName(primary)
+ summary.AvatarURL = strings.TrimSpace(firstStringIdentityValue(primary.Metadata, "avatar_url", "suggested_avatar_url", "headimgurl"))
+ summary.SubjectHint = maskOpaqueIdentity(primary.ProviderSubject)
+ summary.ProviderKey = strings.TrimSpace(primary.ProviderKey)
+ summary.VerifiedAt = primary.VerifiedAt
+ summary.CanUnbind = s.canUnbindProvider(provider, user, records)
+ if summary.CanUnbind {
+ summary.NoteKey = userIdentityNoteCanUnbind
+ summary.Note = "You can unbind this sign-in method."
+ } else {
+ summary.NoteKey = userIdentityNoteBindAnotherBeforeUnbind
+ summary.Note = "Bind another sign-in method before unbinding."
+ }
+ return summary
+}
+
+func (s *UserService) canUnbindProvider(provider string, user *User, records []UserAuthIdentityRecord) bool {
+ if provider == "" || provider == "email" || len(filterUserAuthIdentities(records, provider)) == 0 {
+ return false
+ }
+
+ if s.canUseEmailAsSignInMethod(user, records) {
+ return true
+ }
+
+ for _, candidate := range []string{"linuxdo", "oidc", "wechat"} {
+ if candidate == provider {
+ continue
+ }
+ if len(filterUserAuthIdentities(records, candidate)) > 0 {
+ return true
+ }
+ }
+
+ return false
+}
+
+func (s *UserService) canUseEmailAsSignInMethod(user *User, records []UserAuthIdentityRecord) bool {
+ if user == nil {
+ return false
+ }
+
+ email := strings.ToLower(strings.TrimSpace(user.Email))
+ if email == "" || isReservedEmail(email) {
+ return false
+ }
+
+ if emailSignupSourceAllowsLogin(user.SignupSource) {
+ return true
+ }
+
+ for _, record := range filterUserAuthIdentities(records, "email") {
+ if emailIdentitySupportsSignIn(record) {
+ return true
+ }
+ }
+
+ return false
+}
+
+func emailSignupSourceAllowsLogin(signupSource string) bool {
+ signupSource = strings.ToLower(strings.TrimSpace(signupSource))
+ return signupSource == "" || signupSource == "email"
+}
+
+func emailIdentitySupportsSignIn(record UserAuthIdentityRecord) bool {
+ source := strings.TrimSpace(firstStringIdentityValue(record.Metadata, "source"))
+ switch source {
+ case "auth_service_email_bind", "auth_service_login_backfill", "auth_service_dual_write":
+ return true
+ default:
+ return false
+ }
+}
+
+func (s *UserService) listUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) {
+ if userID <= 0 || s == nil || s.userRepo == nil {
+ return nil, nil
+ }
+ return s.userRepo.ListUserAuthIdentities(ctx, userID)
+}
+
+func buildUserIdentityBindAuthorizeURL(provider, redirectTo string) (string, error) {
+ provider = normalizeUserIdentityProvider(provider)
+ if provider == "" || provider == "email" {
+ return "", ErrIdentityProviderInvalid
+ }
+
+ redirectTo, err := normalizeUserIdentityRedirect(redirectTo)
+ if err != nil {
+ return "", err
+ }
+
+ path := ""
+ switch provider {
+ case "linuxdo":
+ path = "/api/v1/auth/oauth/linuxdo/bind/start"
+ case "oidc":
+ path = "/api/v1/auth/oauth/oidc/bind/start"
+ case "wechat":
+ path = "/api/v1/auth/oauth/wechat/bind/start"
+ default:
+ return "", ErrIdentityProviderInvalid
+ }
+
+ query := url.Values{}
+ query.Set("redirect", redirectTo)
+ query.Set("intent", "bind_current_user")
+ return path + "?" + query.Encode(), nil
+}
+
+func normalizeUserIdentityProvider(provider string) string {
+ switch strings.ToLower(strings.TrimSpace(provider)) {
+ case "linuxdo":
+ return "linuxdo"
+ case "oidc":
+ return "oidc"
+ case "wechat":
+ return "wechat"
+ case "email":
+ return "email"
+ default:
+ return ""
+ }
+}
+
+func normalizeUserIdentityRedirect(raw string) (string, error) {
+ redirect := strings.TrimSpace(raw)
+ if redirect == "" {
+ return defaultUserIdentityRedirect, nil
+ }
+ if len(redirect) > 2048 || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
+ return "", ErrIdentityRedirectInvalid
+ }
+ return redirect, nil
+}
+
+func filterUserAuthIdentities(records []UserAuthIdentityRecord, provider string) []UserAuthIdentityRecord {
+ if len(records) == 0 {
+ return nil
+ }
+ filtered := make([]UserAuthIdentityRecord, 0, len(records))
+ for _, record := range records {
+ if strings.EqualFold(strings.TrimSpace(record.ProviderType), provider) {
+ filtered = append(filtered, record)
+ }
+ }
+ return filtered
+}
+
+func selectPrimaryUserAuthIdentity(records []UserAuthIdentityRecord) UserAuthIdentityRecord {
+ if len(records) == 0 {
+ return UserAuthIdentityRecord{}
+ }
+ sort.SliceStable(records, func(i, j int) bool {
+ left := userAuthIdentitySortTime(records[i])
+ right := userAuthIdentitySortTime(records[j])
+ if !left.Equal(right) {
+ return left.After(right)
+ }
+ return records[i].ProviderKey < records[j].ProviderKey
+ })
+ return records[0]
+}
+
+func userAuthIdentitySortTime(record UserAuthIdentityRecord) time.Time {
+ if record.VerifiedAt != nil && !record.VerifiedAt.IsZero() {
+ return record.VerifiedAt.UTC()
+ }
+ if !record.UpdatedAt.IsZero() {
+ return record.UpdatedAt.UTC()
+ }
+ if !record.CreatedAt.IsZero() {
+ return record.CreatedAt.UTC()
+ }
+ return time.Time{}
+}
+
+func userAuthIdentityDisplayName(record UserAuthIdentityRecord) string {
+ if displayName := firstStringIdentityValue(record.Metadata,
+ "display_name",
+ "suggested_display_name",
+ "username",
+ "name",
+ "nickname",
+ "email",
+ ); displayName != "" {
+ return displayName
+ }
+ if subject := strings.TrimSpace(record.ProviderSubject); subject != "" {
+ return subject
+ }
+ return strings.TrimSpace(record.ProviderType)
+}
+
+func firstStringIdentityValue(values map[string]any, keys ...string) string {
+ for _, key := range keys {
+ raw, ok := values[key]
+ if !ok {
+ continue
+ }
+ switch value := raw.(type) {
+ case string:
+ if trimmed := strings.TrimSpace(value); trimmed != "" {
+ return trimmed
+ }
+ case fmt.Stringer:
+ if trimmed := strings.TrimSpace(value.String()); trimmed != "" {
+ return trimmed
+ }
+ }
+ }
+ return ""
+}
+
+func maskEmailIdentity(email string) string {
+ local, domain, ok := strings.Cut(strings.TrimSpace(email), "@")
+ if !ok || local == "" || domain == "" {
+ return maskOpaqueIdentity(email)
+ }
+ runes := []rune(local)
+ if len(runes) == 1 {
+ return string(runes[0]) + "***@" + domain
+ }
+ return string(runes[0]) + "***" + string(runes[len(runes)-1]) + "@" + domain
+}
+
+func maskOpaqueIdentity(value string) string {
+ value = strings.TrimSpace(value)
+ runes := []rune(value)
+ switch {
+ case len(runes) == 0:
+ return ""
+ case len(runes) <= 4:
+ return string(runes[0]) + "***"
+ case len(runes) <= 8:
+ return string(runes[:2]) + "***" + string(runes[len(runes)-1:])
+ default:
+ return string(runes[:3]) + "***" + string(runes[len(runes)-3:])
+ }
}
// ChangePassword 修改密码
@@ -202,9 +945,94 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) {
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
+ normalizeLoadedUserTokenVersion(user)
+ if err := s.hydrateUserAvatar(ctx, user); err != nil {
+ return nil, fmt.Errorf("get user avatar: %w", err)
+ }
return user, nil
}
+func normalizeLoadedUserTokenVersion(user *User) {
+ if user == nil || user.TokenVersionResolved {
+ return
+ }
+ user.TokenVersion = resolvedTokenVersion(user)
+ user.TokenVersionResolved = true
+}
+
+// TouchLastActive 通过防抖更新 users.last_active_at,减少鉴权热路径写放大。
+// 该操作为尽力而为,不应中断正常请求。
+func (s *UserService) TouchLastActive(ctx context.Context, userID int64) {
+ if s == nil || s.userRepo == nil || userID <= 0 {
+ return
+ }
+
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ slog.Debug("skip touch user last active after load failure", "user_id", userID, "error", err)
+ return
+ }
+ s.TouchLastActiveForUser(ctx, user)
+}
+
+// TouchLastActiveForUser 使用已加载的用户信息更新 last_active_at,避免重复读取数据库。
+func (s *UserService) TouchLastActiveForUser(ctx context.Context, user *User) {
+ if s == nil || s.userRepo == nil || user == nil || user.ID <= 0 {
+ return
+ }
+
+ now := time.Now()
+ if userLastActiveFresh(user.LastActiveAt, now) {
+ return
+ }
+ if v, ok := s.lastActiveTouchL1.Load(user.ID); ok {
+ if nextAllowedAt, ok := v.(time.Time); ok && now.Before(nextAllowedAt) {
+ return
+ }
+ }
+
+ _, err, _ := s.lastActiveTouchSF.Do(strconv.FormatInt(user.ID, 10), func() (any, error) {
+ latest := time.Now()
+ if v, ok := s.lastActiveTouchL1.Load(user.ID); ok {
+ if nextAllowedAt, ok := v.(time.Time); ok && latest.Before(nextAllowedAt) {
+ return nil, nil
+ }
+ }
+ if userLastActiveFresh(user.LastActiveAt, latest) {
+ return nil, nil
+ }
+ if err := s.userRepo.UpdateUserLastActiveAt(ctx, user.ID, latest); err != nil {
+ s.lastActiveTouchL1.Store(user.ID, latest.Add(userLastActiveFailBackoff))
+ return nil, fmt.Errorf("touch user last active: %w", err)
+ }
+ s.lastActiveTouchL1.Store(user.ID, latest.Add(userLastActiveMinTouch))
+ return nil, nil
+ })
+ if err != nil {
+ slog.Warn("touch user last active failed", "user_id", user.ID, "error", err)
+ }
+}
+
+func userLastActiveFresh(lastActiveAt *time.Time, now time.Time) bool {
+ if lastActiveAt == nil {
+ return false
+ }
+ return now.Before(lastActiveAt.Add(userLastActiveMinTouch))
+}
+
+func (s *UserService) hydrateUserAvatar(ctx context.Context, user *User) error {
+ if s == nil || s.userRepo == nil || user == nil || user.ID == 0 {
+ return nil
+ }
+
+ avatar, err := s.userRepo.GetUserAvatar(ctx, user.ID)
+ if err != nil {
+ return err
+ }
+ applyUserAvatar(user, avatar)
+ return nil
+}
+
// List 获取用户列表(管理员功能)
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
users, pagination, err := s.userRepo.List(ctx, params)
diff --git a/backend/internal/service/user_service_email_identity_sync_test.go b/backend/internal/service/user_service_email_identity_sync_test.go
new file mode 100644
index 00000000..702b3b1a
--- /dev/null
+++ b/backend/internal/service/user_service_email_identity_sync_test.go
@@ -0,0 +1,34 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestUpdateProfile_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
+ repo := &emailSyncRepoStub{
+ user: &User{
+ ID: 19,
+ Email: "profile-before@example.com",
+ Username: "tester",
+ Concurrency: 2,
+ },
+ replaceErr: context.DeadlineExceeded,
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ newEmail := "profile-after@example.com"
+ updated, err := svc.UpdateProfile(context.Background(), 19, UpdateProfileRequest{
+ Email: &newEmail,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, updated)
+ require.Equal(t, newEmail, updated.Email)
+ require.Equal(t, 1, repo.updateCalls)
+ require.Empty(t, repo.replaceCalls)
+ require.Empty(t, repo.ensureCalls)
+}
diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go
index a998d5f4..ff55c2a5 100644
--- a/backend/internal/service/user_service_test.go
+++ b/backend/internal/service/user_service_test.go
@@ -3,8 +3,14 @@
package service
import (
+ "bytes"
"context"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/hex"
"errors"
+ "image"
+ "image/png"
"sync"
"sync/atomic"
"testing"
@@ -17,16 +23,159 @@ import (
// --- mock: UserRepository ---
type mockUserRepo struct {
- updateBalanceErr error
- updateBalanceFn func(ctx context.Context, id int64, amount float64) error
+ updateBalanceErr error
+ updateBalanceFn func(ctx context.Context, id int64, amount float64) error
+ getByIDUser *User
+ getByIDErr error
+ identities []UserAuthIdentityRecord
+ unbindIdentityErr error
+ unboundProviders []string
+ updateLastActiveErr error
+ updateLastActiveUserIDs []int64
+ updateLastActiveAt []time.Time
+ updateFn func(ctx context.Context, user *User) error
+ updateCalls int
+ upsertAvatarFn func(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error)
+ upsertAvatarArgs []UpsertUserAvatarInput
+ deleteAvatarFn func(ctx context.Context, userID int64) error
+ deleteAvatarIDs []int64
+ getAvatarFn func(ctx context.Context, userID int64) (*UserAvatar, error)
+ txCalls int
}
-func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
-func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil }
+type mockUserRepoTxKey struct{}
+
+type mockUserRepoTxState struct {
+ getByIDUser *User
+ upsertAvatarArgs []UpsertUserAvatarInput
+ deleteAvatarIDs []int64
+}
+
+type mockUserSettingRepo struct {
+ values map[string]string
+}
+
+func (m *mockUserSettingRepo) Get(context.Context, string) (*Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (m *mockUserSettingRepo) GetValue(context.Context, string) (string, error) {
+ panic("unexpected GetValue call")
+}
+
+func (m *mockUserSettingRepo) Set(context.Context, string, string) error {
+ panic("unexpected Set call")
+}
+
+func (m *mockUserSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := m.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (m *mockUserSettingRepo) SetMultiple(context.Context, map[string]string) error {
+ panic("unexpected SetMultiple call")
+}
+
+func (m *mockUserSettingRepo) GetAll(context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (m *mockUserSettingRepo) Delete(context.Context, string) error {
+ panic("unexpected Delete call")
+}
+
+func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
+func (m *mockUserRepo) GetByID(ctx context.Context, _ int64) (*User, error) {
+ if m.getByIDErr != nil {
+ return nil, m.getByIDErr
+ }
+ if txState, _ := ctx.Value(mockUserRepoTxKey{}).(*mockUserRepoTxState); txState != nil && txState.getByIDUser != nil {
+ cloned := *txState.getByIDUser
+ return &cloned, nil
+ }
+ if m.getByIDUser != nil {
+ cloned := *m.getByIDUser
+ return &cloned, nil
+ }
+ return &User{}, nil
+}
func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil }
func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil }
-func (m *mockUserRepo) Update(context.Context, *User) error { return nil }
-func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
+func (m *mockUserRepo) Update(ctx context.Context, user *User) error {
+ m.updateCalls++
+ if m.updateFn != nil {
+ return m.updateFn(ctx, user)
+ }
+ return nil
+}
+func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
+func (m *mockUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
+ if m.getAvatarFn != nil {
+ return m.getAvatarFn(ctx, userID)
+ }
+ return nil, nil
+}
+func (m *mockUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
+ if txState, _ := ctx.Value(mockUserRepoTxKey{}).(*mockUserRepoTxState); txState != nil {
+ txState.upsertAvatarArgs = append(txState.upsertAvatarArgs, input)
+ if txState.getByIDUser != nil {
+ txState.getByIDUser.AvatarURL = input.URL
+ txState.getByIDUser.AvatarSource = input.StorageProvider
+ txState.getByIDUser.AvatarMIME = input.ContentType
+ txState.getByIDUser.AvatarByteSize = input.ByteSize
+ txState.getByIDUser.AvatarSHA256 = input.SHA256
+ }
+ if m.upsertAvatarFn != nil {
+ return m.upsertAvatarFn(ctx, userID, input)
+ }
+ return &UserAvatar{
+ StorageProvider: input.StorageProvider,
+ StorageKey: input.StorageKey,
+ URL: input.URL,
+ ContentType: input.ContentType,
+ ByteSize: input.ByteSize,
+ SHA256: input.SHA256,
+ }, nil
+ }
+ m.upsertAvatarArgs = append(m.upsertAvatarArgs, input)
+ if m.upsertAvatarFn != nil {
+ return m.upsertAvatarFn(ctx, userID, input)
+ }
+ return &UserAvatar{
+ StorageProvider: input.StorageProvider,
+ StorageKey: input.StorageKey,
+ URL: input.URL,
+ ContentType: input.ContentType,
+ ByteSize: input.ByteSize,
+ SHA256: input.SHA256,
+ }, nil
+}
+func (m *mockUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ if txState, _ := ctx.Value(mockUserRepoTxKey{}).(*mockUserRepoTxState); txState != nil {
+ txState.deleteAvatarIDs = append(txState.deleteAvatarIDs, userID)
+ if txState.getByIDUser != nil {
+ txState.getByIDUser.AvatarURL = ""
+ txState.getByIDUser.AvatarSource = ""
+ txState.getByIDUser.AvatarMIME = ""
+ txState.getByIDUser.AvatarByteSize = 0
+ txState.getByIDUser.AvatarSHA256 = ""
+ }
+ if m.deleteAvatarFn != nil {
+ return m.deleteAvatarFn(ctx, userID)
+ }
+ return nil
+ }
+ m.deleteAvatarIDs = append(m.deleteAvatarIDs, userID)
+ if m.deleteAvatarFn != nil {
+ return m.deleteAvatarFn(ctx, userID)
+ }
+ return nil
+}
func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
@@ -39,6 +188,11 @@ func (m *mockUserRepo) UpdateBalance(ctx context.Context, id int64, amount float
}
return m.updateBalanceErr
}
+func (m *mockUserRepo) UpdateUserLastActiveAt(_ context.Context, userID int64, activeAt time.Time) error {
+ m.updateLastActiveUserIDs = append(m.updateLastActiveUserIDs, userID)
+ m.updateLastActiveAt = append(m.updateLastActiveAt, activeAt)
+ return m.updateLastActiveErr
+}
func (m *mockUserRepo) DeductBalance(context.Context, int64, float64) error { return nil }
func (m *mockUserRepo) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
@@ -46,12 +200,58 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int
return 0, nil
}
func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
-func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
-func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
-func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
+func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
+ out := make([]UserAuthIdentityRecord, len(m.identities))
+ copy(out, m.identities)
+ return out, nil
+}
+func (m *mockUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+func (m *mockUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
+func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
+func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil
}
+func (m *mockUserRepo) UnbindUserAuthProvider(_ context.Context, _ int64, provider string) error {
+ if m.unbindIdentityErr != nil {
+ return m.unbindIdentityErr
+ }
+ m.unboundProviders = append(m.unboundProviders, provider)
+ filtered := m.identities[:0]
+ for _, identity := range m.identities {
+ if identity.ProviderType == provider {
+ continue
+ }
+ filtered = append(filtered, identity)
+ }
+ m.identities = append([]UserAuthIdentityRecord(nil), filtered...)
+ return nil
+}
+
+func (m *mockUserRepo) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
+ m.txCalls++
+ txState := &mockUserRepoTxState{
+ upsertAvatarArgs: append([]UpsertUserAvatarInput(nil), m.upsertAvatarArgs...),
+ deleteAvatarIDs: append([]int64(nil), m.deleteAvatarIDs...),
+ }
+ if m.getByIDUser != nil {
+ userCopy := *m.getByIDUser
+ txState.getByIDUser = &userCopy
+ }
+ err := fn(context.WithValue(ctx, mockUserRepoTxKey{}, txState))
+ if err != nil {
+ return err
+ }
+ m.getByIDUser = txState.getByIDUser
+ m.upsertAvatarArgs = txState.upsertAvatarArgs
+ m.deleteAvatarIDs = txState.deleteAvatarIDs
+ return nil
+}
// --- mock: APIKeyAuthCacheInvalidator ---
@@ -132,6 +332,225 @@ func TestUpdateBalance_Success(t *testing.T) {
require.Equal(t, []int64{42}, cache.invalidatedUserIDs, "应对 userID=42 失效缓存")
}
+func TestGetProfileIdentitySummaries_AllowsUnbindWhenAnotherLoginMethodRemains(t *testing.T) {
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 7,
+ Email: "alice@example.com",
+ },
+ identities: []UserAuthIdentityRecord{
+ {
+ ProviderType: "email",
+ ProviderKey: "email",
+ ProviderSubject: "alice@example.com",
+ },
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-123456",
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ },
+ },
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 7, repo.getByIDUser)
+
+ require.NoError(t, err)
+ require.True(t, summaries.LinuxDo.Bound)
+ require.True(t, summaries.LinuxDo.CanUnbind)
+ require.Equal(t, "linuxdo-handle", summaries.LinuxDo.DisplayName)
+ require.NotEmpty(t, summaries.LinuxDo.SubjectHint)
+}
+
+func TestUnbindUserAuthProviderRejectsLastRemainingLoginMethod(t *testing.T) {
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 9,
+ Email: "only-user@linuxdo-connect.invalid",
+ },
+ identities: []UserAuthIdentityRecord{
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-only-subject",
+ },
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ _, err := svc.UnbindUserAuthProvider(context.Background(), 9, "linuxdo")
+
+ require.ErrorIs(t, err, ErrIdentityUnbindLastMethod)
+ require.Empty(t, repo.unboundProviders)
+}
+
+func TestGetProfileIdentitySummaries_DoesNotTreatOAuthOnlyCompatEmailAsAlternativeLoginMethod(t *testing.T) {
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 10,
+ Email: "oauth-only@example.com",
+ SignupSource: "oidc",
+ },
+ identities: []UserAuthIdentityRecord{
+ {
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example.com",
+ ProviderSubject: "oidc-only-subject",
+ },
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 10, repo.getByIDUser)
+
+ require.NoError(t, err)
+ require.False(t, summaries.OIDC.CanUnbind)
+
+ _, err = svc.UnbindUserAuthProvider(context.Background(), 10, "oidc")
+ require.ErrorIs(t, err, ErrIdentityUnbindLastMethod)
+ require.Empty(t, repo.unboundProviders)
+}
+
+func TestGetProfileIdentitySummaries_DoesNotTreatCompatBackfilledEmailIdentityAsAlternativeLoginMethod(t *testing.T) {
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 11,
+ Email: "oauth-only@example.com",
+ SignupSource: "wechat",
+ },
+ identities: []UserAuthIdentityRecord{
+ {
+ ProviderType: "email",
+ ProviderKey: "email",
+ ProviderSubject: "oauth-only@example.com",
+ Metadata: map[string]any{
+ "backfill_source": "users.email",
+ "migration": "109_auth_identity_compat_backfill",
+ },
+ },
+ {
+ ProviderType: "wechat",
+ ProviderKey: "wechat",
+ ProviderSubject: "wechat-only-subject",
+ },
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 11, repo.getByIDUser)
+
+ require.NoError(t, err)
+ require.True(t, summaries.Email.Bound)
+ require.False(t, summaries.WeChat.CanUnbind)
+
+ _, err = svc.UnbindUserAuthProvider(context.Background(), 11, "wechat")
+ require.ErrorIs(t, err, ErrIdentityUnbindLastMethod)
+ require.Empty(t, repo.unboundProviders)
+}
+
+func TestUnbindUserAuthProviderRemovesProviderAndReturnsUpdatedProfile(t *testing.T) {
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 12,
+ Email: "alice@example.com",
+ },
+ identities: []UserAuthIdentityRecord{
+ {
+ ProviderType: "email",
+ ProviderKey: "email",
+ ProviderSubject: "alice@example.com",
+ },
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-12",
+ },
+ },
+ }
+ invalidator := &mockAuthCacheInvalidator{}
+ svc := NewUserService(repo, nil, invalidator, nil)
+
+ user, err := svc.UnbindUserAuthProvider(context.Background(), 12, "linuxdo")
+
+ require.NoError(t, err)
+ require.Equal(t, []string{"linuxdo"}, repo.unboundProviders)
+ require.Equal(t, int64(12), user.ID)
+ require.Equal(t, []int64{12}, invalidator.invalidatedUserIDs)
+
+ summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 12, user)
+ require.NoError(t, err)
+ require.False(t, summaries.LinuxDo.Bound)
+ require.True(t, summaries.LinuxDo.CanBind)
+}
+
+func TestGetProfileIdentitySummaries_HidesBindActionWhenProviderExplicitlyDisabled(t *testing.T) {
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 15,
+ Email: "alice@example.com",
+ },
+ identities: []UserAuthIdentityRecord{
+ {
+ ProviderType: "email",
+ ProviderKey: "email",
+ ProviderSubject: "alice@example.com",
+ },
+ },
+ }
+ settingRepo := &mockUserSettingRepo{
+ values: map[string]string{
+ SettingKeyLinuxDoConnectEnabled: "false",
+ },
+ }
+ svc := NewUserService(repo, settingRepo, nil, nil)
+
+ summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 15, repo.getByIDUser)
+
+ require.NoError(t, err)
+ require.False(t, summaries.LinuxDo.Bound)
+ require.False(t, summaries.LinuxDo.CanBind)
+ require.Empty(t, summaries.LinuxDo.BindStartPath)
+}
+
+func TestGetProfileIdentitySummaries_UsesBindStartRoute(t *testing.T) {
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 16,
+ Email: "alice@example.com",
+ },
+ identities: []UserAuthIdentityRecord{
+ {
+ ProviderType: "email",
+ ProviderKey: "email",
+ ProviderSubject: "alice@example.com",
+ },
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 16, repo.getByIDUser)
+
+ require.NoError(t, err)
+ require.Equal(
+ t,
+ "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile",
+ summaries.LinuxDo.BindStartPath,
+ )
+ require.Equal(
+ t,
+ "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile",
+ summaries.OIDC.BindStartPath,
+ )
+ require.Equal(
+ t,
+ "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile",
+ summaries.WeChat.BindStartPath,
+ )
+}
+
func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
repo := &mockUserRepo{}
svc := NewUserService(repo, nil, nil, nil) // billingCache = nil
@@ -154,6 +573,39 @@ func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) {
}, 2*time.Second, 10*time.Millisecond, "即使失败也应调用 InvalidateUserBalance")
}
+func TestTouchLastActive_UpdatesWhenStale(t *testing.T) {
+ stale := time.Now().Add(-11 * time.Minute)
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 42,
+ LastActiveAt: &stale,
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ svc.TouchLastActive(context.Background(), 42)
+
+ require.Equal(t, []int64{42}, repo.updateLastActiveUserIDs)
+ require.Len(t, repo.updateLastActiveAt, 1)
+ require.WithinDuration(t, time.Now(), repo.updateLastActiveAt[0], 2*time.Second)
+}
+
+func TestTouchLastActive_SkipsWhenRecent(t *testing.T) {
+ recent := time.Now().Add(-time.Minute)
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 42,
+ LastActiveAt: &recent,
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ svc.TouchLastActive(context.Background(), 42)
+
+ require.Empty(t, repo.updateLastActiveUserIDs)
+ require.Empty(t, repo.updateLastActiveAt)
+}
+
func TestUpdateBalance_RepoError_ReturnsError(t *testing.T) {
repo := &mockUserRepo{updateBalanceErr: errors.New("database error")}
cache := &mockBillingCache{}
@@ -200,3 +652,199 @@ func TestNewUserService_FieldsAssignment(t *testing.T) {
require.Equal(t, auth, svc.authCacheInvalidator)
require.Equal(t, cache, svc.billingCache)
}
+
+func TestUpdateProfile_StoresInlineAvatarWithinLimit(t *testing.T) {
+ raw := []byte("small-avatar")
+ dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw)
+ expectedSum := sha256.Sum256(raw)
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 7,
+ Email: "avatar@example.com",
+ Username: "avatar-user",
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ updated, err := svc.UpdateProfile(context.Background(), 7, UpdateProfileRequest{
+ AvatarURL: &dataURL,
+ })
+ require.NoError(t, err)
+ require.Len(t, repo.upsertAvatarArgs, 1)
+ require.Equal(t, "inline", repo.upsertAvatarArgs[0].StorageProvider)
+ require.Equal(t, "image/png", repo.upsertAvatarArgs[0].ContentType)
+ require.Equal(t, len(raw), repo.upsertAvatarArgs[0].ByteSize)
+ require.Equal(t, hex.EncodeToString(expectedSum[:]), repo.upsertAvatarArgs[0].SHA256)
+ require.Equal(t, dataURL, updated.AvatarURL)
+ require.Equal(t, "inline", updated.AvatarSource)
+ require.Equal(t, "image/png", updated.AvatarMIME)
+ require.Equal(t, len(raw), updated.AvatarByteSize)
+ require.Equal(t, hex.EncodeToString(expectedSum[:]), updated.AvatarSHA256)
+}
+
+func TestUpdateProfile_CompressesInlineAvatarToTwentyKilobytes(t *testing.T) {
+ var encoded bytes.Buffer
+ for _, size := range []int{192, 224, 256, 288} {
+ encoded.Reset()
+ var img image.RGBA
+ img.Rect = image.Rect(0, 0, size, size)
+ img.Stride = size * 4
+ img.Pix = make([]byte, size*size*4)
+ for y := 0; y < size; y++ {
+ for x := 0; x < size; x++ {
+ offset := y*img.Stride + x*4
+ img.Pix[offset] = uint8((x*x + y*17) % 255)
+ img.Pix[offset+1] = uint8((y*y + x*29) % 255)
+ img.Pix[offset+2] = uint8(((x * y) + x*13 + y*7) % 255)
+ img.Pix[offset+3] = 0xff
+ }
+ }
+ require.NoError(t, png.Encode(&encoded, &img))
+ if encoded.Len() > 20*1024 && encoded.Len() <= maxInlineAvatarBytes {
+ break
+ }
+ }
+ require.Greater(t, encoded.Len(), 20*1024)
+ require.LessOrEqual(t, encoded.Len(), maxInlineAvatarBytes)
+
+ dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(encoded.Bytes())
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 17,
+ Email: "avatar-compress@example.com",
+ Username: "avatar-compress",
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ updated, err := svc.UpdateProfile(context.Background(), 17, UpdateProfileRequest{
+ AvatarURL: &dataURL,
+ })
+ require.NoError(t, err)
+ require.Len(t, repo.upsertAvatarArgs, 1)
+ require.Equal(t, "inline", repo.upsertAvatarArgs[0].StorageProvider)
+ require.LessOrEqual(t, repo.upsertAvatarArgs[0].ByteSize, 20*1024)
+ require.Equal(t, "image/jpeg", repo.upsertAvatarArgs[0].ContentType)
+ require.Contains(t, repo.upsertAvatarArgs[0].URL, "data:image/jpeg;base64,")
+ require.Equal(t, "inline", updated.AvatarSource)
+ require.Equal(t, "image/jpeg", updated.AvatarMIME)
+ require.LessOrEqual(t, updated.AvatarByteSize, 20*1024)
+ require.Contains(t, updated.AvatarURL, "data:image/jpeg;base64,")
+ require.NotEmpty(t, updated.AvatarSHA256)
+}
+
+func TestUpdateProfile_RejectsInlineAvatarOverLimit(t *testing.T) {
+ raw := make([]byte, maxInlineAvatarBytes+1)
+ dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw)
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 8,
+ Email: "large-avatar@example.com",
+ Username: "too-large",
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ _, err := svc.UpdateProfile(context.Background(), 8, UpdateProfileRequest{
+ AvatarURL: &dataURL,
+ })
+ require.ErrorIs(t, err, ErrAvatarTooLarge)
+ require.Empty(t, repo.upsertAvatarArgs)
+ require.Empty(t, repo.deleteAvatarIDs)
+ require.Zero(t, repo.updateCalls)
+}
+
+func TestUpdateProfile_StoresRemoteAvatarURL(t *testing.T) {
+ remoteURL := "https://cdn.example.com/avatar.png"
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 9,
+ Email: "remote-avatar@example.com",
+ Username: "remote-avatar",
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ updated, err := svc.UpdateProfile(context.Background(), 9, UpdateProfileRequest{
+ AvatarURL: &remoteURL,
+ })
+ require.NoError(t, err)
+ require.Len(t, repo.upsertAvatarArgs, 1)
+ require.Equal(t, "remote_url", repo.upsertAvatarArgs[0].StorageProvider)
+ require.Equal(t, remoteURL, repo.upsertAvatarArgs[0].URL)
+ require.Equal(t, remoteURL, updated.AvatarURL)
+ require.Equal(t, "remote_url", updated.AvatarSource)
+ require.Zero(t, updated.AvatarByteSize)
+}
+
+func TestUpdateProfile_DeletesAvatarOnEmptyString(t *testing.T) {
+ empty := ""
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 10,
+ Email: "delete-avatar@example.com",
+ Username: "delete-avatar",
+ AvatarURL: "https://cdn.example.com/old.png",
+ AvatarSource: "remote_url",
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ updated, err := svc.UpdateProfile(context.Background(), 10, UpdateProfileRequest{
+ AvatarURL: &empty,
+ })
+ require.NoError(t, err)
+ require.Equal(t, []int64{10}, repo.deleteAvatarIDs)
+ require.Empty(t, repo.upsertAvatarArgs)
+ require.Empty(t, updated.AvatarURL)
+ require.Empty(t, updated.AvatarSource)
+}
+
+func TestUpdateProfile_RollsBackAvatarMutationWhenUserUpdateFails(t *testing.T) {
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 11,
+ Email: "rollback@example.com",
+ AvatarURL: "https://cdn.example.com/original.png",
+ AvatarSource: "remote_url",
+ },
+ updateFn: func(context.Context, *User) error {
+ return errors.New("write user failed")
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ remoteURL := "https://cdn.example.com/new.png"
+ _, err := svc.UpdateProfile(context.Background(), 11, UpdateProfileRequest{
+ AvatarURL: &remoteURL,
+ })
+
+ require.EqualError(t, err, "update user: write user failed")
+ require.Equal(t, 1, repo.txCalls)
+ require.Empty(t, repo.upsertAvatarArgs)
+ require.Empty(t, repo.deleteAvatarIDs)
+ require.Equal(t, "https://cdn.example.com/original.png", repo.getByIDUser.AvatarURL)
+ require.Equal(t, "remote_url", repo.getByIDUser.AvatarSource)
+}
+
+func TestGetProfile_HydratesAvatarFromRepository(t *testing.T) {
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 12,
+ Email: "profile-avatar@example.com",
+ Username: "profile-avatar",
+ },
+ getAvatarFn: func(context.Context, int64) (*UserAvatar, error) {
+ return &UserAvatar{
+ StorageProvider: "remote_url",
+ URL: "https://cdn.example.com/profile.png",
+ }, nil
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ user, err := svc.GetProfile(context.Background(), 12)
+ require.NoError(t, err)
+ require.Equal(t, "https://cdn.example.com/profile.png", user.AvatarURL)
+ require.Equal(t, "remote_url", user.AvatarSource)
+}
diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go
index 89d09eef..5f3719be 100644
--- a/backend/internal/web/embed_on.go
+++ b/backend/internal/web/embed_on.go
@@ -305,7 +305,8 @@ func shouldBypassEmbeddedFrontend(path string) bool {
strings.HasPrefix(trimmed, "/setup/") ||
trimmed == "/health" ||
trimmed == "/responses" ||
- strings.HasPrefix(trimmed, "/responses/")
+ strings.HasPrefix(trimmed, "/responses/") ||
+ strings.HasPrefix(trimmed, "/images/")
}
func serveIndexHTML(c *gin.Context, fsys fs.FS) {
diff --git a/backend/migrations/108_auth_identity_foundation_core.sql b/backend/migrations/108_auth_identity_foundation_core.sql
new file mode 100644
index 00000000..117e3ca3
--- /dev/null
+++ b/backend/migrations/108_auth_identity_foundation_core.sql
@@ -0,0 +1,141 @@
+ALTER TABLE users
+ADD COLUMN IF NOT EXISTS signup_source VARCHAR(20) NOT NULL DEFAULT 'email',
+ADD COLUMN IF NOT EXISTS last_login_at TIMESTAMPTZ NULL,
+ADD COLUMN IF NOT EXISTS last_active_at TIMESTAMPTZ NULL;
+
+UPDATE users
+SET signup_source = 'email'
+WHERE signup_source IS NULL OR signup_source = '';
+
+DO $$
+BEGIN
+ IF NOT EXISTS (
+ SELECT 1
+ FROM pg_constraint
+ WHERE conname = 'users_signup_source_check'
+ ) THEN
+ ALTER TABLE users
+ ADD CONSTRAINT users_signup_source_check
+ CHECK (signup_source IN ('email', 'linuxdo', 'wechat', 'oidc'));
+ END IF;
+END $$;
+
+CREATE TABLE IF NOT EXISTS auth_identities (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+ provider_type VARCHAR(20) NOT NULL,
+ provider_key TEXT NOT NULL,
+ provider_subject TEXT NOT NULL,
+ verified_at TIMESTAMPTZ NULL,
+ issuer TEXT NULL,
+ metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT auth_identities_provider_type_check
+ CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS auth_identities_provider_subject_key
+ ON auth_identities (provider_type, provider_key, provider_subject);
+
+CREATE INDEX IF NOT EXISTS auth_identities_user_id_idx
+ ON auth_identities (user_id);
+
+CREATE INDEX IF NOT EXISTS auth_identities_user_provider_idx
+ ON auth_identities (user_id, provider_type);
+
+CREATE TABLE IF NOT EXISTS auth_identity_channels (
+ id BIGSERIAL PRIMARY KEY,
+ identity_id BIGINT NOT NULL REFERENCES auth_identities(id) ON DELETE CASCADE,
+ provider_type VARCHAR(20) NOT NULL,
+ provider_key TEXT NOT NULL,
+ channel VARCHAR(20) NOT NULL,
+ channel_app_id TEXT NOT NULL,
+ channel_subject TEXT NOT NULL,
+ metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT auth_identity_channels_provider_type_check
+ CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_channels_channel_key
+ ON auth_identity_channels (provider_type, provider_key, channel, channel_app_id, channel_subject);
+
+CREATE INDEX IF NOT EXISTS auth_identity_channels_identity_id_idx
+ ON auth_identity_channels (identity_id);
+
+CREATE TABLE IF NOT EXISTS pending_auth_sessions (
+ id BIGSERIAL PRIMARY KEY,
+ session_token VARCHAR(255) NOT NULL,
+ intent VARCHAR(40) NOT NULL,
+ provider_type VARCHAR(20) NOT NULL,
+ provider_key TEXT NOT NULL,
+ provider_subject TEXT NOT NULL,
+ target_user_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL,
+ redirect_to TEXT NOT NULL DEFAULT '',
+ resolved_email TEXT NOT NULL DEFAULT '',
+ registration_password_hash TEXT NOT NULL DEFAULT '',
+ upstream_identity_claims JSONB NOT NULL DEFAULT '{}'::jsonb,
+ local_flow_state JSONB NOT NULL DEFAULT '{}'::jsonb,
+ browser_session_key TEXT NOT NULL DEFAULT '',
+ completion_code_hash TEXT NOT NULL DEFAULT '',
+ completion_code_expires_at TIMESTAMPTZ NULL,
+ email_verified_at TIMESTAMPTZ NULL,
+ password_verified_at TIMESTAMPTZ NULL,
+ totp_verified_at TIMESTAMPTZ NULL,
+ expires_at TIMESTAMPTZ NOT NULL,
+ consumed_at TIMESTAMPTZ NULL,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT pending_auth_sessions_intent_check
+ CHECK (intent IN ('login', 'bind_current_user', 'adopt_existing_user_by_email')),
+ CONSTRAINT pending_auth_sessions_provider_type_check
+ CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS pending_auth_sessions_session_token_key
+ ON pending_auth_sessions (session_token);
+
+CREATE INDEX IF NOT EXISTS pending_auth_sessions_target_user_id_idx
+ ON pending_auth_sessions (target_user_id);
+
+CREATE INDEX IF NOT EXISTS pending_auth_sessions_expires_at_idx
+ ON pending_auth_sessions (expires_at);
+
+CREATE INDEX IF NOT EXISTS pending_auth_sessions_provider_idx
+ ON pending_auth_sessions (provider_type, provider_key, provider_subject);
+
+CREATE INDEX IF NOT EXISTS pending_auth_sessions_completion_code_idx
+ ON pending_auth_sessions (completion_code_hash);
+
+CREATE TABLE IF NOT EXISTS identity_adoption_decisions (
+ id BIGSERIAL PRIMARY KEY,
+ pending_auth_session_id BIGINT NOT NULL REFERENCES pending_auth_sessions(id) ON DELETE CASCADE,
+ identity_id BIGINT NULL REFERENCES auth_identities(id) ON DELETE SET NULL,
+ adopt_display_name BOOLEAN NOT NULL DEFAULT FALSE,
+ adopt_avatar BOOLEAN NOT NULL DEFAULT FALSE,
+ decided_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS identity_adoption_decisions_pending_auth_session_id_key
+ ON identity_adoption_decisions (pending_auth_session_id);
+
+CREATE INDEX IF NOT EXISTS identity_adoption_decisions_identity_id_idx
+ ON identity_adoption_decisions (identity_id);
+
+CREATE TABLE IF NOT EXISTS auth_identity_migration_reports (
+ id BIGSERIAL PRIMARY KEY,
+ report_type VARCHAR(40) NOT NULL,
+ report_key TEXT NOT NULL,
+ details JSONB NOT NULL DEFAULT '{}'::jsonb,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS auth_identity_migration_reports_type_idx
+ ON auth_identity_migration_reports (report_type);
+
+CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_migration_reports_type_key
+ ON auth_identity_migration_reports (report_type, report_key);
diff --git a/backend/migrations/108a_widen_auth_identity_migration_report_type.sql b/backend/migrations/108a_widen_auth_identity_migration_report_type.sql
new file mode 100644
index 00000000..bc170fb8
--- /dev/null
+++ b/backend/migrations/108a_widen_auth_identity_migration_report_type.sql
@@ -0,0 +1,14 @@
+DO $$
+BEGIN
+ IF EXISTS (
+ SELECT 1
+ FROM information_schema.columns
+ WHERE table_schema = 'public'
+ AND table_name = 'auth_identity_migration_reports'
+ AND column_name = 'report_type'
+ AND COALESCE(character_maximum_length, 0) < 80
+ ) THEN
+ ALTER TABLE auth_identity_migration_reports
+ ALTER COLUMN report_type TYPE VARCHAR(80);
+ END IF;
+END $$;
diff --git a/backend/migrations/109_auth_identity_compat_backfill.sql b/backend/migrations/109_auth_identity_compat_backfill.sql
new file mode 100644
index 00000000..ddbbedbc
--- /dev/null
+++ b/backend/migrations/109_auth_identity_compat_backfill.sql
@@ -0,0 +1,125 @@
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ u.id,
+ 'email',
+ 'email',
+ LOWER(BTRIM(u.email)),
+ COALESCE(u.updated_at, u.created_at, NOW()),
+ jsonb_build_object(
+ 'backfill_source', 'users.email',
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND BTRIM(COALESCE(u.email, '')) <> ''
+ AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@linuxdo-connect.invalid')) <> '@linuxdo-connect.invalid'
+ AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@oidc-connect.invalid')) <> '@oidc-connect.invalid'
+ AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@wechat-connect.invalid')) <> '@wechat-connect.invalid'
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ u.id,
+ 'linuxdo',
+ 'linuxdo',
+ SUBSTRING(BTRIM(u.email) FROM '(?i)^linuxdo-(.+)@linuxdo-connect\.invalid$'),
+ COALESCE(u.updated_at, u.created_at, NOW()),
+ jsonb_build_object(
+ 'backfill_source', 'synthetic_email',
+ 'legacy_email', BTRIM(u.email),
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(u.email)) ~ '^linuxdo-.+@linuxdo-connect\.invalid$'
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ u.id,
+ 'wechat',
+ 'wechat',
+ SUBSTRING(BTRIM(u.email) FROM '(?i)^wechat-(.+)@wechat-connect\.invalid$'),
+ COALESCE(u.updated_at, u.created_at, NOW()),
+ jsonb_build_object(
+ 'backfill_source', 'synthetic_email',
+ 'legacy_email', BTRIM(u.email),
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$'
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+
+UPDATE users
+SET signup_source = 'linuxdo'
+WHERE deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^linuxdo-.+@linuxdo-connect\.invalid$';
+
+UPDATE users
+SET signup_source = 'wechat'
+WHERE deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^wechat-.+@wechat-connect\.invalid$';
+
+UPDATE users
+SET signup_source = 'oidc'
+WHERE deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^oidc-.+@oidc-connect\.invalid$';
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'oidc_synthetic_email_requires_manual_recovery',
+ CAST(u.id AS TEXT),
+ jsonb_build_object(
+ 'user_id', u.id,
+ 'email', LOWER(BTRIM(u.email)),
+ 'reason', 'cannot recover issuer_plus_sub deterministically from synthetic email alone',
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(u.email)) ~ '^oidc-.+@oidc-connect\.invalid$'
+ON CONFLICT (report_type, report_key) DO NOTHING;
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_openid_only_requires_remediation',
+ CAST(u.id AS TEXT),
+ jsonb_build_object(
+ 'user_id', u.id,
+ 'email', LOWER(BTRIM(u.email)),
+ 'reason', 'legacy wechat synthetic identity requires explicit unionid remediation if channel-only data exists',
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$'
+ AND NOT EXISTS (
+ SELECT 1
+ FROM auth_identities ai
+ WHERE ai.user_id = u.id
+ AND ai.provider_type = 'wechat'
+ AND ai.provider_key = 'wechat'
+ )
+ON CONFLICT (report_type, report_key) DO NOTHING;
diff --git a/backend/migrations/110_pending_auth_and_provider_default_grants.sql b/backend/migrations/110_pending_auth_and_provider_default_grants.sql
new file mode 100644
index 00000000..f59b2188
--- /dev/null
+++ b/backend/migrations/110_pending_auth_and_provider_default_grants.sql
@@ -0,0 +1,59 @@
+CREATE TABLE IF NOT EXISTS user_provider_default_grants (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+ provider_type VARCHAR(20) NOT NULL,
+ grant_reason VARCHAR(20) NOT NULL DEFAULT 'first_bind',
+ granted_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT user_provider_default_grants_provider_type_check
+ CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')),
+ CONSTRAINT user_provider_default_grants_reason_check
+ CHECK (grant_reason IN ('signup', 'first_bind'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS user_provider_default_grants_user_provider_reason_key
+ ON user_provider_default_grants (user_id, provider_type, grant_reason);
+
+CREATE INDEX IF NOT EXISTS user_provider_default_grants_user_id_idx
+ ON user_provider_default_grants (user_id);
+
+CREATE TABLE IF NOT EXISTS user_avatars (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+ storage_provider VARCHAR(20) NOT NULL DEFAULT 'database',
+ storage_key TEXT NOT NULL DEFAULT '',
+ url TEXT NOT NULL DEFAULT '',
+ content_type VARCHAR(100) NOT NULL DEFAULT '',
+ byte_size INT NOT NULL DEFAULT 0,
+ sha256 VARCHAR(64) NOT NULL DEFAULT '',
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS user_avatars_user_id_key
+ ON user_avatars (user_id);
+
+INSERT INTO settings (key, value)
+VALUES
+ ('auth_source_default_email_balance', '0'),
+ ('auth_source_default_email_concurrency', '5'),
+ ('auth_source_default_email_subscriptions', '[]'),
+ ('auth_source_default_email_grant_on_signup', 'false'),
+ ('auth_source_default_email_grant_on_first_bind', 'false'),
+ ('auth_source_default_linuxdo_balance', '0'),
+ ('auth_source_default_linuxdo_concurrency', '5'),
+ ('auth_source_default_linuxdo_subscriptions', '[]'),
+ ('auth_source_default_linuxdo_grant_on_signup', 'false'),
+ ('auth_source_default_linuxdo_grant_on_first_bind', 'false'),
+ ('auth_source_default_oidc_balance', '0'),
+ ('auth_source_default_oidc_concurrency', '5'),
+ ('auth_source_default_oidc_subscriptions', '[]'),
+ ('auth_source_default_oidc_grant_on_signup', 'false'),
+ ('auth_source_default_oidc_grant_on_first_bind', 'false'),
+ ('auth_source_default_wechat_balance', '0'),
+ ('auth_source_default_wechat_concurrency', '5'),
+ ('auth_source_default_wechat_subscriptions', '[]'),
+ ('auth_source_default_wechat_grant_on_signup', 'false'),
+ ('auth_source_default_wechat_grant_on_first_bind', 'false'),
+ ('force_email_on_third_party_signup', 'false')
+ON CONFLICT (key) DO NOTHING;
diff --git a/backend/migrations/111_payment_routing_and_scheduler_flags.sql b/backend/migrations/111_payment_routing_and_scheduler_flags.sql
new file mode 100644
index 00000000..f222a8d4
--- /dev/null
+++ b/backend/migrations/111_payment_routing_and_scheduler_flags.sql
@@ -0,0 +1,8 @@
+INSERT INTO settings (key, value)
+VALUES
+ ('payment_visible_method_alipay_source', ''),
+ ('payment_visible_method_wxpay_source', ''),
+ ('payment_visible_method_alipay_enabled', 'false'),
+ ('payment_visible_method_wxpay_enabled', 'false'),
+ ('openai_advanced_scheduler_enabled', 'false')
+ON CONFLICT (key) DO NOTHING;
diff --git a/backend/migrations/112_add_payment_order_provider_key_snapshot.sql b/backend/migrations/112_add_payment_order_provider_key_snapshot.sql
new file mode 100644
index 00000000..d331b824
--- /dev/null
+++ b/backend/migrations/112_add_payment_order_provider_key_snapshot.sql
@@ -0,0 +1,10 @@
+ALTER TABLE payment_orders ADD COLUMN IF NOT EXISTS provider_key VARCHAR(30);
+
+UPDATE payment_orders
+SET provider_key = (
+ SELECT provider_key
+ FROM payment_provider_instances
+ WHERE CAST(id AS TEXT) = payment_orders.provider_instance_id
+)
+WHERE provider_key IS NULL
+ AND provider_instance_id IS NOT NULL;
diff --git a/backend/migrations/113_normalize_legacy_wechat_provider_key.sql b/backend/migrations/113_normalize_legacy_wechat_provider_key.sql
new file mode 100644
index 00000000..15610af0
--- /dev/null
+++ b/backend/migrations/113_normalize_legacy_wechat_provider_key.sql
@@ -0,0 +1,89 @@
+UPDATE auth_identities AS ai
+SET
+ provider_key = 'wechat-main',
+ metadata = COALESCE(ai.metadata, '{}'::jsonb) || jsonb_build_object(
+ 'legacy_provider_key', 'wechat',
+ 'normalized_by_migration', '113_normalize_legacy_wechat_provider_key'
+ ),
+ updated_at = NOW()
+WHERE ai.provider_type = 'wechat'
+ AND ai.provider_key = 'wechat'
+ AND NOT EXISTS (
+ SELECT 1
+ FROM auth_identities AS canon
+ WHERE canon.provider_type = 'wechat'
+ AND canon.provider_key = 'wechat-main'
+ AND canon.provider_subject = ai.provider_subject
+ );
+
+UPDATE auth_identity_channels AS channel
+SET
+ provider_key = 'wechat-main',
+ metadata = COALESCE(channel.metadata, '{}'::jsonb) || jsonb_build_object(
+ 'legacy_provider_key', 'wechat',
+ 'normalized_by_migration', '113_normalize_legacy_wechat_provider_key'
+ ),
+ updated_at = NOW()
+WHERE channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat'
+ AND NOT EXISTS (
+ SELECT 1
+ FROM auth_identity_channels AS canon
+ WHERE canon.provider_type = 'wechat'
+ AND canon.provider_key = 'wechat-main'
+ AND canon.channel = channel.channel
+ AND canon.channel_app_id = channel.channel_app_id
+ AND canon.channel_subject = channel.channel_subject
+ );
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_provider_key_conflict',
+ CAST(ai.id AS TEXT),
+ jsonb_build_object(
+ 'legacy_identity_id', ai.id,
+ 'legacy_user_id', ai.user_id,
+ 'provider_subject', ai.provider_subject,
+ 'canonical_identity_id', canon.id,
+ 'canonical_user_id', canon.user_id,
+ 'same_user', canon.user_id = ai.user_id,
+ 'migration', '113_normalize_legacy_wechat_provider_key'
+ )
+FROM auth_identities AS ai
+JOIN auth_identities AS canon
+ ON canon.provider_type = 'wechat'
+ AND canon.provider_key = 'wechat-main'
+ AND canon.provider_subject = ai.provider_subject
+WHERE ai.provider_type = 'wechat'
+ AND ai.provider_key = 'wechat'
+ON CONFLICT (report_type, report_key) DO NOTHING;
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_channel_provider_key_conflict',
+ CAST(channel.id AS TEXT),
+ jsonb_build_object(
+ 'legacy_channel_id', channel.id,
+ 'legacy_identity_id', channel.identity_id,
+ 'canonical_channel_id', canon.id,
+ 'canonical_identity_id', canon.identity_id,
+ 'channel', channel.channel,
+ 'channel_app_id', channel.channel_app_id,
+ 'channel_subject', channel.channel_subject,
+ 'same_user', COALESCE(legacy_identity.user_id = canonical_identity.user_id, FALSE),
+ 'migration', '113_normalize_legacy_wechat_provider_key'
+ )
+FROM auth_identity_channels AS channel
+JOIN auth_identity_channels AS canon
+ ON canon.provider_type = 'wechat'
+ AND canon.provider_key = 'wechat-main'
+ AND canon.channel = channel.channel
+ AND canon.channel_app_id = channel.channel_app_id
+ AND canon.channel_subject = channel.channel_subject
+LEFT JOIN auth_identities AS legacy_identity
+ ON legacy_identity.id = channel.identity_id
+LEFT JOIN auth_identities AS canonical_identity
+ ON canonical_identity.id = canon.identity_id
+WHERE channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat'
+ON CONFLICT (report_type, report_key) DO NOTHING;
diff --git a/backend/migrations/114_auth_identity_migration_report_resolution.sql b/backend/migrations/114_auth_identity_migration_report_resolution.sql
new file mode 100644
index 00000000..f84bf822
--- /dev/null
+++ b/backend/migrations/114_auth_identity_migration_report_resolution.sql
@@ -0,0 +1,11 @@
+ALTER TABLE auth_identity_migration_reports
+ ADD COLUMN IF NOT EXISTS resolved_at TIMESTAMPTZ NULL;
+
+ALTER TABLE auth_identity_migration_reports
+ ADD COLUMN IF NOT EXISTS resolved_by_user_id BIGINT NULL;
+
+ALTER TABLE auth_identity_migration_reports
+ ADD COLUMN IF NOT EXISTS resolution_note TEXT NOT NULL DEFAULT '';
+
+CREATE INDEX IF NOT EXISTS idx_auth_identity_migration_reports_resolved_at
+ ON auth_identity_migration_reports (resolved_at);
diff --git a/backend/migrations/115_auth_identity_legacy_external_backfill.sql b/backend/migrations/115_auth_identity_legacy_external_backfill.sql
new file mode 100644
index 00000000..264da3c9
--- /dev/null
+++ b/backend/migrations/115_auth_identity_legacy_external_backfill.sql
@@ -0,0 +1,268 @@
+CREATE OR REPLACE FUNCTION public.__migration_115_safe_legacy_metadata_jsonb(input_text TEXT)
+RETURNS JSONB
+LANGUAGE plpgsql
+AS $$
+DECLARE
+ parsed JSONB;
+BEGIN
+ IF input_text IS NULL OR BTRIM(input_text) = '' THEN
+ RETURN '{}'::jsonb;
+ END IF;
+
+ BEGIN
+ parsed := input_text::jsonb;
+ EXCEPTION
+ WHEN OTHERS THEN
+ RETURN '{}'::jsonb;
+ END;
+
+ IF jsonb_typeof(parsed) = 'object' THEN
+ RETURN parsed;
+ END IF;
+
+ RETURN jsonb_build_object('_legacy_metadata_raw_json', parsed);
+END;
+$$;
+
+DO $$
+BEGIN
+ IF to_regclass('public.user_external_identities') IS NULL THEN
+ RETURN;
+ END IF;
+
+ EXECUTE $sql$
+WITH legacy AS (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(uei.provider_user_id) AS provider_user_id,
+ BTRIM(uei.provider_username) AS provider_username,
+ BTRIM(uei.display_name) AS display_name,
+ public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+ uei.created_at,
+ uei.updated_at
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo'
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+),
+legacy_subjects AS (
+ SELECT
+ provider_user_id AS provider_subject,
+ COUNT(DISTINCT user_id) AS distinct_user_count
+ FROM legacy
+ GROUP BY provider_user_id
+),
+canonical_legacy AS (
+ SELECT
+ legacy.*,
+ ROW_NUMBER() OVER (
+ PARTITION BY legacy.provider_user_id
+ ORDER BY COALESCE(legacy.updated_at, legacy.created_at, NOW()) DESC, legacy.id DESC
+ ) AS canonical_row_num
+ FROM legacy
+ JOIN legacy_subjects AS subjects
+ ON subjects.provider_subject = legacy.provider_user_id
+ AND subjects.distinct_user_count = 1
+)
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ legacy.user_id,
+ 'linuxdo',
+ 'linuxdo',
+ legacy.provider_user_id,
+ COALESCE(legacy.updated_at, legacy.created_at, NOW()),
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'provider_user_id', legacy.provider_user_id,
+ 'provider_username', legacy.provider_username,
+ 'display_name', legacy.display_name,
+ 'migration', '115_auth_identity_legacy_external_backfill'
+ )
+FROM canonical_legacy AS legacy
+WHERE legacy.canonical_row_num = 1
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+WITH legacy AS (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(uei.provider_user_id) AS provider_user_id,
+ BTRIM(uei.provider_union_id) AS provider_union_id,
+ BTRIM(uei.provider_username) AS provider_username,
+ BTRIM(uei.display_name) AS display_name,
+ public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+ uei.created_at,
+ uei.updated_at
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+),
+legacy_subjects AS (
+ SELECT
+ provider_union_id AS provider_subject,
+ COUNT(DISTINCT user_id) AS distinct_user_count
+ FROM legacy
+ GROUP BY provider_union_id
+),
+canonical_legacy AS (
+ SELECT
+ legacy.*,
+ ROW_NUMBER() OVER (
+ PARTITION BY legacy.provider_union_id
+ ORDER BY COALESCE(legacy.updated_at, legacy.created_at, NOW()) DESC, legacy.id DESC
+ ) AS canonical_row_num
+ FROM legacy
+ JOIN legacy_subjects AS subjects
+ ON subjects.provider_subject = legacy.provider_union_id
+ AND subjects.distinct_user_count = 1
+)
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ legacy.user_id,
+ 'wechat',
+ 'wechat-main',
+ legacy.provider_union_id,
+ COALESCE(legacy.updated_at, legacy.created_at, NOW()),
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'openid', legacy.provider_user_id,
+ 'unionid', legacy.provider_union_id,
+ 'provider_user_id', legacy.provider_user_id,
+ 'provider_union_id', legacy.provider_union_id,
+ 'provider_username', legacy.provider_username,
+ 'display_name', legacy.display_name,
+ 'migration', '115_auth_identity_legacy_external_backfill'
+ )
+FROM canonical_legacy AS legacy
+WHERE legacy.canonical_row_num = 1
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+WITH legacy AS (
+ SELECT
+ uei.user_id,
+ BTRIM(uei.provider_user_id) AS provider_user_id,
+ BTRIM(uei.provider_union_id) AS provider_union_id,
+ BTRIM(COALESCE(meta.metadata_json ->> 'channel', '')) AS channel,
+ BTRIM(COALESCE(meta.metadata_json ->> 'channel_app_id', meta.metadata_json ->> 'appid', meta.metadata_json ->> 'app_id', '')) AS channel_app_id,
+ meta.metadata_json
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ CROSS JOIN LATERAL (
+ SELECT public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
+ ) AS meta
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+),
+legacy_subjects AS (
+ SELECT
+ provider_union_id AS provider_subject,
+ COUNT(DISTINCT user_id) AS distinct_user_count
+ FROM legacy
+ GROUP BY provider_union_id
+)
+INSERT INTO auth_identity_channels (
+ identity_id,
+ provider_type,
+ provider_key,
+ channel,
+ channel_app_id,
+ channel_subject,
+ metadata
+)
+SELECT
+ ai.id,
+ 'wechat',
+ 'wechat-main',
+ legacy.channel,
+ legacy.channel_app_id,
+ legacy.provider_user_id,
+ legacy.metadata_json || jsonb_build_object(
+ 'openid', legacy.provider_user_id,
+ 'unionid', legacy.provider_union_id,
+ 'migration', '115_auth_identity_legacy_external_backfill'
+ )
+FROM legacy
+JOIN legacy_subjects AS subjects
+ ON subjects.provider_subject = legacy.provider_union_id
+ AND subjects.distinct_user_count = 1
+JOIN auth_identities AS ai
+ ON ai.user_id = legacy.user_id
+ AND ai.provider_type = 'wechat'
+ AND ai.provider_key = 'wechat-main'
+ AND ai.provider_subject = legacy.provider_union_id
+WHERE legacy.channel <> ''
+ AND legacy.channel_app_id <> ''
+ AND legacy.provider_user_id <> ''
+ON CONFLICT DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_openid_only_requires_remediation',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'user_id', legacy.user_id,
+ 'openid', legacy.provider_user_id,
+ 'reason', 'legacy user_external_identities row only has openid and cannot be canonicalized offline',
+ 'migration', '115_auth_identity_legacy_external_backfill'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(uei.provider_user_id) AS provider_user_id,
+ public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) = ''
+) AS legacy
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+END $$;
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_openid_only_requires_remediation',
+ 'synthetic_auth_identity:' || ai.id::text,
+ COALESCE(ai.metadata, '{}'::jsonb) || jsonb_build_object(
+ 'auth_identity_id', ai.id,
+ 'user_id', ai.user_id,
+ 'provider_subject', ai.provider_subject,
+ 'reason', 'synthetic wechat auth identity still lacks unionid metadata and needs remediation',
+ 'migration', '115_auth_identity_legacy_external_backfill'
+ )
+FROM auth_identities AS ai
+WHERE ai.provider_type = 'wechat'
+ AND COALESCE(ai.metadata ->> 'backfill_source', '') = 'synthetic_email'
+ AND BTRIM(COALESCE(ai.metadata ->> 'unionid', '')) = ''
+ON CONFLICT (report_type, report_key) DO NOTHING;
+
+DROP FUNCTION IF EXISTS public.__migration_115_safe_legacy_metadata_jsonb(TEXT);
diff --git a/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql
new file mode 100644
index 00000000..81eb133c
--- /dev/null
+++ b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql
@@ -0,0 +1,525 @@
+CREATE OR REPLACE FUNCTION public.__migration_116_safe_legacy_metadata_jsonb(input_text TEXT)
+RETURNS JSONB
+LANGUAGE plpgsql
+AS $$
+DECLARE
+ parsed JSONB;
+BEGIN
+ IF input_text IS NULL OR BTRIM(input_text) = '' THEN
+ RETURN '{}'::jsonb;
+ END IF;
+
+ BEGIN
+ parsed := input_text::jsonb;
+ EXCEPTION
+ WHEN OTHERS THEN
+ RETURN '{}'::jsonb;
+ END;
+
+ IF jsonb_typeof(parsed) = 'object' THEN
+ RETURN parsed;
+ END IF;
+
+ RETURN jsonb_build_object('_legacy_metadata_raw_json', parsed);
+END;
+$$;
+
+CREATE OR REPLACE FUNCTION public.__migration_116_is_valid_legacy_metadata_jsonb(input_text TEXT)
+RETURNS BOOLEAN
+LANGUAGE plpgsql
+AS $$
+DECLARE
+ parsed JSONB;
+BEGIN
+ IF input_text IS NULL OR BTRIM(input_text) = '' THEN
+ RETURN TRUE;
+ END IF;
+
+ parsed := input_text::jsonb;
+ RETURN TRUE;
+EXCEPTION
+ WHEN OTHERS THEN
+ RETURN FALSE;
+END;
+$$;
+
+DO $$
+BEGIN
+ IF to_regclass('public.user_external_identities') IS NULL THEN
+ RETURN;
+ END IF;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'legacy_external_identity_invalid_metadata_json',
+ 'legacy_external_identity:' || uei.id::text,
+ jsonb_build_object(
+ 'legacy_identity_id', uei.id,
+ 'user_id', uei.user_id,
+ 'provider', LOWER(BTRIM(COALESCE(uei.provider, ''))),
+ 'provider_user_id', BTRIM(COALESCE(uei.provider_user_id, '')),
+ 'provider_union_id', BTRIM(COALESCE(uei.provider_union_id, '')),
+ 'reason', 'legacy metadata is not valid JSON; migration downgraded metadata to empty object',
+ 'raw_metadata', LEFT(BTRIM(COALESCE(uei.metadata, '')), 1000),
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM user_external_identities AS uei
+JOIN users AS u ON u.id = uei.user_id
+WHERE u.deleted_at IS NULL
+ AND BTRIM(COALESCE(uei.metadata, '')) <> ''
+ AND NOT public.__migration_116_is_valid_legacy_metadata_jsonb(uei.metadata)
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'legacy_external_identity_conflict',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'legacy_user_id', legacy.user_id,
+ 'provider_type', legacy.provider_type,
+ 'provider_key', legacy.provider_key,
+ 'provider_subject', legacy.provider_subject,
+ 'conflicting_legacy_user_ids', ambiguous.conflicting_legacy_user_ids,
+ 'reason', 'legacy canonical identity subject belongs to multiple legacy users and cannot be auto-resolved',
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
+ ELSE 'linuxdo'
+ END AS provider_key,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
+ ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
+ END AS provider_subject,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
+ AND (
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
+ OR
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
+ )
+) AS legacy
+JOIN (
+ SELECT
+ provider_type,
+ provider_key,
+ provider_subject,
+ to_jsonb(array_agg(DISTINCT user_id ORDER BY user_id)) AS conflicting_legacy_user_ids
+ FROM (
+ SELECT
+ uei.user_id,
+ LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
+ ELSE 'linuxdo'
+ END AS provider_key,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
+ ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
+ END AS provider_subject
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
+ AND (
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
+ OR
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
+ )
+ ) AS legacy_subjects
+ GROUP BY provider_type, provider_key, provider_subject
+ HAVING COUNT(DISTINCT user_id) > 1
+) AS ambiguous
+ ON ambiguous.provider_type = legacy.provider_type
+ AND ambiguous.provider_key = legacy.provider_key
+ AND ambiguous.provider_subject = legacy.provider_subject
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'legacy_external_identity_conflict',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'legacy_user_id', legacy.user_id,
+ 'existing_identity_id', ai.id,
+ 'existing_user_id', ai.user_id,
+ 'provider_type', legacy.provider_type,
+ 'provider_key', legacy.provider_key,
+ 'provider_subject', legacy.provider_subject,
+ 'reason', 'legacy canonical identity subject already belongs to another user',
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
+ ELSE 'linuxdo'
+ END AS provider_key,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
+ ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
+ END AS provider_subject,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
+ BTRIM(COALESCE(uei.provider_username, '')) AS provider_username,
+ BTRIM(COALESCE(uei.display_name, '')) AS display_name,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
+ AND (
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
+ OR
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
+ )
+) AS legacy
+JOIN (
+ SELECT
+ provider_type,
+ provider_key,
+ provider_subject
+ FROM (
+ SELECT
+ uei.user_id,
+ LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
+ ELSE 'linuxdo'
+ END AS provider_key,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
+ ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
+ END AS provider_subject
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
+ AND (
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
+ OR
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
+ )
+ ) AS legacy_subjects
+ GROUP BY provider_type, provider_key, provider_subject
+ HAVING COUNT(DISTINCT user_id) = 1
+) AS clear_subjects
+ ON clear_subjects.provider_type = legacy.provider_type
+ AND clear_subjects.provider_key = legacy.provider_key
+ AND clear_subjects.provider_subject = legacy.provider_subject
+JOIN auth_identities AS ai
+ ON ai.provider_type = legacy.provider_type
+ AND ai.provider_key = legacy.provider_key
+ AND ai.provider_subject = legacy.provider_subject
+WHERE ai.user_id <> legacy.user_id
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+WITH legacy AS (
+ SELECT
+ uei.id,
+ uei.user_id,
+ LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
+ ELSE 'linuxdo'
+ END AS provider_key,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
+ ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
+ END AS provider_subject,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
+ BTRIM(COALESCE(uei.provider_username, '')) AS provider_username,
+ BTRIM(COALESCE(uei.display_name, '')) AS display_name,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+ COALESCE(uei.updated_at, uei.created_at, NOW()) AS verified_at
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
+ AND (
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
+ OR
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
+ )
+),
+clear_subjects AS (
+ SELECT
+ provider_type,
+ provider_key,
+ provider_subject
+ FROM legacy
+ GROUP BY provider_type, provider_key, provider_subject
+ HAVING COUNT(DISTINCT user_id) = 1
+),
+canonical_legacy AS (
+ SELECT
+ legacy.*,
+ ROW_NUMBER() OVER (
+ PARTITION BY legacy.provider_type, legacy.provider_key, legacy.provider_subject
+ ORDER BY legacy.verified_at DESC, legacy.id DESC
+ ) AS canonical_row_num
+ FROM legacy
+ JOIN clear_subjects
+ ON clear_subjects.provider_type = legacy.provider_type
+ AND clear_subjects.provider_key = legacy.provider_key
+ AND clear_subjects.provider_subject = legacy.provider_subject
+)
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ legacy.user_id,
+ legacy.provider_type,
+ legacy.provider_key,
+ legacy.provider_subject,
+ legacy.verified_at,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'provider_user_id', legacy.provider_user_id,
+ 'provider_union_id', NULLIF(legacy.provider_union_id, ''),
+ 'provider_username', legacy.provider_username,
+ 'display_name', legacy.display_name,
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM canonical_legacy AS legacy
+LEFT JOIN auth_identities AS ai
+ ON ai.provider_type = legacy.provider_type
+ AND ai.provider_key = legacy.provider_key
+ AND ai.provider_subject = legacy.provider_subject
+WHERE legacy.canonical_row_num = 1
+ AND ai.id IS NULL
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'legacy_external_channel_conflict',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'legacy_user_id', legacy.user_id,
+ 'existing_channel_id', channel.id,
+ 'existing_identity_id', existing_ai.id,
+ 'existing_user_id', existing_ai.user_id,
+ 'provider_type', 'wechat',
+ 'provider_key', 'wechat-main',
+ 'provider_subject', legacy.provider_union_id,
+ 'channel', legacy.channel,
+ 'channel_app_id', legacy.channel_app_id,
+ 'channel_subject', legacy.provider_user_id,
+ 'reason', 'legacy channel subject already belongs to another user',
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+ BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel,
+ BTRIM(COALESCE(
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id',
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid',
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id',
+ ''
+ )) AS channel_app_id
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+) AS legacy
+JOIN (
+ SELECT
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_subject
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+ GROUP BY BTRIM(COALESCE(uei.provider_union_id, ''))
+ HAVING COUNT(DISTINCT uei.user_id) = 1
+) AS clear_subjects
+ ON clear_subjects.provider_subject = legacy.provider_union_id
+JOIN auth_identities AS legacy_ai
+ ON legacy_ai.user_id = legacy.user_id
+ AND legacy_ai.provider_type = 'wechat'
+ AND legacy_ai.provider_key = 'wechat-main'
+ AND legacy_ai.provider_subject = legacy.provider_union_id
+JOIN auth_identity_channels AS channel
+ ON channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat-main'
+ AND channel.channel = legacy.channel
+ AND channel.channel_app_id = legacy.channel_app_id
+ AND channel.channel_subject = legacy.provider_user_id
+JOIN auth_identities AS existing_ai
+ ON existing_ai.id = channel.identity_id
+WHERE legacy.channel <> ''
+ AND legacy.channel_app_id <> ''
+ AND existing_ai.user_id <> legacy.user_id
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+WITH legacy AS (
+ SELECT
+ uei.user_id,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+ BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel,
+ BTRIM(COALESCE(
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id',
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid',
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id',
+ ''
+ )) AS channel_app_id
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+),
+clear_subjects AS (
+ SELECT
+ provider_union_id AS provider_subject
+ FROM legacy
+ GROUP BY provider_union_id
+ HAVING COUNT(DISTINCT user_id) = 1
+)
+INSERT INTO auth_identity_channels (
+ identity_id,
+ provider_type,
+ provider_key,
+ channel,
+ channel_app_id,
+ channel_subject,
+ metadata
+)
+SELECT
+ legacy_ai.id,
+ 'wechat',
+ 'wechat-main',
+ legacy.channel,
+ legacy.channel_app_id,
+ legacy.provider_user_id,
+ legacy.metadata_json || jsonb_build_object(
+ 'openid', legacy.provider_user_id,
+ 'unionid', legacy.provider_union_id,
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM legacy
+JOIN clear_subjects
+ ON clear_subjects.provider_subject = legacy.provider_union_id
+JOIN auth_identities AS legacy_ai
+ ON legacy_ai.user_id = legacy.user_id
+ AND legacy_ai.provider_type = 'wechat'
+ AND legacy_ai.provider_key = 'wechat-main'
+ AND legacy_ai.provider_subject = legacy.provider_union_id
+LEFT JOIN auth_identity_channels AS channel
+ ON channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat-main'
+ AND channel.channel = legacy.channel
+ AND channel.channel_app_id = legacy.channel_app_id
+ AND channel.channel_subject = legacy.provider_user_id
+WHERE legacy.channel <> ''
+ AND legacy.channel_app_id <> ''
+ AND channel.id IS NULL
+ON CONFLICT DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_openid_only_requires_remediation',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'user_id', legacy.user_id,
+ 'openid', legacy.provider_user_id,
+ 'reason', 'legacy user_external_identities row only has openid and cannot be canonicalized offline',
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) = ''
+) AS legacy
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+END $$;
+
+DO $$
+BEGIN
+ IF NOT EXISTS (
+ SELECT 1
+ FROM pg_constraint
+ WHERE conname = 'auth_identities_metadata_is_object_check'
+ ) THEN
+ ALTER TABLE auth_identities
+ ADD CONSTRAINT auth_identities_metadata_is_object_check
+ CHECK (jsonb_typeof(metadata) = 'object');
+ END IF;
+
+ IF NOT EXISTS (
+ SELECT 1
+ FROM pg_constraint
+ WHERE conname = 'auth_identity_channels_metadata_is_object_check'
+ ) THEN
+ ALTER TABLE auth_identity_channels
+ ADD CONSTRAINT auth_identity_channels_metadata_is_object_check
+ CHECK (jsonb_typeof(metadata) = 'object');
+ END IF;
+
+ IF NOT EXISTS (
+ SELECT 1
+ FROM pg_constraint
+ WHERE conname = 'auth_identity_migration_reports_details_is_object_check'
+ ) THEN
+ ALTER TABLE auth_identity_migration_reports
+ ADD CONSTRAINT auth_identity_migration_reports_details_is_object_check
+ CHECK (jsonb_typeof(details) = 'object');
+ END IF;
+END $$;
+
+DROP FUNCTION IF EXISTS public.__migration_116_is_valid_legacy_metadata_jsonb(TEXT);
+DROP FUNCTION IF EXISTS public.__migration_116_safe_legacy_metadata_jsonb(TEXT);
diff --git a/backend/migrations/117_add_payment_order_provider_snapshot.sql b/backend/migrations/117_add_payment_order_provider_snapshot.sql
new file mode 100644
index 00000000..56a5fe2d
--- /dev/null
+++ b/backend/migrations/117_add_payment_order_provider_snapshot.sql
@@ -0,0 +1,2 @@
+ALTER TABLE payment_orders
+ADD COLUMN IF NOT EXISTS provider_snapshot JSONB;
diff --git a/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql b/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql
new file mode 100644
index 00000000..18782617
--- /dev/null
+++ b/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql
@@ -0,0 +1,25 @@
+INSERT INTO settings (key, value)
+VALUES
+ (
+ 'wechat_connect_open_enabled',
+ CASE
+ WHEN NOT EXISTS (SELECT 1 FROM settings WHERE key = 'wechat_connect_enabled') THEN ''
+ WHEN COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_enabled'), 'false') <> 'true' THEN 'false'
+ WHEN LOWER(TRIM(COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_mode'), 'open'))) = 'mp' THEN 'false'
+ ELSE 'true'
+ END
+ ),
+ (
+ 'wechat_connect_mp_enabled',
+ CASE
+ WHEN NOT EXISTS (SELECT 1 FROM settings WHERE key = 'wechat_connect_enabled') THEN ''
+ WHEN COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_enabled'), 'false') <> 'true' THEN 'false'
+ WHEN LOWER(TRIM(COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_mode'), 'open'))) = 'mp' THEN 'true'
+ ELSE 'false'
+ END
+ ),
+ ('auth_source_default_email_grant_on_signup', 'false'),
+ ('auth_source_default_linuxdo_grant_on_signup', 'false'),
+ ('auth_source_default_oidc_grant_on_signup', 'false'),
+ ('auth_source_default_wechat_grant_on_signup', 'false')
+ON CONFLICT (key) DO NOTHING;
diff --git a/backend/migrations/119_enforce_payment_orders_out_trade_no_unique.sql b/backend/migrations/119_enforce_payment_orders_out_trade_no_unique.sql
new file mode 100644
index 00000000..15e2c15f
--- /dev/null
+++ b/backend/migrations/119_enforce_payment_orders_out_trade_no_unique.sql
@@ -0,0 +1,6 @@
+-- Intentionally left as a no-op.
+-- The online index rollout lives in 120_enforce_payment_orders_out_trade_no_unique_notx.sql
+DO $$
+BEGIN
+ NULL;
+END $$;
diff --git a/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql b/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql
new file mode 100644
index 00000000..638d8622
--- /dev/null
+++ b/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql
@@ -0,0 +1,10 @@
+-- Build the payment order uniqueness guarantee online.
+-- The migration runner performs an explicit duplicate out_trade_no precheck and
+-- drops any stale invalid paymentorder_out_trade_no_unique index before retrying.
+-- Create the new partial unique index concurrently first so writes keep flowing,
+-- then remove the legacy index name once the replacement is ready.
+CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
+ ON payment_orders (out_trade_no)
+ WHERE out_trade_no <> '';
+
+DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
diff --git a/backend/migrations/120a_align_payment_orders_out_trade_no_index_name.sql b/backend/migrations/120a_align_payment_orders_out_trade_no_index_name.sql
new file mode 100644
index 00000000..ef2599dc
--- /dev/null
+++ b/backend/migrations/120a_align_payment_orders_out_trade_no_index_name.sql
@@ -0,0 +1,22 @@
+DO $$
+BEGIN
+ IF EXISTS (
+ SELECT 1
+ FROM pg_indexes
+ WHERE schemaname = 'public'
+ AND tablename = 'payment_orders'
+ AND indexname = 'paymentorder_out_trade_no_unique'
+ ) THEN
+ IF EXISTS (
+ SELECT 1
+ FROM pg_indexes
+ WHERE schemaname = 'public'
+ AND tablename = 'payment_orders'
+ AND indexname = 'paymentorder_out_trade_no'
+ ) THEN
+ EXECUTE 'DROP INDEX IF EXISTS paymentorder_out_trade_no';
+ END IF;
+
+ EXECUTE 'ALTER INDEX paymentorder_out_trade_no_unique RENAME TO paymentorder_out_trade_no';
+ END IF;
+END $$;
diff --git a/backend/migrations/121_auth_identity_migration_report_type_widen.sql b/backend/migrations/121_auth_identity_migration_report_type_widen.sql
new file mode 100644
index 00000000..66bfb44a
--- /dev/null
+++ b/backend/migrations/121_auth_identity_migration_report_type_widen.sql
@@ -0,0 +1,2 @@
+ALTER TABLE auth_identity_migration_reports
+ALTER COLUMN report_type TYPE VARCHAR(80);
diff --git a/backend/migrations/122_pending_auth_completion_token_cleanup.sql b/backend/migrations/122_pending_auth_completion_token_cleanup.sql
new file mode 100644
index 00000000..e6341142
--- /dev/null
+++ b/backend/migrations/122_pending_auth_completion_token_cleanup.sql
@@ -0,0 +1,15 @@
+UPDATE pending_auth_sessions
+SET
+ local_flow_state = jsonb_set(
+ local_flow_state,
+ '{completion_response}',
+ ((local_flow_state -> 'completion_response') - 'access_token' - 'refresh_token' - 'expires_in' - 'token_type'),
+ true
+ )
+WHERE jsonb_typeof(local_flow_state -> 'completion_response') = 'object'
+ AND (
+ (local_flow_state -> 'completion_response') ? 'access_token'
+ OR (local_flow_state -> 'completion_response') ? 'refresh_token'
+ OR (local_flow_state -> 'completion_response') ? 'expires_in'
+ OR (local_flow_state -> 'completion_response') ? 'token_type'
+ );
diff --git a/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql b/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql
new file mode 100644
index 00000000..4388285a
--- /dev/null
+++ b/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql
@@ -0,0 +1,68 @@
+-- Auto-backfill untouched migration 110 signup-grant defaults to the corrected false value.
+-- Rows still matching the migration-110 default payload and timestamp window are treated as
+-- untouched legacy defaults; any remaining legacy true values are reported for manual review.
+
+WITH migration_110 AS (
+ SELECT applied_at
+ FROM schema_migrations
+ WHERE filename = '110_pending_auth_and_provider_default_grants.sql'
+),
+providers AS (
+ SELECT provider_type
+ FROM (
+ VALUES ('email'), ('linuxdo'), ('oidc'), ('wechat')
+ ) AS providers(provider_type)
+),
+legacy_provider_defaults AS (
+ SELECT providers.provider_type
+ FROM providers
+ CROSS JOIN migration_110
+ JOIN settings balance
+ ON balance.key = 'auth_source_default_' || providers.provider_type || '_balance'
+ JOIN settings concurrency
+ ON concurrency.key = 'auth_source_default_' || providers.provider_type || '_concurrency'
+ JOIN settings subscriptions
+ ON subscriptions.key = 'auth_source_default_' || providers.provider_type || '_subscriptions'
+ JOIN settings grant_on_signup
+ ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup'
+ JOIN settings grant_on_first_bind
+ ON grant_on_first_bind.key = 'auth_source_default_' || providers.provider_type || '_grant_on_first_bind'
+ WHERE balance.value = '0'
+ AND concurrency.value = '5'
+ AND subscriptions.value = '[]'
+ AND grant_on_signup.value = 'true'
+ AND grant_on_first_bind.value = 'false'
+ AND balance.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
+ AND concurrency.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
+ AND subscriptions.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
+ AND grant_on_signup.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
+ AND grant_on_first_bind.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
+),
+updated_signup_grants AS (
+ UPDATE settings
+ SET
+ value = 'false',
+ updated_at = NOW()
+ FROM legacy_provider_defaults
+ WHERE settings.key = 'auth_source_default_' || legacy_provider_defaults.provider_type || '_grant_on_signup'
+ AND settings.value = 'true'
+ RETURNING legacy_provider_defaults.provider_type
+)
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'legacy_auth_source_signup_grant_review',
+ providers.provider_type,
+ jsonb_build_object(
+ 'provider_type', providers.provider_type,
+ 'current_value', grant_on_signup.value,
+ 'auto_backfilled', FALSE,
+ 'reason', 'legacy_true_default_not_auto_backfilled'
+ )
+FROM providers
+JOIN settings grant_on_signup
+ ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup'
+LEFT JOIN updated_signup_grants
+ ON updated_signup_grants.provider_type = providers.provider_type
+WHERE grant_on_signup.value = 'true'
+ AND updated_signup_grants.provider_type IS NULL
+ON CONFLICT (report_type, report_key) DO NOTHING;
diff --git a/backend/migrations/124_backfill_legacy_oidc_security_flags.sql b/backend/migrations/124_backfill_legacy_oidc_security_flags.sql
new file mode 100644
index 00000000..e68bb11a
--- /dev/null
+++ b/backend/migrations/124_backfill_legacy_oidc_security_flags.sql
@@ -0,0 +1,32 @@
+-- Preserve legacy OIDC behavior for upgraded installs that predate the
+-- introduction of secure PKCE/id_token defaults. Fresh installs continue to
+-- inherit runtime defaults when these rows are absent.
+
+WITH legacy_oidc_install AS (
+ SELECT 1
+ FROM settings
+ WHERE key IN (
+ 'oidc_connect_enabled',
+ 'oidc_connect_client_id',
+ 'oidc_connect_authorize_url',
+ 'oidc_connect_token_url',
+ 'oidc_connect_issuer_url',
+ 'oidc_connect_userinfo_url',
+ 'oidc_connect_frontend_redirect_url'
+ )
+ LIMIT 1
+)
+INSERT INTO settings (key, value)
+SELECT defaults.key, 'false'
+FROM legacy_oidc_install
+CROSS JOIN (
+ VALUES
+ ('oidc_connect_use_pkce'),
+ ('oidc_connect_validate_id_token')
+) AS defaults(key)
+WHERE NOT EXISTS (
+ SELECT 1
+ FROM settings existing
+ WHERE existing.key = defaults.key
+)
+ON CONFLICT (key) DO NOTHING;
diff --git a/backend/migrations/auth_identity_payment_migrations_regression_test.go b/backend/migrations/auth_identity_payment_migrations_regression_test.go
new file mode 100644
index 00000000..798ae0fe
--- /dev/null
+++ b/backend/migrations/auth_identity_payment_migrations_regression_test.go
@@ -0,0 +1,129 @@
+package migrations
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestMigration112UsesIdempotentAddColumn(t *testing.T) {
+ content, err := FS.ReadFile("112_add_payment_order_provider_key_snapshot.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS provider_key VARCHAR(30)")
+ require.NotContains(t, sql, "ADD COLUMN provider_key VARCHAR(30);")
+}
+
+func TestMigration118DoesNotForceOverwriteAuthSourceGrantDefaults(t *testing.T) {
+ content, err := FS.ReadFile("118_wechat_dual_mode_and_auth_source_defaults.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.NotContains(t, sql, "UPDATE settings")
+ require.NotContains(t, sql, "SET value = 'false'")
+ require.True(t, strings.Contains(sql, "ON CONFLICT (key) DO NOTHING"))
+ require.Contains(t, sql, "THEN ''")
+}
+
+func TestAuthIdentityReportTypeWideningRunsBeforeLongReportWritersAndStillReconcilesAt121(t *testing.T) {
+ preflightContent, err := FS.ReadFile("108a_widen_auth_identity_migration_report_type.sql")
+ require.NoError(t, err)
+
+ preflightSQL := string(preflightContent)
+ require.Contains(t, preflightSQL, "ALTER TABLE auth_identity_migration_reports")
+ require.Contains(t, preflightSQL, "ALTER COLUMN report_type TYPE VARCHAR(80)")
+
+ content, err := FS.ReadFile("109_auth_identity_compat_backfill.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.NotContains(t, sql, "ALTER TABLE auth_identity_migration_reports")
+
+ followupContent, err := FS.ReadFile("121_auth_identity_migration_report_type_widen.sql")
+ require.NoError(t, err)
+
+ followupSQL := string(followupContent)
+ require.Contains(t, followupSQL, "ALTER TABLE auth_identity_migration_reports")
+ require.Contains(t, followupSQL, "ALTER COLUMN report_type TYPE VARCHAR(80)")
+}
+
+func TestMigration119DefersPaymentIndexRolloutToOnlineFollowup(t *testing.T) {
+ content, err := FS.ReadFile("119_enforce_payment_orders_out_trade_no_unique.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.Contains(t, sql, "120_enforce_payment_orders_out_trade_no_unique_notx.sql")
+ require.Contains(t, sql, "NULL;")
+ require.NotContains(t, sql, "CREATE UNIQUE INDEX")
+ require.NotContains(t, sql, "DROP INDEX")
+
+ followupContent, err := FS.ReadFile("120_enforce_payment_orders_out_trade_no_unique_notx.sql")
+ require.NoError(t, err)
+
+ followupSQL := string(followupContent)
+ require.Contains(t, followupSQL, "explicit duplicate out_trade_no precheck")
+ require.Contains(t, followupSQL, "stale invalid paymentorder_out_trade_no_unique index")
+ require.Contains(t, followupSQL, "CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique")
+ require.NotContains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique")
+ require.Contains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no")
+ require.Contains(t, followupSQL, "WHERE out_trade_no <> ''")
+
+ alignmentContent, err := FS.ReadFile("120a_align_payment_orders_out_trade_no_index_name.sql")
+ require.NoError(t, err)
+
+ alignmentSQL := string(alignmentContent)
+ require.Contains(t, alignmentSQL, "paymentorder_out_trade_no_unique")
+ require.Contains(t, alignmentSQL, "RENAME TO paymentorder_out_trade_no")
+}
+
+func TestMigration110SeedsAuthSourceSignupGrantsDisabledByDefault(t *testing.T) {
+ content, err := FS.ReadFile("110_pending_auth_and_provider_default_grants.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.Contains(t, sql, "('auth_source_default_email_grant_on_signup', 'false')")
+ require.Contains(t, sql, "('auth_source_default_linuxdo_grant_on_signup', 'false')")
+ require.Contains(t, sql, "('auth_source_default_oidc_grant_on_signup', 'false')")
+ require.Contains(t, sql, "('auth_source_default_wechat_grant_on_signup', 'false')")
+ require.NotContains(t, sql, "('auth_source_default_email_grant_on_signup', 'true')")
+}
+
+func TestMigration122ScrubsPendingOAuthCompletionTokensAtRest(t *testing.T) {
+ content, err := FS.ReadFile("122_pending_auth_completion_token_cleanup.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.Contains(t, sql, "UPDATE pending_auth_sessions")
+ require.Contains(t, sql, "completion_response")
+ require.Contains(t, sql, "access_token")
+ require.Contains(t, sql, "refresh_token")
+ require.Contains(t, sql, "expires_in")
+ require.Contains(t, sql, "token_type")
+}
+
+func TestMigration123BackfillsLegacyAuthSourceGrantDefaultsSafely(t *testing.T) {
+ content, err := FS.ReadFile("123_fix_legacy_auth_source_grant_on_signup_defaults.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.Contains(t, sql, "110_pending_auth_and_provider_default_grants.sql")
+ require.Contains(t, sql, "schema_migrations")
+ require.Contains(t, sql, "updated_at")
+ require.Contains(t, sql, "'_grant_on_signup'")
+ require.Contains(t, sql, "value = 'false'")
+ require.Contains(t, sql, "auth_identity_migration_reports")
+}
+
+func TestMigration124BackfillsLegacyOIDCSecurityFlagsSafely(t *testing.T) {
+ content, err := FS.ReadFile("124_backfill_legacy_oidc_security_flags.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.Contains(t, sql, "oidc_connect_use_pkce")
+ require.Contains(t, sql, "oidc_connect_validate_id_token")
+ require.Contains(t, sql, "ON CONFLICT (key) DO NOTHING")
+ require.Contains(t, sql, "oidc_connect_enabled")
+ require.Contains(t, sql, "'false'")
+}
diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml
index df586de2..c53c430e 100644
--- a/deploy/config.example.yaml
+++ b/deploy/config.example.yaml
@@ -865,7 +865,7 @@ linuxdo_connect:
frontend_redirect_url: "/auth/linuxdo/callback"
token_auth_method: "client_secret_post" # client_secret_post | client_secret_basic | none
# 注意:当 token_auth_method=none(public client)时,必须启用 PKCE
- use_pkce: false
+ use_pkce: true
userinfo_email_path: ""
userinfo_id_path: ""
userinfo_username_path: ""
diff --git a/docs/PAYMENT.md b/docs/PAYMENT.md
index b66a791c..9322f7bf 100644
--- a/docs/PAYMENT.md
+++ b/docs/PAYMENT.md
@@ -22,13 +22,18 @@ Sub2API has a built-in payment system that enables user self-service top-up with
| Provider | Payment Methods | Description |
|----------|----------------|-------------|
| **EasyPay** | Alipay, WeChat Pay | Third-party aggregation via EasyPay protocol |
-| **Alipay (Direct)** | PC Page Pay, H5 Mobile Pay | Direct integration with Alipay Open Platform, auto-switches by device |
-| **WeChat Pay (Direct)** | Native QR Code, H5 Pay | Direct integration with WeChat Pay APIv3, mobile-first H5 |
+| **Alipay (Direct)** | Desktop QR code, mobile Alipay redirect | Direct integration with Alipay Open Platform, returning desktop QR codes and mobile WAP/app launch links |
+| **WeChat Pay (Direct)** | Native QR, H5, MP/JSAPI Pay | Direct integration with WeChat Pay APIv3 with environment-aware routing |
| **Stripe** | Card, Alipay, WeChat Pay, Link, etc. | International payments, multi-currency support |
-> Alipay/WeChat Pay direct and EasyPay can coexist. Direct channels connect to payment APIs directly with lower fees; EasyPay aggregates through third-party platforms with easier setup.
+> Alipay/WeChat Pay direct and EasyPay can both exist as backend provider instances, but the frontend always exposes only two visible buttons: `Alipay` and `WeChat Pay`. Admins choose exactly one source for each visible method: direct or EasyPay. Direct channels connect to payment APIs directly with lower fees; EasyPay aggregates through third-party platforms with easier setup.
-> **EasyPay Recommendation**: [ZPay](https://z-pay.cn/?uid=23808) (`https://z-pay.cn/?uid=23808`) is recommended as an EasyPay provider (link contains the referral code of [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) original author [@touwaeriol](https://github.com/touwaeriol) — feel free to remove it). ZPay supports **individual users** (no business license required) with up to 10,000 CNY daily transactions; business-licensed accounts have no limit. Please evaluate the security, reliability, and compliance of any third-party payment provider on your own — this project does not endorse or guarantee any of them.
+> **EasyPay Provider Recommendations**: Both options below are third-party aggregators compatible with the EasyPay protocol. Pick based on the funding channel and settlement currency you need:
+>
+> - **Domestic channel / CNY settlement** — [ZPay](https://z-pay.cn/?uid=23808) (`https://z-pay.cn/?uid=23808`): direct integration with official Alipay / WeChat Pay APIs, fee **1.6%**; funds go straight to the merchant account with **T+1 automatic settlement**. Supports **individual users** (no business license required) with up to 10,000 CNY daily transactions; business-licensed accounts have no limit. Link contains the referral code of [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) original author [@touwaeriol](https://github.com/touwaeriol) — feel free to remove it.
+> - **International channel / USDT or USD settlement** — [Kyren Topup](https://kyren.top/?code=SUB2API) (`https://kyren.top/?code=SUB2API`): a ready-to-launch global payment stack for AI startups with WeChat Pay and Alipay support, local-currency checkout, and USD settlement. Fees: WeChat 2%, Alipay 2.5%; withdrawal 0.1% (min $40, max $150), settled in **USDT or USD**. No qualification review required — sign up and use immediately, making it the lowest barrier to entry. Withdrawal threshold is relatively high, recommended for users **who do not use domestic Chinese payment channels, cannot tolerate Stripe's 6%+ fees, have high transaction volume, and have USD or USDT channels to receive withdrawn funds**. Kyren Topup charges a $200 account opening fee; signing up via this link (which contains Sub2Api author [@Wei-Shaw](https://github.com/Wei-Shaw)'s referral code) **waives the opening fee**. Feel free to remove it if you prefer.
+>
+> Please evaluate the security, reliability, and compliance of any third-party payment provider on your own — this project does not endorse or guarantee any of them.
---
@@ -56,9 +61,18 @@ Configure the following in Admin Dashboard **Settings → Payment Settings**:
| **Minimum Amount** | Minimum single top-up amount | 1 |
| **Maximum Amount** | Maximum single top-up amount (empty = unlimited) | - |
| **Daily Limit** | Per-user daily cumulative limit (empty = unlimited) | - |
-| **Order Timeout** | Order timeout in minutes (minimum 1) | 5 |
+| **Order Timeout** | Order timeout in minutes (minimum 1) | 30 |
| **Max Pending Orders** | Maximum concurrent pending orders per user | 3 |
-| **Load Balance Strategy** | Strategy for selecting provider instances | Least Amount |
+| **Load Balance Strategy** | Strategy for selecting provider instances | Round Robin |
+
+### Frontend Visible Method Routing
+
+The current payment UX keeps the frontend method list unified and does not expose provider brands directly:
+
+- **Alipay**: when enabled, this button must be routed to either `Alipay (Direct)` or `EasyPay Alipay`
+- **WeChat Pay**: when enabled, this button must be routed to either `WeChat Pay (Direct)` or `EasyPay WeChat`
+- Each visible method can route to only one source at a time
+- If a visible method is enabled without a selected source, the frontend will not expose that method
### Load Balance Strategies
@@ -108,7 +122,7 @@ Compatible with any payment service that implements the EasyPay protocol.
### Alipay (Direct)
-Direct integration with Alipay Open Platform. Supports PC page pay and H5 mobile pay.
+Direct integration with Alipay Open Platform. Desktop flows return a QR code for in-page display, while mobile flows return an Alipay WAP/app redirect URL.
| Parameter | Description | Required |
|-----------|-------------|----------|
@@ -118,7 +132,7 @@ Direct integration with Alipay Open Platform. Supports PC page pay and H5 mobile
### WeChat Pay (Direct)
-Direct integration with WeChat Pay APIv3. Supports Native QR code and H5 payment.
+Direct integration with WeChat Pay APIv3. Supports Native QR code payment, H5 payment, and MP/JSAPI payment inside the WeChat environment.
| Parameter | Description | Required |
|-----------|-------------|----------|
@@ -127,8 +141,8 @@ Direct integration with WeChat Pay APIv3. Supports Native QR code and H5 payment
| **Merchant API Private Key** | Merchant API private key (PEM format) | Yes |
| **APIv3 Key** | 32-byte APIv3 key | Yes |
| **WeChat Pay Public Key** | WeChat Pay public key (PEM format) | Yes |
-| **WeChat Pay Public Key ID** | WeChat Pay public key ID | No |
-| **Certificate Serial Number** | Merchant certificate serial number | No |
+| **WeChat Pay Public Key ID** | WeChat Pay public key ID | Yes |
+| **Certificate Serial Number** | Merchant certificate serial number | Yes |
### Stripe
@@ -215,8 +229,8 @@ User selects amount and payment method
▼
User completes payment
├─ EasyPay → QR code / H5 redirect
- ├─ Alipay → PC page pay / H5 mobile pay
- ├─ WeChat Pay → Native QR / H5 pay
+ ├─ Alipay → Desktop QR / mobile Alipay redirect
+ ├─ WeChat Pay → Desktop Native QR / non-WeChat H5 / in-WeChat JSAPI
└─ Stripe → Payment Element (card/Alipay/WeChat/etc.)
│
▼
diff --git a/docs/PAYMENT_CN.md b/docs/PAYMENT_CN.md
index 9d96557f..0fbc198a 100644
--- a/docs/PAYMENT_CN.md
+++ b/docs/PAYMENT_CN.md
@@ -22,13 +22,18 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
| 服务商 | 支付方式 | 说明 |
|--------|---------|------|
| **EasyPay(易支付)** | 支付宝、微信支付 | 兼容易支付协议的第三方聚合支付 |
-| **支付宝官方** | 支付宝 PC 页面支付、H5 手机网站支付 | 直接对接支付宝开放平台,自动根据终端切换 |
-| **微信官方** | Native 扫码支付、H5 支付 | 直接对接微信支付 APIv3,移动端优先 H5 |
+| **支付宝官方** | 桌面二维码扫码、移动端支付宝跳转 | 直接对接支付宝开放平台,桌面端返回二维码,移动端返回 WAP/唤起链接 |
+| **微信官方** | Native 扫码、H5、公众号/JSAPI 支付 | 直接对接微信支付 APIv3,按终端环境自动分流 |
| **Stripe** | 银行卡、支付宝、微信支付、Link 等 | 国际支付,支持多币种 |
-> 支付宝官方 / 微信官方与 EasyPay 可以共存。官方渠道直接对接 API,资金直达商户账户,手续费更低;EasyPay 通过第三方平台聚合,接入门槛更低。
+> 支付宝官方 / 微信官方与易支付可以同时作为后台服务商实例存在,但前台始终只展示 `支付宝`、`微信支付` 两个可见按钮。管理员需要分别为这两个按钮选择唯一支付来源:官方或易支付。官方渠道直接对接 API,资金直达商户账户,手续费更低;易支付通过第三方平台聚合,接入门槛更低。
-> **EasyPay 推荐**:个人推荐 [ZPay](https://z-pay.cn/?uid=23808)(`https://z-pay.cn/?uid=23808`)作为 EasyPay 服务商(链接含 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 原作者 [@touwaeriol](https://github.com/touwaeriol) 的邀请码,介意可去掉)。ZPay 支持**个人用户**(无营业执照)每日 1 万元以内交易;拥有营业执照则无限额。支付渠道的安全性、稳定性及合规性请自行鉴别,本项目不对任何第三方支付服务商做担保或背书。
+> **易支付服务商推荐**:以下两家均为兼容易支付协议的第三方聚合支付,按资金通道与结算方式选择:
+>
+> - **国内渠道 / 人民币结算** — [ZPay](https://z-pay.cn/?uid=23808)(`https://z-pay.cn/?uid=23808`):支付宝 / 微信官方 API 直连,手续费 **1.6%**;资金直达商家账户,**T+1 自动到账**。支持**个人用户**(无营业执照)每日 1 万元以内交易;拥有营业执照则无限额。链接含 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 原作者 [@touwaeriol](https://github.com/touwaeriol) 的邀请码,介意可去掉。
+> - **国际渠道 / USDT 或美元结算** — [启润支付](https://kyren.top/?code=SUB2API)(`https://kyren.top/?code=SUB2API`):为 AI 项目提供低门槛国际收款通道,支持国际版微信支付与支付宝,本地货币支付、美元结算。手续费:微信 2%、支付宝 2.5%;提现 0.1%(最低 40 美元、最高 150 美元),以 **USDT 或美元**到账。无资质审核、注册即用,使用门槛最低;提现门槛略高,适合**不使用国内支付渠道、无法接受 Stripe 高达 6%+ 手续费、流水较大,且拥有美元或 USDT 渠道可接收提现资金**的用户。启润支付开户费 200 美元,通过本链接注册(含 Sub2Api 作者 [@Wei-Shaw](https://github.com/Wei-Shaw) 邀请码)可**免开户费**,介意可去掉。
+>
+> 支付渠道的安全性、稳定性及合规性请自行鉴别,本项目不对任何第三方支付服务商做担保或背书。
---
@@ -56,9 +61,18 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
| **最低金额** | 单笔最低充值金额 | 1 |
| **最高金额** | 单笔最高充值金额(留空表示不限制) | - |
| **每日限额** | 每用户每日累计充值上限(留空表示不限制) | - |
-| **订单超时时间** | 订单超时分钟数,至少 1 分钟 | 5 |
+| **订单超时时间** | 订单超时分钟数,至少 1 分钟 | 30 |
| **最大待支付订单数** | 同一用户最大并行待支付订单数 | 3 |
-| **负载均衡策略** | 多服务商实例时的选择策略 | 最少金额 |
+| **负载均衡策略** | 多服务商实例时的选择策略 | 轮询 |
+
+### 前台可见支付方式路由
+
+当前版本对用户统一展示支付方式,不区分官方渠道还是易支付:
+
+- **支付宝**:后台启用后,需要额外指定该按钮路由到 `支付宝官方` 或 `易支付支付宝`
+- **微信支付**:后台启用后,需要额外指定该按钮路由到 `微信官方` 或 `易支付微信`
+- 同一个可见支付方式在同一时刻只能路由到一个来源
+- 支付来源未选择时,即使对应按钮被开启,前台也不会暴露该支付方式
### 负载均衡策略
@@ -108,7 +122,7 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
### 支付宝官方
-直接对接支付宝开放平台,支持 PC 页面支付和 H5 手机网站支付。
+直接对接支付宝开放平台。桌面端返回二维码供页面内展示和扫码,移动端返回支付宝手机网站支付跳转链接。
| 参数 | 说明 | 必填 |
|------|------|------|
@@ -118,7 +132,7 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
### 微信官方
-直接对接微信支付 APIv3,支持 Native 扫码支付和 H5 支付。
+直接对接微信支付 APIv3,支持 Native 扫码支付、H5 支付,以及在微信环境内的公众号/JSAPI 支付。
| 参数 | 说明 | 必填 |
|------|------|------|
@@ -127,8 +141,8 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
| **商户 API 私钥** | 商户 API 私钥(PEM 格式) | 是 |
| **APIv3 密钥** | 32 位 APIv3 密钥 | 是 |
| **微信支付公钥** | 微信支付公钥(PEM 格式) | 是 |
-| **微信支付公钥 ID** | 微信支付公钥 ID | 否 |
-| **商户证书序列号** | 商户证书序列号 | 否 |
+| **微信支付公钥 ID** | 微信支付公钥 ID | 是 |
+| **商户证书序列号** | 商户证书序列号 | 是 |
### Stripe
@@ -215,8 +229,8 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
▼
用户完成支付
├─ EasyPay → 扫码 / H5 跳转
- ├─ 支付宝官方 → PC 页面支付 / H5 手机网站支付
- ├─ 微信官方 → Native 扫码 / H5 支付
+ ├─ 支付宝官方 → 桌面二维码 / 移动端支付宝跳转
+ ├─ 微信官方 → 桌面 Native 扫码 / 非微信 H5 / 微信内 JSAPI
└─ Stripe → Payment Element(银行卡/支付宝/微信等)
│
▼
diff --git a/frontend/src/api/__tests__/admin.users.spec.ts b/frontend/src/api/__tests__/admin.users.spec.ts
new file mode 100644
index 00000000..37656b78
--- /dev/null
+++ b/frontend/src/api/__tests__/admin.users.spec.ts
@@ -0,0 +1,117 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+const { post } = vi.hoisted(() => ({
+ post: vi.fn(),
+}))
+
+vi.mock('@/api/client', () => ({
+ apiClient: {
+ post,
+ },
+}))
+
+import {
+ bindUserAuthIdentity,
+ type AdminBindAuthIdentityRequest,
+ type AdminBoundAuthIdentity,
+} from '@/api/admin/users'
+
+type Assert = T
+type IsExact = (
+ (() => G extends T ? 1 : 2) extends (() => G extends U ? 1 : 2)
+ ? ((() => G extends U ? 1 : 2) extends (() => G extends T ? 1 : 2) ? true : false)
+ : false
+)
+
+type ExpectedAdminBindAuthIdentityRequest = {
+ provider_type: string
+ provider_key: string
+ provider_subject: string
+ issuer?: string
+ metadata?: Record
+ channel?: {
+ channel: string
+ channel_app_id: string
+ channel_subject: string
+ metadata?: Record
+ }
+}
+
+type ExpectedAdminBoundAuthIdentity = {
+ user_id: number
+ provider_type: string
+ provider_key: string
+ provider_subject: string
+ verified_at?: string | null
+ issuer?: string | null
+ metadata: Record | null
+ created_at: string
+ updated_at: string
+ channel?: {
+ channel: string
+ channel_app_id: string
+ channel_subject: string
+ metadata: Record | null
+ created_at: string
+ updated_at: string
+ } | null
+}
+
+const requestContractExact: Assert<
+ IsExact
+> = true
+const responseContractExact: Assert<
+ IsExact
+> = true
+
+describe('admin users api auth identity binding', () => {
+ beforeEach(() => {
+ post.mockReset()
+ })
+
+ it('posts the backend-compatible auth identity bind payload and returns the backend response shape', async () => {
+ const payload: AdminBindAuthIdentityRequest = {
+ provider_type: 'wechat',
+ provider_key: 'wechat-main',
+ provider_subject: 'union-123',
+ metadata: { source: 'admin-repair' },
+ channel: {
+ channel: 'open',
+ channel_app_id: 'wx-open',
+ channel_subject: 'openid-123',
+ metadata: { scene: 'migration' },
+ },
+ }
+
+ const response: AdminBoundAuthIdentity = {
+ user_id: 9,
+ provider_type: 'wechat',
+ provider_key: 'wechat-main',
+ provider_subject: 'union-123',
+ verified_at: '2026-04-22T00:00:00Z',
+ issuer: null,
+ metadata: { source: 'admin-repair' },
+ created_at: '2026-04-22T00:00:00Z',
+ updated_at: '2026-04-22T00:00:00Z',
+ channel: {
+ channel: 'open',
+ channel_app_id: 'wx-open',
+ channel_subject: 'openid-123',
+ metadata: { scene: 'migration' },
+ created_at: '2026-04-22T00:00:00Z',
+ updated_at: '2026-04-22T00:00:00Z',
+ },
+ }
+ post.mockResolvedValue({ data: response })
+
+ const result = await bindUserAuthIdentity(9, payload)
+
+ expect(post).toHaveBeenCalledWith('/admin/users/9/auth-identities', payload)
+ expect(result).toEqual(response)
+ })
+
+ it('keeps bind auth identity request and response types aligned with the backend contract', () => {
+ expect(requestContractExact).toBe(true)
+ expect(responseContractExact).toBe(true)
+ })
+})
diff --git a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
new file mode 100644
index 00000000..a484d7ed
--- /dev/null
+++ b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
@@ -0,0 +1,184 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+const post = vi.fn()
+
+vi.mock('@/api/client', () => ({
+ apiClient: {
+ post
+ }
+}))
+
+describe('oauth adoption auth api', () => {
+ beforeEach(() => {
+ post.mockReset()
+ post.mockResolvedValue({ data: {} })
+ localStorage.clear()
+ document.cookie = 'oauth_bind_access_token=; Max-Age=0; path=/'
+ })
+
+ it('posts adoption decisions when exchanging pending oauth completion', async () => {
+ const { exchangePendingOAuthCompletion } = await import('@/api/auth')
+
+ await exchangePendingOAuthCompletion({
+ adoptDisplayName: false,
+ adoptAvatar: true
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/pending/exchange', {
+ adopt_display_name: false,
+ adopt_avatar: true
+ })
+ })
+
+ it('posts bind-login decisions when finalizing pending oauth bind flow', async () => {
+ const { completePendingOAuthBindLogin } = await import('@/api/auth')
+
+ await completePendingOAuthBindLogin({
+ adoptDisplayName: true,
+ adoptAvatar: false
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/pending/exchange', {
+ adopt_display_name: true,
+ adopt_avatar: false
+ })
+ })
+
+ it('posts linuxdo invitation completion with adoption decisions', async () => {
+ const { completeLinuxDoOAuthRegistration } = await import('@/api/auth')
+
+ await completeLinuxDoOAuthRegistration('invite-code', {
+ adoptDisplayName: true,
+ adoptAvatar: false
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: true,
+ adopt_avatar: false
+ })
+ })
+
+ it('posts linuxdo create-account completion with adoption decisions', async () => {
+ const { createPendingLinuxDoOAuthAccount } = await import('@/api/auth')
+
+ await createPendingLinuxDoOAuthAccount('invite-code', {
+ adoptDisplayName: false,
+ adoptAvatar: true
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: false,
+ adopt_avatar: true
+ })
+ })
+
+ it('posts oidc invitation completion with adoption decisions', async () => {
+ const { completeOIDCOAuthRegistration } = await import('@/api/auth')
+
+ await completeOIDCOAuthRegistration('invite-code', {
+ adoptDisplayName: false,
+ adoptAvatar: true
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: false,
+ adopt_avatar: true
+ })
+ })
+
+ it('posts oidc create-account completion with adoption decisions', async () => {
+ const { createPendingOIDCOAuthAccount } = await import('@/api/auth')
+
+ await createPendingOIDCOAuthAccount('invite-code', {
+ adoptDisplayName: true,
+ adoptAvatar: false
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: true,
+ adopt_avatar: false
+ })
+ })
+
+ it('posts wechat invitation completion with adoption decisions', async () => {
+ const { completeWeChatOAuthRegistration } = await import('@/api/auth')
+
+ await completeWeChatOAuthRegistration('invite-code', {
+ adoptDisplayName: true,
+ adoptAvatar: true
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: true,
+ adopt_avatar: true
+ })
+ })
+
+ it('posts wechat create-account completion with adoption decisions', async () => {
+ const { createPendingWeChatOAuthAccount } = await import('@/api/auth')
+
+ await createPendingWeChatOAuthAccount('invite-code', {
+ adoptDisplayName: false,
+ adoptAvatar: false
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: false,
+ adopt_avatar: false
+ })
+ })
+
+ it('classifies oauth completion results as login or bind', async () => {
+ const { getOAuthCompletionKind } = await import('@/api/auth')
+
+ expect(getOAuthCompletionKind({ access_token: 'access-token' })).toBe('login')
+ expect(getOAuthCompletionKind({ redirect: '/profile' })).toBe('bind')
+ })
+
+ it('provides bind-login utility helpers for invitation and suggested profile states', async () => {
+ const {
+ getPendingOAuthBindLoginKind,
+ hasPendingOAuthSuggestedProfile,
+ isPendingOAuthCreateAccountRequired
+ } = await import('@/api/auth')
+
+ expect(getPendingOAuthBindLoginKind({ access_token: 'access-token' })).toBe('login')
+ expect(getPendingOAuthBindLoginKind({ redirect: '/profile' })).toBe('bind')
+ expect(
+ isPendingOAuthCreateAccountRequired({
+ error: 'invitation_required'
+ })
+ ).toBe(true)
+ expect(
+ isPendingOAuthCreateAccountRequired({
+ error: 'other'
+ })
+ ).toBe(false)
+ expect(
+ hasPendingOAuthSuggestedProfile({
+ suggested_display_name: 'OAuth Nick'
+ })
+ ).toBe(true)
+ expect(
+ hasPendingOAuthSuggestedProfile({
+ suggested_avatar_url: 'https://cdn.example/avatar.png'
+ })
+ ).toBe(true)
+ expect(hasPendingOAuthSuggestedProfile({})).toBe(false)
+ })
+
+ it('requests an HttpOnly oauth bind cookie before redirect binding', async () => {
+ localStorage.setItem('auth_token', 'access-token-value')
+ const { prepareOAuthBindAccessTokenCookie } = await import('@/api/auth')
+
+ await prepareOAuthBindAccessTokenCookie()
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/bind-token')
+ })
+})
diff --git a/frontend/src/api/__tests__/client.spec.ts b/frontend/src/api/__tests__/client.spec.ts
index 0f663e76..a46c39eb 100644
--- a/frontend/src/api/__tests__/client.spec.ts
+++ b/frontend/src/api/__tests__/client.spec.ts
@@ -91,6 +91,22 @@ describe('API Client', () => {
const config = adapter.mock.calls[0][0]
expect(config.params?.timezone).toBeUndefined()
})
+
+ it('请求默认带 withCredentials 以支持跨域 cookie', async () => {
+ const adapter = vi.fn().mockResolvedValue({
+ status: 200,
+ data: { code: 0, data: {} },
+ headers: {},
+ config: {},
+ statusText: 'OK',
+ })
+ apiClient.defaults.adapter = adapter
+
+ await apiClient.post('/auth/oauth/bind-token')
+
+ const config = adapter.mock.calls[0][0]
+ expect(config.withCredentials).toBe(true)
+ })
})
// --- 响应拦截器 ---
diff --git a/frontend/src/api/__tests__/payment.spec.ts b/frontend/src/api/__tests__/payment.spec.ts
new file mode 100644
index 00000000..e38fba57
--- /dev/null
+++ b/frontend/src/api/__tests__/payment.spec.ts
@@ -0,0 +1,40 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+const { get, post } = vi.hoisted(() => ({
+ get: vi.fn(),
+ post: vi.fn(),
+}))
+
+vi.mock('@/api/client', () => ({
+ apiClient: {
+ get,
+ post,
+ },
+}))
+
+import { paymentAPI } from '@/api/payment'
+
+describe('payment api', () => {
+ beforeEach(() => {
+ get.mockReset()
+ post.mockReset()
+ get.mockResolvedValue({ data: {} })
+ post.mockResolvedValue({ data: {} })
+ })
+
+ it('keeps legacy public out_trade_no verification for upgrade compatibility', async () => {
+ await paymentAPI.verifyOrderPublic('legacy-order-no')
+
+ expect(post).toHaveBeenCalledWith('/payment/public/orders/verify', {
+ out_trade_no: 'legacy-order-no',
+ })
+ })
+
+ it('keeps signed public resume-token resolve endpoint', async () => {
+ await paymentAPI.resolveOrderPublicByResumeToken('resume-token-123')
+
+ expect(post).toHaveBeenCalledWith('/payment/public/orders/resolve', {
+ resume_token: 'resume-token-123',
+ })
+ })
+})
diff --git a/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts
new file mode 100644
index 00000000..10f6247a
--- /dev/null
+++ b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts
@@ -0,0 +1,131 @@
+import { describe, expect, it } from "vitest";
+
+import {
+ appendAuthSourceDefaultsToUpdateRequest,
+ buildAuthSourceDefaultsState,
+ type UpdateSettingsRequest,
+} from "@/api/admin/settings";
+
+describe("admin settings auth source defaults helpers", () => {
+ it("builds auth source defaults state from flat settings fields", () => {
+ const state = buildAuthSourceDefaultsState({
+ auth_source_default_email_balance: 9.5,
+ auth_source_default_email_concurrency: 3,
+ auth_source_default_email_subscriptions: [
+ { group_id: 1, validity_days: 30 },
+ ],
+ auth_source_default_email_grant_on_signup: false,
+ auth_source_default_email_grant_on_first_bind: true,
+ auth_source_default_linuxdo_balance: 6,
+ auth_source_default_linuxdo_concurrency: 8,
+ auth_source_default_linuxdo_subscriptions: [
+ { group_id: 2, validity_days: 60 },
+ ],
+ auth_source_default_linuxdo_grant_on_signup: true,
+ auth_source_default_linuxdo_grant_on_first_bind: false,
+ });
+
+ expect(state.email).toEqual({
+ balance: 9.5,
+ concurrency: 3,
+ subscriptions: [{ group_id: 1, validity_days: 30 }],
+ grant_on_signup: false,
+ grant_on_first_bind: true,
+ });
+ expect(state.linuxdo).toEqual({
+ balance: 6,
+ concurrency: 8,
+ subscriptions: [{ group_id: 2, validity_days: 60 }],
+ grant_on_signup: true,
+ grant_on_first_bind: false,
+ });
+ expect(state.oidc).toEqual({
+ balance: 0,
+ concurrency: 5,
+ subscriptions: [],
+ grant_on_signup: false,
+ grant_on_first_bind: false,
+ });
+ expect(state.wechat).toEqual({
+ balance: 0,
+ concurrency: 5,
+ subscriptions: [],
+ grant_on_signup: false,
+ grant_on_first_bind: false,
+ });
+ });
+
+ it("defaults grant-on-signup to disabled when settings are missing", () => {
+ const state = buildAuthSourceDefaultsState({});
+
+ expect(state.email.grant_on_signup).toBe(false);
+ expect(state.linuxdo.grant_on_signup).toBe(false);
+ expect(state.oidc.grant_on_signup).toBe(false);
+ expect(state.wechat.grant_on_signup).toBe(false);
+ });
+
+ it("appends auth source defaults back onto update payload", () => {
+ const payload: UpdateSettingsRequest = {
+ site_name: "Sub2API",
+ };
+
+ appendAuthSourceDefaultsToUpdateRequest(payload, {
+ email: {
+ balance: 1.25,
+ concurrency: 2,
+ subscriptions: [{ group_id: 3, validity_days: 7 }],
+ grant_on_signup: true,
+ grant_on_first_bind: false,
+ },
+ linuxdo: {
+ balance: 0,
+ concurrency: 6,
+ subscriptions: [],
+ grant_on_signup: false,
+ grant_on_first_bind: true,
+ },
+ oidc: {
+ balance: 4,
+ concurrency: 9,
+ subscriptions: [{ group_id: 9, validity_days: 90 }],
+ grant_on_signup: true,
+ grant_on_first_bind: true,
+ },
+ wechat: {
+ balance: 2,
+ concurrency: 5,
+ subscriptions: [],
+ grant_on_signup: false,
+ grant_on_first_bind: false,
+ },
+ });
+
+ expect(payload).toMatchObject({
+ site_name: "Sub2API",
+ auth_source_default_email_balance: 1.25,
+ auth_source_default_email_concurrency: 2,
+ auth_source_default_email_subscriptions: [
+ { group_id: 3, validity_days: 7 },
+ ],
+ auth_source_default_email_grant_on_signup: true,
+ auth_source_default_email_grant_on_first_bind: false,
+ auth_source_default_linuxdo_balance: 0,
+ auth_source_default_linuxdo_concurrency: 6,
+ auth_source_default_linuxdo_subscriptions: [],
+ auth_source_default_linuxdo_grant_on_signup: false,
+ auth_source_default_linuxdo_grant_on_first_bind: true,
+ auth_source_default_oidc_balance: 4,
+ auth_source_default_oidc_concurrency: 9,
+ auth_source_default_oidc_subscriptions: [
+ { group_id: 9, validity_days: 90 },
+ ],
+ auth_source_default_oidc_grant_on_signup: true,
+ auth_source_default_oidc_grant_on_first_bind: true,
+ auth_source_default_wechat_balance: 2,
+ auth_source_default_wechat_concurrency: 5,
+ auth_source_default_wechat_subscriptions: [],
+ auth_source_default_wechat_grant_on_signup: false,
+ auth_source_default_wechat_grant_on_first_bind: false,
+ });
+ });
+});
diff --git a/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts b/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts
new file mode 100644
index 00000000..ad355afe
--- /dev/null
+++ b/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts
@@ -0,0 +1,63 @@
+import { describe, expect, it } from 'vitest'
+
+import {
+ getPaymentVisibleMethodSourceOptions,
+ normalizePaymentVisibleMethodSource,
+} from '@/api/admin/settings'
+
+describe('admin settings payment visible method helpers', () => {
+ it('normalizes aliases into canonical source keys per visible method', () => {
+ expect(normalizePaymentVisibleMethodSource('alipay', 'official')).toBe('official_alipay')
+ expect(normalizePaymentVisibleMethodSource('alipay', 'alipay_direct')).toBe('official_alipay')
+ expect(normalizePaymentVisibleMethodSource('alipay', 'easypay')).toBe('easypay_alipay')
+
+ expect(normalizePaymentVisibleMethodSource('wxpay', 'official')).toBe('official_wxpay')
+ expect(normalizePaymentVisibleMethodSource('wxpay', 'wechat')).toBe('official_wxpay')
+ expect(normalizePaymentVisibleMethodSource('wxpay', 'easypay')).toBe('easypay_wxpay')
+ })
+
+ it('rejects unknown or cross-method source values', () => {
+ expect(normalizePaymentVisibleMethodSource('alipay', 'official_wxpay')).toBe('')
+ expect(normalizePaymentVisibleMethodSource('wxpay', 'official_alipay')).toBe('')
+ expect(normalizePaymentVisibleMethodSource('alipay', 'unknown')).toBe('')
+ expect(normalizePaymentVisibleMethodSource('wxpay', null)).toBe('')
+ })
+
+ it('exposes method-scoped source options instead of arbitrary strings', () => {
+ expect(getPaymentVisibleMethodSourceOptions('alipay')).toEqual([
+ {
+ value: '',
+ labelZh: '未配置',
+ labelEn: 'Not configured',
+ },
+ {
+ value: 'official_alipay',
+ labelZh: '支付宝官方',
+ labelEn: 'Official Alipay',
+ },
+ {
+ value: 'easypay_alipay',
+ labelZh: '易支付支付宝',
+ labelEn: 'EasyPay Alipay',
+ },
+ ])
+
+ expect(getPaymentVisibleMethodSourceOptions('wxpay')).toEqual([
+ {
+ value: '',
+ labelZh: '未配置',
+ labelEn: 'Not configured',
+ },
+ {
+ value: 'official_wxpay',
+ labelZh: '微信官方',
+ labelEn: 'Official WeChat Pay',
+ },
+ {
+ value: 'easypay_wxpay',
+ labelZh: '易支付微信',
+ labelEn: 'EasyPay WeChat Pay',
+ },
+ ])
+ })
+})
diff --git a/frontend/src/api/__tests__/settings.wechatConnect.spec.ts b/frontend/src/api/__tests__/settings.wechatConnect.spec.ts
new file mode 100644
index 00000000..eccb7214
--- /dev/null
+++ b/frontend/src/api/__tests__/settings.wechatConnect.spec.ts
@@ -0,0 +1,21 @@
+import { describe, expect, it } from "vitest";
+
+import {
+ defaultWeChatConnectScopesForMode,
+ normalizeWeChatConnectMode,
+} from "@/api/admin/settings";
+
+describe("admin settings wechat connect helpers", () => {
+ it("normalizes legacy or noisy mode values to the backend contract", () => {
+ expect(normalizeWeChatConnectMode("OPEN")).toBe("open");
+ expect(normalizeWeChatConnectMode(" open_platform ")).toBe("open");
+ expect(normalizeWeChatConnectMode("mp")).toBe("mp");
+ expect(normalizeWeChatConnectMode("official_account")).toBe("mp");
+ expect(normalizeWeChatConnectMode("unknown")).toBe("open");
+ });
+
+ it("maps each mode to the backend default scopes", () => {
+ expect(defaultWeChatConnectScopesForMode("open")).toBe("snsapi_login");
+ expect(defaultWeChatConnectScopesForMode("mp")).toBe("snsapi_userinfo");
+ });
+});
diff --git a/frontend/src/api/__tests__/user.spec.ts b/frontend/src/api/__tests__/user.spec.ts
new file mode 100644
index 00000000..887046da
--- /dev/null
+++ b/frontend/src/api/__tests__/user.spec.ts
@@ -0,0 +1,32 @@
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
+
+describe('user api oauth binding urls', () => {
+ beforeEach(() => {
+ vi.resetModules()
+ vi.stubEnv('VITE_API_BASE_URL', 'https://api.example.com/api/v1')
+ })
+
+ afterEach(() => {
+ vi.unstubAllEnvs()
+ })
+
+ it('builds third-party bind urls against the bind start endpoint', async () => {
+ const { buildOAuthBindingStartURL } = await import('@/api/user')
+
+ expect(buildOAuthBindingStartURL('linuxdo', { redirectTo: '/settings/profile' })).toBe(
+ 'https://api.example.com/api/v1/auth/oauth/linuxdo/bind/start?redirect=%2Fsettings%2Fprofile&intent=bind_current_user'
+ )
+ expect(
+ buildOAuthBindingStartURL('wechat', {
+ redirectTo: '/settings/profile',
+ wechatOAuthSettings: {
+ wechat_oauth_open_enabled: true,
+ wechat_oauth_mp_enabled: false,
+ wechat_oauth_mobile_enabled: false
+ }
+ })
+ ).toBe(
+ 'https://api.example.com/api/v1/auth/oauth/wechat/bind/start?redirect=%2Fsettings%2Fprofile&intent=bind_current_user&mode=open'
+ )
+ })
+})
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index 1e4a3053..0403b0f3 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -3,12 +3,293 @@
* Handles system settings management for administrators
*/
-import { apiClient } from '../client'
-import type { CustomMenuItem, CustomEndpoint, NotifyEmailEntry } from '@/types'
+import { apiClient } from "../client";
+import type { CustomMenuItem, CustomEndpoint, NotifyEmailEntry } from "@/types";
export interface DefaultSubscriptionSetting {
- group_id: number
- validity_days: number
+ group_id: number;
+ validity_days: number;
+}
+
+export type AuthSourceType = "email" | "linuxdo" | "oidc" | "wechat";
+
+export interface AuthSourceDefaultsValue {
+ balance: number;
+ concurrency: number;
+ subscriptions: DefaultSubscriptionSetting[];
+ grant_on_signup: boolean;
+ grant_on_first_bind: boolean;
+}
+
+export type AuthSourceDefaultsState = Record<
+ AuthSourceType,
+ AuthSourceDefaultsValue
+>;
+export type PaymentVisibleMethod = "alipay" | "wxpay";
+export type PaymentVisibleMethodSource =
+ | ""
+ | "official_alipay"
+ | "easypay_alipay"
+ | "official_wxpay"
+ | "easypay_wxpay";
+export type WeChatConnectMode = "open" | "mp" | "mobile";
+
+export interface PaymentVisibleMethodSourceOption {
+ value: PaymentVisibleMethodSource;
+ labelZh: string;
+ labelEn: string;
+}
+
+export interface WeChatConnectModeOption {
+ value: WeChatConnectMode;
+ labelZh: string;
+ labelEn: string;
+}
+
+const AUTH_SOURCE_TYPES: AuthSourceType[] = [
+ "email",
+ "linuxdo",
+ "oidc",
+ "wechat",
+];
+const AUTH_SOURCE_DEFAULT_BALANCE = 0;
+const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5;
+const PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS: Record<
+ PaymentVisibleMethod,
+ PaymentVisibleMethodSourceOption[]
+> = {
+ alipay: [
+ { value: "", labelZh: "未配置", labelEn: "Not configured" },
+ {
+ value: "official_alipay",
+ labelZh: "支付宝官方",
+ labelEn: "Official Alipay",
+ },
+ {
+ value: "easypay_alipay",
+ labelZh: "易支付支付宝",
+ labelEn: "EasyPay Alipay",
+ },
+ ],
+ wxpay: [
+ { value: "", labelZh: "未配置", labelEn: "Not configured" },
+ {
+ value: "official_wxpay",
+ labelZh: "微信官方",
+ labelEn: "Official WeChat Pay",
+ },
+ {
+ value: "easypay_wxpay",
+ labelZh: "易支付微信",
+ labelEn: "EasyPay WeChat Pay",
+ },
+ ],
+};
+const PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES: Record<
+ PaymentVisibleMethod,
+ Record
+> = {
+ alipay: {
+ official_alipay: "official_alipay",
+ alipay: "official_alipay",
+ alipay_direct: "official_alipay",
+ official: "official_alipay",
+ easypay_alipay: "easypay_alipay",
+ easypay: "easypay_alipay",
+ },
+ wxpay: {
+ official_wxpay: "official_wxpay",
+ wxpay: "official_wxpay",
+ wxpay_direct: "official_wxpay",
+ wechat: "official_wxpay",
+ official: "official_wxpay",
+ easypay_wxpay: "easypay_wxpay",
+ easypay: "easypay_wxpay",
+ },
+};
+const WECHAT_CONNECT_MODE_OPTIONS: WeChatConnectModeOption[] = [
+ { value: "open", labelZh: "PC 应用", labelEn: "PC App" },
+ {
+ value: "mp",
+ labelZh: "公众号",
+ labelEn: "Official Account",
+ },
+ {
+ value: "mobile",
+ labelZh: "移动应用",
+ labelEn: "Mobile App",
+ },
+];
+const WECHAT_CONNECT_MODE_ALIASES: Record = {
+ open: "open",
+ open_platform: "open",
+ official: "open",
+ wx_open: "open",
+ mp: "mp",
+ official_account: "mp",
+ wechat_mp: "mp",
+ mini_program: "mp",
+ mobile: "mobile",
+ mobile_app: "mobile",
+ native_app: "mobile",
+};
+
+export function normalizeDefaultSubscriptionSettings(
+ subscriptions: DefaultSubscriptionSetting[] | null | undefined,
+): DefaultSubscriptionSetting[] {
+ if (!Array.isArray(subscriptions)) return [];
+
+ return subscriptions
+ .filter((item) => item.group_id > 0 && item.validity_days > 0)
+ .map((item) => ({
+ group_id: Math.floor(item.group_id),
+ validity_days: Math.min(
+ 36500,
+ Math.max(1, Math.floor(item.validity_days)),
+ ),
+ }));
+}
+
+export function buildAuthSourceDefaultsState(
+ settings: Partial,
+): AuthSourceDefaultsState {
+ const raw = settings as Record;
+
+ return AUTH_SOURCE_TYPES.reduce((acc, source) => {
+ const subscriptions = raw[`auth_source_default_${source}_subscriptions`];
+ acc[source] = {
+ balance: Number(
+ raw[`auth_source_default_${source}_balance`] ??
+ AUTH_SOURCE_DEFAULT_BALANCE,
+ ),
+ concurrency: Math.max(
+ 1,
+ Number(
+ raw[`auth_source_default_${source}_concurrency`] ??
+ AUTH_SOURCE_DEFAULT_CONCURRENCY,
+ ),
+ ),
+ subscriptions: normalizeDefaultSubscriptionSettings(
+ Array.isArray(subscriptions)
+ ? (subscriptions as DefaultSubscriptionSetting[])
+ : [],
+ ),
+ grant_on_signup:
+ raw[`auth_source_default_${source}_grant_on_signup`] === true,
+ grant_on_first_bind:
+ raw[`auth_source_default_${source}_grant_on_first_bind`] === true,
+ };
+ return acc;
+ }, {} as AuthSourceDefaultsState);
+}
+
+export function appendAuthSourceDefaultsToUpdateRequest(
+ payload: UpdateSettingsRequest,
+ authSourceDefaults: AuthSourceDefaultsState,
+): UpdateSettingsRequest {
+ const target = payload as Record;
+
+ for (const source of AUTH_SOURCE_TYPES) {
+ const current = authSourceDefaults[source];
+ target[`auth_source_default_${source}_balance`] =
+ Number(current.balance) || 0;
+ target[`auth_source_default_${source}_concurrency`] = Math.max(
+ 1,
+ Math.floor(
+ Number(current.concurrency) || AUTH_SOURCE_DEFAULT_CONCURRENCY,
+ ),
+ );
+ target[`auth_source_default_${source}_subscriptions`] =
+ normalizeDefaultSubscriptionSettings(current.subscriptions);
+ target[`auth_source_default_${source}_grant_on_signup`] =
+ current.grant_on_signup;
+ target[`auth_source_default_${source}_grant_on_first_bind`] =
+ current.grant_on_first_bind;
+ }
+
+ return payload;
+}
+
+export function getPaymentVisibleMethodSourceOptions(
+ method: PaymentVisibleMethod,
+): PaymentVisibleMethodSourceOption[] {
+ return PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS[method];
+}
+
+export function normalizePaymentVisibleMethodSource(
+ method: PaymentVisibleMethod,
+ source: unknown,
+): PaymentVisibleMethodSource {
+ if (typeof source !== "string") return "";
+
+ const normalized = source.trim().toLowerCase();
+ if (!normalized) return "";
+
+ return PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES[method][normalized] ?? "";
+}
+
+export function getWeChatConnectModeOptions(): WeChatConnectModeOption[] {
+ return WECHAT_CONNECT_MODE_OPTIONS;
+}
+
+export function normalizeWeChatConnectMode(source: unknown): WeChatConnectMode {
+ if (typeof source !== "string") return "open";
+
+ const normalized = source.trim().toLowerCase();
+ if (!normalized) return "open";
+
+ return WECHAT_CONNECT_MODE_ALIASES[normalized] ?? "open";
+}
+
+export function defaultWeChatConnectScopesForMode(mode: unknown): string {
+ switch (normalizeWeChatConnectMode(mode)) {
+ case "mp":
+ return "snsapi_userinfo";
+ case "mobile":
+ return "";
+ default:
+ return "snsapi_login";
+ }
+}
+
+export function resolveWeChatConnectModeCapabilities(
+ openEnabled: unknown,
+ mpEnabled: unknown,
+ mobileEnabled: unknown,
+ legacyMode: unknown,
+): { openEnabled: boolean; mpEnabled: boolean; mobileEnabled: boolean } {
+ if (
+ typeof openEnabled === "boolean" ||
+ typeof mpEnabled === "boolean" ||
+ typeof mobileEnabled === "boolean"
+ ) {
+ return {
+ openEnabled: openEnabled === true,
+ mpEnabled: mpEnabled === true,
+ mobileEnabled: mobileEnabled === true,
+ };
+ }
+
+ switch (normalizeWeChatConnectMode(legacyMode)) {
+ case "mp":
+ return { openEnabled: false, mpEnabled: true, mobileEnabled: false };
+ case "mobile":
+ return { openEnabled: false, mpEnabled: false, mobileEnabled: true };
+ default:
+ return { openEnabled: true, mpEnabled: false, mobileEnabled: false };
+ }
+}
+
+export function deriveWeChatConnectStoredMode(
+ openEnabled: boolean,
+ mpEnabled: boolean,
+ mobileEnabled: boolean,
+ legacyMode: unknown,
+): WeChatConnectMode {
+ if (mpEnabled) return "mp";
+ if (mobileEnabled) return "mobile";
+ if (openEnabled) return "open";
+ return normalizeWeChatConnectMode(legacyMode);
}
/**
@@ -16,241 +297,327 @@ export interface DefaultSubscriptionSetting {
*/
export interface SystemSettings {
// Registration settings
- registration_enabled: boolean
- email_verify_enabled: boolean
- registration_email_suffix_whitelist: string[]
- promo_code_enabled: boolean
- password_reset_enabled: boolean
- frontend_url: string
- invitation_code_enabled: boolean
- totp_enabled: boolean // TOTP 双因素认证
- totp_encryption_key_configured: boolean // TOTP 加密密钥是否已配置
+ registration_enabled: boolean;
+ email_verify_enabled: boolean;
+ registration_email_suffix_whitelist: string[];
+ promo_code_enabled: boolean;
+ password_reset_enabled: boolean;
+ frontend_url: string;
+ invitation_code_enabled: boolean;
+ totp_enabled: boolean; // TOTP 双因素认证
+ totp_encryption_key_configured: boolean; // TOTP 加密密钥是否已配置
// Default settings
- default_balance: number
- default_concurrency: number
- default_subscriptions: DefaultSubscriptionSetting[]
+ default_balance: number;
+ default_concurrency: number;
+ default_subscriptions: DefaultSubscriptionSetting[];
+ auth_source_default_email_balance?: number;
+ auth_source_default_email_concurrency?: number;
+ auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_email_grant_on_signup?: boolean;
+ auth_source_default_email_grant_on_first_bind?: boolean;
+ auth_source_default_linuxdo_balance?: number;
+ auth_source_default_linuxdo_concurrency?: number;
+ auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_linuxdo_grant_on_signup?: boolean;
+ auth_source_default_linuxdo_grant_on_first_bind?: boolean;
+ auth_source_default_oidc_balance?: number;
+ auth_source_default_oidc_concurrency?: number;
+ auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_oidc_grant_on_signup?: boolean;
+ auth_source_default_oidc_grant_on_first_bind?: boolean;
+ auth_source_default_wechat_balance?: number;
+ auth_source_default_wechat_concurrency?: number;
+ auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_wechat_grant_on_signup?: boolean;
+ auth_source_default_wechat_grant_on_first_bind?: boolean;
+ force_email_on_third_party_signup?: boolean;
// OEM settings
- site_name: string
- site_logo: string
- site_subtitle: string
- api_base_url: string
- contact_info: string
- doc_url: string
- home_content: string
- hide_ccs_import_button: boolean
- table_default_page_size: number
- table_page_size_options: number[]
- backend_mode_enabled: boolean
- custom_menu_items: CustomMenuItem[]
- custom_endpoints: CustomEndpoint[]
+ site_name: string;
+ site_logo: string;
+ site_subtitle: string;
+ api_base_url: string;
+ contact_info: string;
+ doc_url: string;
+ home_content: string;
+ hide_ccs_import_button: boolean;
+ table_default_page_size: number;
+ table_page_size_options: number[];
+ backend_mode_enabled: boolean;
+ custom_menu_items: CustomMenuItem[];
+ custom_endpoints: CustomEndpoint[];
// SMTP settings
- smtp_host: string
- smtp_port: number
- smtp_username: string
- smtp_password_configured: boolean
- smtp_from_email: string
- smtp_from_name: string
- smtp_use_tls: boolean
+ smtp_host: string;
+ smtp_port: number;
+ smtp_username: string;
+ smtp_password_configured: boolean;
+ smtp_from_email: string;
+ smtp_from_name: string;
+ smtp_use_tls: boolean;
// Cloudflare Turnstile settings
- turnstile_enabled: boolean
- turnstile_site_key: string
- turnstile_secret_key_configured: boolean
+ turnstile_enabled: boolean;
+ turnstile_site_key: string;
+ turnstile_secret_key_configured: boolean;
// LinuxDo Connect OAuth settings
- linuxdo_connect_enabled: boolean
- linuxdo_connect_client_id: string
- linuxdo_connect_client_secret_configured: boolean
- linuxdo_connect_redirect_url: string
+ linuxdo_connect_enabled: boolean;
+ linuxdo_connect_client_id: string;
+ linuxdo_connect_client_secret_configured: boolean;
+ linuxdo_connect_redirect_url: string;
+
+ // WeChat Connect OAuth settings
+ wechat_connect_enabled: boolean;
+ wechat_connect_app_id: string;
+ wechat_connect_app_secret_configured: boolean;
+ wechat_connect_open_app_id?: string;
+ wechat_connect_open_app_secret_configured?: boolean;
+ wechat_connect_mp_app_id?: string;
+ wechat_connect_mp_app_secret_configured?: boolean;
+ wechat_connect_mobile_app_id?: string;
+ wechat_connect_mobile_app_secret_configured?: boolean;
+ wechat_connect_open_enabled?: boolean;
+ wechat_connect_mp_enabled?: boolean;
+ wechat_connect_mobile_enabled?: boolean;
+ wechat_connect_mode: string;
+ wechat_connect_scopes: string;
+ wechat_connect_redirect_url: string;
+ wechat_connect_frontend_redirect_url: string;
// Generic OIDC OAuth settings
- oidc_connect_enabled: boolean
- oidc_connect_provider_name: string
- oidc_connect_client_id: string
- oidc_connect_client_secret_configured: boolean
- oidc_connect_issuer_url: string
- oidc_connect_discovery_url: string
- oidc_connect_authorize_url: string
- oidc_connect_token_url: string
- oidc_connect_userinfo_url: string
- oidc_connect_jwks_url: string
- oidc_connect_scopes: string
- oidc_connect_redirect_url: string
- oidc_connect_frontend_redirect_url: string
- oidc_connect_token_auth_method: string
- oidc_connect_use_pkce: boolean
- oidc_connect_validate_id_token: boolean
- oidc_connect_allowed_signing_algs: string
- oidc_connect_clock_skew_seconds: number
- oidc_connect_require_email_verified: boolean
- oidc_connect_userinfo_email_path: string
- oidc_connect_userinfo_id_path: string
- oidc_connect_userinfo_username_path: string
+ oidc_connect_enabled: boolean;
+ oidc_connect_provider_name: string;
+ oidc_connect_client_id: string;
+ oidc_connect_client_secret_configured: boolean;
+ oidc_connect_issuer_url: string;
+ oidc_connect_discovery_url: string;
+ oidc_connect_authorize_url: string;
+ oidc_connect_token_url: string;
+ oidc_connect_userinfo_url: string;
+ oidc_connect_jwks_url: string;
+ oidc_connect_scopes: string;
+ oidc_connect_redirect_url: string;
+ oidc_connect_frontend_redirect_url: string;
+ oidc_connect_token_auth_method: string;
+ oidc_connect_use_pkce: boolean;
+ oidc_connect_validate_id_token: boolean;
+ oidc_connect_allowed_signing_algs: string;
+ oidc_connect_clock_skew_seconds: number;
+ oidc_connect_require_email_verified: boolean;
+ oidc_connect_userinfo_email_path: string;
+ oidc_connect_userinfo_id_path: string;
+ oidc_connect_userinfo_username_path: string;
// Model fallback configuration
- enable_model_fallback: boolean
- fallback_model_anthropic: string
- fallback_model_openai: string
- fallback_model_gemini: string
- fallback_model_antigravity: string
+ enable_model_fallback: boolean;
+ fallback_model_anthropic: string;
+ fallback_model_openai: string;
+ fallback_model_gemini: string;
+ fallback_model_antigravity: string;
// Identity patch configuration (Claude -> Gemini)
- enable_identity_patch: boolean
- identity_patch_prompt: string
+ enable_identity_patch: boolean;
+ identity_patch_prompt: string;
// Ops Monitoring (vNext)
- ops_monitoring_enabled: boolean
- ops_realtime_monitoring_enabled: boolean
- ops_query_mode_default: 'auto' | 'raw' | 'preagg' | string
- ops_metrics_interval_seconds: number
+ ops_monitoring_enabled: boolean;
+ ops_realtime_monitoring_enabled: boolean;
+ ops_query_mode_default: "auto" | "raw" | "preagg" | string;
+ ops_metrics_interval_seconds: number;
// Claude Code version check
- min_claude_code_version: string
- max_claude_code_version: string
+ min_claude_code_version: string;
+ max_claude_code_version: string;
// 分组隔离
- allow_ungrouped_key_scheduling: boolean
+ allow_ungrouped_key_scheduling: boolean;
// Gateway forwarding behavior
- enable_fingerprint_unification: boolean
- enable_metadata_passthrough: boolean
- enable_cch_signing: boolean
- web_search_emulation_enabled?: boolean
+ enable_fingerprint_unification: boolean;
+ enable_metadata_passthrough: boolean;
+ enable_cch_signing: boolean;
+ web_search_emulation_enabled?: boolean;
// Payment configuration
- payment_enabled: boolean
- payment_min_amount: number
- payment_max_amount: number
- payment_daily_limit: number
- payment_order_timeout_minutes: number
- payment_max_pending_orders: number
- payment_enabled_types: string[]
- payment_balance_disabled: boolean
- payment_balance_recharge_multiplier: number
- payment_recharge_fee_rate: number
- payment_load_balance_strategy: string
- payment_product_name_prefix: string
- payment_product_name_suffix: string
- payment_help_image_url: string
- payment_help_text: string
- payment_cancel_rate_limit_enabled: boolean
- payment_cancel_rate_limit_max: number
- payment_cancel_rate_limit_window: number
- payment_cancel_rate_limit_unit: string
- payment_cancel_rate_limit_window_mode: string
+ payment_enabled: boolean;
+ payment_min_amount: number;
+ payment_max_amount: number;
+ payment_daily_limit: number;
+ payment_order_timeout_minutes: number;
+ payment_max_pending_orders: number;
+ payment_enabled_types: string[];
+ payment_balance_disabled: boolean;
+ payment_balance_recharge_multiplier: number;
+ payment_recharge_fee_rate: number;
+ payment_load_balance_strategy: string;
+ payment_product_name_prefix: string;
+ payment_product_name_suffix: string;
+ payment_help_image_url: string;
+ payment_help_text: string;
+ payment_cancel_rate_limit_enabled: boolean;
+ payment_cancel_rate_limit_max: number;
+ payment_cancel_rate_limit_window: number;
+ payment_cancel_rate_limit_unit: string;
+ payment_cancel_rate_limit_window_mode: string;
+ payment_visible_method_alipay_source?: string;
+ payment_visible_method_wxpay_source?: string;
+ payment_visible_method_alipay_enabled?: boolean;
+ payment_visible_method_wxpay_enabled?: boolean;
+ openai_advanced_scheduler_enabled?: boolean;
// Balance & quota notification
- balance_low_notify_enabled: boolean
- balance_low_notify_threshold: number
- balance_low_notify_recharge_url: string
- account_quota_notify_enabled: boolean
- account_quota_notify_emails: NotifyEmailEntry[]
+ balance_low_notify_enabled: boolean;
+ balance_low_notify_threshold: number;
+ balance_low_notify_recharge_url: string;
+ account_quota_notify_enabled: boolean;
+ account_quota_notify_emails: NotifyEmailEntry[];
}
export interface UpdateSettingsRequest {
- registration_enabled?: boolean
- email_verify_enabled?: boolean
- registration_email_suffix_whitelist?: string[]
- promo_code_enabled?: boolean
- password_reset_enabled?: boolean
- frontend_url?: string
- invitation_code_enabled?: boolean
- totp_enabled?: boolean // TOTP 双因素认证
- default_balance?: number
- default_concurrency?: number
- default_subscriptions?: DefaultSubscriptionSetting[]
- site_name?: string
- site_logo?: string
- site_subtitle?: string
- api_base_url?: string
- contact_info?: string
- doc_url?: string
- home_content?: string
- hide_ccs_import_button?: boolean
- table_default_page_size?: number
- table_page_size_options?: number[]
- backend_mode_enabled?: boolean
- custom_menu_items?: CustomMenuItem[]
- custom_endpoints?: CustomEndpoint[]
- smtp_host?: string
- smtp_port?: number
- smtp_username?: string
- smtp_password?: string
- smtp_from_email?: string
- smtp_from_name?: string
- smtp_use_tls?: boolean
- turnstile_enabled?: boolean
- turnstile_site_key?: string
- turnstile_secret_key?: string
- linuxdo_connect_enabled?: boolean
- linuxdo_connect_client_id?: string
- linuxdo_connect_client_secret?: string
- linuxdo_connect_redirect_url?: string
- oidc_connect_enabled?: boolean
- oidc_connect_provider_name?: string
- oidc_connect_client_id?: string
- oidc_connect_client_secret?: string
- oidc_connect_issuer_url?: string
- oidc_connect_discovery_url?: string
- oidc_connect_authorize_url?: string
- oidc_connect_token_url?: string
- oidc_connect_userinfo_url?: string
- oidc_connect_jwks_url?: string
- oidc_connect_scopes?: string
- oidc_connect_redirect_url?: string
- oidc_connect_frontend_redirect_url?: string
- oidc_connect_token_auth_method?: string
- oidc_connect_use_pkce?: boolean
- oidc_connect_validate_id_token?: boolean
- oidc_connect_allowed_signing_algs?: string
- oidc_connect_clock_skew_seconds?: number
- oidc_connect_require_email_verified?: boolean
- oidc_connect_userinfo_email_path?: string
- oidc_connect_userinfo_id_path?: string
- oidc_connect_userinfo_username_path?: string
- enable_model_fallback?: boolean
- fallback_model_anthropic?: string
- fallback_model_openai?: string
- fallback_model_gemini?: string
- fallback_model_antigravity?: string
- enable_identity_patch?: boolean
- identity_patch_prompt?: string
- ops_monitoring_enabled?: boolean
- ops_realtime_monitoring_enabled?: boolean
- ops_query_mode_default?: 'auto' | 'raw' | 'preagg' | string
- ops_metrics_interval_seconds?: number
- min_claude_code_version?: string
- max_claude_code_version?: string
- allow_ungrouped_key_scheduling?: boolean
- enable_fingerprint_unification?: boolean
- enable_metadata_passthrough?: boolean
- enable_cch_signing?: boolean
+ registration_enabled?: boolean;
+ email_verify_enabled?: boolean;
+ registration_email_suffix_whitelist?: string[];
+ promo_code_enabled?: boolean;
+ password_reset_enabled?: boolean;
+ frontend_url?: string;
+ invitation_code_enabled?: boolean;
+ totp_enabled?: boolean; // TOTP 双因素认证
+ default_balance?: number;
+ default_concurrency?: number;
+ default_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_email_balance?: number;
+ auth_source_default_email_concurrency?: number;
+ auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_email_grant_on_signup?: boolean;
+ auth_source_default_email_grant_on_first_bind?: boolean;
+ auth_source_default_linuxdo_balance?: number;
+ auth_source_default_linuxdo_concurrency?: number;
+ auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_linuxdo_grant_on_signup?: boolean;
+ auth_source_default_linuxdo_grant_on_first_bind?: boolean;
+ auth_source_default_oidc_balance?: number;
+ auth_source_default_oidc_concurrency?: number;
+ auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_oidc_grant_on_signup?: boolean;
+ auth_source_default_oidc_grant_on_first_bind?: boolean;
+ auth_source_default_wechat_balance?: number;
+ auth_source_default_wechat_concurrency?: number;
+ auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_wechat_grant_on_signup?: boolean;
+ auth_source_default_wechat_grant_on_first_bind?: boolean;
+ force_email_on_third_party_signup?: boolean;
+ site_name?: string;
+ site_logo?: string;
+ site_subtitle?: string;
+ api_base_url?: string;
+ contact_info?: string;
+ doc_url?: string;
+ home_content?: string;
+ hide_ccs_import_button?: boolean;
+ table_default_page_size?: number;
+ table_page_size_options?: number[];
+ backend_mode_enabled?: boolean;
+ custom_menu_items?: CustomMenuItem[];
+ custom_endpoints?: CustomEndpoint[];
+ smtp_host?: string;
+ smtp_port?: number;
+ smtp_username?: string;
+ smtp_password?: string;
+ smtp_from_email?: string;
+ smtp_from_name?: string;
+ smtp_use_tls?: boolean;
+ turnstile_enabled?: boolean;
+ turnstile_site_key?: string;
+ turnstile_secret_key?: string;
+ linuxdo_connect_enabled?: boolean;
+ linuxdo_connect_client_id?: string;
+ linuxdo_connect_client_secret?: string;
+ linuxdo_connect_redirect_url?: string;
+ wechat_connect_enabled?: boolean;
+ wechat_connect_app_id?: string;
+ wechat_connect_app_secret?: string;
+ wechat_connect_open_app_id?: string;
+ wechat_connect_open_app_secret?: string;
+ wechat_connect_mp_app_id?: string;
+ wechat_connect_mp_app_secret?: string;
+ wechat_connect_mobile_app_id?: string;
+ wechat_connect_mobile_app_secret?: string;
+ wechat_connect_open_enabled?: boolean;
+ wechat_connect_mp_enabled?: boolean;
+ wechat_connect_mobile_enabled?: boolean;
+ wechat_connect_mode?: string;
+ wechat_connect_scopes?: string;
+ wechat_connect_redirect_url?: string;
+ wechat_connect_frontend_redirect_url?: string;
+ oidc_connect_enabled?: boolean;
+ oidc_connect_provider_name?: string;
+ oidc_connect_client_id?: string;
+ oidc_connect_client_secret?: string;
+ oidc_connect_issuer_url?: string;
+ oidc_connect_discovery_url?: string;
+ oidc_connect_authorize_url?: string;
+ oidc_connect_token_url?: string;
+ oidc_connect_userinfo_url?: string;
+ oidc_connect_jwks_url?: string;
+ oidc_connect_scopes?: string;
+ oidc_connect_redirect_url?: string;
+ oidc_connect_frontend_redirect_url?: string;
+ oidc_connect_token_auth_method?: string;
+ oidc_connect_use_pkce?: boolean;
+ oidc_connect_validate_id_token?: boolean;
+ oidc_connect_allowed_signing_algs?: string;
+ oidc_connect_clock_skew_seconds?: number;
+ oidc_connect_require_email_verified?: boolean;
+ oidc_connect_userinfo_email_path?: string;
+ oidc_connect_userinfo_id_path?: string;
+ oidc_connect_userinfo_username_path?: string;
+ enable_model_fallback?: boolean;
+ fallback_model_anthropic?: string;
+ fallback_model_openai?: string;
+ fallback_model_gemini?: string;
+ fallback_model_antigravity?: string;
+ enable_identity_patch?: boolean;
+ identity_patch_prompt?: string;
+ ops_monitoring_enabled?: boolean;
+ ops_realtime_monitoring_enabled?: boolean;
+ ops_query_mode_default?: "auto" | "raw" | "preagg" | string;
+ ops_metrics_interval_seconds?: number;
+ min_claude_code_version?: string;
+ max_claude_code_version?: string;
+ allow_ungrouped_key_scheduling?: boolean;
+ enable_fingerprint_unification?: boolean;
+ enable_metadata_passthrough?: boolean;
+ enable_cch_signing?: boolean;
// Payment configuration
- payment_enabled?: boolean
- payment_min_amount?: number
- payment_max_amount?: number
- payment_daily_limit?: number
- payment_order_timeout_minutes?: number
- payment_max_pending_orders?: number
- payment_enabled_types?: string[]
- payment_balance_disabled?: boolean
- payment_balance_recharge_multiplier?: number
- payment_recharge_fee_rate?: number
- payment_load_balance_strategy?: string
- payment_product_name_prefix?: string
- payment_product_name_suffix?: string
- payment_help_image_url?: string
- payment_help_text?: string
- payment_cancel_rate_limit_enabled?: boolean
- payment_cancel_rate_limit_max?: number
- payment_cancel_rate_limit_window?: number
- payment_cancel_rate_limit_unit?: string
- payment_cancel_rate_limit_window_mode?: string
+ payment_enabled?: boolean;
+ payment_min_amount?: number;
+ payment_max_amount?: number;
+ payment_daily_limit?: number;
+ payment_order_timeout_minutes?: number;
+ payment_max_pending_orders?: number;
+ payment_enabled_types?: string[];
+ payment_balance_disabled?: boolean;
+ payment_balance_recharge_multiplier?: number;
+ payment_recharge_fee_rate?: number;
+ payment_load_balance_strategy?: string;
+ payment_product_name_prefix?: string;
+ payment_product_name_suffix?: string;
+ payment_help_image_url?: string;
+ payment_help_text?: string;
+ payment_cancel_rate_limit_enabled?: boolean;
+ payment_cancel_rate_limit_max?: number;
+ payment_cancel_rate_limit_window?: number;
+ payment_cancel_rate_limit_unit?: string;
+ payment_cancel_rate_limit_window_mode?: string;
+ payment_visible_method_alipay_source?: string;
+ payment_visible_method_wxpay_source?: string;
+ payment_visible_method_alipay_enabled?: boolean;
+ payment_visible_method_wxpay_enabled?: boolean;
+ openai_advanced_scheduler_enabled?: boolean;
// Balance & quota notification
- balance_low_notify_enabled?: boolean
- balance_low_notify_threshold?: number
- balance_low_notify_recharge_url?: string
- account_quota_notify_enabled?: boolean
- account_quota_notify_emails?: NotifyEmailEntry[]
+ balance_low_notify_enabled?: boolean;
+ balance_low_notify_threshold?: number;
+ balance_low_notify_recharge_url?: string;
+ account_quota_notify_enabled?: boolean;
+ account_quota_notify_emails?: NotifyEmailEntry[];
}
/**
@@ -258,8 +625,8 @@ export interface UpdateSettingsRequest {
* @returns System settings
*/
export async function getSettings(): Promise {
- const { data } = await apiClient.get('/admin/settings')
- return data
+ const { data } = await apiClient.get("/admin/settings");
+ return data;
}
/**
@@ -267,20 +634,25 @@ export async function getSettings(): Promise {
* @param settings - Partial settings to update
* @returns Updated settings
*/
-export async function updateSettings(settings: UpdateSettingsRequest): Promise {
- const { data } = await apiClient.put('/admin/settings', settings)
- return data
+export async function updateSettings(
+ settings: UpdateSettingsRequest,
+): Promise {
+ const { data } = await apiClient.put(
+ "/admin/settings",
+ settings,
+ );
+ return data;
}
/**
* Test SMTP connection request
*/
export interface TestSmtpRequest {
- smtp_host: string
- smtp_port: number
- smtp_username: string
- smtp_password: string
- smtp_use_tls: boolean
+ smtp_host: string;
+ smtp_port: number;
+ smtp_username: string;
+ smtp_password: string;
+ smtp_use_tls: boolean;
}
/**
@@ -288,23 +660,28 @@ export interface TestSmtpRequest {
* @param config - SMTP configuration to test
* @returns Test result message
*/
-export async function testSmtpConnection(config: TestSmtpRequest): Promise<{ message: string }> {
- const { data } = await apiClient.post<{ message: string }>('/admin/settings/test-smtp', config)
- return data
+export async function testSmtpConnection(
+ config: TestSmtpRequest,
+): Promise<{ message: string }> {
+ const { data } = await apiClient.post<{ message: string }>(
+ "/admin/settings/test-smtp",
+ config,
+ );
+ return data;
}
/**
* Send test email request
*/
export interface SendTestEmailRequest {
- email: string
- smtp_host: string
- smtp_port: number
- smtp_username: string
- smtp_password: string
- smtp_from_email: string
- smtp_from_name: string
- smtp_use_tls: boolean
+ email: string;
+ smtp_host: string;
+ smtp_port: number;
+ smtp_username: string;
+ smtp_password: string;
+ smtp_from_email: string;
+ smtp_from_name: string;
+ smtp_use_tls: boolean;
}
/**
@@ -312,20 +689,22 @@ export interface SendTestEmailRequest {
* @param request - Email address and SMTP config
* @returns Test result message
*/
-export async function sendTestEmail(request: SendTestEmailRequest): Promise<{ message: string }> {
+export async function sendTestEmail(
+ request: SendTestEmailRequest,
+): Promise<{ message: string }> {
const { data } = await apiClient.post<{ message: string }>(
- '/admin/settings/send-test-email',
- request
- )
- return data
+ "/admin/settings/send-test-email",
+ request,
+ );
+ return data;
}
/**
* Admin API Key status response
*/
export interface AdminApiKeyStatus {
- exists: boolean
- masked_key: string
+ exists: boolean;
+ masked_key: string;
}
/**
@@ -333,8 +712,10 @@ export interface AdminApiKeyStatus {
* @returns Status indicating if key exists and masked version
*/
export async function getAdminApiKey(): Promise {
- const { data } = await apiClient.get('/admin/settings/admin-api-key')
- return data
+ const { data } = await apiClient.get(
+ "/admin/settings/admin-api-key",
+ );
+ return data;
}
/**
@@ -342,8 +723,10 @@ export async function getAdminApiKey(): Promise {
* @returns The new full API key (only shown once)
*/
export async function regenerateAdminApiKey(): Promise<{ key: string }> {
- const { data } = await apiClient.post<{ key: string }>('/admin/settings/admin-api-key/regenerate')
- return data
+ const { data } = await apiClient.post<{ key: string }>(
+ "/admin/settings/admin-api-key/regenerate",
+ );
+ return data;
}
/**
@@ -351,8 +734,10 @@ export async function regenerateAdminApiKey(): Promise<{ key: string }> {
* @returns Success message
*/
export async function deleteAdminApiKey(): Promise<{ message: string }> {
- const { data } = await apiClient.delete<{ message: string }>('/admin/settings/admin-api-key')
- return data
+ const { data } = await apiClient.delete<{ message: string }>(
+ "/admin/settings/admin-api-key",
+ );
+ return data;
}
// ==================== Overload Cooldown Settings ====================
@@ -361,23 +746,25 @@ export async function deleteAdminApiKey(): Promise<{ message: string }> {
* Overload cooldown settings interface (529 handling)
*/
export interface OverloadCooldownSettings {
- enabled: boolean
- cooldown_minutes: number
+ enabled: boolean;
+ cooldown_minutes: number;
}
export async function getOverloadCooldownSettings(): Promise {
- const { data } = await apiClient.get('/admin/settings/overload-cooldown')
- return data
+ const { data } = await apiClient.get(
+ "/admin/settings/overload-cooldown",
+ );
+ return data;
}
export async function updateOverloadCooldownSettings(
- settings: OverloadCooldownSettings
+ settings: OverloadCooldownSettings,
): Promise {
const { data } = await apiClient.put(
- '/admin/settings/overload-cooldown',
- settings
- )
- return data
+ "/admin/settings/overload-cooldown",
+ settings,
+ );
+ return data;
}
// ==================== Stream Timeout Settings ====================
@@ -386,11 +773,11 @@ export async function updateOverloadCooldownSettings(
* Stream timeout settings interface
*/
export interface StreamTimeoutSettings {
- enabled: boolean
- action: 'temp_unsched' | 'error' | 'none'
- temp_unsched_minutes: number
- threshold_count: number
- threshold_window_minutes: number
+ enabled: boolean;
+ action: "temp_unsched" | "error" | "none";
+ temp_unsched_minutes: number;
+ threshold_count: number;
+ threshold_window_minutes: number;
}
/**
@@ -398,8 +785,10 @@ export interface StreamTimeoutSettings {
* @returns Stream timeout settings
*/
export async function getStreamTimeoutSettings(): Promise {
- const { data } = await apiClient.get('/admin/settings/stream-timeout')
- return data
+ const { data } = await apiClient.get(
+ "/admin/settings/stream-timeout",
+ );
+ return data;
}
/**
@@ -408,13 +797,13 @@ export async function getStreamTimeoutSettings(): Promise
* @returns Updated settings
*/
export async function updateStreamTimeoutSettings(
- settings: StreamTimeoutSettings
+ settings: StreamTimeoutSettings,
): Promise {
const { data } = await apiClient.put(
- '/admin/settings/stream-timeout',
- settings
- )
- return data
+ "/admin/settings/stream-timeout",
+ settings,
+ );
+ return data;
}
// ==================== Rectifier Settings ====================
@@ -423,11 +812,11 @@ export async function updateStreamTimeoutSettings(
* Rectifier settings interface
*/
export interface RectifierSettings {
- enabled: boolean
- thinking_signature_enabled: boolean
- thinking_budget_enabled: boolean
- apikey_signature_enabled: boolean
- apikey_signature_patterns: string[]
+ enabled: boolean;
+ thinking_signature_enabled: boolean;
+ thinking_budget_enabled: boolean;
+ apikey_signature_enabled: boolean;
+ apikey_signature_patterns: string[];
}
/**
@@ -435,8 +824,10 @@ export interface RectifierSettings {
* @returns Rectifier settings
*/
export async function getRectifierSettings(): Promise {
- const { data } = await apiClient.get('/admin/settings/rectifier')
- return data
+ const { data } = await apiClient.get(
+ "/admin/settings/rectifier",
+ );
+ return data;
}
/**
@@ -445,13 +836,13 @@ export async function getRectifierSettings(): Promise {
* @returns Updated settings
*/
export async function updateRectifierSettings(
- settings: RectifierSettings
+ settings: RectifierSettings,
): Promise {
const { data } = await apiClient.put(
- '/admin/settings/rectifier',
- settings
- )
- return data
+ "/admin/settings/rectifier",
+ settings,
+ );
+ return data;
}
// ==================== Beta Policy Settings ====================
@@ -460,20 +851,20 @@ export async function updateRectifierSettings(
* Beta policy rule interface
*/
export interface BetaPolicyRule {
- beta_token: string
- action: 'pass' | 'filter' | 'block'
- scope: 'all' | 'oauth' | 'apikey' | 'bedrock'
- error_message?: string
- model_whitelist?: string[]
- fallback_action?: 'pass' | 'filter' | 'block'
- fallback_error_message?: string
+ beta_token: string;
+ action: "pass" | "filter" | "block";
+ scope: "all" | "oauth" | "apikey" | "bedrock";
+ error_message?: string;
+ model_whitelist?: string[];
+ fallback_action?: "pass" | "filter" | "block";
+ fallback_error_message?: string;
}
/**
* Beta policy settings interface
*/
export interface BetaPolicySettings {
- rules: BetaPolicyRule[]
+ rules: BetaPolicyRule[];
}
/**
@@ -481,8 +872,10 @@ export interface BetaPolicySettings {
* @returns Beta policy settings
*/
export async function getBetaPolicySettings(): Promise {
- const { data } = await apiClient.get('/admin/settings/beta-policy')
- return data
+ const { data } = await apiClient.get(
+ "/admin/settings/beta-policy",
+ );
+ return data;
}
/**
@@ -491,70 +884,73 @@ export async function getBetaPolicySettings(): Promise {
* @returns Updated settings
*/
export async function updateBetaPolicySettings(
- settings: BetaPolicySettings
+ settings: BetaPolicySettings,
): Promise {
const { data } = await apiClient.put(
- '/admin/settings/beta-policy',
- settings
- )
- return data
+ "/admin/settings/beta-policy",
+ settings,
+ );
+ return data;
}
// --- Web Search Emulation Config ---
export interface WebSearchProviderConfig {
- type: 'brave' | 'tavily'
- api_key: string
- api_key_configured: boolean
- quota_limit: number | null
- subscribed_at: number | null
- quota_used?: number
- proxy_id: number | null
- expires_at: number | null
+ type: "brave" | "tavily";
+ api_key: string;
+ api_key_configured: boolean;
+ quota_limit: number | null;
+ subscribed_at: number | null;
+ quota_used?: number;
+ proxy_id: number | null;
+ expires_at: number | null;
}
export interface WebSearchEmulationConfig {
- enabled: boolean
- providers: WebSearchProviderConfig[]
+ enabled: boolean;
+ providers: WebSearchProviderConfig[];
}
export interface WebSearchTestResult {
- provider: string
- results: { url: string; title: string; snippet: string; page_age?: string }[]
- query: string
+ provider: string;
+ results: { url: string; title: string; snippet: string; page_age?: string }[];
+ query: string;
}
export async function getWebSearchEmulationConfig(): Promise {
const { data } = await apiClient.get(
- '/admin/settings/web-search-emulation'
- )
- return data
+ "/admin/settings/web-search-emulation",
+ );
+ return data;
}
export async function updateWebSearchEmulationConfig(
- config: WebSearchEmulationConfig
+ config: WebSearchEmulationConfig,
): Promise {
const { data } = await apiClient.put(
- '/admin/settings/web-search-emulation',
- config
- )
- return data
+ "/admin/settings/web-search-emulation",
+ config,
+ );
+ return data;
}
export async function testWebSearchEmulation(
- query: string
+ query: string,
): Promise {
const { data } = await apiClient.post(
- '/admin/settings/web-search-emulation/test',
- { query }
- )
- return data
+ "/admin/settings/web-search-emulation/test",
+ { query },
+ );
+ return data;
}
-export async function resetWebSearchUsage(
- payload: { provider_type: string }
-): Promise {
- await apiClient.post('/admin/settings/web-search-emulation/reset-usage', payload)
+export async function resetWebSearchUsage(payload: {
+ provider_type: string;
+}): Promise {
+ await apiClient.post(
+ "/admin/settings/web-search-emulation/reset-usage",
+ payload,
+ );
}
export const settingsAPI = {
@@ -576,7 +972,7 @@ export const settingsAPI = {
getWebSearchEmulationConfig,
updateWebSearchEmulationConfig,
testWebSearchEmulation,
- resetWebSearchUsage
-}
+ resetWebSearchUsage,
+};
-export default settingsAPI
+export default settingsAPI;
diff --git a/frontend/src/api/admin/users.ts b/frontend/src/api/admin/users.ts
index 39cb1dfa..3c75a6c4 100644
--- a/frontend/src/api/admin/users.ts
+++ b/frontend/src/api/admin/users.ts
@@ -6,6 +6,44 @@
import { apiClient } from '../client'
import type { AdminUser, UpdateUserRequest, PaginatedResponse, ApiKey } from '@/types'
+export interface AdminBindAuthIdentityChannelRequest {
+ channel: string
+ channel_app_id: string
+ channel_subject: string
+ metadata?: Record | null
+}
+
+export interface AdminBindAuthIdentityRequest {
+ provider_type: string
+ provider_key: string
+ provider_subject: string
+ issuer?: string | null
+ metadata?: Record | null
+ channel?: AdminBindAuthIdentityChannelRequest
+}
+
+export interface AdminBoundAuthIdentityChannel {
+ channel: string
+ channel_app_id: string
+ channel_subject: string
+ metadata: Record | null
+ created_at: string
+ updated_at: string
+}
+
+export interface AdminBoundAuthIdentity {
+ user_id: number
+ provider_type: string
+ provider_key: string
+ provider_subject: string
+ verified_at?: string | null
+ issuer?: string | null
+ metadata: Record | null
+ created_at: string
+ updated_at: string
+ channel?: AdminBoundAuthIdentityChannel | null
+}
+
/**
* List all users with pagination
* @param page - Page number (default: 1)
@@ -248,6 +286,17 @@ export async function replaceGroup(
return data
}
+export async function bindUserAuthIdentity(
+ userId: number,
+ input: AdminBindAuthIdentityRequest
+): Promise {
+ const { data } = await apiClient.post(
+ `/admin/users/${userId}/auth-identities`,
+ input
+ )
+ return data
+}
+
export const usersAPI = {
list,
getById,
@@ -260,7 +309,8 @@ export const usersAPI = {
getUserApiKeys,
getUserUsageStats,
getUserBalanceHistory,
- replaceGroup
+ replaceGroup,
+ bindUserAuthIdentity
}
export default usersAPI
diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts
index 837c4f4c..f49f3a1f 100644
--- a/frontend/src/api/auth.ts
+++ b/frontend/src/api/auth.ts
@@ -186,6 +186,108 @@ export interface RefreshTokenResponse {
token_type: string
}
+export interface OAuthTokenResponse {
+ access_token: string
+ refresh_token?: string
+ expires_in?: number
+ token_type?: string
+}
+
+export interface PendingOAuthBindLoginResponse extends Partial {
+ auth_result?: string
+ redirect?: string
+ error?: string
+ requires_2fa?: boolean
+ temp_token?: string
+ user_email_masked?: string
+ adoption_required?: boolean
+ suggested_display_name?: string
+ suggested_avatar_url?: string
+}
+
+export type PendingOAuthExchangeResponse = PendingOAuthBindLoginResponse
+
+export interface PendingOAuthCreateAccountResponse extends OAuthTokenResponse {
+ auth_result?: string
+}
+
+export interface PendingOAuthSendVerifyCodeResponse extends SendVerifyCodeResponse {
+ auth_result?: string
+ provider?: string
+ redirect?: string
+}
+
+export type OAuthCompletionKind = 'login' | 'bind'
+
+export interface OAuthAdoptionDecision {
+ adoptDisplayName?: boolean
+ adoptAvatar?: boolean
+}
+
+function serializeOAuthAdoptionDecision(
+ decision?: OAuthAdoptionDecision
+): Record {
+ const payload: Record = {}
+
+ if (typeof decision?.adoptDisplayName === 'boolean') {
+ payload.adopt_display_name = decision.adoptDisplayName
+ }
+ if (typeof decision?.adoptAvatar === 'boolean') {
+ payload.adopt_avatar = decision.adoptAvatar
+ }
+
+ return payload
+}
+
+export function isOAuthLoginCompletion(
+ completion: Partial
+): completion is OAuthTokenResponse {
+ return typeof completion.access_token === 'string' && completion.access_token.trim().length > 0
+}
+
+export function getOAuthCompletionKind(
+ completion: Partial
+): OAuthCompletionKind {
+ return isOAuthLoginCompletion(completion) ? 'login' : 'bind'
+}
+
+export function getPendingOAuthBindLoginKind(
+ completion: PendingOAuthBindLoginResponse
+): OAuthCompletionKind {
+ return getOAuthCompletionKind(completion)
+}
+
+export function isPendingOAuthCreateAccountRequired(
+ completion: Pick
+): boolean {
+ return completion.error === 'invitation_required'
+}
+
+export function hasPendingOAuthSuggestedProfile(
+ completion: Pick<
+ PendingOAuthBindLoginResponse,
+ 'suggested_display_name' | 'suggested_avatar_url'
+ >
+): boolean {
+ return Boolean(completion.suggested_display_name || completion.suggested_avatar_url)
+}
+
+export function persistOAuthTokenContext(tokens: Partial): void {
+ if (tokens.refresh_token) {
+ setRefreshToken(tokens.refresh_token)
+ }
+ if (tokens.expires_in) {
+ setTokenExpiresAt(tokens.expires_in)
+ }
+}
+
+export async function prepareOAuthBindAccessTokenCookie(): Promise {
+ if (!getAuthToken()) {
+ return
+ }
+ await apiClient.post('/auth/oauth/bind-token')
+}
+
/**
* Refresh the access token using the refresh token
* @returns New token pair
@@ -234,6 +336,116 @@ export async function getPublicSettings(): Promise {
return data
}
+export type WeChatOAuthMode = 'open' | 'mp'
+export type WeChatOAuthUnavailableReason =
+ | 'not_configured'
+ | 'capability_unknown'
+ | 'external_browser_required'
+ | 'wechat_browser_required'
+ | 'native_app_required'
+
+export interface ResolvedWeChatOAuthStart {
+ mode: WeChatOAuthMode | null
+ openEnabled: boolean
+ mpEnabled: boolean
+ mobileEnabled: boolean
+ isWeChatBrowser: boolean
+ unavailableReason: WeChatOAuthUnavailableReason | null
+}
+
+export type WeChatOAuthPublicSettings = {
+ wechat_oauth_enabled?: boolean
+ wechat_oauth_open_enabled?: boolean
+ wechat_oauth_mp_enabled?: boolean
+ wechat_oauth_mobile_enabled?: boolean
+}
+
+export function isWeChatWebOAuthEnabled(
+ settings: WeChatOAuthPublicSettings | null | undefined,
+): boolean {
+ const legacyEnabled = settings?.wechat_oauth_enabled ?? false
+ const hasExplicitCapabilities =
+ typeof settings?.wechat_oauth_open_enabled === 'boolean' ||
+ typeof settings?.wechat_oauth_mp_enabled === 'boolean'
+
+ if (!hasExplicitCapabilities) {
+ return legacyEnabled
+ }
+
+ return settings?.wechat_oauth_open_enabled === true || settings?.wechat_oauth_mp_enabled === true
+}
+
+export function hasExplicitWeChatOAuthCapabilities(
+ settings: WeChatOAuthPublicSettings | null | undefined,
+): settings is WeChatOAuthPublicSettings & {
+ wechat_oauth_open_enabled: boolean
+ wechat_oauth_mp_enabled: boolean
+} {
+ return typeof settings?.wechat_oauth_open_enabled === 'boolean'
+ && typeof settings?.wechat_oauth_mp_enabled === 'boolean'
+}
+
+export function resolveWeChatOAuthStart(
+ settings: WeChatOAuthPublicSettings | null | undefined,
+ userAgent?: string
+): ResolvedWeChatOAuthStart {
+ const normalizedUserAgent = (userAgent
+ ?? (typeof navigator !== 'undefined' ? navigator.userAgent : '')
+ ?? '').trim()
+ const isWeChatBrowser = /MicroMessenger/i.test(normalizedUserAgent)
+ const legacyEnabled = settings?.wechat_oauth_enabled ?? false
+ const openEnabled = typeof settings?.wechat_oauth_open_enabled === 'boolean'
+ ? settings.wechat_oauth_open_enabled
+ : legacyEnabled
+ const mpEnabled = typeof settings?.wechat_oauth_mp_enabled === 'boolean'
+ ? settings.wechat_oauth_mp_enabled
+ : legacyEnabled
+ const mobileEnabled = typeof settings?.wechat_oauth_mobile_enabled === 'boolean'
+ ? settings.wechat_oauth_mobile_enabled
+ : false
+
+ if (isWeChatBrowser) {
+ if (mpEnabled) {
+ return { mode: 'mp', openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: null }
+ }
+ if (openEnabled) {
+ return { mode: null, openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: 'external_browser_required' }
+ }
+ return { mode: null, openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: 'not_configured' }
+ }
+
+ if (openEnabled) {
+ return { mode: 'open', openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: null }
+ }
+ if (mpEnabled) {
+ return { mode: null, openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: 'wechat_browser_required' }
+ }
+ return { mode: null, openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: 'not_configured' }
+}
+
+export function resolveWeChatOAuthStartStrict(
+ settings: WeChatOAuthPublicSettings | null | undefined,
+ userAgent?: string,
+): ResolvedWeChatOAuthStart {
+ const normalizedUserAgent = (userAgent
+ ?? (typeof navigator !== 'undefined' ? navigator.userAgent : '')
+ ?? '').trim()
+ const isWeChatBrowser = /MicroMessenger/i.test(normalizedUserAgent)
+
+ if (!hasExplicitWeChatOAuthCapabilities(settings)) {
+ return {
+ mode: null,
+ openEnabled: false,
+ mpEnabled: false,
+ mobileEnabled: false,
+ isWeChatBrowser,
+ unavailableReason: 'capability_unknown',
+ }
+ }
+
+ return resolveWeChatOAuthStart(settings, normalizedUserAgent)
+}
+
/**
* Send verification code to email
* @param request - Email and optional Turnstile token
@@ -246,6 +458,16 @@ export async function sendVerifyCode(
return data
}
+export async function sendPendingOAuthVerifyCode(
+ request: SendVerifyCodeRequest
+): Promise {
+ const { data } = await apiClient.post(
+ '/auth/oauth/pending/send-verify-code',
+ request
+ )
+ return data
+}
+
/**
* Validate promo code response
*/
@@ -337,48 +559,87 @@ export async function resetPassword(request: ResetPasswordRequest): Promise {
- const { data } = await apiClient.post<{
- access_token: string
- refresh_token: string
- expires_in: number
- token_type: string
- }>('/auth/oauth/linuxdo/complete-registration', {
- pending_oauth_token: pendingOAuthToken,
- invitation_code: invitationCode
- })
- return data
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision
+): Promise {
+ return createPendingLinuxDoOAuthAccount(invitationCode, decision)
}
/**
* Complete OIDC OAuth registration by supplying an invitation code
- * @param pendingOAuthToken - Short-lived JWT from the OAuth callback
* @param invitationCode - Invitation code entered by the user
* @returns Token pair on success
*/
export async function completeOIDCOAuthRegistration(
- pendingOAuthToken: string,
- invitationCode: string
-): Promise<{ access_token: string; refresh_token: string; expires_in: number; token_type: string }> {
- const { data } = await apiClient.post<{
- access_token: string
- refresh_token: string
- expires_in: number
- token_type: string
- }>('/auth/oauth/oidc/complete-registration', {
- pending_oauth_token: pendingOAuthToken,
- invitation_code: invitationCode
- })
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision
+): Promise {
+ return createPendingOIDCOAuthAccount(invitationCode, decision)
+}
+
+export async function completeWeChatOAuthRegistration(
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision
+): Promise {
+ return createPendingWeChatOAuthAccount(invitationCode, decision)
+}
+
+async function createPendingOAuthAccount(
+ provider: 'linuxdo' | 'oidc' | 'wechat',
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision
+): Promise {
+ const { data } = await apiClient.post(
+ `/auth/oauth/${provider}/complete-registration`,
+ {
+ invitation_code: invitationCode,
+ ...serializeOAuthAdoptionDecision(decision)
+ }
+ )
return data
}
+export async function createPendingLinuxDoOAuthAccount(
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision
+): Promise {
+ return createPendingOAuthAccount('linuxdo', invitationCode, decision)
+}
+
+export async function createPendingOIDCOAuthAccount(
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision
+): Promise {
+ return createPendingOAuthAccount('oidc', invitationCode, decision)
+}
+
+export async function createPendingWeChatOAuthAccount(
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision
+): Promise {
+ return createPendingOAuthAccount('wechat', invitationCode, decision)
+}
+
+export async function completePendingOAuthBindLogin(
+ decision?: OAuthAdoptionDecision
+): Promise {
+ const { data } = await apiClient.post(
+ '/auth/oauth/pending/exchange',
+ serializeOAuthAdoptionDecision(decision)
+ )
+ return data
+}
+
+export async function exchangePendingOAuthCompletion(
+ decision?: OAuthAdoptionDecision
+): Promise {
+ return completePendingOAuthBindLogin(decision)
+}
+
export const authAPI = {
login,
login2FA,
@@ -396,14 +657,24 @@ export const authAPI = {
clearAuthToken,
getPublicSettings,
sendVerifyCode,
+ sendPendingOAuthVerifyCode,
validatePromoCode,
validateInvitationCode,
forgotPassword,
resetPassword,
refreshToken,
revokeAllSessions,
+ getPendingOAuthBindLoginKind,
+ isPendingOAuthCreateAccountRequired,
+ hasPendingOAuthSuggestedProfile,
+ completePendingOAuthBindLogin,
+ createPendingLinuxDoOAuthAccount,
+ createPendingOIDCOAuthAccount,
+ createPendingWeChatOAuthAccount,
+ exchangePendingOAuthCompletion,
completeLinuxDoOAuthRegistration,
- completeOIDCOAuthRegistration
+ completeOIDCOAuthRegistration,
+ completeWeChatOAuthRegistration
}
export default authAPI
diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts
index 8a586902..54ea4520 100644
--- a/frontend/src/api/client.ts
+++ b/frontend/src/api/client.ts
@@ -13,6 +13,7 @@ const API_BASE_URL = import.meta.env.VITE_API_BASE_URL || '/api/v1'
export const apiClient: AxiosInstance = axios.create({
baseURL: API_BASE_URL,
+ withCredentials: true,
timeout: 30000,
headers: {
'Content-Type': 'application/json'
diff --git a/frontend/src/api/payment.ts b/frontend/src/api/payment.ts
index 5cedb107..92b0ec90 100644
--- a/frontend/src/api/payment.ts
+++ b/frontend/src/api/payment.ts
@@ -67,11 +67,16 @@ export const paymentAPI = {
return apiClient.post('/payment/orders/verify', { out_trade_no: outTradeNo })
},
- /** Verify order payment status without auth (public endpoint for result page) */
+ /** Legacy-compatible public order lookup by out_trade_no */
verifyOrderPublic(outTradeNo: string) {
return apiClient.post('/payment/public/orders/verify', { out_trade_no: outTradeNo })
},
+ /** Resolve an order from a signed resume token without auth */
+ resolveOrderPublicByResumeToken(resumeToken: string) {
+ return apiClient.post('/payment/public/orders/resolve', { resume_token: resumeToken })
+ },
+
/** Request a refund for a completed order */
requestRefund(id: number, data: { reason: string }) {
return apiClient.post(`/payment/orders/${id}/refund-request`, data)
diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts
index cd648270..fd3cedb9 100644
--- a/frontend/src/api/user.ts
+++ b/frontend/src/api/user.ts
@@ -4,7 +4,12 @@
*/
import { apiClient } from './client'
-import type { User, ChangePasswordRequest, NotifyEmailEntry } from '@/types'
+import {
+ resolveWeChatOAuthStartStrict,
+ prepareOAuthBindAccessTokenCookie,
+ type WeChatOAuthPublicSettings,
+} from './auth'
+import type { User, ChangePasswordRequest, NotifyEmailEntry, UserAuthProvider } from '@/types'
/**
* Get current user profile
@@ -22,6 +27,7 @@ export async function getProfile(): Promise {
*/
export async function updateProfile(profile: {
username?: string
+ avatar_url?: string | null
balance_notify_enabled?: boolean
balance_notify_threshold?: number | null
balance_notify_extra_emails?: NotifyEmailEntry[]
@@ -83,6 +89,85 @@ export async function toggleNotifyEmail(email: string, disabled: boolean): Promi
return data
}
+export async function sendEmailBindingCode(email: string): Promise {
+ await apiClient.post('/user/account-bindings/email/send-code', { email })
+}
+
+export async function bindEmailIdentity(payload: {
+ email: string
+ verify_code: string
+ password: string
+}): Promise {
+ const { data } = await apiClient.post('/user/account-bindings/email', payload)
+ return data
+}
+
+export async function unbindAuthIdentity(provider: BindableOAuthProvider): Promise {
+ const { data } = await apiClient.delete(`/user/account-bindings/${provider}`)
+ return data
+}
+
+export type BindableOAuthProvider = Exclude
+
+interface BuildOAuthBindingStartURLOptions {
+ redirectTo?: string
+ wechatOAuthSettings?: WeChatOAuthPublicSettings | null
+}
+
+export function resolveWeChatOAuthMode(): 'open' | 'mp' {
+ if (typeof navigator === 'undefined') {
+ return 'open'
+ }
+ return /MicroMessenger/i.test(navigator.userAgent) ? 'mp' : 'open'
+}
+
+function resolveWeChatOAuthBindingMode(
+ settings?: WeChatOAuthPublicSettings | null
+): 'open' | 'mp' | null {
+ if (settings) {
+ return resolveWeChatOAuthStartStrict(settings).mode
+ }
+ return resolveWeChatOAuthMode()
+}
+
+export function buildOAuthBindingStartURL(
+ provider: BindableOAuthProvider,
+ options: BuildOAuthBindingStartURLOptions = {}
+): string | null {
+ const redirectTo = options.redirectTo?.trim() || '/profile'
+ const apiBase = (import.meta.env.VITE_API_BASE_URL as string | undefined) || '/api/v1'
+ const normalized = apiBase.replace(/\/$/, '')
+ const params = new URLSearchParams({
+ redirect: redirectTo,
+ intent: 'bind_current_user'
+ })
+
+ if (provider === 'wechat') {
+ const mode = resolveWeChatOAuthBindingMode(options.wechatOAuthSettings)
+ if (!mode) {
+ return null
+ }
+ params.set('mode', mode)
+ }
+
+ return `${normalized}/auth/oauth/${provider}/bind/start?${params.toString()}`
+}
+
+export async function startOAuthBinding(
+ provider: BindableOAuthProvider,
+ options: BuildOAuthBindingStartURLOptions = {}
+): Promise {
+ if (typeof window === 'undefined') {
+ return
+ }
+ const startURL = buildOAuthBindingStartURL(provider, options)
+ if (!startURL) {
+ return
+ }
+ await prepareOAuthBindAccessTokenCookie()
+ window.location.href = startURL
+}
+
export const userAPI = {
getProfile,
updateProfile,
@@ -90,7 +175,12 @@ export const userAPI = {
sendNotifyEmailCode,
verifyNotifyEmail,
removeNotifyEmail,
- toggleNotifyEmail
+ toggleNotifyEmail,
+ sendEmailBindingCode,
+ bindEmailIdentity,
+ unbindAuthIdentity,
+ buildOAuthBindingStartURL,
+ startOAuthBinding
}
export default userAPI
diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue
index fc2f7d0c..dd38a49f 100644
--- a/frontend/src/components/account/AccountStatusIndicator.vue
+++ b/frontend/src/components/account/AccountStatusIndicator.vue
@@ -284,6 +284,16 @@ const hasError = computed(() => {
return props.account.status === 'error'
})
+const isQuotaExceeded = computed(() => {
+ const exceeded = (used?: number | null, limit?: number | null) =>
+ typeof limit === 'number' && limit > 0 && typeof used === 'number' && used >= limit
+ return (
+ exceeded(props.account.quota_used, props.account.quota_limit) ||
+ exceeded(props.account.quota_daily_used, props.account.quota_daily_limit) ||
+ exceeded(props.account.quota_weekly_used, props.account.quota_weekly_limit)
+ )
+})
+
// Computed: countdown text for rate limit (429)
const rateLimitCountdown = computed(() => {
return formatCountdown(props.account.rate_limit_reset_at)
@@ -307,19 +317,16 @@ const statusClass = computed(() => {
if (isTempUnschedulable.value) {
return 'badge-warning'
}
+ if (props.account.status !== 'active') {
+ return props.account.status === 'error' ? 'badge-danger' : 'badge-gray'
+ }
+ if (isQuotaExceeded.value) {
+ return 'badge-warning'
+ }
if (!props.account.schedulable) {
return 'badge-gray'
}
- switch (props.account.status) {
- case 'active':
- return 'badge-success'
- case 'inactive':
- return 'badge-gray'
- case 'error':
- return 'badge-danger'
- default:
- return 'badge-gray'
- }
+ return 'badge-success'
})
// Computed: status text
@@ -330,6 +337,12 @@ const statusText = computed(() => {
if (isTempUnschedulable.value) {
return t('admin.accounts.status.tempUnschedulable')
}
+ if (props.account.status !== 'active') {
+ return t(`admin.accounts.status.${props.account.status}`)
+ }
+ if (isQuotaExceeded.value) {
+ return t('admin.accounts.status.quotaExceeded')
+ }
if (!props.account.schedulable) {
return t('admin.accounts.status.paused')
}
diff --git a/frontend/src/components/account/AccountTestModal.vue b/frontend/src/components/account/AccountTestModal.vue
index b0ce7e70..ae0fd9a7 100644
--- a/frontend/src/components/account/AccountTestModal.vue
+++ b/frontend/src/components/account/AccountTestModal.vue
@@ -55,12 +55,12 @@
/>
-
+
@@ -122,25 +122,49 @@
- {{ t('admin.accounts.geminiImagePreview') }}
+ {{ t('admin.accounts.imagePreview') }}
-
+
+
+
+
+
+
+
+
+
+
+
+
@@ -152,8 +176,8 @@
{{
- supportsGeminiImageTest
- ? t('admin.accounts.geminiImageTestMode')
+ supportsImageTest
+ ? t('admin.accounts.imageTestMode')
: t('admin.accounts.testPrompt')
}}
@@ -250,6 +274,7 @@ const testPrompt = ref('')
const loadingModels = ref(false)
let abortController: AbortController | null = null
const generatedImages = ref
([])
+const previewImageUrl = ref('')
const prioritizedGeminiModels = ['gemini-3.1-flash-image', 'gemini-2.5-flash-image', 'gemini-2.5-flash', 'gemini-2.5-pro', 'gemini-3-flash-preview', 'gemini-3-pro-preview', 'gemini-2.0-flash']
const supportsGeminiImageTest = computed(() => {
const modelID = selectedModelId.value.toLowerCase()
@@ -258,6 +283,14 @@ const supportsGeminiImageTest = computed(() => {
return props.account?.platform === 'gemini' || (props.account?.platform === 'antigravity' && props.account?.type === 'apikey')
})
+const supportsOpenAIImageTest = computed(() => {
+ const modelID = selectedModelId.value.toLowerCase()
+ if (!modelID.startsWith('gpt-image-')) return false
+ return props.account?.platform === 'openai'
+})
+
+const supportsImageTest = computed(() => supportsGeminiImageTest.value || supportsOpenAIImageTest.value)
+
const sortTestModels = (models: ClaudeModel[]) => {
const priorityMap = new Map(prioritizedGeminiModels.map((id, index) => [id, index]))
@@ -284,8 +317,8 @@ watch(
)
watch(selectedModelId, () => {
- if (supportsGeminiImageTest.value && !testPrompt.value.trim()) {
- testPrompt.value = t('admin.accounts.geminiImagePromptDefault')
+ if (supportsImageTest.value && !testPrompt.value.trim()) {
+ testPrompt.value = t('admin.accounts.imagePromptDefault')
}
})
@@ -325,6 +358,7 @@ const resetState = () => {
streamingContent.value = ''
errorMessage.value = ''
generatedImages.value = []
+ previewImageUrl.value = ''
}
const handleClose = () => {
@@ -377,7 +411,7 @@ const startTest = async () => {
},
body: JSON.stringify({
model_id: selectedModelId.value,
- prompt: supportsGeminiImageTest.value ? testPrompt.value.trim() : ''
+ prompt: supportsImageTest.value ? testPrompt.value.trim() : ''
}),
signal: abortController.signal
})
@@ -444,8 +478,8 @@ const handleEvent = (event: {
addLine(t('admin.accounts.usingModel', { model: event.model }), 'text-cyan-400')
}
addLine(
- supportsGeminiImageTest.value
- ? t('admin.accounts.sendingGeminiImageRequest')
+ supportsImageTest.value
+ ? t('admin.accounts.sendingImageRequest')
: t('admin.accounts.sendingTestMessage'),
'text-gray-400'
)
@@ -466,7 +500,7 @@ const handleEvent = (event: {
url: event.image_url,
mimeType: event.mime_type
})
- addLine(t('admin.accounts.geminiImageReceived', { count: generatedImages.value.length }), 'text-purple-300')
+ addLine(t('admin.accounts.imageReceived', { count: generatedImages.value.length }), 'text-purple-300')
}
break
@@ -501,3 +535,14 @@ const copyOutput = () => {
copyToClipboard(text, t('admin.accounts.outputCopied'))
}
+
+
diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue
index 1da32e2c..59ca0b9c 100644
--- a/frontend/src/components/account/EditAccountModal.vue
+++ b/frontend/src/components/account/EditAccountModal.vue
@@ -52,6 +52,10 @@
v-model="editApiKey"
type="password"
class="input font-mono"
+ autocomplete="new-password"
+ data-1p-ignore
+ data-lpignore="true"
+ data-bwignore="true"
:placeholder="
account.platform === 'openai'
? 'sk-proj-...'
diff --git a/frontend/src/components/admin/account/AccountTestModal.vue b/frontend/src/components/admin/account/AccountTestModal.vue
index b0ce7e70..ae0fd9a7 100644
--- a/frontend/src/components/admin/account/AccountTestModal.vue
+++ b/frontend/src/components/admin/account/AccountTestModal.vue
@@ -55,12 +55,12 @@
/>
-
+
@@ -122,25 +122,49 @@
- {{ t('admin.accounts.geminiImagePreview') }}
+ {{ t('admin.accounts.imagePreview') }}
-
+
+
+
+
+
+
+
+
+
+
+
+
@@ -152,8 +176,8 @@
{{
- supportsGeminiImageTest
- ? t('admin.accounts.geminiImageTestMode')
+ supportsImageTest
+ ? t('admin.accounts.imageTestMode')
: t('admin.accounts.testPrompt')
}}
@@ -250,6 +274,7 @@ const testPrompt = ref('')
const loadingModels = ref(false)
let abortController: AbortController | null = null
const generatedImages = ref
([])
+const previewImageUrl = ref('')
const prioritizedGeminiModels = ['gemini-3.1-flash-image', 'gemini-2.5-flash-image', 'gemini-2.5-flash', 'gemini-2.5-pro', 'gemini-3-flash-preview', 'gemini-3-pro-preview', 'gemini-2.0-flash']
const supportsGeminiImageTest = computed(() => {
const modelID = selectedModelId.value.toLowerCase()
@@ -258,6 +283,14 @@ const supportsGeminiImageTest = computed(() => {
return props.account?.platform === 'gemini' || (props.account?.platform === 'antigravity' && props.account?.type === 'apikey')
})
+const supportsOpenAIImageTest = computed(() => {
+ const modelID = selectedModelId.value.toLowerCase()
+ if (!modelID.startsWith('gpt-image-')) return false
+ return props.account?.platform === 'openai'
+})
+
+const supportsImageTest = computed(() => supportsGeminiImageTest.value || supportsOpenAIImageTest.value)
+
const sortTestModels = (models: ClaudeModel[]) => {
const priorityMap = new Map(prioritizedGeminiModels.map((id, index) => [id, index]))
@@ -284,8 +317,8 @@ watch(
)
watch(selectedModelId, () => {
- if (supportsGeminiImageTest.value && !testPrompt.value.trim()) {
- testPrompt.value = t('admin.accounts.geminiImagePromptDefault')
+ if (supportsImageTest.value && !testPrompt.value.trim()) {
+ testPrompt.value = t('admin.accounts.imagePromptDefault')
}
})
@@ -325,6 +358,7 @@ const resetState = () => {
streamingContent.value = ''
errorMessage.value = ''
generatedImages.value = []
+ previewImageUrl.value = ''
}
const handleClose = () => {
@@ -377,7 +411,7 @@ const startTest = async () => {
},
body: JSON.stringify({
model_id: selectedModelId.value,
- prompt: supportsGeminiImageTest.value ? testPrompt.value.trim() : ''
+ prompt: supportsImageTest.value ? testPrompt.value.trim() : ''
}),
signal: abortController.signal
})
@@ -444,8 +478,8 @@ const handleEvent = (event: {
addLine(t('admin.accounts.usingModel', { model: event.model }), 'text-cyan-400')
}
addLine(
- supportsGeminiImageTest.value
- ? t('admin.accounts.sendingGeminiImageRequest')
+ supportsImageTest.value
+ ? t('admin.accounts.sendingImageRequest')
: t('admin.accounts.sendingTestMessage'),
'text-gray-400'
)
@@ -466,7 +500,7 @@ const handleEvent = (event: {
url: event.image_url,
mimeType: event.mime_type
})
- addLine(t('admin.accounts.geminiImageReceived', { count: generatedImages.value.length }), 'text-purple-300')
+ addLine(t('admin.accounts.imageReceived', { count: generatedImages.value.length }), 'text-purple-300')
}
break
@@ -501,3 +535,14 @@ const copyOutput = () => {
copyToClipboard(text, t('admin.accounts.outputCopied'))
}
+
+
diff --git a/frontend/src/components/admin/account/__tests__/AccountTestModal.spec.ts b/frontend/src/components/admin/account/__tests__/AccountTestModal.spec.ts
index 801eab02..2dc7d504 100644
--- a/frontend/src/components/admin/account/__tests__/AccountTestModal.spec.ts
+++ b/frontend/src/components/admin/account/__tests__/AccountTestModal.spec.ts
@@ -24,13 +24,13 @@ vi.mock('@/composables/useClipboard', () => ({
vi.mock('vue-i18n', async () => {
const actual = await vi.importActual('vue-i18n')
const messages: Record = {
- 'admin.accounts.geminiImagePromptDefault': 'Generate a cute orange cat astronaut sticker on a clean pastel background.'
+ 'admin.accounts.imagePromptDefault': 'Generate a cute orange cat astronaut sticker on a clean pastel background.'
}
return {
...actual,
useI18n: () => ({
t: (key: string, params?: Record) => {
- if (key === 'admin.accounts.geminiImageReceived' && params?.count) {
+ if (key === 'admin.accounts.imageReceived' && params?.count) {
return `received-${params.count}`
}
return messages[key] || key
@@ -140,7 +140,7 @@ describe('AccountTestModal', () => {
prompt: 'draw a tiny orange cat astronaut'
})
- const preview = wrapper.find('img[alt="gemini-test-image-1"]')
+ const preview = wrapper.find('img[alt="test-image-1"]')
expect(preview.exists()).toBe(true)
expect(preview.attributes('src')).toBe('data:image/png;base64,QUJD')
})
diff --git a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue
index bf79bea2..41b2e63c 100644
--- a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue
+++ b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue
@@ -166,7 +166,7 @@
+
+
+
+
+
+
diff --git a/frontend/src/components/auth/TotpLoginModal.vue b/frontend/src/components/auth/TotpLoginModal.vue
index 03fa718d..0ae2f482 100644
--- a/frontend/src/components/auth/TotpLoginModal.vue
+++ b/frontend/src/components/auth/TotpLoginModal.vue
@@ -47,11 +47,6 @@
-
-
- {{ error }}
-
-
import { ref, watch, nextTick, onMounted } from 'vue'
import { useI18n } from 'vue-i18n'
+import { useAppStore } from '@/stores'
defineProps<{
tempToken: string
@@ -81,9 +77,9 @@ const emit = defineEmits<{
}>()
const { t } = useI18n()
+const appStore = useAppStore()
const verifying = ref(false)
-const error = ref('')
const code = ref(['', '', '', '', '', ''])
const inputRefs = ref<(HTMLInputElement | null)[]>([])
@@ -100,7 +96,9 @@ watch(
defineExpose({
setVerifying: (value: boolean) => { verifying.value = value },
setError: (message: string) => {
- error.value = message
+ if (message) {
+ appStore.showError(message)
+ }
code.value = ['', '', '', '', '', '']
// Clear input DOM values
inputRefs.value.forEach(input => {
diff --git a/frontend/src/components/auth/WechatOAuthSection.vue b/frontend/src/components/auth/WechatOAuthSection.vue
new file mode 100644
index 00000000..ce90738c
--- /dev/null
+++ b/frontend/src/components/auth/WechatOAuthSection.vue
@@ -0,0 +1,93 @@
+
+
+
+
+ W
+
+ {{ t('auth.oidc.signIn', { providerName }) }}
+
+
+
+ {{ disabledHint }}
+
+
+
+
+
+ {{ t('auth.oauthOrContinue') }}
+
+
+
+
+
+
+
diff --git a/frontend/src/components/auth/__tests__/PendingOAuthCreateAccountForm.spec.ts b/frontend/src/components/auth/__tests__/PendingOAuthCreateAccountForm.spec.ts
new file mode 100644
index 00000000..1e462e29
--- /dev/null
+++ b/frontend/src/components/auth/__tests__/PendingOAuthCreateAccountForm.spec.ts
@@ -0,0 +1,205 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { flushPromises, mount } from '@vue/test-utils'
+
+import PendingOAuthCreateAccountForm from '../PendingOAuthCreateAccountForm.vue'
+
+const sendVerifyCode = vi.fn()
+const sendPendingOAuthVerifyCode = vi.fn()
+const getPublicSettings = vi.fn()
+const showError = vi.fn()
+
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string) => key
+ })
+ }
+})
+
+vi.mock('@/api/auth', async () => {
+ const actual = await vi.importActual('@/api/auth')
+ return {
+ ...actual,
+ sendVerifyCode: (...args: any[]) => sendVerifyCode(...args),
+ sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCode(...args),
+ getPublicSettings: (...args: any[]) => getPublicSettings(...args)
+ }
+})
+
+vi.mock('@/stores', () => ({
+ useAppStore: () => ({
+ showError
+ })
+}))
+
+describe('PendingOAuthCreateAccountForm', () => {
+ beforeEach(() => {
+ sendVerifyCode.mockReset()
+ sendPendingOAuthVerifyCode.mockReset()
+ getPublicSettings.mockReset()
+ showError.mockReset()
+ getPublicSettings.mockResolvedValue({
+ turnstile_enabled: false,
+ turnstile_site_key: ''
+ })
+ })
+
+ it('emits trimmed email, password, and verify code on submit', async () => {
+ const wrapper = mount(PendingOAuthCreateAccountForm, {
+ props: {
+ providerName: 'LinuxDo',
+ testIdPrefix: 'linuxdo',
+ initialEmail: 'prefill@example.com',
+ isSubmitting: false
+ }
+ })
+
+ await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' user@example.com ')
+ await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123')
+ await wrapper.get('[data-testid="linuxdo-create-account-verify-code"]').setValue(' 246810 ')
+ await wrapper.get('form').trigger('submit.prevent')
+
+ expect(wrapper.emitted('submit')).toEqual([
+ [
+ {
+ email: 'user@example.com',
+ password: 'secret-123',
+ verifyCode: '246810'
+ }
+ ]
+ ])
+ })
+
+ it('renders action labels through i18n keys', () => {
+ const wrapper = mount(PendingOAuthCreateAccountForm, {
+ props: {
+ testIdPrefix: 'linuxdo',
+ initialEmail: '',
+ isSubmitting: false
+ }
+ })
+
+ expect(wrapper.text()).toContain('auth.createAccount')
+ expect(wrapper.text()).toContain('auth.alreadyHaveAccount')
+ })
+
+ it('shows and emits invitation code when invitation-only signup is enabled', async () => {
+ getPublicSettings.mockResolvedValue({
+ invitation_code_enabled: true,
+ turnstile_enabled: false,
+ turnstile_site_key: ''
+ })
+
+ const wrapper = mount(PendingOAuthCreateAccountForm, {
+ props: {
+ providerName: 'LinuxDo',
+ testIdPrefix: 'linuxdo',
+ initialEmail: 'prefill@example.com',
+ isSubmitting: false
+ }
+ })
+
+ await flushPromises()
+ await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123')
+ await wrapper.get('[data-testid="linuxdo-create-account-verify-code"]').setValue('246810')
+ await wrapper.get('[data-testid="linuxdo-create-account-invitation-code"]').setValue(' INVITE123 ')
+ await wrapper.get('form').trigger('submit.prevent')
+
+ expect(wrapper.emitted('submit')).toEqual([
+ [
+ {
+ email: 'prefill@example.com',
+ password: 'secret-123',
+ verifyCode: '246810',
+ invitationCode: 'INVITE123'
+ }
+ ]
+ ])
+ })
+
+ it('sends a verify code for the trimmed email value', async () => {
+ sendPendingOAuthVerifyCode.mockResolvedValue({
+ message: 'sent',
+ countdown: 60
+ })
+
+ const wrapper = mount(PendingOAuthCreateAccountForm, {
+ props: {
+ providerName: 'LinuxDo',
+ testIdPrefix: 'linuxdo',
+ initialEmail: '',
+ isSubmitting: false
+ }
+ })
+
+ await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' user@example.com ')
+ await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click')
+ await flushPromises()
+
+ expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({
+ email: 'user@example.com'
+ })
+ })
+
+ it('shows send-code failures via toast without rendering inline error text', async () => {
+ sendPendingOAuthVerifyCode.mockRejectedValue(new Error('send failed'))
+
+ const wrapper = mount(PendingOAuthCreateAccountForm, {
+ props: {
+ testIdPrefix: 'linuxdo',
+ initialEmail: '',
+ isSubmitting: false
+ }
+ })
+
+ await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue('user@example.com')
+ await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click')
+ await flushPromises()
+
+ expect(showError).toHaveBeenCalledWith('send failed')
+ expect(wrapper.text()).not.toContain('send failed')
+ })
+
+ it('requires a turnstile token before sending a verify code when turnstile is enabled', async () => {
+ getPublicSettings.mockResolvedValue({
+ turnstile_enabled: true,
+ turnstile_site_key: 'site-key'
+ })
+ sendPendingOAuthVerifyCode.mockResolvedValue({
+ message: 'sent',
+ countdown: 60
+ })
+
+ const wrapper = mount(PendingOAuthCreateAccountForm, {
+ props: {
+ providerName: 'LinuxDo',
+ testIdPrefix: 'linuxdo',
+ initialEmail: '',
+ isSubmitting: false
+ },
+ global: {
+ stubs: {
+ TurnstileWidget: {
+ template: 'verify '
+ }
+ }
+ }
+ })
+
+ await flushPromises()
+ await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' user@example.com ')
+
+ expect(wrapper.get('[data-testid="linuxdo-create-account-send-code"]').attributes('disabled')).toBeDefined()
+
+ await wrapper.get('[data-testid="turnstile-verify"]').trigger('click')
+ await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click')
+ await flushPromises()
+
+ expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({
+ email: 'user@example.com',
+ turnstile_token: 'turnstile-token'
+ })
+ })
+})
diff --git a/frontend/src/components/auth/__tests__/TotpLoginModal.spec.ts b/frontend/src/components/auth/__tests__/TotpLoginModal.spec.ts
new file mode 100644
index 00000000..06fbe397
--- /dev/null
+++ b/frontend/src/components/auth/__tests__/TotpLoginModal.spec.ts
@@ -0,0 +1,41 @@
+import { mount } from '@vue/test-utils'
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import TotpLoginModal from '@/components/auth/TotpLoginModal.vue'
+
+const { showErrorMock } = vi.hoisted(() => ({
+ showErrorMock: vi.fn(),
+}))
+
+vi.mock('vue-i18n', () => ({
+ useI18n: () => ({
+ t: (key: string) => key,
+ }),
+}))
+
+vi.mock('@/stores', () => ({
+ useAppStore: () => ({
+ showError: (...args: any[]) => showErrorMock(...args),
+ }),
+}))
+
+describe('TotpLoginModal', () => {
+ beforeEach(() => {
+ showErrorMock.mockReset()
+ })
+
+ it('sends verification errors to toast and does not render inline red text', async () => {
+ const wrapper = mount(TotpLoginModal, {
+ props: {
+ tempToken: 'temp-token',
+ userEmailMasked: 'u***@example.com',
+ },
+ })
+
+ ;(wrapper.vm as unknown as { setError: (message: string) => void }).setError('Invalid code')
+ await wrapper.vm.$nextTick()
+
+ expect(showErrorMock).toHaveBeenCalledWith('Invalid code')
+ expect(wrapper.text()).not.toContain('Invalid code')
+ expect(wrapper.find('.bg-red-50').exists()).toBe(false)
+ })
+})
diff --git a/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts b/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts
new file mode 100644
index 00000000..2f269e0b
--- /dev/null
+++ b/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts
@@ -0,0 +1,238 @@
+import { mount } from '@vue/test-utils'
+import { createPinia, setActivePinia } from 'pinia'
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
+import WechatOAuthSection from '@/components/auth/WechatOAuthSection.vue'
+import { useAppStore } from '@/stores'
+import type { PublicSettings } from '@/types'
+
+const routeState = vi.hoisted(() => ({
+ query: {} as Record,
+}))
+
+const locationState = vi.hoisted(() => ({
+ current: { href: 'http://localhost/login' } as { href: string },
+}))
+
+let pinia: ReturnType
+
+vi.mock('vue-router', () => ({
+ useRoute: () => routeState,
+}))
+
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ locale: { value: 'en' },
+ t: (key: string, params?: Record) => {
+ if (key === 'auth.wechatProviderName') {
+ return 'Mock WeChat'
+ }
+ if (key === 'auth.oidc.signIn') {
+ return `Continue with ${params?.providerName ?? ''}`.trim()
+ }
+ if (key === 'auth.oauthFlow.wechatSystemBrowserOnly') {
+ return 'MOCK-SYSTEM-BROWSER-ONLY'
+ }
+ if (key === 'auth.oauthFlow.wechatBrowserOnly') {
+ return 'MOCK-WECHAT-BROWSER-ONLY'
+ }
+ if (key === 'auth.oauthFlow.wechatNotConfigured') {
+ return 'MOCK-NOT-CONFIGURED'
+ }
+ if (key === 'auth.oauthOrContinue') {
+ return 'or continue'
+ }
+ return key
+ },
+ }),
+ }
+})
+
+type WeChatPublicSettings = PublicSettings & {
+ wechat_oauth_open_enabled?: boolean
+ wechat_oauth_mp_enabled?: boolean
+}
+
+function buildPublicSettings(overrides: Partial = {}): WeChatPublicSettings {
+ return {
+ registration_enabled: true,
+ email_verify_enabled: false,
+ force_email_on_third_party_signup: false,
+ registration_email_suffix_whitelist: [],
+ promo_code_enabled: true,
+ password_reset_enabled: false,
+ invitation_code_enabled: false,
+ turnstile_enabled: false,
+ turnstile_site_key: '',
+ site_name: 'Sub2API',
+ site_logo: '',
+ site_subtitle: '',
+ api_base_url: '/api/v1',
+ contact_info: '',
+ doc_url: '',
+ home_content: '',
+ hide_ccs_import_button: false,
+ payment_enabled: false,
+ table_default_page_size: 20,
+ table_page_size_options: [10, 20, 50, 100],
+ custom_menu_items: [],
+ custom_endpoints: [],
+ linuxdo_oauth_enabled: false,
+ wechat_oauth_enabled: true,
+ oidc_oauth_enabled: false,
+ oidc_oauth_provider_name: 'OIDC',
+ backend_mode_enabled: false,
+ version: 'test',
+ balance_low_notify_enabled: false,
+ account_quota_notify_enabled: false,
+ balance_low_notify_threshold: 0,
+ ...overrides,
+ }
+}
+
+function seedPublicSettings(overrides: Partial = {}): void {
+ const appStore = useAppStore()
+ const settings = buildPublicSettings(overrides)
+ appStore.cachedPublicSettings = settings
+ appStore.publicSettingsLoaded = true
+}
+
+describe('WechatOAuthSection', () => {
+ beforeEach(() => {
+ pinia = createPinia()
+ setActivePinia(pinia)
+ routeState.query = { redirect: '/billing?plan=pro' }
+ locationState.current = { href: 'http://localhost/login' }
+ Object.defineProperty(window, 'location', {
+ configurable: true,
+ value: locationState.current,
+ })
+ Object.defineProperty(window.navigator, 'userAgent', {
+ configurable: true,
+ value: 'Mozilla/5.0',
+ })
+ })
+
+ afterEach(() => {
+ vi.unstubAllGlobals()
+ })
+
+ it('starts the open WeChat OAuth flow with the current redirect target when open mode is configured', async () => {
+ seedPublicSettings({
+ wechat_oauth_open_enabled: true,
+ wechat_oauth_mp_enabled: false,
+ })
+ const wrapper = mount(WechatOAuthSection, {
+ global: {
+ plugins: [pinia],
+ },
+ })
+
+ expect(wrapper.text()).toContain('Mock WeChat')
+
+ await wrapper.get('button').trigger('click')
+
+ expect(locationState.current.href).toContain(
+ '/api/v1/auth/oauth/wechat/start?mode=open&redirect=%2Fbilling%3Fplan%3Dpro'
+ )
+ })
+
+ it('uses mp mode inside the WeChat browser when mp mode is configured', async () => {
+ Object.defineProperty(window.navigator, 'userAgent', {
+ configurable: true,
+ value: 'Mozilla/5.0 MicroMessenger',
+ })
+ seedPublicSettings({
+ wechat_oauth_open_enabled: false,
+ wechat_oauth_mp_enabled: true,
+ })
+ const wrapper = mount(WechatOAuthSection, {
+ global: {
+ plugins: [pinia],
+ },
+ })
+
+ await wrapper.get('button').trigger('click')
+
+ expect(locationState.current.href).toContain(
+ '/api/v1/auth/oauth/wechat/start?mode=mp&redirect=%2Fbilling%3Fplan%3Dpro'
+ )
+ })
+
+ it('disables the button outside the WeChat browser when only mp mode is configured', async () => {
+ seedPublicSettings({
+ wechat_oauth_open_enabled: false,
+ wechat_oauth_mp_enabled: true,
+ })
+ const wrapper = mount(WechatOAuthSection, {
+ global: {
+ plugins: [pinia],
+ },
+ })
+
+ expect(wrapper.get('button').attributes('disabled')).toBeDefined()
+ expect(wrapper.text()).toContain('MOCK-WECHAT-BROWSER-ONLY')
+
+ await wrapper.get('button').trigger('click')
+
+ expect(locationState.current.href).toBe('http://localhost/login')
+ })
+
+ it('disables the button inside the WeChat browser when only open mode is configured', async () => {
+ Object.defineProperty(window.navigator, 'userAgent', {
+ configurable: true,
+ value: 'Mozilla/5.0 MicroMessenger',
+ })
+ seedPublicSettings({
+ wechat_oauth_open_enabled: true,
+ wechat_oauth_mp_enabled: false,
+ })
+ const wrapper = mount(WechatOAuthSection, {
+ global: {
+ plugins: [pinia],
+ },
+ })
+
+ expect(wrapper.get('button').attributes('disabled')).toBeDefined()
+ expect(wrapper.text()).toContain('MOCK-SYSTEM-BROWSER-ONLY')
+
+ await wrapper.get('button').trigger('click')
+
+ expect(locationState.current.href).toBe('http://localhost/login')
+ })
+
+ it('uses the legacy overall enabled flag when per-mode settings are not present', async () => {
+ seedPublicSettings({
+ wechat_oauth_enabled: true,
+ })
+ const wrapper = mount(WechatOAuthSection, {
+ global: {
+ plugins: [pinia],
+ },
+ })
+
+ await wrapper.get('button').trigger('click')
+
+ expect(locationState.current.href).toContain(
+ '/api/v1/auth/oauth/wechat/start?mode=open&redirect=%2Fbilling%3Fplan%3Dpro'
+ )
+ })
+
+ it('shows the localized not-configured hint when WeChat OAuth is unavailable', async () => {
+ seedPublicSettings({
+ wechat_oauth_enabled: false,
+ wechat_oauth_open_enabled: false,
+ wechat_oauth_mp_enabled: false,
+ })
+
+ const wrapper = mount(WechatOAuthSection, {
+ global: {
+ plugins: [pinia],
+ },
+ })
+
+ expect(wrapper.text()).toContain('MOCK-NOT-CONFIGURED')
+ })
+})
diff --git a/frontend/src/components/keys/UseKeyModal.vue b/frontend/src/components/keys/UseKeyModal.vue
index 7770e658..b3679107 100644
--- a/frontend/src/components/keys/UseKeyModal.vue
+++ b/frontend/src/components/keys/UseKeyModal.vue
@@ -617,66 +617,6 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
}
}
const openaiModels = {
- 'gpt-5-codex': {
- name: 'GPT-5 Codex',
- limit: {
- context: 400000,
- output: 128000
- },
- options: {
- store: false
- },
- variants: {
- low: {},
- medium: {},
- high: {}
- }
- },
- 'gpt-5.1-codex': {
- name: 'GPT-5.1 Codex',
- limit: {
- context: 400000,
- output: 128000
- },
- options: {
- store: false
- },
- variants: {
- low: {},
- medium: {},
- high: {}
- }
- },
- 'gpt-5.1-codex-max': {
- name: 'GPT-5.1 Codex Max',
- limit: {
- context: 400000,
- output: 128000
- },
- options: {
- store: false
- },
- variants: {
- low: {},
- medium: {},
- high: {}
- }
- },
- 'gpt-5.1-codex-mini': {
- name: 'GPT-5.1 Codex Mini',
- limit: {
- context: 400000,
- output: 128000
- },
- options: {
- store: false
- },
- variants: {
- low: {},
- medium: {},
- high: {}
- }
- },
'gpt-5.2': {
name: 'GPT-5.2',
limit: {
@@ -725,22 +665,6 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
xhigh: {}
}
},
- 'gpt-5.4-nano': {
- name: 'GPT-5.4 Nano',
- limit: {
- context: 400000,
- output: 128000
- },
- options: {
- store: false
- },
- variants: {
- low: {},
- medium: {},
- high: {},
- xhigh: {}
- }
- },
'gpt-5.3-codex-spark': {
name: 'GPT-5.3 Codex Spark',
limit: {
@@ -773,22 +697,6 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
xhigh: {}
}
},
- 'gpt-5.2-codex': {
- name: 'GPT-5.2 Codex',
- limit: {
- context: 400000,
- output: 128000
- },
- options: {
- store: false
- },
- variants: {
- low: {},
- medium: {},
- high: {},
- xhigh: {}
- }
- },
'codex-mini-latest': {
name: 'Codex Mini',
limit: {
diff --git a/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts b/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts
index 98b5dede..f7db586a 100644
--- a/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts
+++ b/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts
@@ -17,7 +17,7 @@ vi.mock('@/composables/useClipboard', () => ({
import UseKeyModal from '../UseKeyModal.vue'
describe('UseKeyModal', () => {
- it('renders updated GPT-5.4 mini/nano names in OpenCode config', async () => {
+ it('renders GPT-5.4 mini entry in OpenCode config', async () => {
const wrapper = mount(UseKeyModal, {
props: {
show: true,
@@ -48,6 +48,6 @@ describe('UseKeyModal', () => {
const codeBlock = wrapper.find('pre code')
expect(codeBlock.exists()).toBe(true)
expect(codeBlock.text()).toContain('"name": "GPT-5.4 Mini"')
- expect(codeBlock.text()).toContain('"name": "GPT-5.4 Nano"')
+ expect(codeBlock.text()).not.toContain('"name": "GPT-5.4 Nano"')
})
})
diff --git a/frontend/src/components/layout/AppHeader.vue b/frontend/src/components/layout/AppHeader.vue
index fbcab521..306f1429 100644
--- a/frontend/src/components/layout/AppHeader.vue
+++ b/frontend/src/components/layout/AppHeader.vue
@@ -74,10 +74,14 @@
class="flex items-center gap-2 rounded-xl p-1.5 transition-colors hover:bg-gray-100 dark:hover:bg-dark-800"
aria-label="User Menu"
>
-
- {{ userInitials }}
+
+
+
{{ userInitials }}
@@ -232,6 +236,7 @@ const dropdownOpen = ref(false)
const dropdownRef = ref
(null)
const contactInfo = computed(() => appStore.contactInfo)
const docUrl = computed(() => appStore.docUrl)
+const avatarUrl = computed(() => user.value?.avatar_url?.trim() || '')
// 只在标准模式的管理员下显示新手引导按钮
const showOnboardingButton = computed(() => {
diff --git a/frontend/src/components/layout/__tests__/AppSidebar.spec.ts b/frontend/src/components/layout/__tests__/AppSidebar.spec.ts
index 118c7615..592ce8a3 100644
--- a/frontend/src/components/layout/__tests__/AppSidebar.spec.ts
+++ b/frontend/src/components/layout/__tests__/AppSidebar.spec.ts
@@ -21,7 +21,7 @@ describe('AppSidebar custom SVG styles', () => {
describe('AppSidebar header styles', () => {
it('does not clip the version badge dropdown', () => {
- const sidebarHeaderBlockMatch = styleSource.match(/\.sidebar-header\s*\{[\s\S]*?\n \}/)
+ const sidebarHeaderBlockMatch = styleSource.match(/\.sidebar-header\s*\{[\s\S]*?\n {2}\}/)
const sidebarBrandBlockMatch = componentSource.match(/\.sidebar-brand\s*\{[\s\S]*?\n\}/)
expect(sidebarHeaderBlockMatch).not.toBeNull()
diff --git a/frontend/src/components/payment/PaymentProviderDialog.vue b/frontend/src/components/payment/PaymentProviderDialog.vue
index 10c1bfea..624ddcdd 100644
--- a/frontend/src/components/payment/PaymentProviderDialog.vue
+++ b/frontend/src/components/payment/PaymentProviderDialog.vue
@@ -88,13 +88,24 @@
v-model="config[field.key]"
rows="3"
class="input font-mono text-xs"
+ autocomplete="new-password"
+ data-1p-ignore
+ data-lpignore="true"
+ data-bwignore="true"
+ spellcheck="false"
+ :placeholder="editing ? t('admin.accounts.leaveEmptyToKeep') : ''"
/>
= {}
for (const [k, v] of Object.entries(config)) {
if (!v || !v.trim()) continue
- // Skip masked values — backend keeps existing credentials
- if (v === '••••••••') continue
filteredConfig[k] = v
}
@@ -470,7 +482,8 @@ function loadProvider(provider: ProviderInstance) {
form.refund_enabled = provider.refund_enabled
form.allow_user_refund = provider.allow_user_refund
clearConfig()
- // Pre-fill config from API response (non-sensitive in cleartext, sensitive masked as ••••••••)
+ // Pre-fill config from API response. Backend omits sensitive fields entirely,
+ // so those inputs stay blank — submitting blank preserves the stored secret.
if (provider.config) {
for (const [k, v] of Object.entries(provider.config)) {
// Skip notifyUrl/returnUrl — they are derived from callbackBaseUrl
diff --git a/frontend/src/components/payment/PaymentQRDialog.vue b/frontend/src/components/payment/PaymentQRDialog.vue
index b9026e78..09d273cc 100644
--- a/frontend/src/components/payment/PaymentQRDialog.vue
+++ b/frontend/src/components/payment/PaymentQRDialog.vue
@@ -78,8 +78,8 @@ import Icon from '@/components/icons/Icon.vue'
import { usePaymentStore } from '@/stores/payment'
import { useAppStore } from '@/stores'
import { paymentAPI } from '@/api/payment'
-import { extractApiErrorMessage } from '@/utils/apiError'
-import { POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig'
+import { extractI18nErrorMessage } from '@/utils/apiError'
+import { getPaymentPopupFeatures } from '@/components/payment/providerConfig'
import type { PaymentOrder } from '@/types/payment'
import QRCode from 'qrcode'
import alipayIcon from '@/assets/icons/alipay.svg'
@@ -147,7 +147,7 @@ function getLogoForType(): string | null {
function reopenPopup() {
if (props.payUrl) {
- window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES)
+ window.open(props.payUrl, 'paymentPopup', getPaymentPopupFeatures())
}
}
@@ -222,7 +222,7 @@ async function handleCancel() {
cleanup()
emit('close')
} catch (err: unknown) {
- appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error')))
} finally {
cancelling.value = false
}
diff --git a/frontend/src/components/payment/PaymentStatusPanel.vue b/frontend/src/components/payment/PaymentStatusPanel.vue
index 974dee66..0fd444ac 100644
--- a/frontend/src/components/payment/PaymentStatusPanel.vue
+++ b/frontend/src/components/payment/PaymentStatusPanel.vue
@@ -124,8 +124,8 @@ import { useI18n } from 'vue-i18n'
import { usePaymentStore } from '@/stores/payment'
import { useAppStore } from '@/stores'
import { paymentAPI } from '@/api/payment'
-import { extractApiErrorMessage } from '@/utils/apiError'
-import { POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig'
+import { extractI18nErrorMessage } from '@/utils/apiError'
+import { getPaymentPopupFeatures } from '@/components/payment/providerConfig'
import type { PaymentOrder } from '@/types/payment'
import Icon from '@/components/icons/Icon.vue'
import QRCode from 'qrcode'
@@ -141,7 +141,9 @@ const props = defineProps<{
orderType?: string
}>()
-const emit = defineEmits<{ done: []; success: [] }>()
+type PaymentOutcome = 'success' | 'cancelled' | 'expired'
+
+const emit = defineEmits<{ done: []; success: []; settled: [outcome: PaymentOutcome] }>()
const { t } = useI18n()
const paymentStore = usePaymentStore()
@@ -154,7 +156,7 @@ const cancelling = ref(false)
const paidOrder = ref(null)
// Terminal outcome: null = still active, 'success' | 'cancelled' | 'expired'
-const outcome = ref<'success' | 'cancelled' | 'expired' | null>(null)
+const outcome = ref(null)
let pollTimer: ReturnType | null = null
let countdownTimer: ReturnType | null = null
@@ -192,12 +194,25 @@ const countdownDisplay = computed(() => {
return m.toString().padStart(2, '0') + ':' + s.toString().padStart(2, '0')
})
+function isSuccessStatus(status: string | null | undefined): boolean {
+ return status === 'COMPLETED' || status === 'PAID' || status === 'RECHARGING'
+}
+
function reopenPopup() {
if (props.payUrl) {
- window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES)
+ const win = window.open(props.payUrl, 'paymentPopup', getPaymentPopupFeatures())
+ if (!win || win.closed) {
+ window.location.href = props.payUrl
+ }
}
}
+function setOutcome(next: PaymentOutcome) {
+ if (outcome.value === next) return
+ outcome.value = next
+ emit('settled', next)
+}
+
async function renderQR() {
await nextTick()
if (!qrCanvas.value || !qrUrl.value) return
@@ -211,26 +226,26 @@ async function pollStatus() {
if (!props.orderId || outcome.value) return
const order = await paymentStore.pollOrderStatus(props.orderId)
if (!order) return
- if (order.status === 'COMPLETED' || order.status === 'PAID') {
+ if (isSuccessStatus(order.status)) {
cleanup()
paidOrder.value = order
- outcome.value = 'success'
+ setOutcome('success')
emit('success')
} else if (order.status === 'CANCELLED') {
cleanup()
- outcome.value = 'cancelled'
+ setOutcome('cancelled')
} else if (order.status === 'EXPIRED' || order.status === 'FAILED') {
cleanup()
- outcome.value = 'expired'
+ setOutcome('expired')
}
}
function startCountdown(seconds: number) {
remainingSeconds.value = Math.max(0, seconds)
- if (remainingSeconds.value <= 0) { outcome.value = 'expired'; return }
+ if (remainingSeconds.value <= 0) { setOutcome('expired'); return }
countdownTimer = setInterval(() => {
remainingSeconds.value--
- if (remainingSeconds.value <= 0) { outcome.value = 'expired'; cleanup() }
+ if (remainingSeconds.value <= 0) { setOutcome('expired'); cleanup() }
}, 1000)
}
@@ -240,9 +255,9 @@ async function handleCancel() {
try {
await paymentAPI.cancelOrder(props.orderId)
cleanup()
- outcome.value = 'cancelled'
+ setOutcome('cancelled')
} catch (err: unknown) {
- appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error')))
} finally {
cancelling.value = false
}
diff --git a/frontend/src/components/payment/StripePaymentInline.vue b/frontend/src/components/payment/StripePaymentInline.vue
index b8fd55ef..bdb0dd6b 100644
--- a/frontend/src/components/payment/StripePaymentInline.vue
+++ b/frontend/src/components/payment/StripePaymentInline.vue
@@ -67,10 +67,10 @@
import { ref, onMounted, nextTick } from 'vue'
import { useI18n } from 'vue-i18n'
import { useRouter } from 'vue-router'
-import { extractApiErrorMessage } from '@/utils/apiError'
+import { extractI18nErrorMessage } from '@/utils/apiError'
import { paymentAPI } from '@/api/payment'
import { useAppStore } from '@/stores'
-import { STRIPE_POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig'
+import { getPaymentPopupFeatures } from '@/components/payment/providerConfig'
import type { Stripe, StripeElements } from '@stripe/stripe-js'
import Icon from '@/components/icons/Icon.vue'
@@ -132,7 +132,7 @@ onMounted(async () => {
selectedType.value = event.value.type
})
} catch (err: unknown) {
- initError.value = extractApiErrorMessage(err, t('payment.stripeLoadFailed'))
+ initError.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.stripeLoadFailed'))
} finally {
loading.value = false
}
@@ -151,7 +151,7 @@ async function handlePay() {
amount: String(props.payAmount),
},
}).href
- const popup = window.open(popupUrl, 'paymentPopup', STRIPE_POPUP_WINDOW_FEATURES)
+ const popup = window.open(popupUrl, 'paymentPopup', getPaymentPopupFeatures())
const onReady = (event: MessageEvent) => {
if (event.source !== popup || event.data?.type !== 'STRIPE_POPUP_READY') return
@@ -186,7 +186,7 @@ async function handlePay() {
emit('success')
}
} catch (err: unknown) {
- error.value = extractApiErrorMessage(err, t('payment.result.failed'))
+ error.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.result.failed'))
} finally {
submitting.value = false
}
@@ -199,7 +199,7 @@ async function handleCancel() {
await paymentAPI.cancelOrder(props.orderId)
emit('back')
} catch (err: unknown) {
- appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error')))
} finally {
cancelling.value = false
}
diff --git a/frontend/src/components/payment/__tests__/PaymentStatusPanel.spec.ts b/frontend/src/components/payment/__tests__/PaymentStatusPanel.spec.ts
new file mode 100644
index 00000000..d7017e1f
--- /dev/null
+++ b/frontend/src/components/payment/__tests__/PaymentStatusPanel.spec.ts
@@ -0,0 +1,99 @@
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
+import { flushPromises, mount } from '@vue/test-utils'
+
+const pollOrderStatus = vi.hoisted(() => vi.fn())
+const cancelOrder = vi.hoisted(() => vi.fn())
+const showError = vi.hoisted(() => vi.fn())
+const toCanvas = vi.hoisted(() => vi.fn())
+
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string) => key,
+ }),
+ }
+})
+
+vi.mock('@/stores/payment', () => ({
+ usePaymentStore: () => ({
+ pollOrderStatus,
+ }),
+}))
+
+vi.mock('@/stores', () => ({
+ useAppStore: () => ({
+ showError,
+ }),
+}))
+
+vi.mock('@/api/payment', () => ({
+ paymentAPI: {
+ cancelOrder,
+ },
+}))
+
+vi.mock('qrcode', () => ({
+ default: {
+ toCanvas,
+ },
+}))
+
+import PaymentStatusPanel from '../PaymentStatusPanel.vue'
+
+const orderFactory = (status: string) => ({
+ id: 42,
+ user_id: 9,
+ amount: 88,
+ pay_amount: 88,
+ fee_rate: 0,
+ payment_type: 'alipay',
+ out_trade_no: 'sub2_20260420abcd1234',
+ status,
+ order_type: 'balance',
+ created_at: '2026-04-20T12:00:00Z',
+ expires_at: '2099-01-01T12:30:00Z',
+ refund_amount: 0,
+})
+
+describe('PaymentStatusPanel', () => {
+ beforeEach(() => {
+ vi.useFakeTimers()
+ pollOrderStatus.mockReset()
+ cancelOrder.mockReset()
+ showError.mockReset()
+ toCanvas.mockReset().mockResolvedValue(undefined)
+ })
+
+ afterEach(() => {
+ vi.useRealTimers()
+ })
+
+ it('treats RECHARGING as a successful terminal state', async () => {
+ pollOrderStatus.mockResolvedValue(orderFactory('RECHARGING'))
+
+ const wrapper = mount(PaymentStatusPanel, {
+ props: {
+ orderId: 42,
+ qrCode: 'https://pay.example.com/qr/42',
+ expiresAt: '2099-01-01T12:30:00Z',
+ paymentType: 'alipay',
+ orderType: 'balance',
+ },
+ global: {
+ stubs: {
+ Icon: true,
+ },
+ },
+ })
+
+ await flushPromises()
+ await vi.advanceTimersByTimeAsync(3000)
+ await flushPromises()
+
+ expect(pollOrderStatus).toHaveBeenCalledWith(42)
+ expect(wrapper.text()).toContain('payment.result.success')
+ expect(wrapper.emitted('success')).toHaveLength(1)
+ })
+})
diff --git a/frontend/src/components/payment/__tests__/paymentFlow.spec.ts b/frontend/src/components/payment/__tests__/paymentFlow.spec.ts
new file mode 100644
index 00000000..48c77dfb
--- /dev/null
+++ b/frontend/src/components/payment/__tests__/paymentFlow.spec.ts
@@ -0,0 +1,302 @@
+import { describe, expect, it } from 'vitest'
+import type { CreateOrderResult, MethodLimit } from '@/types/payment'
+import {
+ buildCreateOrderPayload,
+ decidePaymentLaunch,
+ getVisibleMethods,
+ readPaymentRecoverySnapshot,
+ type PaymentRecoverySnapshot,
+} from '@/components/payment/paymentFlow'
+
+function methodLimit(overrides: Partial = {}): MethodLimit {
+ return {
+ daily_limit: 0,
+ daily_used: 0,
+ daily_remaining: 0,
+ single_min: 0,
+ single_max: 0,
+ fee_rate: 0,
+ available: true,
+ ...overrides,
+ }
+}
+
+function createOrderResult(overrides: Partial = {}): CreateOrderResult {
+ return {
+ order_id: 101,
+ amount: 88,
+ pay_amount: 88,
+ fee_rate: 0,
+ expires_at: '2099-01-01T00:10:00.000Z',
+ ...overrides,
+ }
+}
+
+describe('getVisibleMethods', () => {
+ it('filters hidden provider methods and normalizes aliases', () => {
+ const visible = getVisibleMethods({
+ alipay_direct: methodLimit({ single_min: 5 }),
+ wxpay: methodLimit({ single_max: 100 }),
+ stripe: methodLimit({ fee_rate: 3 }),
+ })
+
+ expect(visible).toEqual({
+ alipay: methodLimit({ single_min: 5 }),
+ wxpay: methodLimit({ single_max: 100 }),
+ })
+ })
+
+ it('prefers canonical visible methods over aliases when both exist', () => {
+ const visible = getVisibleMethods({
+ alipay: methodLimit({ single_min: 2 }),
+ alipay_direct: methodLimit({ single_min: 9 }),
+ wxpay_direct: methodLimit({ fee_rate: 1.2 }),
+ })
+
+ expect(visible.alipay.single_min).toBe(2)
+ expect(visible.wxpay.fee_rate).toBe(1.2)
+ })
+})
+
+describe('decidePaymentLaunch', () => {
+ it('uses Stripe popup waiting flow for desktop Alipay client secret', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ client_secret: 'cs_test',
+ resume_token: 'resume-1',
+ }), {
+ visibleMethod: 'alipay',
+ orderType: 'balance',
+ isMobile: false,
+ })
+
+ expect(decision.kind).toBe('stripe_popup')
+ expect(decision.paymentState.paymentType).toBe('alipay')
+ expect(decision.stripeMethod).toBe('alipay')
+ expect(decision.recovery.resumeToken).toBe('resume-1')
+ expect(decision.recovery.outTradeNo).toBe('')
+ })
+
+ it('uses Stripe route flow for mobile WeChat client secret', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ client_secret: 'cs_test',
+ }), {
+ visibleMethod: 'wxpay',
+ orderType: 'subscription',
+ isMobile: true,
+ })
+
+ expect(decision.kind).toBe('stripe_route')
+ expect(decision.stripeMethod).toBe('wechat_pay')
+ expect(decision.paymentState.orderType).toBe('subscription')
+ })
+
+ it('keeps hosted redirect metadata for recovery flows', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ pay_url: 'https://pay.example.com/session/abc',
+ payment_mode: 'popup',
+ resume_token: 'resume-2',
+ out_trade_no: 'sub2_abc',
+ }), {
+ visibleMethod: 'wxpay',
+ orderType: 'balance',
+ isMobile: false,
+ })
+
+ expect(decision.kind).toBe('redirect_waiting')
+ expect(decision.paymentState.payUrl).toBe('https://pay.example.com/session/abc')
+ expect(decision.recovery.paymentMode).toBe('popup')
+ expect(decision.recovery.outTradeNo).toBe('sub2_abc')
+ expect(decision.recovery.resumeToken).toBe('resume-2')
+ })
+
+ it('prefers redirect on mobile when both pay_url and qr_code are present', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ pay_url: 'https://pay.example.com/mobile/session',
+ qr_code: 'https://pay.example.com/qr/session',
+ }), {
+ visibleMethod: 'alipay',
+ orderType: 'balance',
+ isMobile: true,
+ })
+
+ expect(decision.kind).toBe('redirect_waiting')
+ expect(decision.paymentState.payUrl).toBe('https://pay.example.com/mobile/session')
+ expect(decision.paymentState.qrCode).toBe('https://pay.example.com/qr/session')
+ })
+
+ it('keeps QR flow on desktop when both pay_url and qr_code are present', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ pay_url: 'https://pay.example.com/desktop/session',
+ qr_code: 'https://pay.example.com/qr/session',
+ }), {
+ visibleMethod: 'wxpay',
+ orderType: 'balance',
+ isMobile: false,
+ })
+
+ expect(decision.kind).toBe('qr_waiting')
+ expect(decision.paymentState.qrCode).toBe('https://pay.example.com/qr/session')
+ })
+
+ it('returns wechat oauth launch when backend requires in-app authorization', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ result_type: 'oauth_required',
+ payment_type: 'wxpay',
+ oauth: {
+ authorize_url: '/api/v1/auth/oauth/wechat/payment/start?payment_type=wxpay',
+ appid: 'wx123',
+ scope: 'snsapi_base',
+ redirect_url: '/auth/wechat/payment/callback',
+ },
+ }), {
+ visibleMethod: 'wxpay',
+ orderType: 'balance',
+ isMobile: true,
+ })
+
+ expect(decision.kind).toBe('wechat_oauth')
+ expect(decision.oauth?.authorize_url).toContain('/api/v1/auth/oauth/wechat/payment/start')
+ expect(decision.paymentState.paymentType).toBe('wxpay')
+ })
+
+ it('returns wechat jsapi launch when backend has a jsapi payload ready', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ result_type: 'jsapi_ready',
+ payment_type: 'wxpay',
+ jsapi: {
+ appId: 'wx123',
+ timeStamp: '1712345678',
+ nonceStr: 'nonce-123',
+ package: 'prepay_id=wx123',
+ signType: 'RSA',
+ paySign: 'signed-payload',
+ },
+ }), {
+ visibleMethod: 'wxpay',
+ orderType: 'subscription',
+ isMobile: true,
+ })
+
+ expect(decision.kind).toBe('wechat_jsapi')
+ expect(decision.jsapi?.appId).toBe('wx123')
+ expect(decision.paymentState.orderType).toBe('subscription')
+ })
+})
+
+describe('buildCreateOrderPayload', () => {
+ it('normalizes visible method aliases and attaches a canonical result URL', () => {
+ expect(buildCreateOrderPayload({
+ amount: 88,
+ paymentType: 'alipay_direct',
+ orderType: 'balance',
+ origin: 'https://app.example.com/',
+ isWechatBrowser: false,
+ })).toEqual({
+ amount: 88,
+ payment_type: 'alipay',
+ order_type: 'balance',
+ return_url: 'https://app.example.com/payment/result',
+ payment_source: 'hosted_redirect',
+ })
+ })
+
+ it('uses WeChat in-app resume source for visible WeChat payments in the WeChat browser', () => {
+ expect(buildCreateOrderPayload({
+ amount: 128,
+ paymentType: 'wxpay',
+ orderType: 'subscription',
+ planId: 7,
+ origin: 'https://app.example.com',
+ isWechatBrowser: true,
+ })).toEqual({
+ amount: 128,
+ payment_type: 'wxpay',
+ order_type: 'subscription',
+ plan_id: 7,
+ return_url: 'https://app.example.com/payment/result',
+ payment_source: 'wechat_in_app_resume',
+ })
+ })
+})
+
+describe('readPaymentRecoverySnapshot', () => {
+ it('restores an unexpired snapshot when the resume token matches', () => {
+ const snapshot: PaymentRecoverySnapshot = {
+ orderId: 33,
+ amount: 18,
+ qrCode: '',
+ expiresAt: '2099-01-01T00:10:00.000Z',
+ paymentType: 'alipay',
+ payUrl: 'https://pay.example.com/session/33',
+ outTradeNo: 'sub2_33',
+ clientSecret: '',
+ payAmount: 18,
+ orderType: 'balance',
+ paymentMode: 'popup',
+ resumeToken: 'resume-33',
+ createdAt: Date.UTC(2099, 0, 1, 0, 0, 0),
+ }
+
+ const restored = readPaymentRecoverySnapshot(JSON.stringify(snapshot), {
+ now: Date.UTC(2099, 0, 1, 0, 1, 0),
+ resumeToken: 'resume-33',
+ })
+
+ expect(restored?.orderId).toBe(33)
+ })
+
+ it('drops expired or mismatched recovery snapshots', () => {
+ const expiredSnapshot: PaymentRecoverySnapshot = {
+ orderId: 55,
+ amount: 18,
+ qrCode: '',
+ expiresAt: '2024-01-01T00:10:00.000Z',
+ paymentType: 'wxpay',
+ payUrl: 'https://pay.example.com/session/55',
+ outTradeNo: 'sub2_55',
+ clientSecret: '',
+ payAmount: 18,
+ orderType: 'balance',
+ paymentMode: 'popup',
+ resumeToken: 'resume-55',
+ createdAt: Date.UTC(2024, 0, 1, 0, 0, 0),
+ }
+
+ expect(readPaymentRecoverySnapshot(JSON.stringify(expiredSnapshot), {
+ now: Date.UTC(2024, 0, 1, 0, 20, 0),
+ resumeToken: 'resume-55',
+ })).toBeNull()
+
+ expect(readPaymentRecoverySnapshot(JSON.stringify({
+ ...expiredSnapshot,
+ outTradeNo: 'sub2_55',
+ expiresAt: '2099-01-01T00:10:00.000Z',
+ }), {
+ now: Date.UTC(2099, 0, 1, 0, 1, 0),
+ resumeToken: 'other-token',
+ })).toBeNull()
+ })
+
+ it('keeps backward compatibility with snapshots written before outTradeNo existed', () => {
+ const restored = readPaymentRecoverySnapshot(JSON.stringify({
+ orderId: 44,
+ amount: 18,
+ qrCode: '',
+ expiresAt: '2099-01-01T00:10:00.000Z',
+ paymentType: 'alipay',
+ payUrl: 'https://pay.example.com/session/44',
+ clientSecret: '',
+ payAmount: 18,
+ orderType: 'balance',
+ paymentMode: 'popup',
+ resumeToken: 'resume-44',
+ createdAt: Date.UTC(2099, 0, 1, 0, 0, 0),
+ }), {
+ now: Date.UTC(2099, 0, 1, 0, 1, 0),
+ resumeToken: 'resume-44',
+ })
+
+ expect(restored?.orderId).toBe(44)
+ expect(restored?.outTradeNo).toBe('')
+ })
+})
diff --git a/frontend/src/components/payment/__tests__/providerConfig.spec.ts b/frontend/src/components/payment/__tests__/providerConfig.spec.ts
new file mode 100644
index 00000000..6a4c9c26
--- /dev/null
+++ b/frontend/src/components/payment/__tests__/providerConfig.spec.ts
@@ -0,0 +1,20 @@
+import { describe, expect, it } from 'vitest'
+import { PROVIDER_CONFIG_FIELDS } from '@/components/payment/providerConfig'
+
+function findField(key: string) {
+ const fields = PROVIDER_CONFIG_FIELDS.wxpay || []
+ return fields.find(field => field.key === key)
+}
+
+describe('PROVIDER_CONFIG_FIELDS.wxpay', () => {
+ it('keeps admin form validation aligned with backend-required credentials', () => {
+ expect(findField('publicKeyId')?.optional).toBeFalsy()
+ expect(findField('certSerial')?.optional).toBeFalsy()
+ })
+
+ it('exposes optional mp and H5 metadata fields for WeChat-specific flows', () => {
+ expect(findField('mpAppId')?.optional).toBe(true)
+ expect(findField('h5AppName')?.optional).toBe(true)
+ expect(findField('h5AppUrl')?.optional).toBe(true)
+ })
+})
diff --git a/frontend/src/components/payment/paymentFlow.ts b/frontend/src/components/payment/paymentFlow.ts
new file mode 100644
index 00000000..05f36ed0
--- /dev/null
+++ b/frontend/src/components/payment/paymentFlow.ts
@@ -0,0 +1,269 @@
+import type {
+ CreateOrderRequest,
+ CreateOrderResult,
+ MethodLimit,
+ OrderType,
+ WechatJSAPIPayload,
+ WechatOAuthInfo,
+} from '@/types/payment'
+
+export const PAYMENT_RECOVERY_STORAGE_KEY = 'payment.recovery.current'
+
+const VISIBLE_METHOD_ALIASES = {
+ alipay: 'alipay',
+ alipay_direct: 'alipay',
+ wxpay: 'wxpay',
+ wxpay_direct: 'wxpay',
+} as const
+
+export type VisiblePaymentMethod = 'alipay' | 'wxpay'
+export type StripeVisibleMethod = 'alipay' | 'wechat_pay'
+export type PaymentLaunchKind =
+ | 'qr_waiting'
+ | 'redirect_waiting'
+ | 'stripe_popup'
+ | 'stripe_route'
+ | 'wechat_oauth'
+ | 'wechat_jsapi'
+ | 'unhandled'
+
+export interface PaymentRecoverySnapshot {
+ orderId: number
+ amount: number
+ qrCode: string
+ expiresAt: string
+ paymentType: string
+ payUrl: string
+ outTradeNo: string
+ clientSecret: string
+ payAmount: number
+ orderType: OrderType | ''
+ paymentMode: string
+ resumeToken: string
+ createdAt: number
+}
+
+export interface PaymentLaunchContext {
+ visibleMethod: string
+ orderType: OrderType
+ isMobile: boolean
+ isWechatBrowser?: boolean
+ now?: number
+ stripePopupUrl?: string
+ stripeRouteUrl?: string
+}
+
+export interface PaymentLaunchDecision {
+ kind: PaymentLaunchKind
+ paymentState: PaymentRecoverySnapshot
+ recovery: PaymentRecoverySnapshot
+ stripeMethod?: StripeVisibleMethod
+ oauth?: WechatOAuthInfo
+ jsapi?: WechatJSAPIPayload
+}
+
+export interface BuildCreateOrderPayloadInput {
+ amount: number
+ paymentType: string
+ orderType: OrderType
+ planId?: number
+ origin?: string
+ isWechatBrowser: boolean
+}
+
+type CreateOrderFlowResult = CreateOrderResult & {
+ resume_token?: string
+}
+
+type StorageWriter = Pick
+
+export function normalizeVisibleMethod(method: string): VisiblePaymentMethod | '' {
+ const normalized = VISIBLE_METHOD_ALIASES[method.trim() as keyof typeof VISIBLE_METHOD_ALIASES]
+ return normalized ?? ''
+}
+
+export function getVisibleMethods(methods: Record): Record {
+ const visible: Record = {}
+
+ Object.entries(methods).forEach(([type, limit]) => {
+ const normalized = normalizeVisibleMethod(type)
+ if (!normalized) return
+
+ const isCanonical = type === normalized
+ const existing = visible[normalized]
+ if (!existing || isCanonical) {
+ visible[normalized] = { ...limit }
+ }
+ })
+
+ return visible
+}
+
+export function buildCreateOrderPayload(input: BuildCreateOrderPayloadInput): CreateOrderRequest {
+ const visibleMethod = normalizeVisibleMethod(input.paymentType) || input.paymentType.trim()
+ const normalizedOrigin = (input.origin || '').trim().replace(/\/+$/, '')
+ const payload: CreateOrderRequest = {
+ amount: input.amount,
+ payment_type: visibleMethod,
+ order_type: input.orderType,
+ payment_source: visibleMethod === 'wxpay' && input.isWechatBrowser
+ ? 'wechat_in_app_resume'
+ : 'hosted_redirect',
+ }
+
+ if (input.planId) {
+ payload.plan_id = input.planId
+ }
+ if (normalizedOrigin) {
+ payload.return_url = `${normalizedOrigin}/payment/result`
+ }
+
+ return payload
+}
+
+export function decidePaymentLaunch(
+ result: CreateOrderFlowResult,
+ context: PaymentLaunchContext,
+): PaymentLaunchDecision {
+ const visibleMethod = normalizeVisibleMethod(context.visibleMethod) || context.visibleMethod
+ const baseState = createPaymentRecoverySnapshot({
+ orderId: result.order_id,
+ amount: result.amount,
+ qrCode: result.qr_code || '',
+ expiresAt: result.expires_at || '',
+ paymentType: visibleMethod,
+ payUrl: result.pay_url || '',
+ outTradeNo: result.out_trade_no || '',
+ clientSecret: result.client_secret || '',
+ payAmount: result.pay_amount,
+ orderType: context.orderType,
+ paymentMode: (result.payment_mode || '').trim(),
+ resumeToken: result.resume_token || '',
+ }, context.now)
+
+ if (baseState.clientSecret) {
+ const stripeMethod: StripeVisibleMethod = visibleMethod === 'wxpay' ? 'wechat_pay' : 'alipay'
+ const kind: PaymentLaunchKind = stripeMethod === 'alipay' && !context.isMobile
+ ? 'stripe_popup'
+ : 'stripe_route'
+ const payUrl = kind === 'stripe_popup'
+ ? context.stripePopupUrl || context.stripeRouteUrl || ''
+ : context.stripeRouteUrl || context.stripePopupUrl || ''
+ const paymentState = { ...baseState, payUrl }
+ return { kind, paymentState, recovery: paymentState, stripeMethod }
+ }
+
+ if (result.result_type === 'oauth_required' && result.oauth?.authorize_url) {
+ return { kind: 'wechat_oauth', paymentState: baseState, recovery: baseState, oauth: result.oauth }
+ }
+
+ const jsapiPayload = result.jsapi ?? result.jsapi_payload
+ if (result.result_type === 'jsapi_ready' && jsapiPayload) {
+ return { kind: 'wechat_jsapi', paymentState: baseState, recovery: baseState, jsapi: jsapiPayload }
+ }
+
+ const normalizedPaymentMode = baseState.paymentMode.trim().toLowerCase()
+ const prefersRedirect = normalizedPaymentMode === 'redirect'
+ || normalizedPaymentMode === 'popup'
+ || (context.isMobile && !!baseState.payUrl)
+ const prefersQr = normalizedPaymentMode === 'qrcode'
+ || normalizedPaymentMode === 'native'
+ || (!prefersRedirect && !!baseState.qrCode)
+
+ if (visibleMethod === 'wxpay' && context.isWechatBrowser && baseState.payUrl && !baseState.qrCode) {
+ return { kind: 'redirect_waiting', paymentState: baseState, recovery: baseState }
+ }
+
+ if (prefersRedirect && baseState.payUrl) {
+ return { kind: 'redirect_waiting', paymentState: baseState, recovery: baseState }
+ }
+
+ if (prefersQr && baseState.qrCode) {
+ return { kind: 'qr_waiting', paymentState: baseState, recovery: baseState }
+ }
+
+ if (baseState.payUrl) {
+ return { kind: 'redirect_waiting', paymentState: baseState, recovery: baseState }
+ }
+
+ return { kind: 'unhandled', paymentState: baseState, recovery: baseState }
+}
+
+export function createPaymentRecoverySnapshot(
+ state: Omit,
+ now = Date.now(),
+): PaymentRecoverySnapshot {
+ return {
+ ...state,
+ createdAt: now,
+ }
+}
+
+export function writePaymentRecoverySnapshot(
+ storage: StorageWriter,
+ snapshot: PaymentRecoverySnapshot,
+ key = PAYMENT_RECOVERY_STORAGE_KEY,
+): void {
+ storage.setItem(key, JSON.stringify(snapshot))
+}
+
+export function clearPaymentRecoverySnapshot(
+ storage: Pick,
+ key = PAYMENT_RECOVERY_STORAGE_KEY,
+): void {
+ storage.removeItem(key)
+}
+
+export function readPaymentRecoverySnapshot(
+ raw: string | null | undefined,
+ options: { now?: number; resumeToken?: string } = {},
+): PaymentRecoverySnapshot | null {
+ if (!raw) return null
+
+ try {
+ const parsed = JSON.parse(raw) as Partial
+ if (
+ typeof parsed.orderId !== 'number'
+ || typeof parsed.amount !== 'number'
+ || typeof parsed.qrCode !== 'string'
+ || typeof parsed.expiresAt !== 'string'
+ || typeof parsed.paymentType !== 'string'
+ || typeof parsed.payUrl !== 'string'
+ || (parsed.outTradeNo != null && typeof parsed.outTradeNo !== 'string')
+ || typeof parsed.clientSecret !== 'string'
+ || typeof parsed.payAmount !== 'number'
+ || typeof parsed.paymentMode !== 'string'
+ || typeof parsed.resumeToken !== 'string'
+ || typeof parsed.createdAt !== 'number'
+ ) {
+ return null
+ }
+
+ const now = options.now ?? Date.now()
+ const expiresAt = Date.parse(parsed.expiresAt)
+ if (Number.isFinite(expiresAt) && expiresAt <= now) {
+ return null
+ }
+ if (options.resumeToken && parsed.resumeToken !== options.resumeToken) {
+ return null
+ }
+
+ return {
+ orderId: parsed.orderId,
+ amount: parsed.amount,
+ qrCode: parsed.qrCode,
+ expiresAt: parsed.expiresAt,
+ paymentType: parsed.paymentType,
+ payUrl: parsed.payUrl,
+ outTradeNo: parsed.outTradeNo || '',
+ clientSecret: parsed.clientSecret,
+ payAmount: parsed.payAmount,
+ orderType: parsed.orderType === 'subscription' ? 'subscription' : 'balance',
+ paymentMode: parsed.paymentMode,
+ resumeToken: parsed.resumeToken,
+ createdAt: parsed.createdAt,
+ }
+ } catch {
+ return null
+ }
+}
diff --git a/frontend/src/components/payment/providerConfig.ts b/frontend/src/components/payment/providerConfig.ts
index a83787fd..67ffdec1 100644
--- a/frontend/src/components/payment/providerConfig.ts
+++ b/frontend/src/components/payment/providerConfig.ts
@@ -43,11 +43,24 @@ export const METHOD_ORDER = ['alipay', 'alipay_direct', 'wxpay', 'wxpay_direct',
export const PAYMENT_MODE_QRCODE = 'qrcode'
export const PAYMENT_MODE_POPUP = 'popup'
-/** Window features for payment popup windows */
-export const POPUP_WINDOW_FEATURES = 'width=1000,height=750,left=100,top=80,scrollbars=yes,resizable=yes'
+/** Preferred popup size for payment gateways. Alipay's standard checkout
+ * (QR + account login panel) needs ~1200×900 to render without any scrolling. */
+const PAYMENT_POPUP_PREFERRED_WIDTH = 1250
+const PAYMENT_POPUP_PREFERRED_HEIGHT = 900
-/** Wider popup for Stripe redirect methods (Alipay checkout page needs ~1200px) */
-export const STRIPE_POPUP_WINDOW_FEATURES = 'width=1250,height=780,left=80,top=60,scrollbars=yes,resizable=yes'
+/** Build a window.open features string sized to fit within the current screen
+ * while preferring the above dimensions. Centers the popup on the available
+ * work area so nothing is clipped on smaller laptop displays. */
+export function getPaymentPopupFeatures(): string {
+ const screen = typeof window !== 'undefined' ? window.screen : null
+ const availW = screen?.availWidth ?? PAYMENT_POPUP_PREFERRED_WIDTH
+ const availH = screen?.availHeight ?? PAYMENT_POPUP_PREFERRED_HEIGHT
+ const width = Math.min(PAYMENT_POPUP_PREFERRED_WIDTH, availW - 40)
+ const height = Math.min(PAYMENT_POPUP_PREFERRED_HEIGHT, availH - 40)
+ const left = Math.max(0, Math.floor((availW - width) / 2))
+ const top = Math.max(0, Math.floor((availH - height) / 2))
+ return `width=${width},height=${height},left=${left},top=${top},scrollbars=yes,resizable=yes`
+}
/** Webhook paths for each provider (relative to origin). */
export const WEBHOOK_PATHS: Record = {
@@ -83,12 +96,15 @@ export const PROVIDER_CONFIG_FIELDS: Record = {
],
wxpay: [
{ key: 'appId', label: 'App ID', sensitive: false },
+ { key: 'mpAppId', label: '', sensitive: false, optional: true },
{ key: 'mchId', label: '', sensitive: false },
{ key: 'privateKey', label: '', sensitive: true },
{ key: 'apiV3Key', label: '', sensitive: true },
+ { key: 'certSerial', label: '', sensitive: false },
{ key: 'publicKey', label: '', sensitive: true },
- { key: 'publicKeyId', label: '', sensitive: false, optional: true },
- { key: 'certSerial', label: '', sensitive: false, optional: true },
+ { key: 'publicKeyId', label: '', sensitive: false },
+ { key: 'h5AppName', label: '', sensitive: false, optional: true },
+ { key: 'h5AppUrl', label: '', sensitive: false, optional: true },
],
stripe: [
{ key: 'secretKey', label: '', sensitive: true },
diff --git a/frontend/src/components/user/profile/ProfileAccountBindingsCard.vue b/frontend/src/components/user/profile/ProfileAccountBindingsCard.vue
new file mode 100644
index 00000000..f1cf54a9
--- /dev/null
+++ b/frontend/src/components/user/profile/ProfileAccountBindingsCard.vue
@@ -0,0 +1,36 @@
+
+
+
+
+
diff --git a/frontend/src/components/user/profile/ProfileAvatarCard.vue b/frontend/src/components/user/profile/ProfileAvatarCard.vue
new file mode 100644
index 00000000..9ff26853
--- /dev/null
+++ b/frontend/src/components/user/profile/ProfileAvatarCard.vue
@@ -0,0 +1,270 @@
+
+
+
+
+ {{ t('profile.avatar.title') }}
+
+
+ {{ t('profile.avatar.description') }}
+
+
+
+
+
+
+
{{ avatarInitial }}
+
+
+
+
+
+ {{ t('profile.avatar.title') }}
+
+
+ {{ displayName }}
+
+
+ {{ t('profile.avatar.uploadHint') }}
+
+
+
+
+
+
+ {{ t('profile.avatar.uploadAction') }}
+
+
+
+ {{ t('common.save') }}
+
+
+
+ {{ t('common.delete') }}
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/user/profile/ProfileEditForm.vue b/frontend/src/components/user/profile/ProfileEditForm.vue
index 2750840a..e1441921 100644
--- a/frontend/src/components/user/profile/ProfileEditForm.vue
+++ b/frontend/src/components/user/profile/ProfileEditForm.vue
@@ -1,12 +1,20 @@
-
-
+
+
{{ t('profile.editProfile') }}
-